In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import spearmanr, pearsonr

from constants import exclude_models, exclude_models_w_mae, cat_name_mapping
from helper import load_model_configs_and_allowed_models, save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets

In [None]:
# base_path_aggregated = '/home/space/diverse_priors/results/aggregated'
base_path_aggregated = Path('/Users/lciernik/Documents/TUB/projects/divers_prios/results/aggregated')

### Config similarity data
# Pairwise model similarity values; later cells use its 'DS', 'Similarity metric',
# 'Model 1', 'Model 2', 'Similarity value' and '... pair' columns.
sim_data = pd.read_csv(base_path_aggregated / 'model_sims' / 'all_metric_ds_model_pair_similarity.csv')

### Config performance data
# Dataset ids are normalised from 'a/b' to 'a_b' (same form as the ids in ds_to_include below).
ds_list_perf = [ds_id.replace('/', '_') for ds_id in parse_datasets('../scripts/webdatasets_wo_ood.txt')]
with open('../scripts/dataset_info.json', 'r') as f:
    raw_ds_info = json.load(f)
# One row per (normalised) dataset id; columns come from the JSON ('name' and 'domain' are used below).
ds_info = pd.DataFrame({ds_id.replace('/', '_'): info for ds_id, info in raw_ds_info.items()}).T

# NOTE(review): not referenced by any live cell visible here (only by the commented-out computation).
results_root = '/home/space/diverse_priors/results/linear_probe/single_model'

### Config datasets to include
# ds_to_include = ['wds_imagenet1k', 'imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam']
ds_to_include = [
    'imagenet-subset-10k', 'wds_imagenet1k', 'wds_vtab_cifar100', 'entity30',
    'wds_cars', 'wds_vtab_flowers', 'wds_vtab_pets',
    'wds_vtab_diabetic_retinopathy', 'wds_vtab_eurosat', 'wds_vtab_pcam',
    'wds_fer2013', 'wds_vtab_dmlab', 'wds_vtab_dtd',
]

# Datasets with performance results that are not part of the selection above.
remaining_ds = sorted(set(ds_list_perf) - set(ds_to_include))

## Storing information
suffix = ''
# suffix = '_wo_mae'

SAVE = True
# storing_path = Path('/home/space/diverse_priors/results/plots/scatter_sim_vs_performance')
storing_path = Path(
    '/Users/lciernik/Documents/TUB/projects/divers_prios/results/'
    'analysis_model_similarities_across_datasets/scatter_sim_vs_performance'
)
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
## Filter similarity data only for desired datasets
# Keep only the datasets selected in the config cell (no filtering when the list is empty).
if ds_to_include:
    is_selected_ds = sim_data['DS'].isin(ds_to_include)
    sim_data = sim_data.loc[is_selected_ds].reset_index(drop=True)

In [None]:
## Rename datasets with info
# Add the dataset domain as 'DS category' and replace the dataset id in 'DS' by its display name.
# Both lookups are computed from the original ids before the frame is updated.
sim_data = sim_data.assign(**{
    'DS category': sim_data['DS'].apply(lambda ds_id: ds_info.loc[ds_id, 'domain']),
    'DS': sim_data['DS'].apply(lambda ds_id: ds_info.loc[ds_id, 'name']),
})

In [None]:
## Post-process 'pair' columns
def pp_pair_col(df_col, mapping=None):
    """Convert a column of stringified category pairs into readable "A, B" labels.

    Each cell is parsed as a Python literal and its first two elements are mapped through
    ``mapping``. Presumably cells look like "('a', 'b')" — TODO confirm against the CSV.

    Args:
        df_col: pandas Series whose cells are string reprs of indexable pairs.
        mapping: dict from raw category key to display name; defaults to ``cat_name_mapping``.

    Returns:
        pandas Series (same index as ``df_col``) of "<name 1>, <name 2>" strings.

    Raises:
        ValueError / SyntaxError: if a cell is not a plain Python literal.
        KeyError: if a category key is missing from ``mapping``.
    """
    import ast  # local import keeps this cell self-contained

    if mapping is None:
        mapping = cat_name_mapping
    # literal_eval only accepts literals; eval() would execute arbitrary code read from the CSV.
    return df_col.apply(ast.literal_eval).apply(lambda pair: f"{mapping[pair[0]]}, {mapping[pair[1]]}")


# Every column whose name contains 'pair' holds stringified category pairs; turn them into labels.
pair_columns = list(filter(lambda column_name: 'pair' in column_name, sim_data.columns))
for pair_col in pair_columns:
    sim_data[pair_col] = pp_pair_col(sim_data[pair_col])

In [None]:
# Exclusion list for this run: `exclude_models_w_mae` when `suffix` contains 'mae', else `exclude_models`.
curr_excl_models = exclude_models_w_mae if 'mae' in suffix else exclude_models
print(curr_excl_models)

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/models_config_wo_barlowtwins_n_alignment.json',
    # Pass the list selected above; passing `exclude_models` here ignored the suffix-based choice.
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)

