In [1]:
import json
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from clip_benchmark.utils.utils import prepare_ds_name

sys.path.append('..')
from scripts.helper import parse_datasets

In [2]:
# --- Configuration -----------------------------------------------------------
# File listing the datasets to process; the next cell replaces this path with
# the parsed dataset list.
datasets = "../scripts/webdatasets_wo_imagenet.txt"

# Root of the cached features: training targets are read from
# <features_base>/<dataset>/<base_model>/targets_train.pt.
features_base = Path('/home/space/diverse_priors/features')
# Root under which the per-dataset subset index files are written.
dataset_base = Path('/home/space/diverse_priors/datasets/subsets')

# Model whose cached training targets are used to compute class statistics.
base_model = 'dinov2-vit-large-p14'

# Datasets with more training samples than this are subsampled; the per-class
# quota is derived from this budget.
total_sample_nr = 10_000

# Subset selection draws from NumPy's global RNG, so it is seeded here.
# Re-running later cells without re-running this one yields different subsets.
np.random.seed(42)  ## IMPORTANT: SET SEED

In [3]:
# Resolve the dataset list file into dataset names, then map each name to the
# directory name used under `features_base`.
datasets = parse_datasets(datasets)
datasets_features = list(map(prepare_ds_name, datasets))

In [4]:
# Per-dataset label statistics from the base model's cached training targets.
# Datasets without a `targets_train.pt` file are reported and skipped.
ds_stats = {}
for ds in datasets_features:
    targets_path = features_base / ds / base_model / 'targets_train.pt'
    try:
        targets = pd.Series(torch.load(targets_path))
    except FileNotFoundError:
        print(f'No training data available for {ds=}')
        continue
    # Single pass over the targets: label -> positional indices of that label.
    # This replaces one full `np.where` scan per class, which cost
    # O(n_samples * n_classes) and was done twice per class.
    # NOTE(review): groupby drops null labels; targets are assumed to be non-null.
    grouped_indices = targets.groupby(targets).indices
    indices_dict = {label: grouped_indices[label] for label in sorted(grouped_indices)}
    ds_stats[ds] = dict(
        nsamples_train=len(targets),
        ncls_train=targets.nunique(),
        indices_dict=indices_dict,
        # Number of training samples per class.
        indices_dict_les={label: len(idx) for label, idx in indices_dict.items()},
    )

In [5]:
# One row per dataset (columns: nsamples_train, ncls_train, indices_dict,
# indices_dict_les), smallest training set first. The frame is built
# column-per-dataset and then transposed, so the mixed-type columns keep
# object dtype.
datasets = pd.DataFrame(ds_stats).transpose().sort_values(by='nsamples_train')

In [6]:
# Per-class quota, set only for datasets larger than the sample budget (all
# other rows stay NaN). Classes smaller than the quota contribute all of their
# samples, so the final subset size can differ from total_sample_nr.
exceeds_budget = datasets['nsamples_train'] > total_sample_nr
per_class_quota = np.ceil(total_sample_nr / datasets['ncls_train'])
datasets.loc[exceeds_budget, 'samples_per_class'] = per_class_quota

In [7]:
def create_subset(indices_dict, samples_per_class, rng=None):
    """Draw a class-stratified random subset of sample indices.

    Args:
        indices_dict: mapping class label -> array-like of that class's sample indices.
        samples_per_class: maximum number of indices drawn (without replacement)
            per class. Classes with fewer samples contribute all of them.
            NaN/None means "no subset requested". Float values such as 10.0
            are accepted.
        rng: object exposing ``choice`` (e.g. ``np.random.Generator`` or the
            legacy ``np.random`` module). Defaults to the global ``np.random``
            state, so results depend on the seed set earlier and on how many
            draws have happened since.

    Returns:
        Sorted list of the selected indices (empty if ``indices_dict`` is
        empty), or ``np.nan`` when ``samples_per_class`` is NaN/None.
    """
    if pd.isna(samples_per_class):
        return np.nan
    if rng is None:
        rng = np.random
    # The quota may arrive as a float depending on the column dtype; `choice`
    # needs an integer size.
    quota = int(samples_per_class)
    per_class = [
        rng.choice(class_indices, min(len(class_indices), quota), replace=False)
        for class_indices in indices_dict.values()
    ]
    if not per_class:
        # np.concatenate/np.hstack raise on an empty sequence.
        return []
    return sorted(np.concatenate(per_class))


def _subset_for_row(row):
    """Row adapter: stratified subset for one dataset row (NaN when no quota is set)."""
    return create_subset(row['indices_dict'], row['samples_per_class'])


datasets['subset_indices'] = datasets.apply(_subset_for_row, axis=1)

In [8]:
# Persist the subset indices of every dataset that exceeded the sample budget
# to <dataset_base>/<ds>/subset_indices_train.json, as plain ints so they are
# JSON-serializable.
for ds, row in datasets.iterrows():
    if row['nsamples_train'] > total_sample_nr:
        subset_path = dataset_base / ds / 'subset_indices_train.json'
        subset_path.parent.mkdir(parents=True, exist_ok=True)
        # Convert before opening: mode 'w' truncates the file, so a conversion
        # failure would otherwise leave an empty/corrupt JSON behind.
        subset_indices = [int(idx) for idx in row['subset_indices']]
        with open(subset_path, 'w') as json_file:
            json.dump(subset_indices, json_file)