## Notebook to check the sizes of the webdatasets and create subsets if necessary
In this notebook we load each dataset and check its size. If a dataset has more than `total_sample_nr` training samples (10k, or 30k as currently configured below), we create a subset of the dataset with at most a fixed number of samples per class. The subsets are stored in the folder `BASE_PATH_PROJECT/datasets/subsets` (or `BASE_PATH_PROJECT/datasets/subsets_30k` for the 30k setting) and are used for the experiments in the paper.

In [1]:
import json
import sys

import numpy as np
import pandas as pd
import torch

from constants import BASE_PATH_PROJECT
from sim_consistency.utils.utils import prepare_ds_name

from helper import parse_datasets

#### Global Variables

In [3]:
# Text file naming the datasets to inspect; it is parsed in the "Load datasets" section.
datasets = "../scripts/configs/webdatasets_wo_imagenet.txt"

# Sample budget: datasets with more training samples than this are subsampled.
# 10000 corresponds to the `datasets/subsets` folder, 30000 to `datasets/subsets_30k`.
total_sample_nr = 30000

features_base = BASE_PATH_PROJECT / 'features'
# Derive the output folder from the budget so the two settings cannot drift apart
# (previously both had to be toggled by hand, risking 30k indices in the 10k folder).
subset_folder = 'subsets' if total_sample_nr == 10000 else f'subsets_{total_sample_nr // 1000}k'
dataset_base = BASE_PATH_PROJECT / 'datasets' / subset_folder
assert features_base.exists(), f'{features_base} does not exist. Please run the feature extraction script first.'
assert dataset_base.exists(), f'{dataset_base} does not exist. Please create the folder first.'

# Any model that has been used for feature extraction, as only the targets are needed. The model name is used to find the correct folder.
base_model = 'dinov2-vit-small-p14'

# The subsets are drawn from NumPy's global random state, so the seed must be set
# before any sampling cell runs; re-run this cell before re-running the sampling cell.
np.random.seed(42)  ## IMPORTANT: SET SEED

#### Load datasets

In [4]:
# Keep the parsed names in their own variable instead of overwriting `datasets`:
# rebinding the input made this cell non-idempotent, because a second run passed
# the already-parsed result back into `parse_datasets`.
dataset_names = parse_datasets(datasets)
# Names of the per-dataset folders below `features_base`.
datasets_features = [prepare_ds_name(name) for name in dataset_names]
datasets_features

['wds_fer2013',
 'wds_voc2007',
 'wds_cars',
 'wds_fgvc_aircraft',
 'wds_stl10',
 'wds_gtsrb',
 'wds_country211',
 'wds_vtab_caltech101',
 'wds_vtab_cifar10',
 'wds_vtab_cifar100',
 'wds_vtab_diabetic_retinopathy',
 'wds_vtab_dmlab',
 'wds_vtab_dtd',
 'wds_vtab_eurosat',
 'wds_vtab_flowers',
 'wds_vtab_pets',
 'wds_vtab_pcam',
 'wds_vtab_resisc45',
 'wds_vtab_svhn',
 'entity30',
 'living17',
 'nonliving26']

In [5]:
# Collect, per dataset, the number of training samples, the number of classes,
# the sample positions of every class and the size of every class.
ds_stats = {}
for ds in datasets_features:
    targets_path = features_base / ds / base_model / 'targets_train.pt'
    try:
        # NOTE(review): torch.load unpickles the file; only load target files from a trusted source.
        targets = pd.Series(torch.load(targets_path))
    except FileNotFoundError:
        # Best effort: datasets without stored training targets are reported and skipped.
        print(f'No training data available for {ds=}')
        continue
    # Scan the targets once per class and derive the class sizes from the result,
    # instead of running the same np.where search twice.
    indices_dict = {value: np.where(targets == value)[0] for value in sorted(targets.unique())}
    ds_stats[ds] = dict(
        nsamples_train=len(targets),
        ncls_train=targets.nunique(),
        indices_dict=indices_dict,
        indices_dict_les={value: len(positions) for value, positions in indices_dict.items()},
    )