In [None]:
# Keep only pairs in which both models are allowed. Reset the index (like the dataset filter
# above) so that per-row results built later on a fresh 0..n-1 index line up with these rows
# when concatenated column-wise.
sim_data = sim_data[
    sim_data['Model 1'].isin(allowed_models) & sim_data['Model 2'].isin(allowed_models)
].reset_index(drop=True)

In [None]:
# res = []
# for ds, mid in product(ds_list_perf, allowed_models):
#     performance = retrieve_performance(
#         model_id=mid, 
#         dataset_id=ds, 
#         metric_column='test_lp_acc1',
#         results_root='/home/space/diverse_priors/results/linear_probe/single_model',
#         regularization="weight_decay",
#         allow_db_results=False
#     )
#     res.append({
#         'DS': ds,
#         'Model': mid,
#         'TestAcc': performance
#     })
# perf_res = pd.DataFrame(res)

In [None]:
# perf_res.to_csv(base_path_aggregated/ f'single_model_performance/all_ds{suffix}.csv', index=False)

In [None]:
# Cached per-model performance table (see the commented-out computation above for how it was produced).
perf_res = pd.read_csv(base_path_aggregated / 'single_model_performance' / 'all_ds.csv')

In [None]:
# Restrict the performance table to the selected datasets (if any), then add the dataset domain
# as 'DS category' and replace the dataset id in 'DS' by its display name.
if ds_to_include:
    perf_res = perf_res.loc[perf_res['DS'].isin(ds_to_include)].reset_index(drop=True)
perf_res = perf_res.assign(**{
    'DS category': perf_res['DS'].apply(lambda ds_id: ds_info.loc[ds_id, 'domain']),
    'DS': perf_res['DS'].apply(lambda ds_id: ds_info.loc[ds_id, 'name']),
})
perf_res.head()

In [None]:
# perf_res.loc[perf_res['DS'] == 'wds_imagenet1k', 'DS'] = 'imagenet-subset-10k'

In [None]:
def _lookup_test_acc(perf_df, model_id, ds_name):
    """Return the single 'TestAcc' value of ``perf_df`` for ``model_id`` on dataset ``ds_name``.

    Raises:
        ValueError: if there is not exactly one matching row.
    """
    matches = perf_df.loc[(perf_df['Model'] == model_id) & (perf_df['DS'] == ds_name), 'TestAcc']
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one performance entry for model {model_id!r} on dataset {ds_name!r}, "
            f"found {len(matches)}."
        )
    return matches.item()


def get_model_perf(row, perf_df=None):
    """Look up the performance of both models of a similarity row and their absolute gap.

    Args:
        row: Mapping/row with keys 'Model 1', 'Model 2' and 'DS'.
        perf_df: Performance table with 'Model', 'DS' and 'TestAcc' columns;
            defaults to the notebook-level ``perf_res``.

    Returns:
        Tuple (model 1 TestAcc, model 2 TestAcc, |difference|).

    Raises:
        ValueError: if a model has no (or more than one) entry for the row's dataset.
    """
    if perf_df is None:
        perf_df = perf_res
    m1_perf = _lookup_test_acc(perf_df, row['Model 1'], row['DS'])
    m2_perf = _lookup_test_acc(perf_df, row['Model 2'], row['DS'])
    return m1_perf, m2_perf, np.abs(m1_perf - m2_perf)


In [None]:
# Build the per-pair performance table on sim_data's own index, so a column-wise concat pairs
# each performance row with the similarity row it was computed from. A fresh RangeIndex (plus
# reset_index) misaligns whenever sim_data's index has gaps, e.g. after a filter.
performance_per_pair = pd.DataFrame(
    sim_data.apply(get_model_perf, axis=1).tolist(),
    columns=['Model 1 perf.', 'Model 2 perf.', 'abs. diff. perf.'],
    index=sim_data.index,
)

In [None]:
# Attach the per-pair performances column-wise; pandas aligns the two frames on their index.
sim_data_new = pd.concat([sim_data, performance_per_pair], axis='columns')

In [None]:
def get_correlation(subset_data):
    """Correlate similarity with the absolute performance gap within one group.

    Args:
        subset_data: DataFrame with 'Similarity value' and 'abs. diff. perf.' columns.

    Returns:
        dict with the Spearman ('spearmanr') and Pearson ('pearsonr') coefficients;
        p-values are discarded.
    """
    sim_values = subset_data['Similarity value']
    perf_gaps = subset_data['abs. diff. perf.']
    coefficients = {}
    for label, corr_fn in (('spearmanr', spearmanr), ('pearsonr', pearsonr)):
        statistic, _pvalue = corr_fn(sim_values, perf_gaps)
        coefficients[label] = statistic
    return coefficients


