In [1]:
from fungiclef.utils import get_spark, spark_resource, read_config
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = read_config(path='../fungiclef/config.json')

# First, we read the metadata for the dataset and make a proper new one. This will be the single source of truth we use to build the rest of our stuff on
# This corresponds to the DF20 dataset
TRAIN_METADATA = config["gs_paths"]["train"]["metadata"]

# These two correspond to the DF21 dataset
VALID_METADATA = config["gs_paths"]["val"]["metadata"]
TEST_METADATA = config["gs_paths"]["test"]["metadata"]

PRODUCTION_BUCKET = 'gs://dsgt-clef-fungiclef-2024/production/'

### 1. Making a new train / valid / test set split with metadata file

The motivation to this is to make a proper, bigger training set, where unknown classes are also included in the dataset. 


In [None]:
train_metadata_df = pd.read_csv(TRAIN_METADATA)
valid_metadata_df = pd.read_csv(VALID_METADATA)
test_metadata_df = pd.read_csv(TEST_METADATA)

In [None]:
# 
full_dataset_pq = config["gs_paths"]["train_and_test_300px_corrected"]["raw_parquet"]
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

full_dataset_df = spark.read.parquet(full_dataset_pq)

24/04/24 12:06:00 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [42]:
observation_ids = full_dataset_df.select("observationID").toPandas()['observationID']

                                                                                

In [52]:
print("available images: ", len(observation_ids))

available images:  356770


In [53]:
print("train: ", len(train_metadata_df.observationID))
print("valid: ", len(valid_metadata_df.observationID))
print("test: ", len(test_metadata_df.observationID))

# It appears that the only images we have are the train and validation datasets. Only metadata is available for the public test set. As such, we need to be a bit wiser in how we split up the data in that case. 
print("train + valid: ", len(train_metadata_df.observationID) + len(valid_metadata_df.observationID))

train:  295938
valid:  60832
test:  60225
train + valid:  356770


Given that we don't have as much "new" data, but we still want to keep the data and stratified, the most straightforward way is to split the valid set (60,832 cases) three-fold.

1/3rd will go to the training set (such that it can learn concepts such as unknown), 

1/3rd will become the validation set (for model tuning)

1/3rd will become the sacred held out test set

In [125]:
from sklearn.model_selection import train_test_split

# Here, we stratify the validation dataset by class_id
class_df = valid_metadata_df[['observationID', 'class_id']]

# I note that the validation set has numerous repeated observation for some reason. 
class_df = class_df.drop_duplicates()

# For rare classes, we want to keep them within the test set only. 
test_only_class_ids = class_df.class_id.value_counts()[class_df.class_id.value_counts() < 5].index.tolist()

class_df = class_df[~class_df.class_id.isin(test_only_class_ids)]
additional_train, new_valid_test = train_test_split(class_df, test_size=0.67, stratify=class_df['class_id'])
new_valid, new_test = train_test_split(new_valid_test, test_size=0.5, stratify=new_valid_test['class_id'])
new_test = pd.concat((new_test, class_df[class_df.class_id.isin(test_only_class_ids)]))


In [130]:
# These are the new / additional sets
print("additional train: ", valid_metadata_df.observationID.isin(additional_train.observationID).sum())
print("new valid: ", valid_metadata_df.observationID.isin(new_valid.observationID).sum())
print("new test: ", valid_metadata_df.observationID.isin(new_test.observationID).sum())

additional train:  19394
new valid:  19788
new test:  19746


In [136]:
train_metadata_df['dataset'] = "train"
# Label the old validation metadata df with new tag
valid_metadata_df.loc[valid_metadata_df.observationID.isin(additional_train.observationID), 'dataset'] = "train"
valid_metadata_df.loc[valid_metadata_df.observationID.isin(new_valid.observationID), 'dataset'] = "valid"
valid_metadata_df.loc[valid_metadata_df.observationID.isin(new_test.observationID), 'dataset'] = "test"

In [141]:
# concat the two and save to bucket
PRODUCTION_BUCKET = 'gs://dsgt-clef-fungiclef-2024/production/'

full_metadata_df = pd.concat((train_metadata_df, valid_metadata_df))
full_metadata_df.to_csv(PRODUCTION_BUCKET + "metadata/DF_combined_metadata_full_raw.csv", index=False)

