<!-- ## Notebook for section 4.2: *Do representational similarities transfer across datasets?*
This notebook creates the figures for section 4.2. It shows the similarity matrices for three datasets (ImageNet-1k, Flowers, and PCAM) as well as the mean and std similarity matrices over all 23 datasets. Furthermore, it shows the mean vs. std scatter plot of the tan-transformed CKA linear values and a comparison of different similarity metrics. -->

# TODO fill

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.cluster.hierarchy import dendrogram, linkage

from constants import (
    BASE_PATH_PROJECT,
    BASE_PATH_RESULTS,
    ds_list_sim_file,
    ds_list_sim_file_30k,
    exclude_models,
    exclude_models_w_mae,
    fontsizes,
    fontsizes_cols,
    model_config_file,
    sim_metric_name_mapping
)
from helper import (
    get_fmt_name,
    load_all_datasetnames_n_info,
    load_model_configs_and_allowed_models,
    load_similarity_matrices,
    pp_storing_path,
    save_or_show
)
# Bare expression at the end of the cell: in a notebook this displays the
# dataset-list file and the results base path as the cell output. It has no
# other effect.
ds_list_sim_file, BASE_PATH_RESULTS

#### Global variables

In [None]:
# Directory with the similarity matrices shared by every metric except CKA RBF 0.2
base_path_similarity_matrices = BASE_PATH_PROJECT / 'model_similarities'

# Dataset names (plus accompanying info) read from the two dataset-list files
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)
ds_list_30k, ds_info_30k = load_all_datasetnames_n_info(ds_list_sim_file_30k, verbose=False)

# Similarity metric -> (directory holding its matrices, datasets to load for it)
sim_metrics = {
    'cka_kernel_rbf_unbiased_sigma_0.2': (
        BASE_PATH_PROJECT / 'model_similarities_rbf02',
        ds_list_30k,
    ),
    'cka_kernel_rbf_unbiased_sigma_0.4': (base_path_similarity_matrices, ds_list),
    'cka_kernel_linear_unbiased': (base_path_similarity_matrices, ds_list),
    'rsa_method_correlation_corr_method_spearman': (base_path_similarity_matrices, ds_list),
}

# Model filtering suffix; the empty string disables filtering (alternative: '_wo_mae')
suffix = ''

# Figure version; selects font sizes and layout further down
version = 'arxiv'

# Where (and whether) to store the generated plots
SAVE = True
storing_path = pp_storing_path(
    BASE_PATH_RESULTS / 'plots' / 'experiment_with_rbf02' / 'mean_cka_comparison',
    SAVE,
)

#### Load data

In [None]:
# Pick the models to exclude based on the filtering suffix:
# no suffix -> exclude nothing; a suffix mentioning 'mae' -> the MAE exclusion
# list; any other suffix -> the default exclusion list.
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Load the model configurations together with the list of models that remain
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)


In [None]:
# Load the similarity matrices of every metric from that metric's own directory
# and dataset list. The loader returns a mapping keyed by metric name, from
# which only the requested metric is kept.
sim_mats = {
    metric: load_similarity_matrices(
        path=metric_path,
        ds_list=curr_ds_list,
        sim_metrics=[metric],
        allowed_models=allowed_models,
    )[metric]
    for metric, (metric_path, curr_ds_list) in sim_metrics.items()
}

#### Plotting helper functions

In [None]:
def compute_ordering(sim_mat):
    """
    Compute the ordering of the models in the similarity matrix based on the average linkage clustering.

    Similarities are converted to distances as ``1 - similarity``; the strict
    upper triangle (row-major) is passed to SciPy as a condensed distance
    matrix and clustered with average linkage.

    Args:
        sim_mat: Square pairwise similarity matrix, either a ``pd.DataFrame``
            or anything ``np.asarray`` turns into a 2-D array.

    Returns:
        np.ndarray: Positional indices of the rows/columns in dendrogram leaf
        order (left to right).
    """
    values = sim_mat.values if isinstance(sim_mat, pd.DataFrame) else np.asarray(sim_mat)
    dist_mat = 1 - values
    upper_tri_indices = np.triu_indices_from(dist_mat, k=1)
    linkage_matrix = linkage(dist_mat[upper_tri_indices], method='average')
    # no_plot=True: only the leaf order is needed. Without it the dendrogram is
    # drawn into the current matplotlib figure, which then has to be closed and
    # may be an unrelated figure that was already open.
    leaves = dendrogram(linkage_matrix, no_plot=True)['leaves']
    return np.array(leaves)