# Correlation coefficients per (similarity metric, dataset) group, as a DataFrame indexed by
# (metric, dataset) with one column per coefficient returned by get_correlation.
r_coeffs = (
    sim_data_new
    .groupby(['Similarity metric', 'DS'])[['Similarity value', 'abs. diff. perf.']]
    .apply(get_correlation)
)
r_coeffs = pd.DataFrame(list(r_coeffs), index=r_coeffs.index)

In [None]:
# From here on only the 'CKA linear' similarity metric is plotted.
sim_data_new = sim_data_new.loc[sim_data_new['Similarity metric'].eq('CKA linear')]

In [None]:
# Order rows by dataset domain, then dataset name, with a fresh 0..n-1 index.
sim_data_new = sim_data_new.sort_values(by=['DS category', 'DS'], ignore_index=True)

In [None]:
# Panel layout of the 3x4 scatter grid, written one grid column at a time (top to bottom) and
# flattened row-major below, which is the order get_scatter_grid fills its axes in.
_ds_grid_columns = [
    ['ImageNet-1K', 'CIFAR-100', 'Entity-30'],
    ['Flowers', 'Pets', 'Stanford Cars'],
    ['Diabetic Retinopathy', 'EuroSAT', 'PCAM'],
    ['DTD', 'Dmlab', 'FER2013'],
]
ds_order = [ds_name for grid_row in zip(*_ds_grid_columns) for ds_name in grid_row]

In [None]:
def get_scatter_grid(hue_col, figsize=(9, 7), corr_type='spearmanr'):
    """Plot similarity vs. performance gap for every dataset in ``ds_order`` on a 3x4 grid.

    Each panel shows the rows of ``sim_data_new`` for one dataset (x: 'Similarity value',
    y: 'abs. diff. perf.'), coloured by ``hue_col``. The panel is annotated with the coefficient
    stored in ``r_coeffs`` under ('CKA linear', dataset). Panels are filled row-major in
    ``ds_order`` order. Titles in the first grid row also show the dataset domain from ``ds_info``.
    NOTE: assumes ``sim_data_new`` is already restricted to the 'CKA linear' metric, so that
    points and annotated coefficient refer to the same metric.

    Args:
        hue_col: Column of ``sim_data_new`` used to colour the points.
        figsize: (width, height) of a single panel in centimetres.
        corr_type: Coefficient to annotate, either 'spearmanr' or 'pearsonr'.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: If ``corr_type`` is not 'spearmanr' or 'pearsonr'.
    """
    # Fail loudly instead of silently falling back to Pearson for unknown values.
    if corr_type not in ('spearmanr', 'pearsonr'):
        raise ValueError(f"corr_type must be 'spearmanr' or 'pearsonr', got {corr_type!r}")

    n, m = 3, 4  # grid rows, grid columns
    cm = 0.393701  # inches per centimetre; matplotlib figure sizes are given in inches
    fontsize_title = 12
    fontsize_label = 12
    fontsize_ticks = 11

    fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(figsize[0] * cm * m, figsize[1] * cm * n),
                             sharex=True, sharey=False)
    axes = axes.flatten()

    for i, (col, ax) in enumerate(zip(ds_order, axes)):
        grid_row, grid_col = divmod(i, m)
        group_data = sim_data_new[sim_data_new['DS'] == col]

        sns.scatterplot(
            group_data,
            x='Similarity value',
            y='abs. diff. perf.',
            hue=hue_col,
            alpha=0.5,
            ax=ax,
            s=15,
        )
        # Axis labels only on the outer panels: x on the bottom row, y on the leftmost column.
        ax.set_xlabel('Similarity value' if grid_row == n - 1 else '', fontsize=fontsize_label)
        ax.set_ylabel('Model Performance Gap' if grid_col == 0 else '', fontsize=fontsize_label)

        col_cat = ds_info.loc[ds_info['name'] == col, 'domain'].unique()[0]
        if grid_row == 0:
            # Italic domain above the dataset name. Mathtext ignores plain spaces,
            # so escape them to keep multi-word domain names readable.
            col_cat_tex = str(col_cat).replace(' ', r'\ ')
            title = f"$\\it{{{col_cat_tex}}}$\n{col}"
        else:
            title = col
        ax.set_title(title, fontsize=fontsize_title)

        # A single legend, moved outside the last panel of the first grid row.
        if i == m - 1:
            sns.move_legend(ax,
                            loc='upper left',
                            title=hue_col,
                            bbox_to_anchor=(1, 1), fontsize=fontsize_label,
                            title_fontsize=fontsize_label, frameon=False)
        else:
            legend = ax.get_legend()
            if legend is not None:  # get_legend() returns None when no legend was drawn
                legend.remove()

        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
        ax.tick_params(axis='both', which='major', labelsize=fontsize_ticks)

        r_coeff = r_coeffs.loc[('CKA linear', col), corr_type]
        ax.text(0.95, 0.9, f'r coeff.: {r_coeff:.2f}',
                transform=ax.transAxes, fontsize=fontsize_label,
                bbox=dict(facecolor='white', alpha=0.5),
                ha='right')  # right-aligned in the panel's top-right corner

    # tight_layout() recomputes all subplot spacing, so a preceding subplots_adjust() had no effect.
    fig.tight_layout()
    return fig


