In [None]:
import ast
import json
import os
import sys
from itertools import product

import pandas as pd
from tqdm.notebook import tqdm

from clip_benchmark.utils.utils import retrieve_model_dataset_results

sys.path.append('..')
from scripts.helper import load_models, get_hyperparams, parse_datasets

In [None]:
# --- Inputs -----------------------------------------------------------------
# Path to the dataset list (handed to `parse_datasets` further down) and to
# the JSON model configuration (handed to `load_models`).
datasets = "../scripts/webdatasets_wo_imagenet.txt"
model_config = "../scripts/filtered_models_config.json"

# Model that every other model is paired with, and the combiner suffix used
# in the directory names of the combined-model results.
# (Alternative anchor used previously: "resnet50")
anchor_model = "OpenCLIP_ViT-L-14_openai"
combiner = 'concat'

# --- Result directories -----------------------------------------------------
# Root folders of the linear-probe results, one per evaluation setting.
single_path = '/home/space/diverse_priors/results/linear_probe/single_model'
ensemble_path = '/home/space/diverse_priors/results/linear_probe/ensemble'
combined_path = '/home/space/diverse_priors/results/linear_probe/combined_models'

In [None]:
# Turn the dataset list file into dataset names and replace '/' with '_' so
# each name can be used as a single path component of the result folders.
datasets = [name.replace('/', '_') for name in parse_datasets(datasets)]

In [None]:
# Load the model registry described by the JSON config.
models, n_models = load_models(model_config)

# The assertion message must only reference names that exist: it is evaluated
# exactly when the assertion fails, so an undefined name there would turn the
# intended AssertionError into a NameError.
assert anchor_model in models, f"Model in {anchor_model} not available in {model_config=}."

# The anchor is paired with every *other* model, so take it out of the pool.
models.pop(anchor_model)

# SegmentAnything_vit_b is excluded from the pairings whenever it is configured.
models.pop('SegmentAnything_vit_b', None)

# One alphabetically sorted [model_a, model_b] pair per remaining model; the
# sorted order is what the "{m1}__{m2}" result-directory names are built from.
model_keys = [sorted([anchor_model, other_model]) for other_model in models]

In [None]:
def _load_results(base_path, ds, model_id):
    """Load the results stored under ``<base_path>/<ds>/<model_id>``.

    Parameters
    ----------
    base_path : str
        Root directory of one evaluation setting (single / ensemble / combined).
    ds : str
        Dataset name as used in the directory layout.
    model_id : str
        Directory name identifying the model (or model combination).

    Returns
    -------
    The object returned by ``retrieve_model_dataset_results``, or ``None`` if
    that call raised ``FileNotFoundError``. A miss is reported on stdout and
    never aborts the aggregation.
    """
    curr_path = os.path.join(base_path, ds, model_id)
    try:
        return retrieve_model_dataset_results(curr_path, allow_db_results=False)
    except FileNotFoundError as e:
        print(e)
        print(f"No results found for dataset={ds} and {model_id=}!")
        return None


dfs = []
n_pairs = len(datasets) * len(model_keys)

# Combined models: directory name is "<m1>__<m2>_<combiner>".
for ds, (m1, m2) in tqdm(product(datasets, model_keys), total=n_pairs, desc=f"Loading combined ({combiner}) results"):
    result_df = _load_results(combined_path, ds, f"{m1}__{m2}_{combiner}")
    if result_df is not None:
        dfs.append(result_df)

# Ensembles: directory name is "<m1>__<m2>".
for ds, (m1, m2) in tqdm(product(datasets, model_keys), total=n_pairs, desc="Loading ensemble results"):
    result_df = _load_results(ensemble_path, ds, f"{m1}__{m2}")
    if result_df is not None:
        dfs.append(result_df)

# Anchor model on its own. The helper reports the anchor's own id on a miss
# (previously the message showed a stale pair id left over from the loop above).
for ds in tqdm(datasets, desc=f"Loading {anchor_model} results"):
    result_df = _load_results(single_path, ds, anchor_model)
    if result_df is not None:
        dfs.append(result_df)

# pd.concat raises a generic ValueError on an empty list; fail with the same
# exception type but say what actually went wrong.
if not dfs:
    raise ValueError(
        f"No results could be loaded for anchor model {anchor_model!r}; "
        f"checked {combined_path}, {ensemble_path} and {single_path}."
    )

df = pd.concat(dfs, axis=0)

In [None]:
# Fetch the hyper-parameter grid; the second return value is not needed here.
hyper_params, _ = get_hyperparams(num_seeds=3, size='imagenet1k')

# Discard the entries that are not used for filtering the results.
for dropped_key in ('fewshot_lrs', 'reg_lambda'):
    _ = hyper_params.pop(dropped_key)

# Rename the plural keys to the singular names; the keys are used as column
# names of the result frame when filtering.
for new_key, old_key in (("fewshot_k", "fewshot_ks"), ("seed", "seeds")):
    hyper_params[new_key] = hyper_params.pop(old_key)

# Cast every value list to floats where possible; lists containing values
# that float() rejects with a ValueError are kept untouched.
for name in list(hyper_params):
    try:
        as_floats = list(map(float, hyper_params[name]))
    except ValueError:
        continue
    hyper_params[name] = as_floats

In [None]:
# Keep only the rows whose value in every hyper-parameter column is one of
# the allowed values for that column.
for column, allowed_values in hyper_params.items():
    row_mask = df[column].isin(allowed_values)
    df = df.loc[row_mask]

In [None]:
# Persist the filtered, aggregated results for this anchor model as a pickle.
aggregated_pkl_path = f'/home/lciernik/projects/divers-priors/diverse_priors/benchmark/scripts/test_results/aggregated/anchor_{anchor_model}.pkl'
df.to_pickle(aggregated_pkl_path)

In [None]:
# Columns that together identify one evaluated configuration; used as the
# groupby keys when counting results per configuration.
HYPER_PARAM_COLS = [
    'task',
    'mode',
    'combiner',
    'dataset',
    'model_ids',
    'fewshot_k',
    'fewshot_epochs',
    'batch_size',
    'regularization',
]

In [None]:
def _parse_model_ids(raw):
    """Convert one ``model_ids`` cell into a hashable tuple.

    String cells are parsed with ``ast.literal_eval`` (literals only, so no
    arbitrary code from the result files is executed, unlike ``eval``).
    Cells that are already sequences are converted directly, which makes
    re-running this cell a no-op instead of an error.
    """
    if isinstance(raw, str):
        raw = ast.literal_eval(raw)
    return tuple(raw)


# Tuples are hashable, so the column can be used as a groupby key.
df['model_ids'] = df['model_ids'].apply(_parse_model_ids)

In [None]:
# Number of result rows per (dataset, mode, regularization) combination.
df.value_counts(subset=['dataset', 'mode', 'regularization'])

In [None]:
# Count the non-null test_lp_acc1 entries of every hyper-parameter
# configuration (NaN group keys are kept), then tally how often each count
# occurs.
runs_per_config = df.groupby(HYPER_PARAM_COLS, dropna=False)['test_lp_acc1'].count()
runs_per_config.value_counts()