#### Compute the mean and std similarity matrices using the similarity matrices over all datasets and different preprocessing methods 


In [None]:
# Containers keyed by similarity-metric name; each holds one model-by-model
# DataFrame per metric.

# Raw similarity values
mean_sim_mats, std_sim_mats = {}, {}

# tanh-transformed similarity values
mean_tanh_sim_mats, std_tanh_sim_mats = {}, {}

# tan-transformed similarity values
mean_tan_sim_mats, std_tan_sim_mats = {}, {}

# arccos-transformed similarity values
mean_arccos_sim_mats, std_arccos_sim_mats = {}, {}

# Model names in clustering order, per metric
ordering = {}


def build_mean_std(df):
    """
    Reduce a stack of similarity matrices along axis 0 to its mean and std.

    Args:
        df: Object exposing ``mean(axis=0)`` and ``std(axis=0)``; the caller in
            this notebook passes a stacked NumPy array.

    Returns:
        Tuple ``(mean, std)`` of DataFrames whose index and columns are both
        the module-level ``allowed_models``.
    """
    def _as_model_frame(values):
        # Label rows and columns with the model names
        return pd.DataFrame(values, index=allowed_models, columns=allowed_models)

    return _as_model_frame(df.mean(axis=0)), _as_model_frame(df.std(axis=0))


for sim_metric in sim_metrics:
    # Stack all matrices loaded for this metric along a new leading axis
    # (presumably one matrix per dataset; verify against load_similarity_matrices).
    result = np.stack(list(sim_mats[sim_metric].values()), axis=0)

    # Mean/std over the stacked matrices on the raw similarity values
    mean_sim_mats[sim_metric], std_sim_mats[sim_metric] = build_mean_std(result)

    # Model names reordered by average-linkage clustering of the mean matrix
    ordering[sim_metric] = np.array(allowed_models)[compute_ordering(mean_sim_mats[sim_metric])]

    # Same statistics on tanh- and tan-transformed values
    mean_tanh_sim_mats[sim_metric], std_tanh_sim_mats[sim_metric] = build_mean_std(np.tanh(result))

    mean_tan_sim_mats[sim_metric], std_tan_sim_mats[sim_metric] = build_mean_std(np.tan(result))

    # arccos is only defined on [-1, 1]; a value marginally outside that range
    # (e.g. from floating-point error) would become NaN and poison the mean/std,
    # so clip before transforming.
    # NOTE(review): confirm that clipping is the intended handling if the
    # similarity estimates can legitimately leave [-1, 1].
    mean_arccos_sim_mats[sim_metric], std_arccos_sim_mats[sim_metric] = build_mean_std(
        np.arccos(np.clip(result, -1.0, 1.0))
    )

## Mean vs. STD scatter plot and comparison of different similarity metrics
The first scatter plot shows the mean vs. std of the CKA linear values (raw, or tan/tanh/arccos-transformed, depending on `pp_type`). The other scatter plots compare the mean similarity values of the metric pairs listed in `combs`, e.g., CKA linear vs. CKA RBF 0.4 (global vs. local) and CKA linear vs. RSA Spearman.

In [None]:
# Font sizes, horizontal subplot spacing and figure height depend on whether
# the figure is built for the 'arxiv' version or for any other version.
if version == 'arxiv':
    curr_fontsizes, wspace, height = fontsizes, 0.3, 2.7
else:
    curr_fontsizes, wspace, height = fontsizes_cols, 0.35, 3.1

