## Notebook appendix C: *CKA sensitivity to number of samples in dataset*

In [None]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

from constants import sim_metric_name_mapping, fontsizes, BASE_PATH_PROJECT, BASE_PATH_RESULTS
from helper import save_or_show, pp_storing_path

In [None]:
# Define datasets: ImageNet subsets of increasing size (names end in "<n>k").
datasets = [f'imagenet-subset-{i}k' for i in [1, 5, 10, 20, 30, 40]]
model_sim_root_path = BASE_PATH_PROJECT / 'model_similarities_old_model_set'
# One experiment directory per dataset below the common root.
model_sim_paths = [model_sim_root_path / dataset for dataset in datasets]

# Fail early and name every missing directory so the problem can be fixed
# without re-running the check by hand.
missing_model_sim_paths = [path for path in model_sim_paths if not path.exists()]
if missing_model_sim_paths:
    raise FileNotFoundError(
        "Some experiment directory does not exist! Missing: "
        + ", ".join(str(path) for path in missing_model_sim_paths)
    )

In [None]:
# Switch handed to pp_storing_path here and to save_or_show for each figure below.
SAVE = True
# Location under BASE_PATH_RESULTS used for this notebook's plots.
# NOTE(review): pp_storing_path is a project helper; presumably it prepares and
# returns the output directory depending on SAVE - confirm against helper.py.
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots/nr_samples_for_cka_stability', SAVE)

In [None]:
# Keep only the entries of sim_metric_name_mapping whose key contains 'cka';
# the values are used as subplot titles and legend labels further down.
name_mapping = {k:v for k,v in sim_metric_name_mapping.items() if 'cka' in k}

# Human-readable label for every dataset directory name. A subset named
# "imagenet-subset-<n>k" is labelled "<n> sample(s) per class".
name_mapping_ds = {
    f'imagenet-subset-{per_class}k': f"{per_class} sample{'' if per_class == 1 else 's'} per class"
    for per_class in (1, 5, 10, 20, 30, 40)
}

# Models removed from every similarity matrix before analysis.
models_to_exclude = ['SegmentAnything_vit_b']

In [None]:
def get_model_ids(fn):
    """Read model identifiers from a text file, one per line.

    Args:
        fn: Path of the text file to read.

    Returns:
        A list with one entry per line of the file, each stripped of leading
        and trailing whitespace (a blank line yields an empty string).
    """
    with open(fn, 'r') as id_file:
        return [raw_line.strip() for raw_line in id_file]

In [None]:
# Load every CKA similarity matrix of every dataset subset.
# Layout: sim_mats[<dataset directory name>][<metric directory name>] -> DataFrame
# whose rows and columns are labelled with the model ids.
sim_mats = {}
for model_sim_path in model_sim_paths:
    print(model_sim_path)
    sim_mats[model_sim_path.name] = {}
    # rglob is already recursive, so the bare file name is sufficient.
    for sim_method in model_sim_path.rglob("similarity_matrix.pt"):
        if 'cka_kernel' not in str(sim_method):
            continue

        # The model ids labelling rows/columns live next to the matrix file.
        model_ids_fn = sim_method.parent / 'model_ids.txt'
        if not model_ids_fn.exists():
            raise FileNotFoundError(f'{model_ids_fn} does not exist.')
        model_ids = get_model_ids(model_ids_fn)

        # NOTE(security): torch.load unpickles the file; only load matrices
        # produced by trusted runs. map_location keeps the result on the CPU
        # so that matrices saved from a GPU run can be read on any machine.
        sim_mat = torch.load(sim_method, map_location='cpu')
        df = pd.DataFrame(sim_mat, index=model_ids, columns=model_ids)
        # Remove excluded models from both axes; absent models are skipped.
        df = df.drop(index=models_to_exclude, columns=models_to_exclude, errors='ignore')

        # Set the self-similarity to exactly 1. Work on an explicit copy of
        # the data: writing through `df.values` is not guaranteed to reach the
        # DataFrame and fails on the read-only arrays returned under pandas
        # Copy-on-Write.
        values = df.to_numpy(copy=True)
        np.fill_diagonal(values, 1)
        sim_mats[model_sim_path.name][sim_method.parent.name] = pd.DataFrame(
            values, index=df.index, columns=df.columns
        )

# Keep only the metrics listed in name_mapping, in that order, and report
# which dataset lacks which metric instead of failing with a bare key.
for ds_name, curr_sim_mats in sim_mats.items():
    missing_metrics = [metric for metric in name_mapping if metric not in curr_sim_mats]
    if missing_metrics:
        raise KeyError(f"Dataset '{ds_name}' has no similarity matrix for: {missing_metrics}")

