In [1]:
import os
import sys
from itertools import combinations, product
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, kendalltau, pearsonr
from tqdm.notebook import tqdm

from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show
from clip_benchmark.analysis.utils import retrieve_performance 
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr

from constants import exclude_models, exclude_models_w_mae, ds_name_mapping, cat_name_mapping

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping, anchors

In [2]:
### Config similarity data
# Pairwise model similarities for every (similarity metric, dataset, model pair).
sim_data = pd.read_csv(
    '/home/space/diverse_priors/results/aggregated/model_sims/all_metric_ds_model_pair_similarity.csv'
)

### Config performance data
# '/' in dataset names is replaced with '_', presumably to match the underscore-style
# 'DS' names (e.g. 'wds_vtab_flowers'). Only used by the commented-out performance cell.
ds_list_perf = [ds_name.replace('/', '_') for ds_name in parse_datasets('../scripts/webdatasets_wo_ood.txt')]

results_root = '/home/space/diverse_priors/results/linear_probe/single_model'

### Config datasets to include
# An empty list keeps every dataset present in sim_data.
# ds_to_include = []
ds_to_include = ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam']

## Storing information
# `suffix` is appended to output filenames; if it contains 'mae', MAE models are dropped below.
# suffix = '_wo_swav_pirl_timm_clip'
suffix = '_wo_mae_swav_pirl_timm_clip'

SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/scatter_sim_vs_performance')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [3]:
## Filter similarity data only for desired datasets
# An empty `ds_to_include` means "keep every dataset".
if ds_to_include:
    keep_ds = sim_data['DS'].isin(ds_to_include)
    sim_data = sim_data.loc[keep_ds].reset_index(drop=True)

In [4]:
## Remove MAE models if requested
# Drop every pair in which either model id contains 'mae'.
if 'mae' in suffix:
    m1_is_mae = sim_data['Model 1'].str.contains('mae')
    m2_is_mae = sim_data['Model 2'].str.contains('mae')
    sim_data = sim_data.loc[~(m1_is_mae | m2_is_mae)].reset_index(drop=True)

In [5]:
## Post-process 'pair' columns
def pp_pair_col(df_col, mapping=None):
    """Turn serialized category pairs into readable "<cat 1>, <cat 2>" labels.

    Each entry of `df_col` is expected to be the string repr of an indexable
    literal (presumably a 2-tuple of category keys, e.g. "('a', 'b')").
    Its first two elements are mapped through `mapping`.

    Args:
        df_col: pandas Series of serialized pairs.
        mapping: Dict from raw category key to display name. Defaults to
            `cat_name_mapping` from `constants`.

    Returns:
        Series (same index and name as `df_col`) of "<name 1>, <name 2>" strings.

    Raises:
        ValueError / SyntaxError: if an entry is not a Python literal. This includes
            entries that were already formatted by a previous run of this function.
        KeyError: if a category key is missing from `mapping`.
    """
    import ast

    if mapping is None:
        mapping = cat_name_mapping

    def _format(serialized):
        # literal_eval only accepts literals, unlike eval, which would execute
        # arbitrary code coming from the CSV.
        pair = ast.literal_eval(serialized)
        return f"{mapping[pair[0]]}, {mapping[pair[1]]}"

    return df_col.map(_format)

# Every column whose name contains 'pair' holds a serialized category pair.
# NOTE: not idempotent -- after one run these columns hold formatted strings,
# so re-running this cell will fail to parse them.
pair_columns = list(filter(lambda name: 'pair' in name, sim_data.columns))
for pair_col in pair_columns:
    sim_data[pair_col] = pp_pair_col(sim_data[pair_col])

In [None]:
# Choose the exclusion list that matches the MAE setting encoded in `suffix`.
curr_excl_models = exclude_models_w_mae if 'mae' in suffix else exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/models_config_wo_barlowtwins_n_alignment.json',
    # Pass the suffix-dependent list computed above. Passing the plain
    # `exclude_models` here would silently ignore the MAE setting.
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)

In [7]:
# res = []
# for ds, mid in product(ds_list_perf, allowed_models):
#     performance = retrieve_performance(
#         model_id=mid, 
#         dataset_id=ds, 
#         metric_column='test_lp_acc1',
#         results_root='/home/space/diverse_priors/results/linear_probe/single_model',
#         regularization="weight_decay",
#         allow_db_results=False
#     )
#     res.append({
#         'DS': ds,
#         'Model': mid,
#         'TestAcc': performance
#     })
# perf_res = pd.DataFrame(res)

In [8]:
# perf_res.to_csv(f'/home/space/diverse_priors/results/aggregated/single_model_performance/all_ds{suffix}.csv', index=False)