In [None]:
# Pairs (x-axis metric, y-axis metric) of similarity-metric keys whose mean
# similarity values are plotted against each other; one scatter panel per pair.
combs = [
    ('cka_kernel_linear_unbiased', 'cka_kernel_rbf_unbiased_sigma_0.4'),
    ('cka_kernel_linear_unbiased', 'cka_kernel_rbf_unbiased_sigma_0.2'),
    ('cka_kernel_rbf_unbiased_sigma_0.2', 'cka_kernel_rbf_unbiased_sigma_0.4'),
    ('cka_kernel_linear_unbiased', 'rsa_method_correlation_corr_method_spearman')
]

In [None]:
# One panel for mean vs. std plus one panel per metric pair in `combs`
ncols = len(combs) + 1
color_palette = sns.color_palette("viridis", ncols)
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4.1, height))

# ---- Panel 0: mean vs. std of the (optionally transformed) CKA linear values ----
ax = axs[0]

# Transformation applied to the values: '', 'tan', 'tanh' or 'arccos'
pp_type = ''
kernel = 'cka_kernel_linear_unbiased'

# Processing type -> (mean matrices, std matrices, axis label)
pp_variants = {
    'tan': (mean_tan_sim_mats, std_tan_sim_mats, "tan(CKA linear)"),
    'tanh': (mean_tanh_sim_mats, std_tanh_sim_mats, "tanh(CKA linear)"),
    'arccos': (mean_arccos_sim_mats, std_arccos_sim_mats, "arccos(CKA linear)"),
    '': (mean_sim_mats, std_sim_mats, "CKA linear"),
}
# Any falsy processing type selects the untransformed values
pp_key = pp_type if pp_type else ''
if pp_key not in pp_variants:
    raise ValueError("Unknown processing type of CKA values")
mean_by_metric, std_by_metric, lbl_left = pp_variants[pp_key]
mean_res = mean_by_metric[kernel]
std_res = std_by_metric[kernel]

# Strict upper triangle: every model pair once, diagonal excluded
iu2 = np.triu_indices(mean_res.shape[0], k=1)

sns.scatterplot(
    x=mean_res.values[iu2],
    y=std_res.values[iu2],
    alpha=0.6,
    ax=ax,
    color=color_palette[0],
)

ax.set_xlabel(f'Mean {lbl_left}', fontsize=curr_fontsizes['label'])
ax.set_ylabel(f'Std {lbl_left}', fontsize=curr_fontsizes['label'])
ax.tick_params('both', labelsize=curr_fontsizes['ticks'])

# ---- Remaining panels: mean similarity of one metric against another ----
for panel_idx, (metric_x, metric_y) in enumerate(combs, start=1):
    ax = axs[panel_idx]
    mean_x = mean_sim_mats[metric_x].values[iu2]
    mean_y = mean_sim_mats[metric_y].values[iu2]

    sns.scatterplot(
        x=mean_x,
        y=mean_y,
        alpha=0.6,
        ax=ax,
        color=color_palette[panel_idx],
    )

    ax.set_xlabel(f'Mean {sim_metric_name_mapping[metric_x]}', fontsize=curr_fontsizes['label'])
    ax.set_ylabel(f'Mean {sim_metric_name_mapping[metric_y]}', fontsize=curr_fontsizes['label'])
    ax.tick_params('both', labelsize=curr_fontsizes['ticks'])

    # Annotate the panel with the Pearson correlation of the two mean vectors
    corr = stats.pearsonr(mean_x, mean_y)[0]
    ax.text(
        0.05,
        0.95,
        f'r coeff.: {corr:.2f}',
        transform=ax.transAxes,
        verticalalignment='top',
        fontsize=curr_fontsizes['label'],
    )

plt.tight_layout()
plt.subplots_adjust(wspace=wspace)
save_or_show(fig, storing_path / f'mean_std_{pp_type}_comp_local_global_scatter_plot{suffix}.pdf', SAVE)

## Do local and global similarity differ?

In [None]:
# Datasets: reload the base list and additionally allow the 30k ImageNet subset
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)
ds_list.append('imagenet-subset-30k')

# Experiment configuration
corr_type = 'spearmanr'  # 'pearsonr', 'spearmanr'
suffix = ''  # '', '_wo_mae'

