## Notebook for section 4.2: *Do representational similarities transfer across datasets?*
This notebook creates the figures for section 4.2. It shows the similarity matrices for three datasets (ImageNet-1k, Flowers, and PCAM) as well as the mean and std similarity matrices over all 23 datasets. Furthermore, it shows the mean vs. std scatter plot of the CKA linear values (raw, or transformed e.g. with tan, depending on `pp_type`) and the comparison of different similarity metrics.

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.cluster.hierarchy import dendrogram, linkage

from constants import (
    BASE_PATH_PROJECT,
    BASE_PATH_RESULTS,
    ds_list_sim_file,
    exclude_models,
    exclude_models_w_mae,
    fontsizes,
    fontsizes_cols,
    model_config_file,
    sim_metric_name_mapping
)
from helper import (
    get_fmt_name,
    load_all_datasetnames_n_info,
    load_model_configs_and_allowed_models,
    load_similarity_matrices,
    pp_storing_path,
    save_or_show
)


#### Global variables

In [None]:
# Root directory containing the precomputed model-similarity matrices.
base_path_similarity_matrices = BASE_PATH_PROJECT / 'model_similarities'

# Similarity metrics whose matrices are loaded and aggregated in this notebook.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# Dataset names plus `ds_info`, which is later passed to get_fmt_name for titles.
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# Model-filtering suffix. '' keeps every model; a non-empty value (e.g. '_wo_mae')
# selects an exclusion list and is appended to the output file names.
suffix = ''

# Figure variant: 'arxiv' uses `fontsizes` and the wider layout; any other value
# switches to `fontsizes_cols` and the column layout.
version = 'arxiv'

# Saving switch, forwarded to pp_storing_path and save_or_show.
SAVE = True
storing_path = pp_storing_path(
    BASE_PATH_RESULTS / 'plots' / 'final' / version / 'sec_4_2_sim_mats',
    SAVE,
)

#### Load data

In [None]:
# Choose which models to drop based on `suffix`:
#   no suffix               -> keep all models
#   suffix containing 'mae' -> exclude_models_w_mae
#   any other suffix        -> exclude_models
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)


In [None]:
# Similarity matrices for every metric and dataset, indexed further down as
# sim_mats[metric][dataset]; the model set is (presumably) filtered via `allowed_models`.
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    sim_metrics=sim_metrics,
    ds_list=ds_list,
    allowed_models=allowed_models,
)

#### Plotting helper functions

In [None]:
def compute_ordering(sim_mat):
    """
    Compute a model ordering from average-linkage clustering of a similarity matrix.

    Similarities are turned into distances as ``1 - similarity``. The strict upper
    triangle (row-major, which is SciPy's condensed distance format) is clustered
    with ``linkage(method='average')``, and the dendrogram's leaf order is returned.

    Args:
        sim_mat: square similarity matrix (DataFrame or array-like). Assumed to be
            symmetric; only the strict upper triangle is read.

    Returns:
        np.ndarray of positional row/column indices in dendrogram leaf order.
    """
    if not isinstance(sim_mat, pd.DataFrame):
        sim_mat = pd.DataFrame(sim_mat)
    dist_mat = 1 - sim_mat.values
    upper_tri_indices = np.triu_indices_from(dist_mat, k=1)
    linkage_matrix = linkage(dist_mat[upper_tri_indices], method='average')
    # Only the leaf order is needed. With plotting enabled, dendrogram() draws into
    # whatever matplotlib figure is current, and that figure then had to be closed.
    idx_new = np.array(dendrogram(linkage_matrix, no_plot=True)['leaves'])
    return idx_new

#### Compute the mean and std similarity matrices over all datasets, for the raw similarity values and for several element-wise transformations (tanh, tan, arccos)


In [None]:
# Per-metric results, keyed by similarity-metric name. Each (mean, std) pair holds
# the element-wise statistics over datasets of the raw values or of a transform.
mean_sim_mats, std_sim_mats = {}, {}
mean_tanh_sim_mats, std_tanh_sim_mats = {}, {}
mean_tan_sim_mats, std_tan_sim_mats = {}, {}
mean_arccos_sim_mats, std_arccos_sim_mats = {}, {}

# Clustered model order per metric (filled in by the aggregation loop).
ordering = {}


