## Notebook to check the sizes of the webdatasets and create subsets where necessary
In this notebook we load the training targets of each dataset and check the dataset's size. If a dataset has more than 10k training samples, we create a subset by drawing `ceil(10k / number of classes)` samples per class; classes with fewer samples contribute all of their samples. The subset indices are stored in the folder `BASE_PATH_PROJECT/datasets/subsets` and are used for the experiments in the paper.

In [None]:
import json
import sys

import numpy as np
import pandas as pd
import torch

from clip_benchmark.utils.utils import prepare_ds_name
from constants import BASE_PATH_PROJECT

sys.path.append('..')
from scripts.helper import parse_datasets

In [None]:
# Seed NumPy's global random state first: the subset sampling further down
# draws from it, so the seed determines which indices end up in the subsets.
np.random.seed(42)  ## IMPORTANT: SET SEED

# Text file naming the datasets to inspect; it is handed to `parse_datasets`.
datasets = "../scripts/webdatasets_wo_imagenet.txt"

# Feature-extractor whose stored training targets are read for the statistics.
base_model = 'dinov2-vit-large-p14'

# Size threshold: datasets with more training samples than this get a subset,
# and the per-class budget is derived from this number.
total_sample_nr = 10_000

# Input location (stored features/targets) and output location (subset indices).
features_base = BASE_PATH_PROJECT / 'features'
dataset_base = BASE_PATH_PROJECT / 'datasets/subsets'

In [None]:
# Replace the path to the dataset list with the parsed dataset names
# (presumably one entry per listed dataset -- confirm in scripts.helper).
datasets = parse_datasets(datasets)

# Convert every dataset name with `prepare_ds_name`; the results are used as
# directory names below `features_base`.
datasets_features = list(map(prepare_ds_name, datasets))

In [None]:
# Collect statistics about the training targets of every dataset:
# number of samples, number of classes, and the sample positions per class.
ds_stats = {}
for ds in datasets_features:
    try:
        df = pd.Series(torch.load(features_base / ds / base_model / 'targets_train.pt'))
    except FileNotFoundError:
        # Datasets without a stored training-target file are skipped.
        print(f'No training data available for {ds=}')
        continue
    # Locate the positions of each class once and reuse them for the counts
    # instead of scanning the targets a second time per class.
    class_indices = {value: np.where(df == value)[0] for value in sorted(df.unique())}
    ds_stats[ds] = dict(
        nsamples_train=len(df),
        ncls_train=df.nunique(),
        # class label -> positions of the samples carrying that label
        indices_dict=class_indices,
        # class label -> number of samples of that class
        # (key name kept unchanged so existing column lookups keep working)
        indices_dict_les={value: len(positions) for value, positions in class_indices.items()},
    )

In [None]:
# Turn the statistics into a table with one row per dataset (the dict keys of
# `ds_stats` become the index after transposing) ...
datasets = pd.DataFrame(ds_stats).T
# ... and order the rows from the smallest to the largest training set.
datasets = datasets.sort_values(by='nsamples_train')
datasets

In [None]:
# Only datasets above the size threshold receive a per-class sampling budget;
# all other rows are left without a value in 'samples_per_class'.
# The budget is rounded up, so the subset can slightly exceed
# `total_sample_nr` when the class count does not divide it evenly.
datasets.loc[
    datasets['nsamples_train'].gt(total_sample_nr),
    'samples_per_class',
] = np.ceil(total_sample_nr / datasets['ncls_train'])

In [None]:
def create_subset(indices_dict, samples_per_class, rng=None):
    """Draw a class-balanced random subset of sample indices.

    Args:
        indices_dict: Mapping from class label to an array-like holding the
            sample indices of that class.
        samples_per_class: Maximum number of indices to draw per class.
            Integral floats (e.g. the output of ``np.ceil``) are accepted.
            NaN marks a dataset for which no subset is needed.
        rng: Optional object with a NumPy-style ``choice`` method, e.g. a
            ``np.random.RandomState`` or ``np.random.Generator``. When
            omitted, the global ``np.random`` state is used, so the result
            depends on the most recent ``np.random.seed`` call.

    Returns:
        ``np.nan`` if ``samples_per_class`` is NaN. Otherwise a sorted list
        with the drawn indices of all classes; a class with fewer than
        ``samples_per_class`` members contributes all of its indices. An
        empty ``indices_dict`` yields an empty list.
    """
    if np.isnan(samples_per_class):
        return np.nan

    # The budget usually arrives as a float; the sample size handed to
    # ``choice`` has to be an integer.
    per_class = int(samples_per_class)
    sampler = np.random if rng is None else rng

    # Sample without replacement, class by class, in the mapping's order.
    picked = [
        sampler.choice(members, min(len(members), per_class), replace=False)
        for members in indices_dict.values()
    ]
    if not picked:
        # np.hstack cannot concatenate an empty sequence.
        return []
    return sorted(np.hstack(picked))


# Draw the subset for every dataset row by row; rows without a per-class
# budget (NaN in 'samples_per_class') receive NaN instead of an index list.
datasets['subset_indices'] = datasets.apply(
    lambda row: create_subset(row['indices_dict'], row['samples_per_class']),
    axis=1,
)

In [None]:
# Write the drawn subset indices of every oversized dataset to
# <dataset_base>/<dataset>/subset_indices_train.json.
for ds, row in datasets.iterrows():
    if not row['nsamples_train'] > total_sample_nr:
        # Small datasets are used in full; nothing to store for them.
        continue
    out_dir = dataset_base / ds
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / 'subset_indices_train.json', 'w') as json_file:
        # Cast to built-in int: NumPy integer scalars are not JSON serializable.
        json.dump([int(val) for val in row['subset_indices']], json_file)