In [9]:
# Single-model linear-probe accuracies. NOTE(review): this reads the unsuffixed
# 'all_ds.csv', not the f'all_ds{suffix}.csv' written by the commented-out export cell.
perf_csv_path = Path('/home/space/diverse_priors/results/aggregated/single_model_performance/all_ds.csv')
perf_res = pd.read_csv(perf_csv_path)

In [10]:
# Rename 'wds_imagenet1k' rows to 'imagenet-subset-10k', the dataset name used on the
# similarity side (see `ds_to_include`), so performance and similarity rows can be matched.
perf_res['DS'] = perf_res['DS'].replace({'wds_imagenet1k': 'imagenet-subset-10k'})

In [11]:
def _lookup_perf(perf_df, model_id, ds):
    """Return the single 'TestAcc' value for `model_id` on dataset `ds` in `perf_df`.

    Raises:
        ValueError: if there is not exactly one matching row (missing or duplicated
            result). The message names the offending model and dataset.
    """
    matches = perf_df.loc[(perf_df['Model'] == model_id) & (perf_df['DS'] == ds), 'TestAcc']
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one 'TestAcc' entry for Model={model_id!r}, DS={ds!r}; "
            f"found {len(matches)}."
        )
    return matches.item()


def get_model_perf(row, perf_df=None):
    """Look up both models' test accuracies for one similarity row.

    Args:
        row: Mapping/Series with 'Model 1', 'Model 2' and 'DS' entries.
        perf_df: Performance table with 'Model', 'DS', 'TestAcc' columns.
            Defaults to the notebook-level `perf_res`.

    Returns:
        (model 1 accuracy, model 2 accuracy, absolute accuracy difference).

    Raises:
        ValueError: if either model does not have exactly one result on `row['DS']`.
    """
    if perf_df is None:
        perf_df = perf_res
    m1_perf = _lookup_perf(perf_df, row['Model 1'], row['DS'])
    m2_perf = _lookup_perf(perf_df, row['Model 2'], row['DS'])
    return m1_perf, m2_perf, np.abs(m1_perf - m2_perf)


In [12]:
# Look up both models' accuracies and their absolute gap for every similarity row.
# The frame is indexed like `sim_data` so the column-wise concat below is row-exact.
# (Previously a `.reset_index()` leaked a spurious 'index' column into sim_data_new.)
performance_per_pair = pd.DataFrame(
    [get_model_perf(row) for _, row in sim_data.iterrows()],  # also works for an empty sim_data
    columns=['Model 1 perf.', 'Model 2 perf.', 'abs. diff. perf.'],
    index=sim_data.index,
)

In [13]:
# Column-wise join of similarity rows and their performance lookups.
sim_data_new = pd.concat([sim_data, performance_per_pair], axis=1)
# concat(axis=1) aligns on the index; mismatched indexes would silently add NaN-padded rows.
assert len(sim_data_new) == len(sim_data), "sim_data and performance_per_pair are misaligned"

