In [None]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Dataset whose model-similarity results are visualised in this notebook; it is
# used both as a sub-directory name and (upper-cased) in the figure titles.
dataset = 'imagenet-subset-10k'
# NOTE(review): absolute, machine-specific root directory; adjust when running elsewhere.
model_sim_root_path = Path('/home/space/diverse_priors/model_similarities')
model_sim_path = model_sim_root_path / dataset
# Fail fast when the results are missing. `is_dir()` (instead of `exists()`) also
# rejects a regular file at this location, which could not be searched for matrices.
if not model_sim_path.is_dir():
    raise FileNotFoundError(f"Directory {str(model_sim_path)} does not exist!")

In [None]:
# Master switch: when True the figures created below are written to `storing_path`.
SAVE = True
storing_path = model_sim_path.joinpath('plots')

# Only touch the file system when figures are actually going to be saved.
if SAVE:
    storing_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Echo the resolved output directory as the cell result for a quick visual check.
storing_path

In [None]:
# Directory name of each similarity method -> label shown in the figure titles.
# The insertion order is significant: it defines the order in which the methods
# are kept and plotted further down.
name_mapping = dict([
    ('cka_kernel_rbf_unbiased_sigma_0.2', 'CKA RBF 0.2'),
    ('cka_kernel_rbf_unbiased_sigma_0.4', 'CKA RBF 0.4'),
    ('cka_kernel_rbf_unbiased_sigma_0.6', 'CKA RBF 0.6'),
    ('cka_kernel_rbf_unbiased_sigma_0.8', 'CKA RBF 0.8'),
    ('cka_kernel_linear_unbiased', 'CKA linear'),
    ('rsa_method_correlation_corr_method_pearson', 'RSA pearson'),
    ('rsa_method_correlation_corr_method_spearman', 'RSA spearman'),
])

# Models whose row and column are removed from every loaded similarity matrix.
models_to_exclude = ['SegmentAnything_vit_b']

In [None]:
def get_model_ids(fn):
    """Read model identifiers from a text file, one identifier per line.

    Parameters
    ----------
    fn : str or path-like
        Path of the text file to read.

    Returns
    -------
    list of str
        The whitespace-stripped lines of the file, in file order. Lines that
        are empty after stripping are skipped, so a trailing blank line does
        not turn into an empty identifier.
    """
    with open(fn, 'r') as file:
        # Iterate over the file line by line; keep only non-blank entries.
        return [model_id for model_id in map(str.strip, file) if model_id]

In [None]:
# Load every similarity matrix found below `model_sim_path`. Each matrix is stored
# under the name of the directory that contains it, which is the key that
# `name_mapping` uses for the similarity method.
sim_mats = {}
storing_paths = {}
# `rglob` already searches recursively, so no extra "**/" prefix is needed.
# Sorting makes the processing (and printing) order reproducible.
for sim_mat_fn in sorted(model_sim_path.rglob("similarity_matrix.pt")):
    print(sim_mat_fn)
    model_ids_fn = sim_mat_fn.parent / 'model_ids.txt'
    if not model_ids_fn.exists():
        raise FileNotFoundError(f'{str(model_ids_fn)} does not exist.')
    model_ids = get_model_ids(model_ids_fn)

    # NOTE(review): `torch.load` unpickles the file; only load matrices from
    # trusted locations.
    # `map_location='cpu'` lets matrices saved from another device load here.
    sim_mat = torch.load(sim_mat_fn, map_location='cpu')
    if isinstance(sim_mat, torch.Tensor):
        sim_mat = sim_mat.detach().cpu().numpy()

    # Set the diagonal on an array we own. Writing through `df.values` is
    # fragile: depending on the pandas version/settings it is a read-only view
    # or a copy, so the write either fails or is silently lost.
    sim_values = np.array(sim_mat, copy=True)
    np.fill_diagonal(sim_values, 1)

    df = pd.DataFrame(sim_values, index=model_ids, columns=model_ids)
    print(df.shape)
    # Rows and columns are dropped symmetrically, so the remaining diagonal
    # entries stay on the diagonal. Models that are not present are ignored.
    df = df.drop(index=models_to_exclude, columns=models_to_exclude, errors='ignore')
    print(df.shape)
    sim_mats[sim_mat_fn.parent.name] = df.copy()


# Keep only the methods listed in `name_mapping`, in the order defined there;
# matrices of any other method are discarded.
sim_mats = {x: sim_mats[x] for x in name_mapping if x in sim_mats}

In [None]:
# Check all matrices have the same index
# Without any matrix there is nothing to compare; fail with a clear message
# instead of the bare StopIteration that `next()` would raise.
if not sim_mats:
    raise ValueError('No similarity matrices were loaded; nothing to compare or plot.')

first_index = next(iter(sim_mats.values())).index
# `Index.equals` compares both the labels and their order, which the positional
# reordering of the matrices further down relies on.
all_same_index = all(df.index.equals(first_index) for df in sim_mats.values())
if not all_same_index:
    raise ValueError('All DataFrames must have the same index.')

In [None]:
# Smallest and largest entry over all similarity matrices; the overview figures
# use them as a shared colour range so that the methods are comparable.
vmin = min(sim_df.min().min() for sim_df in sim_mats.values())
vmax = max(sim_df.max().max() for sim_df in sim_mats.values())

## Create single method heatmaps