def build_mean_std(df, labels=None):
    """Reduce a stack of similarity matrices to their element-wise mean and std.

    Args:
        df: array-like of shape (n_matrices, n_models, n_models), e.g. one
            similarity matrix per dataset stacked along axis 0.
        labels: row/column labels of the returned DataFrames. Defaults to the
            notebook-level ``allowed_models``.

    Returns:
        Tuple ``(mean, std)`` of DataFrames labelled by ``labels`` on both axes.
        ``std`` is NumPy's population standard deviation (ddof=0).
    """
    if labels is None:
        labels = allowed_models
    # asarray keeps ndarray input unchanged and also accepts lists of matrices.
    stacked = np.asarray(df)
    mean = pd.DataFrame(stacked.mean(axis=0), index=labels, columns=labels)
    std = pd.DataFrame(stacked.std(axis=0), index=labels, columns=labels)
    return mean, std


for sim_metric in sim_metrics:
    # One similarity matrix per dataset, stacked along axis 0.
    result = np.stack(list(sim_mats[sim_metric].values()), axis=0)

    mean_sim_mats[sim_metric], std_sim_mats[sim_metric] = build_mean_std(result)

    # Clustered model order of the mean matrix; used to order this metric's heatmaps.
    ordering[sim_metric] = np.array(allowed_models)[compute_ordering(mean_sim_mats[sim_metric])]

    mean_tanh_sim_mats[sim_metric], std_tanh_sim_mats[sim_metric] = build_mean_std(np.tanh(result))

    mean_tan_sim_mats[sim_metric], std_tan_sim_mats[sim_metric] = build_mean_std(np.tan(result))

    # arccos is only defined on [-1, 1]. Clip so that floating-point overshoot
    # (e.g. a self-similarity of 1.0000001) does not turn into NaN.
    mean_arccos_sim_mats[sim_metric], std_arccos_sim_mats[sim_metric] = build_mean_std(
        np.arccos(np.clip(result, -1.0, 1.0))
    )

### Similarity matrices for three datasets as well as mean and std similarity matrix over all datasets
The following cells construct and plot the similarity matrices (CKA linear) for three datasets (ImageNet-1k, Flowers, and PCAM) as well as the mean and std similarity matrices over all 23 datasets. 

In [None]:
# Datasets shown individually next to the mean/std matrices; one figure per entry.
ds_lists = {
    'ds_row_1_v2': ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam'],
}

In [None]:
def plot_one_heatmap(df, ordering, title, ax, cmap, vmin=0, vmax=1, cbar=False):
    """Draw ``df`` as a seaborn heatmap on ``ax`` with rows/columns in ``ordering``.

    Args:
        df: square DataFrame, indexed on both axes by the labels in ``ordering``.
        ordering: labels giving the display order of rows and columns.
        title: panel title; the text from ' (N' onwards is moved to a second line.
        ax: matplotlib Axes to draw on.
        cmap: colormap name passed to seaborn.
        vmin, vmax: color-scale limits.
        cbar: whether seaborn should add its own colorbar.

    Returns:
        The value returned by ``sns.heatmap`` (the Axes it drew on).
    """
    reordered = df.loc[ordering, ordering]
    heatmap = sns.heatmap(reordered, ax=ax, vmin=vmin, vmax=vmax, cbar=cbar, cmap=cmap)
    two_line_title = title.replace(' (N', '\n(N')
    ax.set_title(two_line_title, fontsize=fontsizes['title'], y=1.01)
    ax.axis('off')
    return heatmap


def add_bounding_box(ax, start_row, start_col, end_row, end_col, color='red', linewidth=2):
    """Outline a rectangular region of a heatmap on ``ax``.

    The rectangle is anchored at (x=start_col, y=start_row) and extends to
    (end_col, end_row) in data coordinates. It is drawn unfilled with the given
    edge ``color`` and ``linewidth``.
    """
    box_width = end_col - start_col
    box_height = end_row - start_row
    outline = patches.Rectangle(
        (start_col, start_row),
        box_width,
        box_height,
        fill=False,
        edgecolor=color,
        linewidth=linewidth,
    )
    ax.add_patch(outline)


def add_all_boxes(ax, color_boxes, colors=('yellow', 'white', 'cyan')):
    """Outline every box in ``color_boxes`` on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        color_boxes: sequence of (start_row, start_col, end_row, end_col) tuples.
        colors: edge colors paired positionally with ``color_boxes``. The pairing
            uses ``zip``, so boxes beyond the last color are not drawn.
    """
    for pos, col in zip(color_boxes, colors):
        add_bounding_box(ax, *pos, color=col)