In [14]:
def get_scatter_n_compute_correlation(hue_col, data=None):
    """Scatter similarity vs. performance gap per (metric, dataset) and correlate them.

    Draws a grid of scatter plots: one row per similarity metric and one column per
    dataset. Each panel plots 'Similarity value' (x) against 'abs. diff. perf.' (y),
    coloured by `hue_col`. Within every panel, the Spearman and Pearson correlations
    between those two columns are computed separately for each `hue_col` category.

    Args:
        hue_col: Column used for colouring and for splitting the correlations
            (e.g. one of the 'pair' columns).
        data: Frame with 'Similarity metric', 'DS', 'Similarity value',
            'abs. diff. perf.' and `hue_col` columns. Defaults to the
            notebook-level `sim_data_new`.

    Returns:
        (fig, r_vals_sr, r_vals_pr): the figure, plus two DataFrames with columns
        ['Similarity metric', 'DS', hue_col, 'corr']. These hold the Spearman and
        Pearson coefficients respectively. Categories with fewer than two points get NaN.
    """
    if data is None:
        data = sim_data_new

    # Fixed row/column order (groupby sorts its keys the same way), so a panel's
    # position does not shift when some (metric, dataset) combination is missing.
    metrics = sorted(data['Similarity metric'].dropna().unique())
    datasets = sorted(data['DS'].dropna().unique())
    n, m = len(metrics), len(datasets)
    box_size = 4

    # squeeze=False keeps `axes` 2-D even for a single row and/or column.
    fig, axes = plt.subplots(
        nrows=n, ncols=m, figsize=(box_size * m, box_size * n),
        sharex=True, sharey=False, squeeze=False,
    )

    r_vals_sr, r_vals_pr = [], []
    for (metric, ds), group_data in data.groupby(['Similarity metric', 'DS']):
        row, col = metrics.index(metric), datasets.index(ds)
        ax = axes[row, col]
        sns.scatterplot(
            group_data,
            x='Similarity value',
            y='abs. diff. perf.',
            hue=hue_col,
            alpha=0.5,
            ax=ax,
            s=15,
        )

        # Label only the outer panels: x-label on the bottom row, y-label on the
        # first column, dataset titles on the top row.
        ax.set_xlabel('Similarity value' if row == n - 1 else '', fontsize=11)
        ax.set_ylabel(f'{metric}\n Model Performance Gap' if col == 0 else '', fontsize=12)
        ax.set_title(ds_name_mapping.get(ds, ds) if row == 0 else '', fontsize=12)

        # Keep a single legend, moved outside the top-right panel.
        legend = ax.get_legend()
        if legend is not None:
            if row == 0 and col == m - 1:
                sns.move_legend(ax,
                                loc='upper left',
                                title=hue_col,
                                bbox_to_anchor=(1, 1), fontsize=12,
                                title_fontsize=12, frameon=False)
            else:
                legend.remove()

        for cat, subset_data in group_data.groupby(hue_col):
            sim_vals = subset_data['Similarity value']
            gap_vals = subset_data['abs. diff. perf.']
            if len(subset_data) < 2:
                # Correlation is undefined for fewer than 2 points (pearsonr raises).
                corr_sp = corr_pr = np.nan
            else:
                corr_sp, _ = spearmanr(sim_vals, gap_vals)
                corr_pr, _ = pearsonr(sim_vals, gap_vals)
            r_vals_sr.append((metric, ds, cat, corr_sp))
            r_vals_pr.append((metric, ds, cat, corr_pr))

    fig.subplots_adjust(wspace=0.2, hspace=0.1)
    # Column order must follow the tuple order above: (metric, dataset, category, corr).
    corr_columns = ['Similarity metric', 'DS', hue_col, 'corr']
    r_vals_sr = pd.DataFrame(r_vals_sr, columns=corr_columns)
    r_vals_pr = pd.DataFrame(r_vals_pr, columns=corr_columns)

    return fig, r_vals_sr, r_vals_pr

In [15]:
# for hue_col in pair_columns:
#     fig, r_vals_sr, r_vals_pr = get_scatter_n_compute_correlation(hue_col)
#     save_or_show(fig, storing_path / f'scatter_{hue_col.replace(" ", "_")}{suffix}.pdf', SAVE)
    
#     m = sim_data_new['Similarity metric'].nunique()
    
#     fig, axs = plt.subplots(nrows=m, ncols=2, figsize=(5*m, 5*2), sharex=True, sharey=True)
#     for j, (corr_metric, corr_data) in enumerate([('spearmanr', r_vals_sr), ('pearsonr', r_vals_pr)]):
#         for i, (metric, tmp_subset) in enumerate(corr_data.groupby('Similarity metric')):
#             ax = axs[i,j]
#             sns.kdeplot(
#                 tmp_subset,
#                 x='corr',
#                 hue=hue_col,
#                 ax=ax
#                 )
#             if i==0:
#                 ax.set_title(f"{corr_metric.capitalize()}{' (w/o MAE)' if 'mae' in suffix else ''}", fontsize=11)
#             if j==0:
#                 ax.set_ylabel(f"{metric}\nDensity", fontsize=11)

#             if i==0 and j==1:
#                 sns.move_legend(ax, 
#                                 loc='upper left', 
#                                 title=hue_col,
#                                 bbox_to_anchor=(1, 1), fontsize=11,
#                                 title_fontsize=11, frameon=False)
#             else:
#                 ax.get_legend().remove()
            
#     fig.subplots_adjust(hspace=0.1, wspace=0.1)
#     save_or_show(fig, storing_path / f'corr_dist_perf_vs_cka_{hue_col.replace(" ", "_")}{suffix}.pdf', SAVE)

In [None]:
# Output-file suffix: append the included dataset names when dataset filtering is active.
# Computed once because it does not depend on the loop variable.
curr_suffix = f'{suffix}_{"_".join(ds_to_include)}' if ds_to_include else suffix

# Keep every hue column's correlation tables instead of overwriting them each iteration.
corr_results = {}
for hue_col in pair_columns:
    fig, r_vals_sr, r_vals_pr = get_scatter_n_compute_correlation(hue_col)
    corr_results[hue_col] = {'spearmanr': r_vals_sr, 'pearsonr': r_vals_pr}
    save_or_show(fig, storing_path / f'scatter_{hue_col.replace(" ", "_")}{curr_suffix}.pdf', SAVE)