## Notebook for appendix D: *Variation of CKA values near the upper bound*
This notebook creates the figure for appendix section D. It shows the mean vs. std scatter plot for the tan-transformed and arccos-transformed CKA values.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from constants import (
    BASE_PATH_PROJECT,
    BASE_PATH_RESULTS,
    ds_list_sim_file,
    exclude_models,
    exclude_models_w_mae,
    fontsizes,
    fontsizes_cols,
    model_config_file
)
from helper import (
    load_all_datasetnames_n_info,
    load_model_configs_and_allowed_models,
    load_similarity_matrices,
    pp_storing_path,
    save_or_show
)


#### Global variables

In [None]:
# Directory (below the project base path) from which the similarity matrices are loaded
base_path_similarity_matrices = BASE_PATH_PROJECT / 'model_similarities'

# Similarity metrics to load. Mean/std statistics are computed for every metric
# listed here, but only 'cka_kernel_linear_unbiased' is plotted further down.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# Load used dataset names
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# Model filtering suffix. An empty string excludes no models; a non-empty suffix
# selects an exclusion list (`exclude_models_w_mae` if it contains 'mae', otherwise
# `exclude_models`). The suffix is also appended to the output file name.
suffix = ''  # '_wo_mae'

# Layout version. It is part of the output directory and selects the plot styling:
# 'arxiv' uses `fontsizes`, any other value uses `fontsizes_cols`.
version = 'arxiv'  # any other value -> `fontsizes_cols` layout

# Storing information: SAVE is forwarded to `pp_storing_path` and `save_or_show`
# NOTE(review): presumably SAVE=True writes the figure to disk instead of showing it — confirm in helper
SAVE = True
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'final' / version / 'app_D_inv_u_shape', SAVE)

#### Load data

In [None]:
# Choose the model exclusion list implied by `suffix`:
#   ''                 -> nothing is excluded
#   contains 'mae'     -> `exclude_models_w_mae`
#   any other suffix   -> `exclude_models`
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Load the model configurations together with the list of allowed models
model_configs, allowed_models = load_model_configs_and_allowed_models(
    exclude_alignment=True,
    exclude_models=curr_excl_models,
    path=model_config_file,
)


In [None]:
# Load similarity matrices for the metrics in `sim_metrics`.
# `sim_mats[metric]` is used as a mapping further down (its `.values()` are stacked
# into one array); presumably it holds one model-by-model matrix per dataset in
# `ds_list` — TODO confirm against `load_similarity_matrices`.
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)

#### Compute the element-wise mean and std of the transformed (arccos and tan) similarity matrices over all datasets


In [None]:
# Filled per similarity metric with {transform name: (mean DataFrame, std DataFrame)}
mean_stds = dict()


def build_mean_std(df, labels=None):
    """Aggregate a stack of similarity matrices into mean and std matrices.

    Parameters
    ----------
    df : array-like
        Stack of (transformed) similarity matrices that is reduced along axis 0.
        Despite the name, the callers in this notebook pass a NumPy array, so
        `.std` is NumPy's default population standard deviation (ddof=0).
    labels : sequence, optional
        Row and column labels of the returned frames. Defaults to the
        notebook-level `allowed_models`, which keeps existing calls unchanged
        while allowing the function to be used without that global.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        Element-wise mean and standard deviation, each indexed by `labels`
        on both axes.
    """
    if labels is None:
        # Fall back to the notebook-level model list (original behaviour)
        labels = allowed_models
    mean = pd.DataFrame(df.mean(axis=0), index=labels, columns=labels)
    std = pd.DataFrame(df.std(axis=0), index=labels, columns=labels)
    return mean, std


for sim_metric in sim_metrics:
    # Stack all matrices of this metric along a new leading axis
    result = np.stack([*sim_mats[sim_metric].values()], axis=0)

    # Apply each element-wise transform to the stack, then aggregate over axis 0
    mean_stds[sim_metric] = {
        transform_name: build_mean_std(transform(result))
        for transform_name, transform in (('arccos', np.arccos), ('tan', np.tan))
    }

### Mean vs. STD scatter plot and comparison of different similarity metrics
The first scatter plot shows the mean vs. std of the arccos-transformed linear CKA values, while the second shows the same relationship for the tan-transformed values.

In [None]:
# Layout parameters per version: font-size table, horizontal subplot spacing, figure height
if version == 'arxiv':
    curr_fontsizes, wspace, height = fontsizes, 0.3, 2.7
else:
    curr_fontsizes, wspace, height = fontsizes_cols, 0.35, 3.1

In [None]:
# Three-colour viridis palette; one colour per panel, so only the first two are consumed
color_palette = sns.color_palette("viridis", 3)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(2 * 4.1, height), sharey=False, sharex=False)

# One panel per transform of the linear CKA statistics, in insertion order of the dict
linear_cka_stats = mean_stds['cka_kernel_linear_unbiased']
for ax, color, (pp_type, (mean_res, std_res)) in zip(axs, color_palette, linear_cka_stats.items()):
    # Strict upper triangle: every model pair once, diagonal (self-similarity) skipped
    rows, cols = np.triu_indices(mean_res.shape[0], k=1)
    sns.scatterplot(
        x=mean_res.values[rows, cols],
        y=std_res.values[rows, cols],
        alpha=0.6,
        color=color,
        ax=ax,
    )

    label_size = curr_fontsizes['label']
    ax.set_xlabel(f'Mean {pp_type}(CKA linear)', fontsize=label_size)
    ax.set_ylabel(f'Std {pp_type}(CKA linear)', fontsize=label_size)
    ax.tick_params('both', labelsize=curr_fontsizes['ticks'])

# Tighten the layout first, then override the horizontal gap between the two panels
fig.tight_layout()
fig.subplots_adjust(wspace=wspace)
save_or_show(fig, storing_path / f'mean_std_transf_scatter_plot{suffix}.pdf', SAVE)