In [6]:
# `ds_stats` is keyed by dataset name, so the frame is transposed to get one row
# per dataset; rows are then ordered from the smallest to the largest training set.
stats_by_dataset = pd.DataFrame(ds_stats).T
datasets = stats_by_dataset.sort_values(by='nsamples_train')
datasets

Unnamed: 0,nsamples_train,ncls_train,indices_dict,indices_dict_les
wds_vtab_flowers,1020,102,"{0: [79, 140, 235, 252, 310, 404, 594, 749, 77...","{0: 10, 1: 10, 2: 10, 3: 10, 4: 10, 5: 10, 6: ..."
wds_vtab_dtd,1880,47,"{0: [81, 134, 160, 165, 224, 261, 350, 410, 41...","{0: 40, 1: 40, 2: 40, 3: 40, 4: 40, 5: 40, 6: ..."
wds_vtab_caltech101,2753,102,"{0: [57, 118, 134, 163, 207, 296, 311, 345, 35...","{0: 29, 1: 27, 2: 25, 3: 29, 4: 27, 5: 28, 6: ..."
wds_vtab_pets,2944,37,"{0: [21, 61, 117, 142, 209, 212, 239, 287, 302...","{0: 76, 1: 82, 2: 87, 3: 80, 4: 83, 5: 82, 6: ..."
wds_fgvc_aircraft,3334,100,"{0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...","{0: 34, 1: 33, 2: 33, 3: 34, 4: 33, 5: 33, 6: ..."
wds_stl10,5000,10,"{0: [10, 12, 30, 31, 32, 47, 53, 71, 73, 81, 9...","{0: 500, 1: 500, 2: 500, 3: 500, 4: 500, 5: 50..."
wds_voc2007,7844,20,"{0: [10, 11, 14, 15, 16, 152, 163, 208, 283, 3...","{0: 156, 1: 202, 2: 294, 3: 208, 4: 338, 5: 13..."
wds_cars,8144,196,"{0: [162, 461, 521, 706, 772, 886, 945, 1276, ...","{0: 45, 1: 32, 2: 43, 3: 42, 4: 41, 5: 45, 6: ..."
wds_vtab_eurosat,16200,10,"{0: [5, 7, 19, 26, 27, 48, 50, 51, 53, 65, 79,...","{0: 1829, 1: 1832, 2: 1768, 3: 1535, 4: 1495, ..."
wds_vtab_resisc45,18900,45,"{0: [57, 83, 230, 321, 406, 413, 431, 435, 453...","{0: 412, 1: 421, 2: 417, 3: 392, 4: 427, 5: 44..."


#### Create subsets

In [7]:
# Only datasets above the sample budget receive a per-class quota; all other rows
# keep a missing value in `samples_per_class`.
# The quota is rounded up, so quota * number of classes can slightly exceed the budget.
exceeds_budget = datasets['nsamples_train'] > total_sample_nr
per_class_quota = np.ceil(total_sample_nr / datasets['ncls_train'])
datasets.loc[exceeds_budget, 'samples_per_class'] = per_class_quota

In [9]:
def create_subset(indices_dict, samples_per_class, rng=None):
    """Draw a random subset of sample indices with a per-class cap.

    Parameters
    ----------
    indices_dict : dict
        Maps each class label to an array-like holding the sample indices of that class.
    samples_per_class : int or float
        Maximum number of indices drawn per class; classes with fewer members
        contribute all of their indices. A missing value (NaN/None) means that
        no subset is needed.
    rng : optional
        Object exposing ``choice(a, size, replace=...)``, e.g.
        ``np.random.default_rng(seed)``. Defaults to NumPy's global random
        state (the one seeded with ``np.random.seed``).

    Returns
    -------
    list or float
        Sorted list of the selected indices, or ``np.nan`` if
        ``samples_per_class`` is missing.
    """
    if pd.isna(samples_per_class):
        return np.nan
    # The quota may arrive as a float (e.g. produced by np.ceil or read from a
    # column that also holds NaN), but `choice` needs an integer sample size.
    per_class = int(samples_per_class)
    sampler = np.random if rng is None else rng
    # Classes are visited in dict order, so a fixed seed yields the same draws.
    picked = [
        sampler.choice(members, min(len(members), per_class), replace=False)
        for members in indices_dict.values()
    ]
    if not picked:
        # np.hstack raises on an empty sequence; no classes means an empty subset.
        return []
    return sorted(np.hstack(picked))