sim_mats = {
    key: {x: curr_sim_mats[x] for x in name_mapping}
    for key, curr_sim_mats in sim_mats.items()
}

In [None]:
# Collapse the two-level {dataset: {metric: DataFrame}} mapping into a single
# level keyed by "<dataset>_<metric>".
flattened_dict = {
    f"{dataset_name}_{metric_name}": matrix
    for dataset_name, per_metric in sim_mats.items()
    for metric_name, matrix in per_metric.items()
}

In [None]:
# Every matrix must be labelled with the same index as the first one,
# otherwise the position-based reordering further down would mix up models.
first_index = next(iter(flattened_dict.values())).index
all_same_index = True
for candidate in flattened_dict.values():
    if not candidate.index.equals(first_index):
        all_same_index = False
        break

if not all_same_index:
    raise ValueError('All DataFrames must have the same index.')

In [None]:
# Smallest and largest entry over all matrices; passed as vmin/vmax to the
# heatmaps below so that they share one colour scale.
all_matrices = list(flattened_dict.values())
vmin = min(matrix.min().min() for matrix in all_matrices)
vmax = max(matrix.max().max() for matrix in all_matrices)

## Create heatmap overviews

In [None]:
# n_ds = len(sim_mats)
# n_sim_met = len(sim_mats['imagenet-subset-10k'])
# n_ds, n_sim_met
# fig, axs = plt.subplots(n_ds, n_sim_met, figsize=(n_sim_met * 2, n_ds * 2))
# for i, (ds, ds_sim_mat) in enumerate(sim_mats.items()):
#     for j, (sim_met, df) in enumerate(ds_sim_mat.items()):
#         ax = axs[i, j]
#         sns.heatmap(df, ax=ax, cbar=False, annot=False, vmin=vmin, vmax=vmax)
#         if i == 0:
#             ax.set_title(name_mapping[sim_met], fontsize=14)
#         if j == 0:
#             ax.set_ylabel(name_mapping_ds[ds], fontsize=14)
#         ax.set_xticks([])
#         ax.set_yticks([])
# # fig.suptitle('Similarity matrices for different ImageNet-1k subsets.', fontsize=15)

# # Adjust layout
# plt.tight_layout()
# save_or_show(fig, storing_path / 'cka_nr_samples_mats.pdf', SAVE)

## Create heatmap overviews sorted by dendrogram order

In [None]:
# Cluster one reference matrix (linear unbiased CKA on the 10k subset) and keep
# only the resulting row/column order; the clustermap figure itself is closed.
reference_matrix = sim_mats['imagenet-subset-10k']['cka_kernel_linear_unbiased']
g = sns.clustermap(reference_matrix, annot=True, fmt='.2f')
cka_lin_dendo_row_ordering, cka_lin_dendo_col_ordering = (
    g.dendrogram_row.reordered_ind,
    g.dendrogram_col.reordered_ind,
)
plt.close()

In [None]:
# Grid of heatmaps: one row per dataset subset, one column per CKA metric.
n_ds = len(sim_mats)
n_sim_met = len(sim_mats['imagenet-subset-10k'])
# squeeze=False keeps `axs` two-dimensional even when there is only a single
# row or column, so `axs[i, j]` below is always valid.
fig, axs = plt.subplots(n_ds, n_sim_met, figsize=(n_sim_met * 2, n_ds * 2), squeeze=False)

for i, (ds, ds_sim_mat) in enumerate(sim_mats.items()):
    for j, (sim_met, df) in enumerate(ds_sim_mat.items()):
        ax = axs[i, j]
        # Apply the same model order (taken from the reference clustermap) to
        # every matrix so that the panels are directly comparable.
        reordered_df = df.iloc[cka_lin_dendo_row_ordering, cka_lin_dendo_col_ordering]
        sns.heatmap(reordered_df, ax=ax, cbar=False, annot=False, vmin=vmin, vmax=vmax)
        # Metric names label the top row, dataset names the first column.
        if i == 0:
            ax.set_title(name_mapping[sim_met], fontsize=fontsizes['label'])
        if j == 0:
            ax.set_ylabel(name_mapping_ds[ds], fontsize=fontsizes['label'])
        ax.set_xticks([])
        ax.set_yticks([])

# Adjust layout
fig.tight_layout()

save_or_show(fig, storing_path / 'cka_nr_samples_mats_dendogram.pdf', SAVE)

## Create box/violin plots showing the similarity differences

