## Notebook 4.7: *Can representational similarity predict performance gaps?*
This notebook creates scatter plots of the performance gaps between models on a downstream task against their representational similarity. The representational similarity is computed using the CKA linear metric. The performance gaps are computed as the absolute difference in the performance of the two models on the downstream task. The notebook also computes the Spearman and Pearson r correlation between the performance gaps and the representational similarity. The notebook creates scatter plots for three datasets per dataset category.

In [1]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import pearsonr, spearmanr

from constants import (
    BASE_PATH_RESULTS,
    cat_name_mapping,
    ds_list_perf_file,
    exclude_models,
    exclude_models_w_mae,
    fontsizes,
    model_config_file
)
from helper import (
    load_all_datasetnames_n_info,
    load_model_configs_and_allowed_models,
    pp_storing_path,
    save_or_show
)

#### Global variables

In [2]:
### Config similarity data
sim_data_path = BASE_PATH_RESULTS / 'aggregated' / 'model_sims/all_metric_ds_model_pair_similarity.csv'
assert sim_data_path.exists(), f"Path does not exist: {sim_data_path}. Aggregated similarity data not found, please run `aggregate_similarities_across_datasets.ipynb` before."

### Config performance data
perf_data_path = BASE_PATH_RESULTS / f'aggregated/single_model_performance/all_ds.csv'
assert perf_data_path.exists(), f"Path does not exist: {perf_data_path}. Aggregated performance data not found, please run `aggregate_downstream_task_perfs.ipynb` before."

### Config datasets
ds_list_perf, ds_info = load_all_datasetnames_n_info(ds_list_perf_file, verbose=False)

### Config datasets to include
ds_to_include = set(ds_list_perf) - set(['cifar100-coarse', 'entity13'])
ds_to_include.add('imagenet-subset-10k')
remaining_ds = sorted(list(set(ds_list_perf) - set(ds_to_include)))

## Storing information
suffix = ''
# suffix = '_ wo_mae'

# Version and plotting info
version = '3_opt'
curr_fontsizes = {k: v + 1 for k, v in fontsizes.items()}

SAVE = True
#storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'final' / version / 'sec_4_7_perf_vs_sim', SAVE)
storing_path = pp_storing_path('/home/lciernik/projects/divers-priors/results_local/corr_mats_ds/sec_4_7_perf_vs_sim', SAVE)




#### Load model configurations and allowed models

In [3]:
curr_excl_models = exclude_models_w_mae if 'mae' in suffix else exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)

Nr. models original=64


#### Load similarity data

In [4]:
sim_data = pd.read_csv(sim_data_path)

In [5]:
## Filter similarity data only for desired datasets
print(sim_data.shape)
if ds_to_include:
    sim_data = sim_data[sim_data['DS'].isin(ds_to_include)].reset_index(drop=True)
print(sim_data.shape)

(100800, 9)
(92736, 9)


In [6]:
## Rename datasets with info
sim_data['DS category'] = sim_data['DS'].apply(lambda x: ds_info.loc[x, 'domain'])
sim_data['DS'] = sim_data['DS'].apply(lambda x: ds_info.loc[x, 'name'])

In [7]:
## Post-process 'pair' columns
def pp_pair_col(df_col):
    return df_col.apply(eval).apply(lambda x: f"{cat_name_mapping[x[0]]}, {cat_name_mapping[x[1]]}")


pair_columns = [col for col in sim_data.columns if 'pair' in col]
sim_data[pair_columns] = sim_data[pair_columns].apply(pp_pair_col, axis=0)
pair_columns += [None]

In [8]:
## Filter only for allowed models
sim_data = sim_data[sim_data['Model 1'].isin(allowed_models) & sim_data['Model 2'].isin(allowed_models)].reset_index(
    drop=True)

#### Load performance data

In [9]:
perf_res = pd.read_csv(perf_data_path)

In [10]:
if ds_to_include:
    perf_res = perf_res[perf_res['DS'].isin(ds_to_include)].reset_index(drop=True)
perf_res['DS category'] = perf_res['DS'].apply(lambda x: ds_info.loc[x, 'domain'])
perf_res['DS'] = perf_res['DS'].apply(lambda x: ds_info.loc[x, 'name'])
perf_res = perf_res[perf_res['Model'].isin(allowed_models)].reset_index(drop=True)

#### Combine model similarities and performance measures

In [11]:
def get_model_perf(row):
    m1_perf = perf_res.loc[(perf_res['Model'] == row['Model 1']) & (perf_res['DS'] == row['DS']), 'TestAcc'].item()
    m2_perf = perf_res.loc[(perf_res['Model'] == row['Model 2']) & (perf_res['DS'] == row['DS']), 'TestAcc'].item()
    return m1_perf, m2_perf, np.abs(m1_perf - m2_perf)


In [12]:
performance_per_pair = pd.DataFrame(sim_data.apply(get_model_perf, axis=1).tolist(),
                                    columns=['Model 1 perf.', 'Model 2 perf.', 'abs. diff. perf.']).reset_index(
    drop=True)