### 2. Preprocessing categorical variables
The following section we will be preprocessing the dataset with categorical columns
We will also save this mapping to use for the public dataset / for any subsequent inference

In [6]:
# Here, we are only keeping columns that are relevant either for training or inference. 
# This includes all the columns that were present in the public test metadata dataset
TEST_DF_COLUMNS = ['observationID', 'month', 'day', 'countryCode', 'locality', 'level0Gid',
       'level0Name', 'level1Gid', 'level1Name', 'level2Gid', 'level2Name',
       'Substrate', 'Latitude', 'Longitude', 'CoorUncert', 'Habitat',
       'image_path', 'filename', 'MetaSubstrate']

# As well as the overall classification of the fungi (this could potentially be useful as additional training targets)
COLUMNS_TO_KEEP = TEST_DF_COLUMNS + ['scientificName', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species', 'poisonous', 'class_id', 'dataset']

# These are the categorical columns we will need to factorize and generate labels for
CATEGORICAL_COLUMNS = ['locality', 'level0Gid', 'level1Gid', 'level2Gid', 'Substrate', 'Habitat', 'MetaSubstrate', 'kingdom', 'phylum', 'class',
       'order', 'family', 'genus', 'species']

In [5]:
selected_metadata_df = full_metadata_df[COLUMNS_TO_KEEP]

# This is important to save 
mapping = {}

for col in CATEGORICAL_COLUMNS:
    selected_metadata_df.sort_values(by=col, ascending=True, inplace=True)
    col_numerical, col_mapping = pd.factorize(selected_metadata_df[col], use_na_sentinel=False)
    selected_metadata_df.loc[:, f"{col}_text"] = selected_metadata_df.loc[:, col]
    selected_metadata_df.loc[:, col] = col_numerical
    mapping[col] = {v: k for k, v in enumerate(col_mapping)}

NameError: name 'full_metadata_df' is not defined

In [4]:
# It is vital to save all the mapping for the categorical columns for inference later.
# P.S. This is saved in /production/metadata. 
import pickle

CATEGORICAL_MAPPING_LOCATION = "./categorical_columns_mapping.pkl"

pickle.dump(mapping, open(CATEGORICAL_MAPPING_LOCATION, 'wb'))

NameError: name 'mapping' is not defined

In [58]:
# Lastly, it is vital to remap the unknown class from -1 to another positive integer. Otherwise it will be hard to train and harder to debug.
import numpy as np

UNKNOWN_CLASS = 1604
selected_metadata_df['class_id'] = np.where(selected_metadata_df.class_id==-1, UNKNOWN_CLASS, selected_metadata_df.class_id)

In [59]:
selected_metadata_df.to_csv(PRODUCTION_BUCKET + "metadata/DF_combined_metadata_mapped_columns.csv", index=False)

## 3. Pair up with embeddings
Now that we have a single metadata dataframe as our single source of truth, we will match the embeddings.

In [7]:
selected_metadata_df = pd.read_csv(PRODUCTION_BUCKET + "metadata/DF_combined_metadata_mapped_columns.csv")

In [8]:
selected_metadata_df

Unnamed: 0,observationID,month,day,countryCode,locality,level0Gid,level0Name,level1Gid,level1Name,level2Gid,...,Substrate_text,Habitat_text,MetaSubstrate_text,kingdom_text,phylum_text,class_text,order_text,family_text,genus_text,species_text
0,2238506390,9.0,1.0,DK,5179,9,Denmark,28,Hovedstaden,80,...,soil,park/churchyard,jord,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
1,2812984326,7.0,8.0,DK,6157,9,Denmark,32,Syddanmark,146,...,dead wood (including bark),lawn,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
2,2465026418,11.0,19.0,DK,6829,9,Denmark,31,Sjælland,139,...,dead wood (including bark),Unmanaged deciduous woodland,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
3,2812994333,7.0,9.0,DK,6145,9,Denmark,32,Syddanmark,146,...,dead wood (including bark),lawn,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
4,2812983332,8.0,20.0,DK,6157,9,Denmark,32,Syddanmark,146,...,dead wood (including bark),lawn,wood,Fungi,Basidiomycota,Agaricomycetes,Polyporales,Meruliaceae,Abortiporus,Abortiporus biennis
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
356765,2238482337,11.0,23.0,DK,9621,9,Denmark,29,Midtjylland,98,...,faeces,other habitat,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,
356766,2237940567,10.0,26.0,DK,2012,9,Denmark,29,Midtjylland,111,...,faeces,salt meadow,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,
356767,2238455560,11.0,7.0,DK,1662,9,Denmark,29,Midtjylland,108,...,faeces,salt meadow,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,
356768,2449442140,11.0,3.0,DK,6116,9,Denmark,31,Sjælland,133,...,faeces,Forest bog,animals,Fungi,Basidiomycota,Agaricomycetes,Agaricales,Tricholomataceae,Clitocybe,


In [9]:
# For pairing up with embeddings, we will use numerical data only so there is less data to load etc
numerical_metadata_df = selected_metadata_df.drop([c + "_text" for c in CATEGORICAL_COLUMNS], axis=1)
numerical_metadata_df = numerical_metadata_df.drop(['filename', 'scientificName', 'countryCode', 'level0Name', 'level1Name', 'level2Name'], axis=1)

In [10]:
numerical_metadata_df['image_path'] = numerical_metadata_df.image_path.apply(lambda x: x.replace(".JPG", ".jpg"))

In [12]:
# Match resnet embeddings
resnet_pq = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px_and_DF21_300px_corrected_FULL_SET_embedding/resnet"
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

resnet_df = spark.read.parquet(resnet_pq)
resnet_embeddings = resnet_df.select("image_path", "embeddings").toPandas()

Setting default log level to "Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Exception in thread "main" java.nio.file.NoSuchFileException: /tmp/tmpbzot30rp/connection9366208419788263865.info
	at java.base/sun.nio.fs.UnixException.translateToIOException(UnixException.java:92)
	at java.base/sun.nio.fs.UnixException.rethrowAsIOException(UnixException.java:111)
	at java.base/sun.nio.fs.UnixException.rethrowAsIOException(UnixException.java:116)
	at java.base/sun.nio.fs.UnixFileSystemProvider.newByteChannel(UnixFileSystemProvider.java:219)
	at java.base/java.nio.file.Files.newByteChannel(Files.java:371)
	at java.base/java.nio.file.Files.createFile(Files.java:648)
	at java.base/java.nio.file.TempFileHelper.create(TempFileHelper.java:137)
	at java.base/java.nio.file.TempFileHelper.createTempFile(TempFileHelper.jav

In [62]:
# Match dataset by image_path
resnet_embeddings_full_df = numerical_metadata_df.set_index('image_path').join(resnet_embeddings.set_index('image_path')).reset_index()


In [63]:
DATASET_PATH = "gs://dsgt-clef-fungiclef-2024/production/resnet/"
resnet_embeddings_full_df[resnet_embeddings_full_df.dataset=="train"].to_parquet(DATASET_PATH + "DF_300_train.parquet")
resnet_embeddings_full_df[resnet_embeddings_full_df.dataset=="valid"].to_parquet(DATASET_PATH + "DF_300_valid.parquet")
resnet_embeddings_full_df[resnet_embeddings_full_df.dataset=="test"].to_parquet(DATASET_PATH + "DF_300_test.parquet")

Repeating the same for dino embeddings

In [None]:
# Match resnet embeddings
resnet_pq = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px_and_DF21_300px_corrected_FULL_SET_embedding/resnet"
spark = get_spark(**{
    "spark.sql.parquet.enableVectorizedReader": False, 
})

resnet_df = spark.read.parquet(resnet_pq)
resnet_embeddings = resnet_df.select("image_path", "embeddings").toPandas()

In [32]:
# Dino embeddings not working out yet because of the bug. TODO: Re-do dino embeddings
dino_dct_pq = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px_and_DF21_300px_corrected_FULL_SET_embedding/dino/data/sample_id=0/part-00000-bfd7deaf-5486-4cf3-b861-f52758aae09c-c000.snappy.parquet"
# spark = get_spark(**{
#     "spark.sql.parquet.enableVectorizedReader": False, 
# })

dino_dct_df = pd.read_parquet(dino_dct_pq)

In [33]:
dino_dct_df.dino_embedding[0]

array([ 3.6364484,  0.0064761,  1.9875255, ..., -2.3837745, -0.7289071,
       -1.0162587], dtype=float32)

In [None]:
dino_dct_df.dino_embedding[0]

In [3]:
from scipy.fftpack import dctn, dct


In [50]:
import torch
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /home/chris/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth" to /home/chris/.cache/torch/hub/checkpoints/dinov2_vitl14_pretrain.pth
100%|██████████| 1.13G/1.13G [00:04<00:00, 252MB/s] 


In [56]:
from torch.nn import LayerNorm

In [63]:
norm_1d = LayerNorm(dino_dct_df.dino_embedding[0].shape)
norm_1d.forward(torch.Tensor(dino_dct_df.dino_embedding[0])).shape

torch.Size([197376])

In [72]:
DinoVisionTransformer

NameError: name 'DinoVisionTransformer' is not defined

In [79]:
outputs = dinov2_vitl14.forward(torch.rand((1, 3, 224, 224)))

In [129]:
last_hidden_states[0]

tensor([[ 1.6972, -2.8514,  3.2909,  ..., -0.6694,  0.3950, -1.0779],
        [ 0.3097, -1.9966, -0.1062,  ..., -1.3243, -0.5795, -0.4068],
        [ 0.2852,  0.3234, -0.4515,  ..., -0.7450,  0.5257,  2.0890],
        ...,
        [ 3.5287, -2.0530, -1.8518,  ..., -1.3409, -1.0811,  0.1767],
        [ 0.5803,  0.8584, -1.0418,  ..., -0.3129, -0.0696,  2.6281],
        [ 3.2293, -2.8874, -0.4942,  ..., -1.9794, -1.2602, -0.4612]],
       device='cuda:0')

In [8]:
DINO_SHAPE=(257, 768)
    
def dctn_filter(tile, k):
    coeff = dctn(tile.reshape(DINO_SHAPE))
    coeff_subset = coeff[:k, :k]
    return coeff_subset.flatten()


In [228]:
from numpy.linalg import norm

In [2]:
#TODO: Norm the inputs here

def process_hidden_states(df):
    rows = []
    for _, row in df.iterrows():
        hidden_state = row.dino_embedding.reshape(DINO_SHAPE)
        cls_token = hidden_state[0]
        dct_16_1d = dct(hidden_state[1:], axis=-1)[:, :16]
        dct_64_2d = dctn(hidden_state[1:])[:64, :64]
        rows.append(dict(cls_token=cls_token.tolist(), dct_16_1d=dct_16_1d.tolist(), dct_64_2d=dct_64_2d.tolist()))
    return rows

In [1]:
from google.cloud import storage


def list_blobs_with_prefix(bucket_name, prefix, delimiter=None):
    """Lists all the blobs in the bucket that begin with the prefix.

    This can be used to list all blobs in a "folder", e.g. "public/".

    The delimiter argument can be used to restrict the results to only the
    "files" in the given "folder". Without the delimiter, the entire tree under
    the prefix is returned. For example, given these blobs:

        a/1.txt
        a/b/2.txt

    If you specify prefix ='a/', without a delimiter, you'll get back:

        a/1.txt
        a/b/2.txt

    However, if you specify prefix='a/' and delimiter='/', you'll get back
    only the file directly under 'a/':

        a/1.txt

    As part of the response, you'll also get back a blobs.prefixes entity
    that lists the "subfolders" under `a/`:

        a/b/
    """

    storage_client = storage.Client()

    # Note: Client.list_blobs requires at least package version 1.17.0.
    blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=delimiter)

    # Note: The call returns a response only when the iterator is consumed.

    file_list = []
    print("Blobs:")
    for blob in blobs:
        if ".parquet" in blob.name: 
            file_list.append("gs://"+bucket_name+"/"+blob.name)

    return file_list

    # if delimiter:
    #     print("Prefixes:")
    #     for prefix in blobs.prefixes:
    #         print(prefix)


In [5]:
dino_outputs = list_blobs_with_prefix("dsgt-clef-fungiclef-2024", prefix="data/parquet/DF20_300px_and_DF21_300px_corrected_FULL_SET_embedding/dino/data/")

Blobs:


In [238]:
from tqdm import tqdm

In [239]:
import pandas as pd
from concurrent.futures import ThreadPoolExecutor

def process_file(file_path):
    df = pd.read_parquet(file_path)
    return process_hidden_states(df)

with ThreadPoolExecutor() as executor:
    all_rows = list(tqdm(executor.map(process_file, dino_outputs)))


9it [00:32,  2.96s/it]

: 

In [9]:
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm

def process_file(file_path):
    df = pd.read_parquet(file_path)
    return process_hidden_states(df)

with ProcessPoolExecutor() as executor:
    results = list(tqdm(executor.map(process_file, dino_outputs), total=len(dino_outputs)))
    all_rows = [item for sublist in results for item in sublist]  # Flatten list if necessary

  0%|          | 0/1000 [00:00<?, ?it/s]

  1%|▏         | 13/1000 [00:24<31:27,  1.91s/it] 


BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

: 

In [236]:

all_rows = []
for file_path in dino_outputs:
    df = pd.read_parquet(file_path)
    rows = process_hidden_states(df)
    all_rows += rowsb

KeyboardInterrupt: 

In [206]:
df = pd.read_parquet(dino_outputs[0])


In [212]:
rows = process_hidden_states(df)

In [237]:
pd.DataFrame(rows)

Unnamed: 0,cls_token,dct_16_1d,dct_64_2d
0,"[3.636448383331299, 0.006476100534200668, 1.98...","[[-90.1867446899414, 39.47825622558594, 41.866...","[[-17564.009765625, 8145.9775390625, 28178.802..."
1,"[0.7649978399276733, 0.4228263795375824, 0.184...","[[-77.36708068847656, 11.427928924560547, -1.5...","[[-14039.01953125, 15973.794921875, -20988.708..."
2,"[0.9485918879508972, 2.36519718170166, 2.81349...","[[-88.82673645019531, -1.0302658081054688, 34....","[[-32892.84375, -14167.4970703125, 14405.75292..."
3,"[-4.624961853027344, -0.2530408501625061, 1.81...","[[-27.530357360839844, -5.908840179443359, 44....","[[-21471.888671875, 17808.14453125, 11436.0507..."
4,"[2.610158920288086, -0.07085422426462173, -3.0...","[[-59.53972625732422, 52.84880447387695, 35.24...","[[-21902.71875, 9298.951171875, 16262.63085937..."
...,...,...,...
338,"[0.9336028695106506, 0.5011065602302551, 0.471...","[[-86.13015747070312, 47.450103759765625, 37.5...","[[-18539.68359375, 30315.546875, 26999.9804687..."
339,"[2.7459166049957275, 1.9657036066055298, -1.76...","[[-89.84234619140625, 75.6082534790039, 17.731...","[[-20514.4453125, 16542.859375, 6334.349609375..."
340,"[0.5621036887168884, -1.015462040901184, -1.71...","[[-87.16651916503906, 65.7000961303711, 37.032...","[[-26762.4609375, 15582.8291015625, 16527.1660..."
341,"[1.960525631904602, -1.2922641038894653, -1.99...","[[5.040805816650391, 1.289764404296875, 29.879...","[[-22208.4921875, 9247.31640625, 18881.9042968..."


In [215]:
df

Unnamed: 0,image_path,species,dino_embedding
0,2237912865-151073.jpg,Hortiboletus engelii,"[3.6364484, 0.0064761005, 1.9875255, 2.2292707..."
1,2238406072-312936.jpg,Peniophora rufomarginata,"[0.76499784, 0.42282638, 0.18455583, -0.292270..."
2,2238556407-330535.jpg,Hortiboletus engelii,"[0.9485919, 2.3651972, 2.8134954, -0.26696065,..."
3,2238566015-183150.jpg,Calvatia gigantea,"[-4.624962, -0.25304085, 1.8104588, -1.1925237..."
4,2238579738-37528.jpg,Panaeolus subfirmus,"[2.610159, -0.070854224, -3.038675, 0.9941942,..."
...,...,...,...
338,0-3355971447.jpg,Hygrophoropsis aurantiaca,"[0.93360287, 0.50110656, 0.47133487, -0.294665..."
339,0-3126935330.jpg,Cladonia crispata,"[2.7459166, 1.9657036, -1.762023, 0.004444801,..."
340,3-3395916415.jpg,Cuphophyllus colemannianus,"[0.5621037, -1.015462, -1.7117326, 2.0356944, ..."
341,0-3386488332.jpg,Lacrymaria lacrymabunda,"[1.9605256, -1.2922641, -1.9904997, 1.3893987,..."


In [220]:
len(df_full_img_paths.unique())

356770