def _style_colorbar(cbar):
    """Hide the colorbar outline and apply the shared tick-label font size."""
    cbar.outline.set_edgecolor('none')
    cbar.ax.tick_params(labelsize=fontsizes['ticks'])


def get_mat_plot(curr_sim_metrics_data, mean_res, std_res, ds_list, ordering, color_boxes, metric_label='CKA'):
    """Plot per-dataset similarity matrices next to the mean and std matrices.

    Layout (one row): [colorbar | one heatmap per dataset | mean | std | colorbar].
    The per-dataset and mean heatmaps use the fixed [0, 1] range described by the
    left colorbar. The std heatmap uses its own min/max range and the right colorbar.

    Args:
        curr_sim_metrics_data: mapping from dataset name to its similarity matrix.
        mean_res: mean similarity matrix over all datasets (DataFrame).
        std_res: std similarity matrix over all datasets (DataFrame).
        ds_list: dataset names to plot individually, left to right.
        ordering: model labels in the row/column display order.
        color_boxes: (start_row, start_col, end_row, end_col) tuples outlined on
            every heatmap.
        metric_label: metric name used in the mean/std panel titles.

    Returns:
        The created matplotlib Figure.
    """
    n_heatmaps = len(ds_list) + 2  # one per dataset, plus mean and std
    fig = plt.figure(figsize=(3 * n_heatmaps + 1.5, 2.7))

    # Two narrow outer columns host the colorbars.
    gs = gridspec.GridSpec(1, n_heatmaps + 2, width_ratios=[0.05] + [1] * n_heatmaps + [0.05])

    vmin, vmax = 0, 1

    # All per-dataset heatmaps share [vmin, vmax], so one colorbar (attached to
    # the first heatmap) describes all of them.
    for i, ds in enumerate(ds_list):
        if i == 0:
            ax_bar = fig.add_subplot(gs[i])
        ax = fig.add_subplot(gs[i + 1])
        g = plot_one_heatmap(curr_sim_metrics_data[ds], ordering, get_fmt_name(ds, ds_info), ax, 'rocket',
                             vmin, vmax)
        if i == 0:
            cbar = plt.colorbar(g.collections[0], cax=ax_bar, format="{x:.2f}", location='left')
            _style_colorbar(cbar)
        add_all_boxes(ax, color_boxes)

    mean_ax = fig.add_subplot(gs[-3])
    std_ax = fig.add_subplot(gs[-2])
    std_ax_bar = fig.add_subplot(gs[-1])

    # Raw f-strings: in a normal literal, the '\m' of '\mathcal' is an invalid
    # escape sequence (SyntaxWarning on Python 3.12+).
    plot_one_heatmap(mean_res, ordering, rf"Mean {metric_label} over $\mathcal{{D}}$", mean_ax, 'rocket',
                     vmin, vmax, False)
    add_all_boxes(mean_ax, color_boxes)

    std_heatmap = plot_one_heatmap(std_res, ordering, rf"Std {metric_label} over $\mathcal{{D}}$", std_ax, 'mako',
                                   std_res.min().min(), std_res.max().max(), False)
    add_all_boxes(std_ax, color_boxes)
    cbar = plt.colorbar(std_heatmap.collections[0], cax=std_ax_bar, format="{x:.2f}")
    _style_colorbar(cbar)

    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    return fig


# Hand-picked (start_row, start_col, end_row, end_col) boxes per metric. They are
# hard-coded indices into the clustered ordering, so they only fit the current
# model set.
bboxes = {
    'cka_kernel_rbf_unbiased_sigma_0.4': [(11, 11, 25, 25), (27, 27, 38, 38), (40, 40, 51, 51)],
    'cka_kernel_linear_unbiased': [(7, 7, 21, 21), (23, 23, 34, 34), (40, 40, 51, 51)],
}

for sim_metric in ['cka_kernel_linear_unbiased']:
    print(sim_metric)
    curr_sim_metrics_data = sim_mats[sim_metric]
    mean_res = mean_sim_mats[sim_metric]
    std_res = std_sim_mats[sim_metric]
    curr_ordering = ordering[sim_metric]
    color_boxes = bboxes[sim_metric]
    # Use a dedicated loop variable: naming it `ds_list` would overwrite the global
    # list of all datasets, so re-running the loading cells would use only these.
    for curr_ds_list in ds_lists.values():
        fig = get_mat_plot(curr_sim_metrics_data, mean_res, std_res, curr_ds_list, curr_ordering, color_boxes)
        save_or_show(fig, storing_path / f'mean_std_sim_matrix_{sim_metric}_DS{"_".join(curr_ds_list)}{suffix}.pdf', SAVE)