In [None]:
# One large, annotated heatmap per similarity method.
for sim_method, sim_mat in sim_mats.items():
    # Force the diagonal to 1 on a private copy. Writing through
    # `sim_mat.values` would mutate the shared matrices and fails when pandas
    # hands out a read-only view.
    plot_values = sim_mat.to_numpy(copy=True)
    np.fill_diagonal(plot_values, 1)
    plot_df = pd.DataFrame(plot_values, index=sim_mat.index, columns=sim_mat.columns)

    # Explicit figure/axes instead of the implicit pyplot "current figure".
    fig, ax = plt.subplots(figsize=(32.5, 25))
    sns.heatmap(plot_df, annot=True, fmt='.2f', ax=ax)
    ax.set_title(f"{dataset.upper()} {name_mapping[sim_method]}", fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.tick_params(axis='both', which='minor', labelsize=12)
    fig.tight_layout()
    if SAVE:
        fig.savefig(storing_path / f'{sim_method}_sim_mat_heatmap.pdf')
        # Release the figure right away; one is created per method.
        plt.close(fig)
    else:
        plt.show()

## Create single method dendrograms

In [None]:
# One clustered heatmap (with dendrograms) per similarity method. The row/column
# order found for linear CKA is remembered so that other figures can reuse it;
# both variables stay None when no linear CKA matrix is available.
cka_lin_dendo_row_ordering = None
cka_lin_dendo_col_ordering = None
for sim_method, sim_mat in sim_mats.items():
    # Force the diagonal to 1 on a private copy. Writing through
    # `sim_mat.values` would mutate the shared matrices and fails when pandas
    # hands out a read-only view.
    plot_values = sim_mat.to_numpy(copy=True)
    np.fill_diagonal(plot_values, 1)
    plot_df = pd.DataFrame(plot_values, index=sim_mat.index, columns=sim_mat.columns)

    g = sns.clustermap(plot_df, annot=True, fmt='.2f', figsize=(31, 30))
    if sim_method == 'cka_kernel_linear_unbiased':
        # Positional orderings (integer indices) produced by the clustering.
        cka_lin_dendo_row_ordering = g.dendrogram_row.reordered_ind
        cka_lin_dendo_col_ordering = g.dendrogram_col.reordered_ind
    g.fig.suptitle(f"{dataset.upper()} {name_mapping[sim_method]}", fontsize=20)
    g.tick_params(axis='both', which='major', labelsize=16)
    g.tick_params(axis='both', which='minor', labelsize=12)
    if SAVE:
        # Save and close the clustermap's own figure explicitly instead of
        # relying on it being pyplot's current figure.
        g.fig.savefig(storing_path / f'{sim_method}_sim_mat_dendogram.pdf')
        plt.close(g.fig)
    else:
        plt.show()

## Create heatmap overviews

In [None]:
# Overview figure: one small heatmap per method, side by side, all drawn with
# the same colour range (vmin/vmax) and without colour bars or tick labels.
n = len(sim_mats)
# squeeze=False always returns a 2-D grid of axes, also for a single panel.
fig, axs = plt.subplots(nrows=1, ncols=n, figsize=(n * 3, 3.25), squeeze=False)
axs = axs[0]

for panel, method_key in zip(axs, sim_mats):
    sns.heatmap(
        sim_mats[method_key],
        ax=panel,
        cbar=False,
        annot=False,
        vmin=vmin,
        vmax=vmax,
    )
    panel.set_title(name_mapping[method_key], fontsize=16)
    panel.set(xticks=[], yticks=[])

# Tighten the spacing between the panels, then optionally export both formats.
fig.tight_layout()
if SAVE:
    fig.savefig(storing_path / 'all_methods_sim_mat_heatmap.pdf')
    fig.savefig(storing_path / 'all_methods_sim_mat_heatmap.png')
plt.show()

In [None]:
# Overview figure like the plain one, but every matrix is reordered with the
# row/column ordering that the clustermap found for linear CKA.
# The orderings are only set when a 'cka_kernel_linear_unbiased' matrix was
# clustered; fail with a clear message instead of an obscure indexing error.
if cka_lin_dendo_row_ordering is None or cka_lin_dendo_col_ordering is None:
    raise ValueError(
        "No dendrogram ordering available: the 'cka_kernel_linear_unbiased' "
        "similarity matrix must be loaded and clustered before this figure can be created."
    )

n = len(sim_mats)
# squeeze=False always returns a 2-D grid of axes, also for a single panel.
fig, axs = plt.subplots(1, n, figsize=(n * 3, 3.25), squeeze=False)
axs = axs[0]

for ax, (key, df) in zip(axs, sim_mats.items()):
    # Positional reordering; valid for every method because all matrices were
    # checked to share the same index.
    reordered_df = df.iloc[cka_lin_dendo_row_ordering, cka_lin_dendo_col_ordering]
    sns.heatmap(reordered_df, ax=ax, cbar=False, annot=False, vmin=vmin, vmax=vmax)
    ax.set_title(name_mapping[key], fontsize=16)
    ax.set_xticks([])
    ax.set_yticks([])

# Adjust layout
fig.tight_layout()
if SAVE:
    fig.savefig(storing_path / 'all_methods_sim_mat_heatmap_cka_lin_dendo_ordering.pdf')
    fig.savefig(storing_path / 'all_methods_sim_mat_heatmap_cka_lin_dendo_ordering.png')
plt.show()