In [1]:
import ast
import json
import os
import sys
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from clip_benchmark.utils.utils import retrieve_model_dataset_results

#from constants import sim_metric_name_mapping
#from helper import get_model_ids

sys.path.append('..')
from scripts.helper import load_models, get_hyperparams, parse_datasets

#### Global variables

In [37]:
## DATASET AND MODEL CONFIG
# Relative paths resolve against the kernel's working directory (presumably this notebook's folder).
datasets = "../scripts/webdatasets_wo_imagenet.txt"
model_config = "../scripts/filtered_models_config.json"
# Anchor model whose pairings (concat / ensemble) are analysed; swap in the
# commented alternative to analyse the second anchor.
anchor_model = "OpenCLIP_ViT-L-14_openai"  # ANCHOR MODEL 1
# anchor_model = "resnet50"  # ANCHOR MODEL 2
combiner = 'concat'

## SIMILARITY METRICS
similarity_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.2',
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_rbf_unbiased_sigma_0.6',
    'cka_kernel_rbf_unbiased_sigma_0.8',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_pearson',
    'rsa_method_correlation_corr_method_spearman',
]
sim_metric = similarity_metrics[1]  # 'cka_kernel_rbf_unbiased_sigma_0.4'

## SHARED ROOT OF ALL EXPERIMENT ARTEFACTS USED BELOW
DIVERSE_PRIORS_ROOT = Path('/home/space/diverse_priors')

### IMAGENET SUBSET SIMILARITIES
base_subset = 'imagenet-subset-10k'
model_similarities_base_path = DIVERSE_PRIORS_ROOT / 'model_similarities' / base_subset
model_similarities_path = model_similarities_base_path / sim_metric

### AGGREGATED RESULTS --> produced by gather_anchor_exp_results.ipynb
base_path_aggregated_results = DIVERSE_PRIORS_ROOT / 'results' / 'aggregated'

### SINGLE MODEL BEST PERFORMANCES --> structure: path / [L1, L2, weight_decay] / [DATASET].json
single_model_best_perf_path = base_path_aggregated_results / 'max_performance_per_model_n_ds'

### ImageNet performance base path (single-model linear probes)
single_model_imgnet_path = DIVERSE_PRIORS_ROOT / 'results' / 'linear_probe' / 'single_model' / 'wds_imagenet1k'
# Alias kept for cells that still use the original (misspelled) name.
singe_model_imgnet_path = single_model_imgnet_path

#### Storing information

In [38]:
# Figures for this (similarity subset, anchor model) combination get their own folder.
base_storing_path = Path('/home/space/diverse_priors/results/plots/performance_gap_imagenet_acc')
storing_path = base_storing_path / f"{base_subset.replace('-', '_')}__{anchor_model}"
# True: the plotting cells write PDFs into `storing_path`; False: they display figures inline.
SAVE = True

if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

#### Load similarity values

In [39]:
# Files stored inside the per-metric similarity folder.
model_ids_fn = model_similarities_path.joinpath('model_ids.txt')
sim_mat_fn = model_similarities_path.joinpath('similarity_matrix.pt')

def get_model_ids(fn):
    """Read the model ids listed in ``fn``, one per line, in file order.

    Each line has its surrounding whitespace (including the newline) stripped;
    blank lines are kept as empty strings.
    """
    with open(fn, 'r') as handle:
        return [raw.strip() for raw in handle]

model_ids = get_model_ids(model_ids_fn)
# Map to CPU so a matrix saved from a GPU run can still be handed to pandas.
sim_raw = torch.load(sim_mat_fn, map_location='cpu')
sim_values = sim_raw.numpy() if torch.is_tensor(sim_raw) else sim_raw
sim_mat = pd.DataFrame(sim_values, index=model_ids, columns=model_ids)

# filter models: keep only the models listed in the model config
models, nmodels = load_models(model_config)
allowed_models = sorted(models.keys())
sim_mat = sim_mat.loc[allowed_models, allowed_models]
# Later cells index sim_mat by the anchor; fail here with a clear message instead.
assert anchor_model in sim_mat.index, f"anchor model {anchor_model!r} missing from filtered similarity matrix"
print(f"{sim_mat.shape=}")

sim_mat.shape=(57, 57)


#### Load experiment results

In [40]:
# Aggregated per-run results for the chosen anchor.
# NOTE: read_pickle can execute arbitrary code — only load files produced by our own pipeline.
anchor_results_fn = base_path_aggregated_results.joinpath(f'anchor_{anchor_model}.pkl')
df = pd.read_pickle(anchor_results_fn)

In [41]:
# Columns identifying one experimental configuration; rows sharing all of them
# are averaged together below (presumably runs differing only by seed).
HYPER_PARAM_COLS = [
    'task',
    'mode',
    'combiner',
    'dataset',
    'model_ids',
    'fewshot_k',
    'fewshot_epochs',
    'batch_size',
    'regularization',
]

In [42]:
def _parse_model_ids(value):
    """Return ``value`` as a tuple of model ids, parsing its string representation if needed.

    Already-parsed values (lists/tuples) pass through, so re-running this cell is safe.
    """
    if isinstance(value, str):
        # literal_eval only accepts Python literals, unlike eval which runs arbitrary code.
        value = ast.literal_eval(value)
    return tuple(value)


df['model_ids'] = df['model_ids'].map(_parse_model_ids)
# Dataset names like "wds/cars" become "wds_cars" to match the per-dataset JSON file names.
df['dataset'] = df['dataset'].str.replace('/', '_', regex=False)

In [43]:
# Average test accuracy over all rows sharing the same configuration (NaN keys kept as groups).
mean_df = (
    df.groupby(HYPER_PARAM_COLS, dropna=False)['test_lp_acc1']
    .mean()
    .reset_index()
)

In [44]:
def read_json(ds, reg):
    """Load the best single-model performances for dataset ``ds`` under regularization ``reg``.

    Reads ``single_model_best_perf_path / reg / f"{ds}.json"`` and returns the parsed JSON,
    which later cells index as ``model_id -> accuracy``.
    """
    with open(single_model_best_perf_path / reg / f"{ds}.json") as f:
        return json.load(f)


# single_results[dataset][regularization][model_id] -> best single-model accuracy
single_results = {
    ds: {reg: read_json(ds, reg) for reg in df["regularization"].unique()}
    for ds in df["dataset"].unique()
}


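#### Prepare data for plotting
Steps:
1. Compute the performance gap between each combined model (concat or ensemble) and the anchor model alone, per dataset and regularization.
2. Add the similarity value between the anchor and the other model of each pair.
3. Add the other model's own best downstream accuracy and ImageNet accuracy, and compute its downstream-accuracy gap to the anchor.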

In [45]:
# Split the averaged results by evaluation mode: anchor alone, concatenated features, ensemble.
single_performance, concat_performance, ensemble_performance = (
    mean_df[mean_df['mode'] == mode].copy().reset_index(drop=True)
    for mode in ('single_model', 'combined_models', 'ensemble')
)
print(f"{single_performance.shape=}, {concat_performance.shape=}, {ensemble_performance.shape=}")

single_performance.shape=(72, 10), concat_performance.shape=(4029, 10), ensemble_performance.shape=(4029, 10)


In [46]:
def _partner_of_anchor(model_ids):
    """Return the model paired with ``anchor_model`` in a two-model combination.

    Raises ValueError if the pair does not contain the anchor; the previous
    lambda silently returned the second id in that case.
    """
    first, second = model_ids
    if second == anchor_model:
        return first
    if first == anchor_model:
        return second
    raise ValueError(f"anchor {anchor_model!r} not in model pair {model_ids!r}")


concat_performance['other_model'] = concat_performance['model_ids'].map(_partner_of_anchor)
ensemble_performance['other_model'] = ensemble_performance['model_ids'].map(_partner_of_anchor)

In [47]:
def read_img_json(model, reg, fewshot_epochs=20, batch_size=1024, seed=0):
    """Return the ImageNet-1k single-model linear-probe top-1 test accuracy of ``model``.

    Reads ``singe_model_imgnet_path / model / "no_fewshot" / f"fewshot_epochs_{fewshot_epochs}"
    / f"regularization_{reg}" / f"batch_size_{batch_size}" / f"seed_{seed}" / "results.json"``
    and returns ``results["test_lp_acc1"]["0"]``. The keyword defaults reproduce the run
    configuration that was previously hard-coded.
    """
    run_dir = (
        singe_model_imgnet_path / model / "no_fewshot" / f"fewshot_epochs_{fewshot_epochs}"
        / f"regularization_{reg}" / f"batch_size_{batch_size}" / f"seed_{seed}"
    )
    with open(run_dir / "results.json") as f:
        return json.load(f)["test_lp_acc1"]["0"]


# Cover the partner models of BOTH modes: later cells look up ensemble partners
# in this dict too, not only concat partners.
partner_models = pd.concat(
    [concat_performance["other_model"], ensemble_performance["other_model"]]
).unique()
imgnet_results = {
    reg: {model: read_img_json(model, reg) for model in partner_models}
    for reg in df["regularization"].unique()
}


In [48]:
# For the non-anchor model of every pair, attach its own best downstream accuracy
# (single_results) and its ImageNet accuracy (imgnet_results) under the same regularization.
def _other_model_ds_acc(row):
    return single_results[row["dataset"]][row["regularization"]][row["other_model"]]


def _other_model_img_acc(row):
    return imgnet_results[row["regularization"]][row["other_model"]]


for _perf in (ensemble_performance, concat_performance):
    _perf["other_model_dsacc"] = _perf.apply(_other_model_ds_acc, axis=1)
    _perf["other_model_imgacc"] = _perf.apply(_other_model_img_acc, axis=1)

In [49]:
## THESE ARE THE ANCHOR MODEL PERFORMANCES FOR DIFFERENT REGULARIZATIONS
single_performance_pivot = pd.pivot_table(
    single_performance,
    index='dataset',
    columns='regularization',
    values='test_lp_acc1'
)
single_performance_pivot

regularization,L1,L2,weight_decay
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
cifar100-coarse,0.9164,0.915667,0.916333
entity13,0.783385,0.796564,0.78359
entity30,0.729444,0.754611,0.741167
living17,0.92451,0.931765,0.917843
nonliving26,0.656026,0.66641,0.656538
wds_cars,0.90735,0.906521,0.908055
wds_country211,0.399289,0.406493,0.407457
wds_fer2013,0.711387,0.709297,0.714637
wds_fgvc_aircraft,0.613161,0.612061,0.615762
wds_gtsrb,0.926128,0.925468,0.926999


In [50]:
def get_performance_gap_n_sim_metric(row):
    """Compare one combined-model result against the anchor model alone.

    Returns ``(gap, sim_val, single_gap)``:
      * gap        - combined accuracy minus the anchor's accuracy for the row's dataset/regularization,
      * sim_val    - similarity between the row's other model and the anchor, read from ``sim_mat``,
      * single_gap - the other model's own accuracy (``other_model_dsacc``) minus the anchor's accuracy.
    """
    anchor_acc = single_performance_pivot.loc[row['dataset'], row['regularization']]
    partner = row['other_model']
    return (
        row['test_lp_acc1'] - anchor_acc,
        sim_mat.loc[partner, anchor_model],
        row["other_model_dsacc"] - anchor_acc,
    )

In [51]:
# Columns produced by get_performance_gap_n_sim_metric, in its return order.
GAP_COLUMNS = ['gap', 'sim_value', "ds_acc_gap"]


def add_gap_columns(perf_df):
    """Return ``perf_df`` with the outputs of ``get_performance_gap_n_sim_metric`` as new columns.

    Any existing gap columns are dropped first, so re-running this cell no longer
    creates duplicate column names (plain ``pd.concat(axis=1)`` did). Results are
    aligned on ``perf_df.index``, and an empty frame yields empty gap columns.
    """
    gap_rows = [get_performance_gap_n_sim_metric(row) for _, row in perf_df.iterrows()]
    gaps = pd.DataFrame(gap_rows, columns=GAP_COLUMNS, index=perf_df.index)
    return pd.concat([perf_df.drop(columns=GAP_COLUMNS, errors='ignore'), gaps], axis=1)


concat_performance = add_gap_columns(concat_performance)
ensemble_performance = add_gap_columns(ensemble_performance)

In [52]:
# Quick look at the anchor-similarity values attached to the ensemble results.
ensemble_performance.sim_value

0       0.414018
1       0.414018
2       0.414018
3       0.748758
4       0.748758
          ...   
4024    0.482365
4025    0.482365
4026    0.569174
4027    0.569174
4028    0.569174
Name: sim_value, Length: 4029, dtype: float64

#### Scatter plots annotated with the Spearman rank correlation coefficient

In [53]:
def plot_scatter(df, title, x='ds_acc_gap'):
    """Scatter the combined-vs-anchor accuracy ``gap`` against column ``x``.

    One panel per (dataset, regularization). Each panel is annotated with the
    Spearman rank correlation between ``x`` and ``gap``. Positive gaps are shaded
    green and negative gaps red. When ``x == 'ds_acc_gap'``, a vertical line marks
    x = 0 if any value is positive.

    Parameters
    ----------
    df : DataFrame containing the columns ``x``, 'gap', 'regularization' and 'dataset'.
    title : figure super-title.
    x : column plotted on the horizontal axis.

    Returns
    -------
    The matplotlib Figure of the facet grid.
    """
    g = sns.relplot(
        df,
        x=x,
        y='gap',
        col='regularization',
        row='dataset',
        height=3,
        aspect=1.25,
        facet_kws={'sharey': False, 'sharex': False},
    )
    g.set_titles("{row_name} – {col_name}")

    def annotate_correlation(data, **kwargs):
        # map_dataframe makes the current facet the active axes before calling this.
        ax = plt.gca()
        rho = data[x].corr(data['gap'], method="spearman")
        # Spearman correlation, so label it rho rather than Pearson's r.
        ax.text(0.05, 0.95, f'Spearman ρ = {rho:.2f}', transform=ax.transAxes,
                fontsize=12, verticalalignment='top')
        # Series.max/min skip NaNs and return NaN on empty input, unlike builtin max/min.
        gap_max, gap_min = data['gap'].max(), data['gap'].min()
        if gap_max > 0:
            ax.axhspan(0, gap_max, facecolor='lightgreen', alpha=0.2, zorder=-1)
        # Check the column name first so other x columns never need a 'ds_acc_gap' column.
        if x == "ds_acc_gap" and data[x].max() > 0:
            ax.axvline(x=0, color='black', linestyle='-', linewidth=1, zorder=-1)
        if gap_min < 0:
            ax.axhspan(gap_min, 0, facecolor='lightcoral', alpha=0.2, zorder=-1)

    g.map_dataframe(annotate_correlation)

    # `figure` replaces the deprecated `fig` attribute of FacetGrid.
    g.figure.suptitle(title, y=1)
    g.figure.tight_layout()
    return g.figure

In [54]:
# SAVE is taken from the storing-configuration cell.
fig = plot_scatter(concat_performance,
                   f"Combined models (Concat) with anchor {anchor_model} and Single Downstream Accuracy Gap.")
if SAVE:
    fig.savefig(storing_path / 'combined_concat_ds.pdf', bbox_inches='tight')
    plt.close(fig)
    print('stored concat img')
else:
    # plt.show() takes no figure argument; it displays the open figures.
    plt.show()

stored concat img


In [55]:
fig = plot_scatter(ensemble_performance,
                   f"Ensemble with anchor {anchor_model} and Single Downstream Accuracy Gap.")
if SAVE:
    fig.savefig(storing_path / 'ensemble_ds.pdf', bbox_inches='tight')
    plt.close(fig)
    print('stored ensemble img')
else:
    # plt.show() takes no figure argument; it displays the open figures.
    plt.show()

stored ensemble img


In [56]:

fig = plot_scatter(concat_performance,
                   f"Combined models (Concat) with anchor {anchor_model} and Imagenet Accuracy", "other_model_imgacc")
if SAVE:
    fig.savefig(storing_path / 'combined_concat_imgacc.pdf', bbox_inches='tight')
    plt.close(fig)
    print('stored concat img')
else:
    # plt.show() takes no figure argument; it displays the open figures.
    plt.show()

stored concat img


In [57]:
fig = plot_scatter(ensemble_performance,
                   f"Ensemble with anchor {anchor_model} and Imagenet Accuracy", "other_model_imgacc")
if SAVE:
    fig.savefig(storing_path / 'ensemble_imgacc.pdf', bbox_inches='tight')
    plt.close(fig)
    print('stored ensemble img')
else:
    # plt.show() takes no figure argument; it displays the open figures.
    plt.show()

stored ensemble img
