## Notebook 4.7: *Can representational similarity predict performance gaps?*

In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import spearmanr, pearsonr
from itertools import product
import textwrap
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable

from constants import exclude_models, exclude_models_w_mae, cat_name_mapping, ds_info_file, model_config_file, fontsizes
from helper import load_model_configs_and_allowed_models, save_or_show, load_ds_info

sys.path.append('..')
from scripts.helper import parse_datasets

from clip_benchmark.analysis.utils import retrieve_performance

In [None]:
# Root folder of the aggregated result tables read by this notebook.
base_path_aggregated = Path('/home/space/diverse_priors/results/aggregated')

### Config similarity data
sim_data = pd.read_csv(
    base_path_aggregated / 'model_sims/all_metric_ds_model_pair_similarity.csv'
)

### Config performance data
# Dataset ids are normalised by replacing '/' with '_'.
ds_list_perf = [
    ds_id.replace('/', '_')
    for ds_id in parse_datasets('../scripts/webdatasets_w_in1k.txt')
]

ds_info = load_ds_info(ds_info_file)

results_root = '/home/space/diverse_priors/results/linear_probe/single_model'

### Config datasets to include
# All performance datasets except two, plus the ImageNet 10k subset.
ds_to_include = set(ds_list_perf) - {'cifar100-coarse', 'entity13'}
ds_to_include.add('imagenet-subset-10k')
remaining_ds = sorted(set(ds_list_perf) - ds_to_include)

## Storing information
suffix = ''
# suffix = '_ wo_mae'

SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/scatter_sim_vs_performance_v2')
# The output folder is only created when figures are going to be saved.
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
## Filter similarity data only for desired datasets
# The shape is printed before and after so the effect of the filter is visible.
print(sim_data.shape)
if ds_to_include:
    sim_data = sim_data.loc[sim_data['DS'].isin(ds_to_include)].reset_index(drop=True)
print(sim_data.shape)

In [None]:
## Rename datasets with info
# Order matters: the category is looked up by the raw dataset id, which the
# second statement then overwrites with the display name.
sim_data['DS category'] = sim_data['DS'].map(lambda ds_id: ds_info.loc[ds_id, 'domain'])
sim_data['DS'] = sim_data['DS'].map(lambda ds_id: ds_info.loc[ds_id, 'name'])