def _subset_for_row(row):
    """Sample the subset for one dataset row from its class positions and per-class quota."""
    return create_subset(row['indices_dict'], row['samples_per_class'])


# Rows without a quota (missing `samples_per_class`) get NaN instead of an index list.
datasets['subset_indices'] = datasets.apply(_subset_for_row, axis=1)
datasets

Unnamed: 0,nsamples_train,ncls_train,indices_dict,indices_dict_les,samples_per_class,subset_indices
wds_vtab_flowers,1020,102,"{0: [79, 140, 235, 252, 310, 404, 594, 749, 77...","{0: 10, 1: 10, 2: 10, 3: 10, 4: 10, 5: 10, 6: ...",,
wds_vtab_dtd,1880,47,"{0: [81, 134, 160, 165, 224, 261, 350, 410, 41...","{0: 40, 1: 40, 2: 40, 3: 40, 4: 40, 5: 40, 6: ...",,
wds_vtab_caltech101,2753,102,"{0: [57, 118, 134, 163, 207, 296, 311, 345, 35...","{0: 29, 1: 27, 2: 25, 3: 29, 4: 27, 5: 28, 6: ...",,
wds_vtab_pets,2944,37,"{0: [21, 61, 117, 142, 209, 212, 239, 287, 302...","{0: 76, 1: 82, 2: 87, 3: 80, 4: 83, 5: 82, 6: ...",,
wds_fgvc_aircraft,3334,100,"{0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...","{0: 34, 1: 33, 2: 33, 3: 34, 4: 33, 5: 33, 6: ...",,
wds_stl10,5000,10,"{0: [10, 12, 30, 31, 32, 47, 53, 71, 73, 81, 9...","{0: 500, 1: 500, 2: 500, 3: 500, 4: 500, 5: 50...",,
wds_voc2007,7844,20,"{0: [10, 11, 14, 15, 16, 152, 163, 208, 283, 3...","{0: 156, 1: 202, 2: 294, 3: 208, 4: 338, 5: 13...",,
wds_cars,8144,196,"{0: [162, 461, 521, 706, 772, 886, 945, 1276, ...","{0: 45, 1: 32, 2: 43, 3: 42, 4: 41, 5: 45, 6: ...",,
wds_vtab_eurosat,16200,10,"{0: [5, 7, 19, 26, 27, 48, 50, 51, 53, 65, 79,...","{0: 1829, 1: 1832, 2: 1768, 3: 1535, 4: 1495, ...",,
wds_vtab_resisc45,18900,45,"{0: [57, 83, 230, 321, 406, 413, 431, 435, 453...","{0: 412, 1: 421, 2: 417, 3: 392, 4: 427, 5: 44...",,


#### Store subsets

In [10]:
# Write the sampled indices of every dataset above the sample budget to
# `<dataset_base>/<dataset>/subset_indices_train.json`.
needs_subset = datasets['nsamples_train'] > total_sample_nr
for ds, row in datasets[needs_subset].iterrows():
    print(f"Storing subset for {ds=}")
    target_dir = dataset_base / ds
    target_dir.mkdir(parents=True, exist_ok=True)
    with open(target_dir / 'subset_indices_train.json', 'w') as json_file:
        # NumPy integers are not JSON serialisable, so convert to plain ints first.
        plain_indices = list(map(int, row['subset_indices']))
        json.dump(plain_indices, json_file)

Storing subset for ds='wds_country211'
Storing subset for ds='wds_vtab_diabetic_retinopathy'
Storing subset for ds='living17'
Storing subset for ds='wds_vtab_cifar100'
Storing subset for ds='wds_vtab_cifar10'
Storing subset for ds='wds_vtab_dmlab'
Storing subset for ds='nonliving26'
Storing subset for ds='wds_vtab_svhn'
Storing subset for ds='entity30'
Storing subset for ds='wds_vtab_pcam'
