In [1]:
from fungiclef.utils import get_spark, spark_resource, read_config
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
config = read_config(path='../fungiclef/config.json')

# First, we read the metadata for the dataset and make a proper new one. This will be the single source of truth we use to build the rest of our stuff on
# This corresponds to the DF20 dataset
TRAIN_METADATA = config["gs_paths"]["train"]["metadata"]

# These two correspond to the DF21 dataset
VALID_METADATA = config["gs_paths"]["val"]["metadata"]
TEST_METADATA = config["gs_paths"]["test"]["metadata"]

PRODUCTION_BUCKET = 'gs://dsgt-clef-fungiclef-2024/production/'

### 1. Making a new train / valid / test set split with metadata file

The motivation to this is to make a proper, bigger training set, where unknown classes are also included in the dataset. 


In [None]:
train_metadata_df = pd.read_csv(TRAIN_METADATA)
valid_metadata_df = pd.read_csv(VALID_METADATA)
test_metadata_df = pd.read_csv(TEST_METADATA)

In [None]:
# 
full_dataset_pq = config["gs_paths"]["train_and_test_300px_corrected"]["raw_parquet"]
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

full_dataset_df = spark.read.parquet(full_dataset_pq)

24/04/24 12:06:00 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [42]:
observation_ids = full_dataset_df.select("observationID").toPandas()['observationID']

                                                                                

In [52]:
print("available images: ", len(observation_ids))

available images:  356770


In [53]:
print("train: ", len(train_metadata_df.observationID))
print("valid: ", len(valid_metadata_df.observationID))
print("test: ", len(test_metadata_df.observationID))

# It appears that the only images we have are the train and validation datasets. Only metadata is available for the public test set. As such, we need to be a bit wiser in how we split up the data in that case. 
print("train + valid: ", len(train_metadata_df.observationID) + len(valid_metadata_df.observationID))

train:  295938
valid:  60832
test:  60225
train + valid:  356770


Given that we don't have as much "new" data, but we still want to keep the data and stratified, the most straightforward way is to split the valid set (60,832 cases) three-fold.

1/3rd will go to the training set (such that it can learn concepts such as unknown), 

1/3rd will become the validation set (for model tuning)

1/3rd will become the sacred held out test set

In [125]:
from sklearn.model_selection import train_test_split

# Here, we stratify the validation dataset by class_id
class_df = valid_metadata_df[['observationID', 'class_id']]

# I note that the validation set has numerous repeated observation for some reason. 
class_df = class_df.drop_duplicates()

# For rare classes, we want to keep them within the test set only. 
test_only_class_ids = class_df.class_id.value_counts()[class_df.class_id.value_counts() < 5].index.tolist()

class_df = class_df[~class_df.class_id.isin(test_only_class_ids)]
additional_train, new_valid_test = train_test_split(class_df, test_size=0.67, stratify=class_df['class_id'])
new_valid, new_test = train_test_split(new_valid_test, test_size=0.5, stratify=new_valid_test['class_id'])
new_test = pd.concat((new_test, class_df[class_df.class_id.isin(test_only_class_ids)]))


In [130]:
# These are the new / additional sets
print("additional train: ", valid_metadata_df.observationID.isin(additional_train.observationID).sum())
print("new valid: ", valid_metadata_df.observationID.isin(new_valid.observationID).sum())
print("new test: ", valid_metadata_df.observationID.isin(new_test.observationID).sum())

additional train:  19394
new valid:  19788
new test:  19746


In [136]:
train_metadata_df['dataset'] = "train"
# Label the old validation metadata df with new tag
valid_metadata_df.loc[valid_metadata_df.observationID.isin(additional_train.observationID), 'dataset'] = "train"
valid_metadata_df.loc[valid_metadata_df.observationID.isin(new_valid.observationID), 'dataset'] = "valid"
valid_metadata_df.loc[valid_metadata_df.observationID.isin(new_test.observationID), 'dataset'] = "test"

In [141]:
# concat the two and save to bucket
PRODUCTION_BUCKET = 'gs://dsgt-clef-fungiclef-2024/production/'

full_metadata_df = pd.concat((train_metadata_df, valid_metadata_df))
full_metadata_df.to_csv(PRODUCTION_BUCKET + "metadata/DF_combined_metadata_full_raw.csv", index=False)

### 2. Preprocessing categorical variables
The following section we will be preprocessing the dataset with categorical columns
We will also save this mapping to use for the public dataset / for any subsequent inference

In [8]:
# Here, we are only keeping columns that are relevant either for training or inference. 
# This includes all the columns that were present in the public test metadata dataset
TEST_DF_COLUMNS = ['observationID', 'month', 'day', 'countryCode', 'locality', 'level0Gid',
       'level0Name', 'level1Gid', 'level1Name', 'level2Gid', 'level2Name',
       'Substrate', 'Latitude', 'Longitude', 'CoorUncert', 'Habitat',
       'image_path', 'filename', 'MetaSubstrate']