In [13]:
sim_data_new = pd.concat([sim_data, performance_per_pair], axis=1)

#### Compute the correlations between the performance gaps and the model similarities

In [14]:
def get_correlation(subset_data):
    corr_sp, _ = spearmanr(subset_data['Similarity value'], subset_data['abs. diff. perf.'])
    corr_pr, _ = pearsonr(subset_data['Similarity value'], subset_data['abs. diff. perf.'])
    return {'spearmanr': corr_sp, 'pearsonr': corr_pr}


r_coeffs = sim_data_new.groupby(['Similarity metric', 'DS'])[['Similarity value', 'abs. diff. perf.']].apply(
    get_correlation)
r_coeffs = pd.DataFrame(r_coeffs.tolist(), index=r_coeffs.index)

### Plot the performance vs. similarity scatter plots for X datasets per dataset category

In [15]:
sim_data_new = sim_data_new.sort_values(['DS category', 'DS']).reset_index(drop=True)
sim_data_new['max_model_perf'] = sim_data_new[['Model 1 perf.', 'Model 2 perf.']].apply(max, axis=1)

In [16]:
ds_order = ['ImageNet-1k', 'Flowers', 'Diabetic Retinopathy', 'DTD',
            'CIFAR-100', 'Pets', 'EuroSAT', 'Dmlab',
            'Entity-30', 'Stanford Cars', 'PCAM', 'FER2013']

if version == 'arxiv':
    ds_order = ['ImageNet-1k', 'Flowers', 'Diabetic Retinopathy', 'DTD',
                'CIFAR-100', 'Pets', 'EuroSAT', 'Dmlab',
                'Entity-30', 'Stanford Cars', 'PCAM', 'FER2013']
    wspace = 0.4
    hspace = 0.3
    size_one_box = (9, 7)
    text_y_pos = 0.9
elif version == 'opt':
    ds_order = ['ImageNet-1k', 'Flowers', 'Diabetic Retinopathy', 'DTD',
                'CIFAR-100', 'Pets', 'PCAM', 'FER2013']
    wspace = 0.4
    hspace = 0.25
    size_one_box = (9, 6.5)
    text_y_pos = 0.85
else:
    ds_order = ['ImageNet-1k', 'Flowers', 'Diabetic Retinopathy', 'DTD']
    wspace = 0.4
    hspace = 0.25
    size_one_box = (9, 7)
    text_y_pos = 0.85


In [17]:
metric_to_consider = 'CKA linear'

sim_data_new_subset = sim_data_new[sim_data_new['Similarity metric'] == metric_to_consider].copy()

In [18]:
def get_scatter_grid_v2(hue_col, figsize=(9, 7), corr_type='spearmanr'):
    n, m = len(ds_order) // 4, 4
    cm = 0.393701
    fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(figsize[0] * cm * m, figsize[1] * cm * n), sharex=True,
                             sharey=False)
    axes = axes.flatten()

    for i, (col, ax) in enumerate(zip(ds_order, axes)):
        group_data = sim_data_new_subset[sim_data_new_subset['DS'] == col]

        # Create a norm for this subplot
        vmin = group_data[hue_col].min()
        vmax = group_data[hue_col].max()
        norm = plt.Normalize(vmin=vmin, vmax=vmax)

        scatter = sns.scatterplot(
            group_data.sort_values(by=hue_col, ascending=True),
            x='Similarity value',
            y='abs. diff. perf.',
            hue=hue_col,
            palette='viridis',
            alpha=0.5,
            ax=ax,
            s=15,
            legend=False,  # Don't show the legend
            hue_norm=norm,  # Use the subplot-specific norm
        )

        xlbl = 'Similarity value' if i // m == (n - 1) else ''
        ax.set_xlabel(xlbl, fontsize=curr_fontsizes['label'])
        ylbl = f'Performance Gap' if i % m == 0 else ''
        ax.set_ylabel(ylbl, fontsize=curr_fontsizes['label'])
        col_cat = ds_info.loc[ds_info['name'] == col, 'domain'].unique()[0]
        title = f"$\\it{{{col_cat}}}$\n{col}" if i // m == 0 else col
        ax.set_title(title, fontsize=curr_fontsizes['title'])

        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        y_min, y_max = ax.get_ylim()
        y_range = y_max - y_min
        fmt_str = '%.1f' if y_range > 0.2 else '%.2f'
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter(fmt_str))
        ax.tick_params(axis='both', which='major', labelsize=curr_fontsizes['ticks'])

        if corr_type == 'spearmanr':
            r_coeff = r_coeffs.loc[(metric_to_consider, col), 'spearmanr']
        else:
            r_coeff = r_coeffs.loc[(metric_to_consider, col), 'pearsonr']
        ax.text(0.95, text_y_pos, f'r coeff.: {r_coeff:.2f}',
                transform=ax.transAxes, fontsize=curr_fontsizes['label'],
                bbox=dict(facecolor='white', alpha=0.5, edgecolor='white', pad=0.2),
                ha='right')

        # Add a colorbar to each subplot
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, cax=cax)
        cbar.ax.tick_params(labelsize=curr_fontsizes['ticks'])

    fig.tight_layout()
    fig.subplots_adjust(wspace=wspace, hspace=hspace)  # Increase spacing to accommodate colorbars

    return fig

