## Notebook appendix F: *Distribution of correlation coefficients between downstream task performance differences and model similarity*
This notebook creates the plots for Appendix F. For every dataset (grouped by dataset category), it correlates the absolute difference in downstream task performance of two models with the similarity of their representations. Both Spearman and Pearson correlation coefficients are computed and stored; the plot shows the Pearson coefficients for the CKA linear similarity.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr
from matplotlib.lines import Line2D

from constants import exclude_models, exclude_models_w_mae, cat_name_mapping, model_config_file, \
    BASE_PATH_RESULTS, ds_list_perf_file, fontsizes, fontsizes_cols
from helper import load_model_configs_and_allowed_models, save_or_show, load_all_datasetnames_n_info, \
    pp_storing_path

In [None]:
### Config datasets
ds_list_perf, ds_info = load_all_datasetnames_n_info(ds_list_perf_file, verbose=False)

### Config similarity data
sim_data_path = BASE_PATH_RESULTS / 'aggregated' / 'model_sims' / 'all_metric_ds_model_pair_similarity.csv'
# Explicit raise instead of `assert`: assert statements are removed when Python runs with -O,
# which would turn a missing input into a confusing error much later.
if not sim_data_path.exists():
    raise FileNotFoundError(
        f"Path does not exist: {sim_data_path}. Aggregated similarity data not found, "
        "please run `aggregate_similarities_across_datasets.ipynb` before."
    )

### Config performance data
perf_data_path = BASE_PATH_RESULTS / 'aggregated' / 'single_model_performance' / 'all_ds.csv'
if not perf_data_path.exists():
    raise FileNotFoundError(
        f"Path does not exist: {perf_data_path}. Aggregated performance data not found, "
        "please run `aggregate_downstream_task_perfs.ipynb` before."
    )

### Config datasets to include
# All datasets of the performance list except the two excluded below, plus the ImageNet 10k subset.
ds_to_include = set(ds_list_perf) - {'cifar100-coarse', 'entity13'}
ds_to_include.add('imagenet-subset-10k')
# Datasets of the performance list that are not part of this analysis.
remaining_ds = sorted(set(ds_list_perf) - ds_to_include)

## Storing information
# A suffix containing 'mae' switches the model exclusion list to `exclude_models_w_mae`.
suffix = ''
# suffix = '_wo_mae'

## Version and plotting info
version = 'arxiv'
curr_fontsizes = fontsizes if version == 'arxiv' else fontsizes_cols

SAVE = True
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'final' / version / 'app_F_corr_perf_vs_sim', SAVE)

#### Load the model configurations and allowed models

In [None]:
# Choose which models to drop: the MAE-specific exclusion list is used whenever the
# storing suffix mentions 'mae', otherwise the default list.
if 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Load every model configuration and the names of the models that survive the exclusions
# (alignment variants are excluded as well).
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)

#### Load similarity data

In [None]:
# Aggregated pairwise similarities for every similarity metric, dataset and model pair.
sim_data = pd.read_csv(filepath_or_buffer=sim_data_path)

In [None]:
## Restrict the similarity table to the selected datasets (shape is printed before and after).
print(sim_data.shape)
if ds_to_include:
    sim_data = (
        sim_data
        .loc[sim_data['DS'].isin(ds_to_include)]
        .reset_index(drop=True)
    )
print(sim_data.shape)

In [None]:
## Attach each dataset's domain and replace the dataset key by its display name.
# Both lookups use the raw dataset key, so it is captured before 'DS' is overwritten.
raw_ds_keys = sim_data['DS']
sim_data['DS category'] = raw_ds_keys.apply(lambda key: ds_info.loc[key, 'domain'])
sim_data['DS'] = raw_ds_keys.apply(lambda key: ds_info.loc[key, 'name'])

In [None]:
## Post-process 'pair' columns
def pp_pair_col(df_col):
    """Convert a column of stringified category pairs into readable labels.

    Each cell is parsed as a Python literal sequence (e.g. the string form of a
    2-tuple as stored in the CSV). Its first two entries are mapped through
    ``cat_name_mapping`` and joined as ``"<name 1>, <name 2>"``.

    Args:
        df_col: pandas Series whose values are the string representations of the pairs.

    Returns:
        Series of label strings with the same index as ``df_col``.
    """
    # literal_eval only accepts Python literals; eval() would execute any
    # expression that ends up in the CSV file.
    from ast import literal_eval

    def _to_label(raw_pair):
        pair = literal_eval(raw_pair)
        return f"{cat_name_mapping[pair[0]]}, {cat_name_mapping[pair[1]]}"

    return df_col.apply(_to_label)