In [None]:
# Re-group as {metric: {dataset: 1-D array}}, where each array holds the
# entries strictly above the diagonal of the corresponding similarity matrix.
flat_triu_sim_mats = {}

for ds, ds_sim_mat in sim_mats.items():
    for sim_met, df in ds_sim_mat.items():
        matrix = df.values
        # k=1 skips the diagonal (self-similarities).
        upper_triangular_indices = np.triu_indices(matrix.shape[0], k=1)
        per_dataset = flat_triu_sim_mats.setdefault(sim_met, {})
        per_dataset[ds] = matrix[upper_triangular_indices]

In [None]:
def subset_size(col_name):
    """Return the part of ``col_name`` after its last '-'.

    For a name such as 'imagenet-subset-10k' this is the size suffix '10k';
    a name without '-' is returned unchanged.
    """
    _, _, suffix = col_name.rpartition('-')
    return suffix

# Column names of the long-format frames built below.
col_subset_comp = 'compared_subsets'
col_abs_diff = 'absolute_difference'

# {metric: long-format DataFrame} holding, for every pair of models, the
# absolute change in similarity between consecutive dataset subsets.
flat_triu_diff_mats = {}

for sim_met, flat_arr in flat_triu_sim_mats.items():
    # One column per dataset subset, one row per model pair.
    per_subset = pd.DataFrame(flat_arr)
    subset_names = list(per_subset.columns)

    # Difference to the previous column; the first column has no predecessor
    # and is dropped.
    step_diffs = per_subset.diff(axis=1).abs().iloc[:, 1:].copy()
    step_diffs.columns = [
        f"{subset_size(previous)} and {subset_size(current)}"
        for previous, current in zip(subset_names, subset_names[1:])
    ]

    flat_triu_diff_mats[sim_met] = pd.melt(
        step_diffs,
        var_name=col_subset_comp,
        value_name=col_abs_diff,
    )

In [None]:
# Plotting function selector (box plot by default; swap in the commented line
# for violin plots).
# NOTE(review): within this notebook only the commented-out per-metric cell
# references plot_fn; the overview plot calls sns.boxplot directly.
plot_fn = sns.boxplot
# plot_fn = sns.violinplot

##### Distribution of differences for each similarity metric and nr. of samples

In [None]:
# #### Single model similarity metrics
# cmap = plt.get_cmap('tab10')

# for i, (sim_met, melted_diff_df) in enumerate(flat_triu_diff_mats.items()):
#     plt.figure(figsize=(8,6))
#     plot_fn(data=melted_diff_df, x=col_subset_comp, y=col_abs_diff, color=cmap(i))
#     plt.xlabel('Compared ImageNet-1k subsets', fontsize=11)
#     plt.ylabel('Absolute difference of pairwise model similarities', fontsize=11)
#     plt.title(f'Influence of ImageNet1k subset size on pairwise model similarities ({name_mapping[sim_met]}).', fontsize=12)
#     if SAVE:
#         # plt.savefig(storing_path / f'boxplot_sim_diff_{sim_met}.pdf', bbox_inches='tight')
#         # plt.savefig(storing_path / f'boxplot_sim_diff_{sim_met}.png', bbox_inches='tight')
#         plt.close()
#     else:
#         plt.show()

##### Overview: distribution of differences for each similarity metric and nr. of samples

In [None]:
# Tag every per-metric frame with the metric's display name (the column is
# added to the frames stored in flat_triu_diff_mats) and collect them.
met_col = 'Similarity metrics'
dfs = []
for metric_key, metric_diffs in flat_triu_diff_mats.items():
    metric_diffs[met_col] = name_mapping[metric_key]
    dfs += [metric_diffs]

In [None]:
# Stack the per-metric frames into one long table. ignore_index=True gives the
# result a fresh 0..n-1 index; without it every per-metric frame contributes
# the same row labels, leaving duplicated index values in the plot data.
all_sim_diffs = pd.concat(dfs, ignore_index=True)

In [None]:
# Overview box plot: absolute similarity difference per compared subset pair,
# with one box per similarity metric.
fig = plt.figure(figsize=(10, 5))
ax = sns.boxplot(data=all_sim_diffs, x=col_subset_comp, y=col_abs_diff, hue=met_col)

plt.xticks(fontsize=fontsizes['label'])
plt.yticks(fontsize=fontsizes['ticks'])

# The x tick labels already name the compared subsets, so no axis label.
ax.set_xlabel('')
ax.set_ylabel('Abs. Similarity Difference', fontsize=fontsizes['label'])

ax.legend(fontsize=fontsizes['legend'])
save_or_show(fig, storing_path / 'cka_nr_samples_box_diff.pdf', SAVE)