#### Plot the scatter plots colored by the maximum performance of the two models

In [22]:
fig = get_scatter_grid_v2('max_model_perf', size_one_box, 'pearsonr')
#save_or_show(fig, storing_path / f'scatter_3_4_pearsonr_grid_cka_linear_max_model_perf.pdf', SAVE)
save_or_show(fig, storing_path / f'scatter_1_4_pearsonr_grid_cka_linear_max_model_perf.pdf', SAVE)

stored img at /home/lciernik/projects/divers-priors/results_local/corr_mats_ds/sec_4_7_perf_vs_sim/scatter_1_4_pearsonr_grid_cka_linear_max_model_perf.pdf.


In [23]:
fig = get_scatter_grid_v2('max_model_perf', size_one_box, 'spearmanr')
#save_or_show(fig, storing_path / f'scatter_3_4_spearmanr_grid_cka_linear_max_model_perf.pdf', SAVE)
save_or_show(fig, storing_path / f'scatter_1_4_spearmanr_grid_cka_linear_max_model_perf.pdf', SAVE)

stored img at /home/lciernik/projects/divers-priors/results_local/corr_mats_ds/sec_4_7_perf_vs_sim/scatter_1_4_spearmanr_grid_cka_linear_max_model_perf.pdf.


#### Plot the scatter plots with all points colored the same or by combination of model categories

In [None]:
# def get_scatter_grid_v1(hue_col, figsize=(9, 7), corr_type='spearmanr'):
#     n, m = 3, 4
#     cm = 0.393701

#     fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(figsize[0] * cm * m, figsize[1] * cm * n), sharex=True,
#                              sharey=False)
#     axes = axes.flatten()

#     for i, (col, ax) in enumerate(zip(ds_order, axes)):
#         group_data = sim_data_new_subset[sim_data_new_subset['DS'] == col]
#         assert group_data['Similarity metric'].nunique() == 1

#         sns.scatterplot(
#             group_data,
#             x='Similarity value',
#             y='abs. diff. perf.',
#             hue=hue_col,
#             alpha=0.5,
#             ax=ax,
#             s=15,
#         )
#         xlbl = 'Similarity value' if i // m == 2 else ''
#         ax.set_xlabel(xlbl, fontsize=curr_fontsizes['label'])

#         ylbl = f'Performance Gap' if i % m == 0 else ''
#         ax.set_ylabel(ylbl, fontsize=curr_fontsizes['label'])

#         col_cat = ds_info.loc[ds_info['name'] == col, 'domain'].unique()[0]
#         # title = f"{col_cat}\n{col}" if i//m == 0 else col
#         title = f"$\\it{{{col_cat}}}$\n{col}" if i // m == 0 else col
#         ax.set_title(title, fontsize=curr_fontsizes['title'])

#         if i == 3 and hue_col:
#             sns.move_legend(ax,
#                             loc='upper left',
#                             title=hue_col,
#                             bbox_to_anchor=(1, 1), fontsize=curr_fontsizes['legend'],
#                             title_fontsize=curr_fontsizes['legend'], frameon=False)
#         elif hue_col:
#             ax.get_legend().remove()

#         ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
#         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
#         ax.tick_params(axis='both',  # Apply to both x and y axes
#                        which='major',  # Apply to major ticks
#                        labelsize=curr_fontsizes['ticks'])
#         if corr_type == 'spearmanr':
#             r_coeff = r_coeffs.loc[('CKA linear', col), 'spearmanr']
#         else:
#             r_coeff = r_coeffs.loc[('CKA linear', col), 'pearsonr']

#         ax.text(0.95, 0.9, f'r coeff.: {r_coeff:.2f}',
#                 transform=ax.transAxes, fontsize=curr_fontsizes['label'],
#                 bbox=dict(facecolor='white', alpha=0.5),
#                 ha='right')  # Align text to the right
#     fig.subplots_adjust(wspace=0.2, hspace=0.1)
#     fig.tight_layout()
#     return fig


In [None]:
# sns.set_style('ticks')
# for corr_type in ['pearsonr', 'spearmanr']:
#     for hue_col, figsize in zip(pair_columns, [(9, 7), (9, 7), (10, 8), (10, 8), (8, 7), (8, 7)]):
#         fig = get_scatter_grid_v1(hue_col, figsize, corr_type)
#         if hue_col:
#             suffix = f'_{hue_col.replace(" ", "_")}'
#         else:
#             suffix = ""
#         save_or_show(fig, storing_path / f'scatter_3_4_{corr_type}_grid_cka_linear{suffix}.pdf', SAVE)