### Mean vs. STD scatter plot and comparison of different similarity metrics
The first scatter plot shows the mean vs. std of the CKA linear values (raw, or transformed e.g. with tan, depending on `pp_type`). The other two scatter plots compare the mean values of CKA linear vs. CKA RBF 0.4 (global vs. local) and of CKA linear vs. RSA Spearman.

In [None]:
# Layout of the scatter-plot figure: the 'arxiv' variant uses the default font
# sizes; every other version uses the column-layout sizes and a taller figure.
if version == 'arxiv':
    curr_fontsizes, wspace, height = fontsizes, 0.3, 2.7
else:
    curr_fontsizes, wspace, height = fontsizes_cols, 0.35, 3.1

In [None]:
color_palette = sns.color_palette("viridis", 3)

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(3 * 4.1, height))

# Panel 1: mean vs. std over datasets of the (optionally transformed) CKA linear values
ax = axs[0]

# Element-wise transform applied before averaging: '', 'tan', 'tanh' or 'arccos'.
pp_type = ''
kernel = 'cka_kernel_linear_unbiased'

# pp_type -> (mean matrices, std matrices, axis-label stem)
pp_options = {
    'tan': (mean_tan_sim_mats, std_tan_sim_mats, "tan(CKA linear)"),
    'tanh': (mean_tanh_sim_mats, std_tanh_sim_mats, "tanh(CKA linear)"),
    'arccos': (mean_arccos_sim_mats, std_arccos_sim_mats, "arccos(CKA linear)"),
    '': (mean_sim_mats, std_sim_mats, "CKA linear"),
}
# Any falsy pp_type (e.g. None) selects the raw values.
pp_key = pp_type or ''
if pp_key not in pp_options:
    raise ValueError(f"Unknown processing type of CKA values: {pp_type!r}")
mean_mats, std_mats, lbl_left = pp_options[pp_key]
mean_res = mean_mats[kernel]
std_res = std_mats[kernel]

# Strict upper triangle: each model pair once, self-similarities excluded
# (assumes symmetric matrices). All metrics share the same model set, so these
# indices are reused for the comparison panels below.
iu2 = np.triu_indices(mean_res.shape[0], k=1)

sns.scatterplot(
    x=mean_res.values[iu2],
    y=std_res.values[iu2],
    alpha=0.6,
    ax=ax,
    color=color_palette[0],
)

ax.set_xlabel(f'Mean {lbl_left}', fontsize=curr_fontsizes['label'])
ax.set_ylabel(f'Std {lbl_left}', fontsize=curr_fontsizes['label'])
ax.tick_params('both', labelsize=curr_fontsizes['ticks'])

# Panels 2-3: mean values of one metric against another, with Pearson r
combs = [
    ('cka_kernel_linear_unbiased', 'cka_kernel_rbf_unbiased_sigma_0.4'),
    ('cka_kernel_linear_unbiased', 'rsa_method_correlation_corr_method_spearman')
]

for i, (x, y) in enumerate(combs, start=1):
    ax = axs[i]
    mean_x = mean_sim_mats[x].values[iu2]
    mean_y = mean_sim_mats[y].values[iu2]

    sns.scatterplot(
        x=mean_x,
        y=mean_y,
        alpha=0.6,
        ax=ax,
        color=color_palette[i],
    )

    ax.set_xlabel(f'Mean {sim_metric_name_mapping[x]}', fontsize=curr_fontsizes['label'])
    ax.set_ylabel(f'Mean {sim_metric_name_mapping[y]}', fontsize=curr_fontsizes['label'])
    ax.tick_params('both', labelsize=curr_fontsizes['ticks'])

    corr, _ = stats.pearsonr(mean_x, mean_y)
    ax.text(0.05, 0.95, f'r coeff.: {corr:.2f}', transform=ax.transAxes,
            verticalalignment='top', fontsize=curr_fontsizes['label'])

plt.tight_layout()
plt.subplots_adjust(wspace=wspace)
save_or_show(fig, storing_path / f'mean_std_{pp_type}_comp_local_global_scatter_plot{suffix}.pdf', SAVE)