In [None]:
## Post-process 'pair' columns
def pp_pair_col(df_col):
    """Turn a column of stringified 2-tuples into readable "A, B" labels.

    Each entry of ``df_col`` is parsed as a Python literal, and its first two
    elements are translated through ``cat_name_mapping`` and joined with
    ", ".

    Args:
        df_col: pandas Series whose values are strings holding a literal
            sequence with at least two elements, e.g. ``"('a', 'b')"``.

    Returns:
        A Series with one formatted label per input entry.

    Raises:
        ValueError, SyntaxError: if an entry is not a plain Python literal.
        KeyError: if a parsed element is missing from ``cat_name_mapping``.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    # literal_eval only accepts literals, so arbitrary expressions coming
    # from the CSV are rejected rather than executed (unlike eval).
    parsed = df_col.apply(ast.literal_eval)
    return parsed.apply(lambda pair: f"{cat_name_mapping[pair[0]]}, {cat_name_mapping[pair[1]]}")


# Every column whose name contains 'pair' is converted to readable labels.
pair_columns = list(filter(lambda name: 'pair' in name, sim_data.columns))
sim_data[pair_columns] = sim_data[pair_columns].apply(pp_pair_col, axis=0)
# A trailing None serves as the "no hue column" entry when iterating later.
pair_columns.append(None)

In [None]:
# The storing suffix decides which exclusion list is applied.
if 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file, exclude_models=curr_excl_models, exclude_alignment=True
)

# allowed_models = sorted(list(set(allowed_models) - set(['jigsaw-rn50', 'rotnet-rn50'])))
# len(allowed_models)

In [None]:
## Filter only for allowed models
# A pair is kept only when both of its models are in `allowed_models`.
sim_data = sim_data.loc[
    sim_data['Model 1'].isin(allowed_models) & sim_data['Model 2'].isin(allowed_models)
].reset_index(drop=True)

#### Retrieve the downstream task performances. 

In [None]:
# import warnings

# # Ignore UserWarnings
# warnings.filterwarnings("ignore", category=UserWarning)

# res = []
# for ds, mid in product(ds_list_perf, allowed_models):
#     performance = retrieve_performance(
#         model_id=mid, 
#         dataset_id=ds, 
#         metric_column='test_lp_acc1',
#         results_root='/home/space/diverse_priors/results/linear_probe/single_model',
#         regularization="weight_decay",
#         allow_db_results=False
#     )
#     res.append({
#         'DS': ds,
#         'Model': mid,
#         'TestAcc': performance
#     })
# perf_res = pd.DataFrame(res)

In [None]:
# perf_res.to_csv(base_path_aggregated/ f'single_model_performance/all_ds{suffix}.csv', index=False)

In [None]:
# Cached single-model performance table (columns 'DS', 'Model', 'TestAcc' are used below).
perf_res = pd.read_csv(base_path_aggregated / 'single_model_performance' / 'all_ds.csv')

In [None]:
# Apply the same dataset selection, renaming and model filter as for `sim_data`.
if ds_to_include:
    perf_res = perf_res.loc[perf_res['DS'].isin(ds_to_include)].reset_index(drop=True)
# The category lookup needs the raw dataset id, so it runs before the rename.
perf_res['DS category'] = perf_res['DS'].map(lambda ds_id: ds_info.loc[ds_id, 'domain'])
perf_res['DS'] = perf_res['DS'].map(lambda ds_id: ds_info.loc[ds_id, 'name'])
perf_res = perf_res.loc[perf_res['Model'].isin(allowed_models)].reset_index(drop=True)

#### Combine model similarities and performance measures

In [None]:
def get_model_perf(row):
    """Look up the test accuracy of both models of a pair on the row's dataset.

    Args:
        row: mapping-like record with the keys 'Model 1', 'Model 2' and 'DS'.

    Returns:
        Tuple ``(model 1 accuracy, model 2 accuracy, absolute difference)``,
        read from the 'TestAcc' column of the module-level ``perf_res``.

    Raises:
        ValueError: if ``perf_res`` does not hold exactly one entry for a
            (model, dataset) combination (raised by ``.item()``).
    """
    on_dataset = perf_res['DS'] == row['DS']

    def _single_acc(model_id):
        # .item() enforces that exactly one row matches.
        return perf_res.loc[on_dataset & (perf_res['Model'] == model_id), 'TestAcc'].item()

    acc_first = _single_acc(row['Model 1'])
    acc_second = _single_acc(row['Model 2'])
    return acc_first, acc_second, np.abs(acc_first - acc_second)


In [None]:
# One (model 1 acc., model 2 acc., abs. gap) record per similarity row; the
# fresh frame carries a default 0..n-1 index.
performance_per_pair = pd.DataFrame(
    data=sim_data.apply(get_model_perf, axis=1).tolist(),
    columns=['Model 1 perf.', 'Model 2 perf.', 'abs. diff. perf.'],
)

In [None]:
# Column-wise concat aligns on the index; both frames were reset to 0..n-1 before.
sim_data_new = pd.concat([sim_data, performance_per_pair], axis='columns')

#### Compute the correlations between the performance gaps and the model similarities

In [None]:
def get_correlation(subset_data):
    """Correlate similarity values with absolute performance gaps.

    Args:
        subset_data: table with the columns 'Similarity value' and
            'abs. diff. perf.'.

    Returns:
        Dict with the Spearman coefficient under 'spearmanr' and the Pearson
        coefficient under 'pearsonr' (p-values are discarded).
    """
    sims = subset_data['Similarity value']
    gaps = subset_data['abs. diff. perf.']
    estimators = {'spearmanr': spearmanr, 'pearsonr': pearsonr}
    # Index 0 of each scipy result is the correlation coefficient.
    return {name: estimate(sims, gaps)[0] for name, estimate in estimators.items()}


# Step 1: one dict of coefficients per (similarity metric, dataset) group.
r_coeffs = (
    sim_data_new
    .groupby(['Similarity metric', 'DS'])[['Similarity value', 'abs. diff. perf.']]
    .apply(get_correlation)
)
# Step 2: expand the dicts into columns while keeping the group index.
r_coeffs = pd.DataFrame(list(r_coeffs), index=r_coeffs.index)

#### Plot the performance-gap vs. similarity scatter plots for 3 datasets per dataset category

In [None]:
# Order rows by dataset category, then dataset, and renumber the index.
sim_data_new = sim_data_new.sort_values(['DS category', 'DS']).reset_index(drop=True)
# Accuracy of the better model of each pair. The vectorised row-wise max
# avoids a Python-level call per row and, unlike the builtin `max`, does not
# depend on column order when a value is NaN.
sim_data_new['max_model_perf'] = sim_data_new[['Model 1 perf.', 'Model 2 perf.']].max(axis=1)

In [None]:
# Datasets shown in the scatter grids. The plotting functions walk this list
# together with the flattened 3 x 4 axes, i.e. four entries per grid row;
# entries of the first row also get their category printed in the title.
ds_order = [
    'ImageNet-1k', 'Flowers', 'Diabetic Retinopathy', 'DTD',
    'CIFAR-100', 'Pets', 'EuroSAT', 'Dmlab',
    'Entity-30', 'Stanford Cars', 'PCAM', 'FER2013',
]

In [None]:
# Similarity metric shown in the scatter grids; compared against the
# 'Similarity metric' column.
metric_to_consider = 'CKA linear'

# Work on a copy so the subset is independent of `sim_data_new`.
sim_data_new_subset = sim_data_new.loc[
    sim_data_new['Similarity metric'] == metric_to_consider
].copy()

In [None]:
def get_scatter_grid_v2(hue_col, figsize=(9, 7), corr_type='spearmanr'):
    """Scatter similarity vs. performance gap on a 3 x 4 grid with colorbars.

    Each panel shows the rows of the module-level ``sim_data_new_subset``
    belonging to one dataset of ``ds_order``. Points are colored by
    ``hue_col`` on a viridis scale normalised to the panel's own min/max, and
    every panel gets its own colorbar plus a text box with the correlation
    coefficient read from ``r_coeffs`` for ``metric_to_consider``.

    Args:
        hue_col: name of the column of ``sim_data_new_subset`` used for the
            point colors.
        figsize: (width, height) of a single panel; scaled by 0.393701
            (cm to inches) and by the number of grid columns / rows.
        corr_type: 'spearmanr' shows the Spearman coefficient; any other
            value shows the Pearson coefficient.

    Returns:
        The created matplotlib figure.
    """
    n_rows, n_cols = 3, 4
    inch_per_cm = 0.393701
    panel_w, panel_h = figsize[0], figsize[1]
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(panel_w * inch_per_cm * n_cols, panel_h * inch_per_cm * n_rows),
        sharex=True,
        sharey=False,
    )

    # Anything other than 'spearmanr' falls back to the Pearson column.
    coeff_col = 'spearmanr' if corr_type == 'spearmanr' else 'pearsonr'

    def _attach_colorbar(ax, norm):
        # Carve a narrow axis off the right side of the panel for its colorbar.
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        mappable = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
        mappable.set_array([])
        cbar = plt.colorbar(mappable, cax=cax)
        cbar.ax.tick_params(labelsize=fontsizes['ticks'])

    for idx, (ds_name, ax) in enumerate(zip(ds_order, axes.flatten())):
        row_idx, col_idx = divmod(idx, n_cols)
        panel = sim_data_new_subset[sim_data_new_subset['DS'] == ds_name]

        # Color scale specific to this panel.
        hue_values = panel[hue_col]
        norm = plt.Normalize(vmin=hue_values.min(), vmax=hue_values.max())

        sns.scatterplot(
            data=panel,
            x='Similarity value',
            y='abs. diff. perf.',
            hue=hue_col,
            hue_norm=norm,
            palette='viridis',
            alpha=0.25,
            s=15,
            legend=False,  # the colorbar replaces the legend
            ax=ax,
        )

        # Axis labels only on the bottom row / left-most column.
        ax.set_xlabel('Similarity value' if row_idx == n_rows - 1 else '',
                      fontsize=fontsizes['label'])
        ax.set_ylabel('Model Performance Gap' if col_idx == 0 else '',
                      fontsize=fontsizes['label'])

        # The first row additionally shows the dataset category in italics.
        category = ds_info.loc[ds_info['name'] == ds_name, 'domain'].unique()[0]
        if row_idx == 0:
            title = f"$\\it{{{category}}}$\n{ds_name}"
        else:
            title = ds_name
        ax.set_title(title, fontsize=fontsizes['title'])

        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
        ax.tick_params(axis='both', which='major', labelsize=fontsizes['ticks'])

        r_coeff = r_coeffs.loc[(metric_to_consider, ds_name), coeff_col]
        ax.text(
            0.9, 0.9, f'r coeff.: {r_coeff:.2f}',
            transform=ax.transAxes,
            fontsize=fontsizes['label'],
            bbox=dict(facecolor='white', alpha=0.5, edgecolor='white'),
            ha='right',
        )

        _attach_colorbar(ax, norm)

    fig.subplots_adjust(wspace=0.4, hspace=0.3)  # Increase spacing to accommodate colorbars
    fig.tight_layout()
    return fig

In [None]:
# Pearson variant, colored by the accuracy of the better model of each pair.
fig = get_scatter_grid_v2(hue_col='max_model_perf', figsize=(9, 7), corr_type='pearsonr')
save_or_show(fig, storing_path / 'scatter_3_4_pearsonr_grid_cka_linear_max_model_perf.pdf', SAVE)

In [None]:
# Spearman variant, colored by the accuracy of the better model of each pair.
fig = get_scatter_grid_v2(hue_col='max_model_perf', figsize=(9, 7), corr_type='spearmanr')
save_or_show(fig, storing_path / 'scatter_3_4_spearmanr_grid_cka_linear_max_model_perf.pdf', SAVE)

In [None]:
def get_scatter_grid_v1(hue_col, figsize=(9, 7), corr_type='spearmanr'):
    """Scatter similarity vs. performance gap on a 3 x 4 grid with one legend.

    Each panel shows the rows of the module-level ``sim_data_new_subset``
    belonging to one dataset of ``ds_order``, optionally colored by a
    categorical column. A text box in every panel shows the correlation
    coefficient read from ``r_coeffs`` for ``metric_to_consider``, i.e. the
    same metric that was used to build ``sim_data_new_subset``.

    Args:
        hue_col: column name used for point colors, or a falsy value
            (e.g. ``None``) for uncolored points without a legend.
        figsize: (width, height) of a single panel; scaled by 0.393701
            (cm to inches) and by the number of grid columns / rows.
        corr_type: 'spearmanr' shows the Spearman coefficient; any other
            value shows the Pearson coefficient.

    Returns:
        The created matplotlib figure.

    Raises:
        AssertionError: if the rows of a dataset do not belong to exactly one
            similarity metric.
    """
    n_rows, n_cols = 3, 4
    inch_per_cm = 0.393701

    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(figsize[0] * inch_per_cm * n_cols, figsize[1] * inch_per_cm * n_rows),
        sharex=True,
        sharey=False,
    )
    axes = axes.flatten()

    # Anything other than 'spearmanr' falls back to the Pearson column.
    coeff_col = 'spearmanr' if corr_type == 'spearmanr' else 'pearsonr'
    # The single shared legend is attached to the top-right panel.
    legend_panel = n_cols - 1

    for i, (col, ax) in enumerate(zip(ds_order, axes)):
        row_idx, col_idx = divmod(i, n_cols)
        group_data = sim_data_new_subset[sim_data_new_subset['DS'] == col]
        assert group_data['Similarity metric'].nunique() == 1, (
            f"expected exactly one similarity metric for dataset {col!r}"
        )

        sns.scatterplot(
            data=group_data,
            x='Similarity value',
            y='abs. diff. perf.',
            hue=hue_col,
            alpha=0.5,
            ax=ax,
            s=15,
        )

        # Axis labels only on the bottom row / left-most column.
        xlbl = 'Similarity value' if row_idx == n_rows - 1 else ''
        ax.set_xlabel(xlbl, fontsize=fontsizes['label'])
        ylbl = 'Model Performance Gap' if col_idx == 0 else ''
        ax.set_ylabel(ylbl, fontsize=fontsizes['label'])

        # The first row additionally shows the dataset category in italics.
        col_cat = ds_info.loc[ds_info['name'] == col, 'domain'].unique()[0]
        title = f"$\\it{{{col_cat}}}$\n{col}" if row_idx == 0 else col
        ax.set_title(title, fontsize=fontsizes['title'])

        if hue_col:
            if i == legend_panel:
                # Move the legend outside the grid, right of the top-right panel.
                sns.move_legend(ax,
                                loc='upper left',
                                title=hue_col,
                                bbox_to_anchor=(1, 1), fontsize=fontsizes['legend'],
                                title_fontsize=fontsizes['legend'], frameon=False)
            else:
                # get_legend() returns None when no legend was drawn.
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()

        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
        ax.tick_params(axis='both',  # Apply to both x and y axes
                       which='major',  # Apply to major ticks
                       labelsize=fontsizes['ticks'])

        # Look the coefficient up for the plotted metric rather than a
        # hard-coded name, so the text always matches the data shown.
        r_coeff = r_coeffs.loc[(metric_to_consider, col), coeff_col]

        ax.text(0.95, 0.9, f'r coeff.: {r_coeff:.2f}',
                transform=ax.transAxes, fontsize=fontsizes['label'],
                bbox=dict(facecolor='white', alpha=0.5),
                ha='right')  # Align text to the right
    fig.subplots_adjust(wspace=0.2, hspace=0.1)
    fig.tight_layout()
    return fig


In [None]:
sns.set_style('ticks')
for corr_type in ['pearsonr', 'spearmanr']:
    # NOTE: zip stops at the shorter input, so hue columns beyond the six
    # listed panel sizes would be skipped silently.
    for hue_col, figsize in zip(pair_columns, [(9, 7), (9, 7), (10, 8), (10, 8), (8, 7), (8, 7)]):
        fig = get_scatter_grid_v1(hue_col, figsize, corr_type)
        # Use a dedicated name: assigning to `suffix` here would overwrite the
        # notebook-level `suffix` setting that selects the model exclusion list.
        hue_suffix = f'_{hue_col.replace(" ", "_")}' if hue_col else ''
        save_or_show(fig, storing_path / f'scatter_3_4_{corr_type}_grid_cka_linear{hue_suffix}.pdf', SAVE)

In [None]:
# imnet = sim_data_new_subset[sim_data_new_subset['DS'] == 'Flowers'].copy()
# imnetsubset = imnet[['Model 1 perf.', 'Model 2 perf.', 'Similarity value']].copy()

In [None]:
# def scatter_to_grid_heatmap(ax, df, nbins=5):
#     x_min, x_max = df['Model 1 perf.'].min(), df['Model 1 perf.'].max()
#     y_min, y_max = df['Model 2 perf.'].min(), df['Model 2 perf.'].max()
#     df['t1'] = pd.cut(df['Model 1 perf.'], bins=nbins, labels=range(nbins))
#     df['t2'] = pd.cut(df['Model 2 perf.'], bins=nbins, labels=range(nbins))
#     tbl = pd.pivot_table(
#         df,
#         values='Similarity value',
#         index='t1',
#         columns='t2',
#         aggfunc='mean',
#         observed=True
#     )
#     tbl = tbl.reindex(index=range(nbins), columns=range(nbins), fill_value=np.nan)
    
#     sns.heatmap(tbl, ax=ax, 
#                 vmin=0.2, vmax=0.95, 
#                 mask = tbl.isna())
    
#     ax.set_xticks(np.linspace(0, nbins, nbins+1))
#     ax.set_yticks(np.linspace(0, nbins, nbins+1))
#     ax.set_xticklabels([f'{x:.1f}' for x in np.linspace(x_min, x_max, nbins+1)])
#     ax.set_yticklabels([f'{y:.1f}' for y in np.linspace(y_min, y_max, nbins+1)], rotation=90)
    
#     ax.invert_yaxis()


In [None]:
# n, m = 3, 4
# cm = 0.393701
# fig, axes = plt.subplots(nrows=n, ncols=m, 
#                          figsize=(10 * cm * m, 8 * cm * n), 
#                          sharex=False,
#                          sharey=False)
# axes = axes.flatten()

# for i, (col, ax) in enumerate(zip(ds_order, axes)):

#     group_data = sim_data_new_subset[sim_data_new_subset['DS'] == col]
    
#     scatter_to_grid_heatmap(ax, group_data.copy())
#     ax.set_title(col)
    
# fig.tight_layout()

#### Old code

In [None]:
# def get_scatter_n_compute_correlation(hue_col):
#     n = sim_data_new['Similarity metric'].nunique()
#     m = sim_data_new['DS'].nunique()
#     box_size = 4
# 
#     fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(box_size * m, box_size * n), sharex=True, sharey=False)
#     axes = axes.flatten()
# 
#     r_vals_sr, r_vals_pr = [], []
#     for i, (keys, group_data) in enumerate(sim_data_new.groupby(['Similarity metric', 'DS'])):
#         ax = axes[i]
#         sns.scatterplot(
#             group_data,
#             x='Similarity value',
#             y='abs. diff. perf.',
#             hue=hue_col,
#             alpha=0.5,
#             ax=ax,
#             s=15,
#         )
# 
#         xlbl = 'Similarity value' if i // m == 1 else ''
#         ax.set_xlabel(xlbl, fontsize=11)
# 
#         ylbl = f'{keys[0]}\n Model Performance Gap' if i % m == 0 else ''
#         ax.set_ylabel(ylbl, fontsize=12)
# 
#         title = keys[1] if i // m == 0 else ''
#         ax.set_title(title, fontsize=12)
# 
#         if i % m == (m - 1) and i // m == 0:
#             sns.move_legend(ax,
#                             loc='upper left',
#                             title=hue_col,
#                             bbox_to_anchor=(1, 1), fontsize=12,
#                             title_fontsize=12, frameon=False)
#         else:
#             ax.get_legend().remove()
# 
#         sp_r_coeff = r_coeffs.loc[keys, 'spearmanr']
#         pr_r_coeff = r_coeffs.loc[keys, 'pearsonr']
# 
#         ax.text(0.95, 0.9, f'r_spearman: {sp_r_coeff:.2f}',
#                 transform=ax.transAxes, fontsize=11,
#                 bbox=dict(facecolor='white', alpha=0.5),
#                 ha='right')  # Align text to the right
# 
#         ax.text(0.95, 0.8, f'r_pearson: {pr_r_coeff:.2f}',
#                 transform=ax.transAxes, fontsize=11,
#                 bbox=dict(facecolor='white', alpha=0.5),
#                 ha='right')  # Align text to the right
# 
#         # for cat, subset_data in group_data.groupby(hue_col):
#         #     corr_sp, _ = spearmanr(subset_data['Similarity value'], subset_data['abs. diff. perf.'])
#         #     corr_pr, _ = pearsonr(subset_data['Similarity value'], subset_data['abs. diff. perf.'])
#         #     r_vals_sr.append(tuple(list(keys) + [cat, corr_sp]))
#         #     r_vals_pr.append(tuple(list(keys) + [cat, corr_pr]))
# 
#     fig.subplots_adjust(wspace=0.2, hspace=0.1)
#     # r_vals_sr = pd.DataFrame(r_vals_sr, columns=['DS', 'Similarity metric', hue_col, 'corr'])
#     # r_vals_pr = pd.DataFrame(r_vals_pr, columns=['DS', 'Similarity metric', hue_col, 'corr'])
# 
#     return fig, r_vals_sr, r_vals_pr

In [None]:
# sns.set_style('ticks')
# for hue_col in pair_columns:
#     fig, _, _ = get_scatter_n_compute_correlation(hue_col)
#     curr_suffix = suffix + f'_{"_".join(ds_to_include)}' if ds_to_include else suffix
#     save_or_show(fig, storing_path / f'scatter_{hue_col.replace(" ", "_")}{curr_suffix}.pdf', SAVE)