In [None]:
sns.set_style('ticks')
corr_type = 'pearsonr'
# Per-panel size (cm) for each hue column, matched to `pair_columns` by position.
grid_panel_sizes = [(9, 7), (9, 7), (10, 8), (10, 8)]
for hue_col, figsize in zip(pair_columns, grid_panel_sizes):
    fig = get_scatter_grid(hue_col, figsize, corr_type)
    pdf_name = f'scatter_3_4_{corr_type}_grid_cka_linear_{hue_col.replace(" ", "_")}.pdf'
    save_or_show(fig, storing_path / pdf_name, SAVE)

#### Old code (superseded by `get_scatter_grid`; kept for reference only)

In [None]:
# def get_scatter_n_compute_correlation(hue_col):
#     n = sim_data_new['Similarity metric'].nunique()
#     m = sim_data_new['DS'].nunique()
#     box_size = 4
# 
#     fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(box_size * m, box_size * n), sharex=True, sharey=False)
#     axes = axes.flatten()
# 
#     r_vals_sr, r_vals_pr = [], []
#     for i, (keys, group_data) in enumerate(sim_data_new.groupby(['Similarity metric', 'DS'])):
#         ax = axes[i]
#         sns.scatterplot(
#             group_data,
#             x='Similarity value',
#             y='abs. diff. perf.',
#             hue=hue_col,
#             alpha=0.5,
#             ax=ax,
#             s=15,
#         )
# 
#         xlbl = 'Similarity value' if i // m == 1 else ''
#         ax.set_xlabel(xlbl, fontsize=11)
# 
#         ylbl = f'{keys[0]}\n Model Performance Gap' if i % m == 0 else ''
#         ax.set_ylabel(ylbl, fontsize=12)
# 
#         title = keys[1] if i // m == 0 else ''
#         ax.set_title(title, fontsize=12)
# 
#         if i % m == (m - 1) and i // m == 0:
#             sns.move_legend(ax,
#                             loc='upper left',
#                             title=hue_col,
#                             bbox_to_anchor=(1, 1), fontsize=12,
#                             title_fontsize=12, frameon=False)
#         else:
#             ax.get_legend().remove()
# 
#         sp_r_coeff = r_coeffs.loc[keys, 'spearmanr']
#         pr_r_coeff = r_coeffs.loc[keys, 'pearsonr']
# 
#         ax.text(0.95, 0.9, f'r_spearman: {sp_r_coeff:.2f}',
#                 transform=ax.transAxes, fontsize=11,
#                 bbox=dict(facecolor='white', alpha=0.5),
#                 ha='right')  # Align text to the right
# 
#         ax.text(0.95, 0.8, f'r_pearson: {pr_r_coeff:.2f}',
#                 transform=ax.transAxes, fontsize=11,
#                 bbox=dict(facecolor='white', alpha=0.5),
#                 ha='right')  # Align text to the right
# 
#         # for cat, subset_data in group_data.groupby(hue_col):
#         #     corr_sp, _ = spearmanr(subset_data['Similarity value'], subset_data['abs. diff. perf.'])
#         #     corr_pr, _ = pearsonr(subset_data['Similarity value'], subset_data['abs. diff. perf.'])
#         #     r_vals_sr.append(tuple(list(keys) + [cat, corr_sp]))
#         #     r_vals_pr.append(tuple(list(keys) + [cat, corr_pr]))
# 
#     fig.subplots_adjust(wspace=0.2, hspace=0.1)
#     # r_vals_sr = pd.DataFrame(r_vals_sr, columns=['DS', 'Similarity metric', hue_col, 'corr'])
#     # r_vals_pr = pd.DataFrame(r_vals_pr, columns=['DS', 'Similarity metric', hue_col, 'corr'])
# 
#     return fig, r_vals_sr, r_vals_pr

In [None]:
# sns.set_style('ticks')
# for hue_col in pair_columns:
#     fig, _, _ = get_scatter_n_compute_correlation(hue_col)
#     curr_suffix = suffix + f'_{"_".join(ds_to_include)}' if ds_to_include else suffix
#     save_or_show(fig, storing_path / f'scatter_{hue_col.replace(" ", "_")}{curr_suffix}.pdf', SAVE)