pair_columns = [col for col in sim_data.columns if 'pair' in col]
sim_data[pair_columns] = sim_data[pair_columns].apply(pp_pair_col, axis=0)
# NOTE(review): None is appended presumably as a "no pair grouping" option for later cells — confirm.
pair_columns += [None]

In [None]:
## Keep only model pairs in which both models are allowed.
sim_data = (
    sim_data
    .loc[sim_data['Model 1'].isin(allowed_models) & sim_data['Model 2'].isin(allowed_models)]
    .reset_index(drop=True)
)

#### Load performance data

In [None]:
# Aggregated single-model downstream task performance for every dataset.
perf_res = pd.read_csv(filepath_or_buffer=perf_data_path)

In [None]:
# Dataset filter, then domain/display-name lookup on the raw key, then allowed-model filter.
if ds_to_include:
    perf_res = perf_res.loc[perf_res['DS'].isin(ds_to_include)].reset_index(drop=True)
raw_perf_ds_keys = perf_res['DS']
perf_res['DS category'] = raw_perf_ds_keys.apply(lambda key: ds_info.loc[key, 'domain'])
perf_res['DS'] = raw_perf_ds_keys.apply(lambda key: ds_info.loc[key, 'name'])
perf_res = perf_res.loc[perf_res['Model'].isin(allowed_models)].reset_index(drop=True)

#### Combine model similarities and performance measures

In [None]:
def get_model_perf(row):
    """Look up the test accuracy of both models of a similarity row.

    Args:
        row: row of the similarity table providing 'Model 1', 'Model 2' and 'DS'.

    Returns:
        Tuple ``(acc_model_1, acc_model_2, abs(acc_model_1 - acc_model_2))`` taken
        from the 'TestAcc' column of the global ``perf_res``.

    Raises:
        ValueError: from ``.item()`` when a model/dataset combination does not
            match exactly one row of ``perf_res``.
    """
    on_dataset = perf_res['DS'] == row['DS']

    def _test_acc(model_name):
        match = on_dataset & (perf_res['Model'] == model_name)
        return perf_res.loc[match, 'TestAcc'].item()

    acc_first = _test_acc(row['Model 1'])
    acc_second = _test_acc(row['Model 2'])
    return acc_first, acc_second, np.abs(acc_first - acc_second)


In [None]:
# One (Model 1 acc., Model 2 acc., |difference|) tuple per similarity row, in row order.
pair_perf_tuples = sim_data.apply(get_model_perf, axis=1).tolist()
performance_per_pair = pd.DataFrame(
    pair_perf_tuples,
    columns=['Model 1 perf.', 'Model 2 perf.', 'abs. diff. perf.'],
).reset_index(drop=True)

In [None]:
# Side-by-side join: both frames carry a fresh 0..n-1 index, so rows line up positionally.
sim_data_new = pd.concat(objs=[sim_data, performance_per_pair], axis='columns')

#### Compute the correlations between the performance gaps and the model similarities

In [None]:
def get_correlation(subset_data):
    """Correlate model-pair similarity with the absolute performance gap.

    Args:
        subset_data: DataFrame with the columns 'Similarity value' and
            'abs. diff. perf.'.

    Returns:
        dict with keys 'spearmanr' and 'pearsonr' holding the respective
        correlation coefficients (p-values are discarded). Both are NaN when
        fewer than two rows are given, because neither coefficient is defined then.
    """
    # pearsonr raises a ValueError for fewer than two observations, which would
    # otherwise abort the whole groupby over (similarity metric, dataset).
    if len(subset_data) < 2:
        return {'spearmanr': np.nan, 'pearsonr': np.nan}
    sims = subset_data['Similarity value']
    perf_gaps = subset_data['abs. diff. perf.']
    corr_sp, _ = spearmanr(sims, perf_gaps)
    corr_pr, _ = pearsonr(sims, perf_gaps)
    return {'spearmanr': corr_sp, 'pearsonr': corr_pr}