# As well as the overall classification of the fungi (this could potentially be useful as additional training targets)
COLUMNS_TO_KEEP = TEST_DF_COLUMNS + ['scientificName', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species', 'poisonous', 'class_id', 'dataset']

# These are the categorical columns we will need to factorize and generate labels for
CATEGORICAL_COLUMNS = ['locality', 'level0Gid', 'level1Gid', 'level2Gid', 'Substrate', 'Habitat', 'MetaSubstrate', 'kingdom', 'phylum', 'class',
       'order', 'family', 'genus', 'species']

In [29]:
selected_metadata_df = full_metadata_df[COLUMNS_TO_KEEP]

# This is important to save 
mapping = {}

for col in CATEGORICAL_COLUMNS:
    selected_metadata_df.sort_values(by=col, ascending=True, inplace=True)
    col_numerical, col_mapping = pd.factorize(selected_metadata_df[col], use_na_sentinel=False)
    selected_metadata_df.loc[:, f"{col}_text"] = selected_metadata_df.loc[:, col]
    selected_metadata_df.loc[:, col] = col_numerical
    mapping[col] = {v: k for k, v in enumerate(col_mapping)}

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  selected_metadata_df.sort_values(by=col, ascending=True, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  selected_metadata_df.loc[:, f"{col}_text"] = selected_metadata_df.loc[:, col]
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  selected_metadata_df.sort_values(by=col, ascending=True, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_i

In [22]:
# It is vital to save all the mapping for the categorical columns for inference later.
# P.S. This is saved in /production/metadata. 
import pickle

CATEGORICAL_MAPPING_LOCATION = "./categorical_columns_mapping.pkl"

pickle.dump(mapping, open(CATEGORICAL_MAPPING_LOCATION, 'wb'))

In [58]:
# Lastly, it is vital to remap the unknown class from -1 to another positive integer. Otherwise it will be hard to train and harder to debug.
import numpy as np

UNKNOWN_CLASS = 1604
selected_metadata_df['class_id'] = np.where(selected_metadata_df.class_id==-1, UNKNOWN_CLASS, selected_metadata_df.class_id)

In [59]:
selected_metadata_df.to_csv(PRODUCTION_BUCKET + "metadata/DF_combined_metadata_mapped_columns.csv", index=False)

## 3. Pair up with embeddings
Now that we have a single metadata dataframe as our single source of truth, we will match the embeddings.

In [25]:
selected_metadata_df = pd.read_csv(PRODUCTION_BUCKET + "metadata/DF_combined_metadata_mapped_columns.csv")

In [28]:
selected_metadata_df

Unnamed: 0,observationID,month,day,countryCode,locality,level0Gid,level0Name,level1Gid,level1Name,level2Gid,...,Substrate_text,Habitat_text,MetaSubstrate_text,kingdom_text,phylum_text,class_text,order_text,family_text,genus_text,species_text
0,2238506390,9.0,1.0,DK,5179,9,Denmark,28,Hovedstaden,80,...,soil,park/churchyard,jord,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
1,2812984326,7.0,8.0,DK,6157,9,Denmark,32,Syddanmark,146,...,dead wood (including bark),lawn,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
2,2465026418,11.0,19.0,DK,6829,9,Denmark,31,Sjælland,139,...,dead wood (including bark),Unmanaged deciduous woodland,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
3,2812994333,7.0,9.0,DK,6145,9,Denmark,32,Syddanmark,146,...,dead wood (including bark),lawn,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
4,2812983332,8.0,20.0,DK,6157,9,Denmark,32,Syddanmark,146,...,dead wood (including bark),lawn,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
356765,2238482337,11.0,23.0,DK,9621,9,Denmark,29,Midtjylland,98,...,faeces,other habitat,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,
356766,2237940567,10.0,26.0,DK,2012,9,Denmark,29,Midtjylland,111,...,faeces,salt meadow,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,
356767,2238455560,11.0,7.0,DK,1662,9,Denmark,29,Midtjylland,108,...,faeces,salt meadow,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,
356768,2449442140,11.0,3.0,DK,6116,9,Denmark,31,Sjælland,133,...,faeces,Forest bog,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,


In [60]:
# For pairing up with embeddings, we will use numerical data only so there is less data to load etc
numerical_metadata_df = selected_metadata_df.drop([c + "_text" for c in CATEGORICAL_COLUMNS], axis=1)
numerical_metadata_df = numerical_metadata_df.drop(['filename', 'scientificName', 'countryCode', 'level0Name', 'level1Name', 'level2Name'], axis=1)

In [61]:
numerical_metadata_df['image_path'] = numerical_metadata_df.image_path.apply(lambda x: x.replace(".JPG", ".jpg"))

In [12]:
# Match resnet embeddings
resnet_pq = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px_and_DF21_300px_corrected_FULL_SET_embedding/resnet"
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

resnet_df = spark.read.parquet(resnet_pq)
resnet_embeddings = resnet_df.select("image_path", "embeddings").toPandas()

Setting default log level to "Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Exception in thread "main" java.nio.file.NoSuchFileException: /tmp/tmpbzot30rp/connection9366208419788263865.info
	at java.base/sun.nio.fs.UnixException.translateToIOException(UnixException.java:92)
	at java.base/sun.nio.fs.UnixException.rethrowAsIOException(UnixException.java:111)
	at java.base/sun.nio.fs.UnixException.rethrowAsIOException(UnixException.java:116)
	at java.base/sun.nio.fs.UnixFileSystemProvider.newByteChannel(UnixFileSystemProvider.java:219)
	at java.base/java.nio.file.Files.newByteChannel(Files.java:371)
	at java.base/java.nio.file.Files.createFile(Files.java:648)
	at java.base/java.nio.file.TempFileHelper.create(TempFileHelper.java:137)
	at java.base/java.nio.file.TempFileHelper.createTempFile(TempFileHelper.jav

In [62]:
# Match dataset by image_path
resnet_embeddings_full_df = numerical_metadata_df.set_index('image_path').join(resnet_embeddings.set_index('image_path')).reset_index()


In [63]:
DATASET_PATH = "gs://dsgt-clef-fungiclef-2024/production/resnet/"
resnet_embeddings_full_df[resnet_embeddings_full_df.dataset=="train"].to_parquet(DATASET_PATH + "DF_300_train.parquet")
resnet_embeddings_full_df[resnet_embeddings_full_df.dataset=="valid"].to_parquet(DATASET_PATH + "DF_300_valid.parquet")
resnet_embeddings_full_df[resnet_embeddings_full_df.dataset=="test"].to_parquet(DATASET_PATH + "DF_300_test.parquet")

In [157]:
# Dino embeddings not working out yet because of the bug. TODO: Re-do dino embeddings
dino_dct_pq = config["gs_paths"]["train_and_test_300px_corrected"]["train_dct_parquet"]
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

dino_dct_df = spark.read.parquet(dino_dct_pq)

24/04/24 13:24:29 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [160]:
dino_embeddings = dino_dct_df.select("ImageUniqueID", "dct_embedding").toPandas()

[Stage 7:>                                                          (0 + 4) / 6]

                                                                                

In [165]:
train_metadata_df.keys()

Index(['observationID', 'year', 'month', 'day', 'countryCode', 'locality',
       'taxonID', 'scientificName', 'kingdom', 'phylum', 'class', 'order',
       'family', 'genus', 'specificEpithet', 'taxonRank', 'species',
       'level0Gid', 'level0Name', 'level1Gid', 'level1Name', 'level2Gid',
       'level2Name', 'ImageUniqueID', 'Substrate', 'rightsHolder', 'Latitude',
       'Longitude', 'CoorUncert', 'Habitat', 'image_path', 'class_id',
       'MetaSubstrate', 'poisonous', 'dataset'],
      dtype='object')

In [166]:
valid_metadata_df.keys()

Index(['observationID', 'month', 'day', 'countryCode', 'locality', 'taxonID',
       'scientificName', 'kingdom', 'phylum', 'class', 'order', 'family',
       'genus', 'specificEpithet', 'taxonRank', 'species', 'level0Gid',
       'level0Name', 'level1Gid', 'level1Name', 'level2Gid', 'level2Name',
       'Substrate', 'Latitude', 'Longitude', 'CoorUncert', 'Habitat',
       'image_path', 'filename', 'MetaSubstrate', 'class_id', 'poisonous',
       'dataset'],
      dtype='object')

In [161]:
dino_embeddings

Unnamed: 0,ImageUniqueID,dct_embedding
0,,"[-27465.547, 14178.638, -20135.004, -6432.8823..."
1,2864913416-287345,"[-23667.527, 2323.5469, 37456.824, 45126.812, ..."
2,2383043463-43428,"[-30186.223, 2974.5952, 15293.627, 45778.152, ..."
3,2446759895-197886,"[-23543.812, 12889.68, 10386.561, 12623.754, 5..."
4,,"[-740.4336, 28029.623, -26031.705, -16861.986,..."
...,...,...
320876,,"[-26018.09, -9693.551, 10605.984, 25729.992, 1..."
320877,,"[-32310.83, 9728.084, 4996.2915, 2508.0183, 16..."
320878,,"[-20754.445, 14093.246, 24992.883, 22008.844, ..."
320879,2238538042-327236,"[-10542.917, 30165.953, 11423.545, 2223.7212, ..."


In [218]:
# Checking we have the correct set
df_full_300_pq = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px_and_DF21_300px_corrected"
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

df_full_300_df = spark.read.parquet(df_full_300_pq)
df_full_img_paths = df_full_300_df.select("image_path").toPandas()["image_path"]

24/04/24 13:55:11 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
                                                                                

In [220]:
len(df_full_img_paths.unique())

356770