# Path to correlation data
data_path = BASE_PATH_RESULTS / f'aggregated/r_coeff_dist/with_cats_as_anchors/agg_{corr_type}_all_ds.csv'
# Explicit check instead of `assert`, so it still runs when Python is started
# with -O (which strips assert statements).
if not data_path.exists():
    raise FileNotFoundError(
        f'Path does not exist: {data_path}. Aggregated correlation coefficients across all dataset pairs not found, please run aggregate_consistencies_for_model_set_pairs.ipynb first.'
    )

## Version and plotting info
version = 'arxiv'
curr_fontsizes = fontsizes if version == 'arxiv' else fontsizes_cols

storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'experiment_with_rbf02' / 'r_coeff_comparison', SAVE)

In [None]:
# Read the aggregated correlation coefficients and keep only the rows whose
# two datasets (ds1 and ds2) are both part of `ds_list`.
r_coeff_data = (
    pd.read_csv(data_path)
    .loc[lambda frame: frame['ds1'].isin(ds_list) & frame['ds2'].isin(ds_list)]
    .reset_index(drop=True)
    .copy()
)

In [None]:
# Collapse every ds2 entry that starts with 'imagenet-subset' to the single
# name 'imagenet-subset'.
# NOTE(review): only ds2 is normalised, ds1 is left untouched; confirm that
# this asymmetry is intended.
is_imagenet_subset = r_coeff_data['ds2'].str.startswith('imagenet-subset')
r_coeff_data.loc[is_imagenet_subset, 'ds2'] = 'imagenet-subset'

In [None]:
# `stats` (scipy) is already imported in the notebook's first cell; the
# redundant mid-notebook import was removed.

# Pairs (x-axis metric, y-axis metric) given by their display names in the
# 'Similarity metric' column. This rebinds the `combs` used for the mean
# similarity plots earlier in the notebook.
combs = [
    ('CKA linear', 'CKA RBF 0.4'),
    ('CKA linear', 'CKA RBF 0.2'),
    ('CKA RBF 0.2', 'CKA RBF 0.4'),
    ('CKA linear', 'RSA spearman'),
]
# Columns identifying one observation; used to align the two metrics row by row
index_cols = ['ds1', 'ds2', 'anchor_cat', 'other_cat']

fig, axs = plt.subplots(nrows=1, ncols=len(combs), figsize=(6.5*len(combs), 5))  # Increased width for better visibility

for i, (x, y) in enumerate(combs):
    dat1 = r_coeff_data[r_coeff_data['Similarity metric'] == x].set_index(index_cols)
    dat2 = r_coeff_data[r_coeff_data['Similarity metric'] == y].set_index(index_cols)

    # Suffix the columns so both metrics can live side by side after the concat
    dat1.columns = [col + ' sm1' for col in dat1.columns]
    dat2.columns = [col + ' sm2' for col in dat2.columns]
    dat_concat = pd.concat([dat1, dat2], axis=1)

    # Rows available for only one of the two metrics hold NaN after the concat.
    # Drop them so the plotted points and the correlation use the same rows and
    # the correlation is not turned into NaN.
    dat_valid = dat_concat.dropna(subset=['r coeff sm1', 'r coeff sm2'])

    ax = axs[i]
    sns.scatterplot(data=dat_valid, x="r coeff sm1", y="r coeff sm2", alpha=0.5, s=10, ax=ax)

    ax.set_xlabel(f'r coeff. ({x})', fontsize=curr_fontsizes['label'])
    ax.set_ylabel(f'r coeff. ({y})', fontsize=curr_fontsizes['label'])
    ax.tick_params(labelsize=curr_fontsizes['ticks'])

    r, p = stats.pearsonr(dat_valid['r coeff sm1'], dat_valid['r coeff sm2'])
    # Report the p-value that was actually computed instead of a fixed claim
    p_text = 'p-value < 0.001' if p < 0.001 else f'p-value = {p:.3f}'
    ax.text(0.05, 0.95, f'Overall r = {r:.2f}\n{p_text}', transform=ax.transAxes,
            verticalalignment='top', fontsize=curr_fontsizes['legend'])

plt.subplots_adjust(wspace=0.25 if version == 'arxiv' else 0.3)
save_or_show(fig, storing_path / f'consistency_local_global_scatter_plot_{corr_type}{suffix}.pdf', SAVE)