# One dict of coefficients per (similarity metric, dataset) group, expanded below
# into one column per correlation type while keeping the group MultiIndex.
r_coeffs = (
    sim_data_new
    .groupby(['Similarity metric', 'DS'])[['Similarity value', 'abs. diff. perf.']]
    .apply(get_correlation)
)
r_coeffs = pd.DataFrame(r_coeffs.tolist(), index=r_coeffs.index)

In [None]:
# Attach the ds_info columns (e.g. 'domain') to every correlation row. 'DS' already holds the
# display name here, so it is copied to 'name' to serve as the join key.
r_coeffs_tmp = r_coeffs.reset_index()
r_coeffs_tmp['name'] = r_coeffs_tmp['DS']
tmp = (
    pd.merge(r_coeffs_tmp, ds_info.reset_index(names=['DS']), how='left', on='name')
    .drop(columns=['DS_y'])
    .drop_duplicates()  # several ds_info keys may share a display name
    .reset_index(drop=True)
    .sort_values(['Similarity metric', 'domain', 'spearmanr'])
    .reset_index(drop=True)
)
if SAVE:
    fn = storing_path / 'corr_perf_vs_sim_per_ds.csv'
    tmp.to_csv(fn, index=False)

In [None]:
# Long format: one row per (similarity metric, dataset, correlation metric).
# value_vars is explicit: otherwise any ds_info column other than 'name'/'domain' that the
# merge carried over would be melted as if it were a correlation metric (and could force
# 'Correlation coefficient' to object dtype).
melted_ds_perf_sim_corr = pd.melt(
    tmp,
    id_vars=['Similarity metric', 'DS_x', 'name', 'domain'],
    value_vars=['spearmanr', 'pearsonr'],
    var_name='Correlation metric',
    value_name='Correlation coefficient',
)

### Plot the correlation coefficient of each dataset, colored by dataset category

In [None]:
# Colour per dataset domain; every value of the 'domain' column must be a key here.
domain_colors = {
    'Natural (multi-domain)': '#8da0cb',
    'Natural (single-domain)': '#e78ac3',
    'Specialized': '#a6d854',
    'Structured': '#b3b3b3'
}

# One point per dataset: Pearson correlation between |performance difference| and CKA linear similarity.
df = melted_ds_perf_sim_corr[
    (melted_ds_perf_sim_corr['Similarity metric'] == 'CKA linear') &
    (melted_ds_perf_sim_corr['Correlation metric'] == 'pearsonr')
]

# Legend placement, legend font size and figure size depend on the target layout.
if version == 'arxiv':
    bbox_to_anchor = (0.19, 1.02)
    fontsize_legend = curr_fontsizes['label']
    figsize = (8, 5)
else:
    bbox_to_anchor = (0.21, 1.02)
    fontsize_legend = curr_fontsizes['ticks']
    figsize = (9, 6)

plt.figure(figsize=figsize)

unique_names = df['name'].unique()
x = np.arange(len(unique_names))

colors = [domain_colors[domain] for domain in df['domain']]
plt.scatter(x, df['Correlation coefficient'], c=colors, s=100, alpha=1)

plt.ylabel('Correlation Coefficient', fontsize=curr_fontsizes['label'])
plt.xticks(x, unique_names, rotation=45, ha='right')
plt.tick_params('both', labelsize=curr_fontsizes['ticks'])

# Dotted reference lines; presumably weak / moderate / strong negative-correlation thresholds.
for threshold in (-.3, -.5, -.7):
    plt.axhline(threshold, alpha=0.5, ls=':', c='grey', zorder=-1)

# Legend proxy artists: Line2D handles are not added to the axes (empty plt.scatter calls
# would be). markersize=10 matches the scatter's s=100, since s is the marker area in pt^2.
domain_patches = [
    Line2D([], [], marker='o', linestyle='', color=color, markersize=10, label=domain)
    for domain, color in domain_colors.items()
]

plt.legend(handles=domain_patches,
           loc='upper center',
           bbox_to_anchor=bbox_to_anchor,
           title='',
           frameon=False,
           fontsize=fontsize_legend,
           ncol=1,
           )

plt.tight_layout()
save_or_show(plt.gcf(), storing_path / 'scatter_corr_perf_vs_sim_per_ds_cat_cka_linear.pdf', SAVE)