In [None]:
import sys
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.ticker as ticker
import scipy
from scipy.cluster.hierarchy import dendrogram, linkage
import textwrap
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec

from constants import (
    exclude_models, 
    exclude_models_w_mae, 
    ds_name_mapping, 
    model_categories, 
    model_cat_mapping, 
    model_config_file, 
    ds_info_file,
    fontsizes
)
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show, load_ds_info, get_fmt_name

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Root directory from which the similarity matrices are loaded.
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')
# sim_metrics = similarity_metrics
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    # 'rsa_method_correlation_corr_method_spearman',
]

# Dataset names with every '/' replaced by '_'.
ds_list = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')
]
print(len(ds_list))

ds_info = load_ds_info(ds_info_file)


# Suffix appended to output file names; it also selects the model exclusion list.
suffix = ''
# suffix = '_wo_mae'

# Inches per centimetre: lets figure sizes be written in cm (matplotlib expects inches).
cm = 0.393701

# When True, the output directory is created here and SAVE is passed on to save_or_show.
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/mean_std_sim_matrix_v2')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Pick the exclusion list from the file-name suffix: no suffix keeps every model,
# a suffix containing 'mae' uses the MAE-specific list, any other suffix the default one.
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)


In [None]:
# Model category keys and the values they map to in model_cat_mapping.
# NOTE(review): neither name is referenced by the other cells visible in this notebook.
info_orig_cols = model_categories
info_cols = [mapped_name for mapped_name in model_cat_mapping.values()]

In [None]:
# Load the similarity matrices for the configured metrics and datasets.
# Later cells index the result as sim_mats[sim_metric][dataset_name] and select
# rows/columns of each matrix by model name.
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)

In [None]:
from itertools import product
def compute_ordering(sim_mat):
    """Return the leaf order of an average-linkage clustering of ``1 - sim_mat``.

    Args:
        sim_mat: Square similarity matrix (DataFrame, ndarray or nested
            sequence). Only the strict upper triangle is read, so the matrix
            is treated as symmetric.

    Returns:
        np.ndarray of positional indices (not labels) giving the dendrogram
        leaf order.
    """
    dist_mat = 1 - np.asarray(sim_mat)
    # The strict upper triangle in row-major order is the condensed distance
    # vector that scipy's `linkage` accepts.
    condensed = dist_mat[np.triu_indices_from(dist_mat, k=1)]
    linkage_matrix = linkage(condensed, method='average')
    # no_plot=True: only the leaf order is needed. Without it the dendrogram is
    # drawn into whatever figure is current, and the plt.close() that was
    # required afterwards could close an unrelated figure.
    leaves = dendrogram(linkage_matrix, no_plot=True)['leaves']
    return np.array(leaves)

In [None]:
# Result containers, all keyed by similarity metric name.
# Mean / std over datasets of the raw similarity values.
mean_sim_mats, std_sim_mats = {}, {}
# Mean / std after applying np.tanh to the similarity values.
mean_tanh_sim_mats, std_tanh_sim_mats = {}, {}
# Mean / std after applying np.tan to the similarity values.
mean_tan_sim_mats, std_tan_sim_mats = {}, {}
# Model names in clustered (dendrogram leaf) order.
ordering = {}

def build_mean_std(df, labels=None):
    """Reduce a stack of matrices along axis 0 into labelled mean and std frames.

    Args:
        df: Array-like of stacked matrices, reduced along axis 0.
        labels: Row and column labels of the returned frames. Defaults to the
            notebook-level ``allowed_models``, which keeps the single-argument
            call working as before.

    Returns:
        Tuple ``(mean, std)`` of DataFrames sharing ``labels`` as index and
        columns. The std is numpy's default (population, ``ddof=0``).
    """
    if labels is None:
        labels = allowed_models
    stacked = np.asarray(df)
    mean = pd.DataFrame(stacked.mean(axis=0), index=labels, columns=labels)
    std = pd.DataFrame(stacked.std(axis=0), index=labels, columns=labels)
    return mean, std


for sim_metric in sim_metrics:
    # Stack the per-dataset matrices of this metric along a new leading axis.
    result = np.stack([*sim_mats[sim_metric].values()], axis=0)

    # Raw similarity values.
    mean_df, std_df = build_mean_std(result)
    mean_sim_mats[sim_metric] = mean_df
    std_sim_mats[sim_metric] = std_df

    # Model names rearranged into the clustered order of the mean matrix.
    leaf_idx = compute_ordering(mean_df)
    ordering[sim_metric] = np.array(allowed_models)[leaf_idx]

    # tanh-transformed values.
    mean_df, std_df = build_mean_std(np.tanh(result))
    mean_tanh_sim_mats[sim_metric] = mean_df
    std_tanh_sim_mats[sim_metric] = std_df

    # tan-transformed values.
    mean_df, std_df = build_mean_std(np.tan(result))
    mean_tan_sim_mats[sim_metric] = mean_df
    std_tan_sim_mats[sim_metric] = std_df

### Similarity matrices plots

We create one figure for each similarity metric. Each figure is a single row of five heatmaps: the similarity matrices of three individual datasets (one natural-image dataset, one single-domain dataset, and one structured dataset), followed by the mean and the standard deviation of the similarity matrices across all datasets.

In [None]:
# Named groups of datasets; one heatmap figure is produced per group.
ds_lists = {
    'ds_row_1_v2': [
        'imagenet-subset-10k',
        'wds_vtab_flowers',
        'wds_vtab_pcam',
    ],
}

In [None]:
def plot_one_heatmap(df, ordering, title, ax, cmap, vmin=0, vmax=1, cbar=False):
    """Draw ``df`` as a heatmap on ``ax`` with rows and columns in ``ordering``.

    Any ' (N' in ``title`` is turned into a line break followed by '(N'.
    The axis decorations are switched off after plotting.

    Returns:
        Whatever ``sns.heatmap`` returns for the drawn heatmap.
    """
    reordered = df.loc[ordering, ordering]
    # g = sns.heatmap(df, ax=ax, vmin=vmin, vmax=vmax, cbar=cbar, cmap=cmap, square=True)
    heatmap = sns.heatmap(reordered, cmap=cmap, vmin=vmin, vmax=vmax, cbar=cbar, ax=ax)
    wrapped_title = title.replace(' (N', '\n(N')
    ax.set_title(wrapped_title, y=1.01, fontsize=fontsizes['title'])
    ax.axis('off')
    return heatmap


def get_mat_plot(curr_sim_metrics_data, mean_res, std_res, ds_list, ordering):
    """Build a one-row figure of per-dataset, mean and std similarity heatmaps.

    Layout, left to right: a colorbar for the fixed [0, 1] scale, one heatmap
    per entry of ``ds_list``, the mean heatmap, the std heatmap, and a colorbar
    for the std heatmap (scaled to the min/max of ``std_res``).

    Args:
        curr_sim_metrics_data: Mapping from dataset name to its similarity
            DataFrame, labelled on both axes by the entries of ``ordering``.
        mean_res: DataFrame of mean similarities.
        std_res: DataFrame of similarity standard deviations.
        ds_list: Dataset names to show individually.
        ordering: Row/column labels in the order they should be displayed.

    Returns:
        The created matplotlib figure.
    """
    n_ds = len(ds_list)
    # Create a GridSpec layout with extra columns for the colorbars
    fig = plt.figure(figsize=(3 * (n_ds + 2) + 1.5, 3))
    gs = gridspec.GridSpec(1, n_ds + 4, width_ratios=[0.05] + [1] * (n_ds + 2) + [0.05])

    vmin, vmax = 0, 1

    # Plot the heatmaps for each dataset
    for i, ds in enumerate(ds_list):
        if i == 0:
            # The left-most grid column holds the colorbar shared by all [0, 1] heatmaps.
            ax_bar = fig.add_subplot(gs[i])
        ax = fig.add_subplot(gs[i + 1])
        df = curr_sim_metrics_data[ds]
        g = plot_one_heatmap(df, ordering, get_fmt_name(ds, ds_info), ax, 'rocket', vmin, vmax)
        if i == 0:
            cbar = plt.colorbar(g.collections[0], cax=ax_bar, format="{x:.2f}", location='left')
            cbar.outline.set_edgecolor('none')
            cbar.ax.tick_params(labelsize=fontsizes['ticks'])

    # Plot the mean and std heatmaps
    mean_ax = fig.add_subplot(gs[-3])
    std_ax = fig.add_subplot(gs[-2])
    std_ax_bar = fig.add_subplot(gs[-1])

    # Raw strings: "\m" is not a valid escape sequence in a normal string literal
    # and triggers a warning; the resulting text is unchanged.
    plot_one_heatmap(mean_res, ordering, r"Mean CKA over $\mathcal{D}$", mean_ax, 'rocket', vmin, vmax, False)
    std_heatmap = plot_one_heatmap(std_res, ordering, r"Std CKA over $\mathcal{D}$", std_ax, 'mako', std_res.min().min(), std_res.max().max(), False)
    cbar = plt.colorbar(std_heatmap.collections[0], cax=std_ax_bar, format="{x:.2f}")
    cbar.outline.set_edgecolor('none')
    cbar.ax.tick_params(labelsize=fontsizes['ticks'])

    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    return fig


# One figure per (similarity metric, dataset group).
for sim_metric in sim_metrics:
    curr_sim_metrics_data = sim_mats[sim_metric]
    mean_res = mean_sim_mats[sim_metric]
    std_res = std_sim_mats[sim_metric]
    curr_ordering = ordering[sim_metric]
    # Do not call the loop variable `ds_list`: that would overwrite the
    # notebook-level `ds_list` (the full dataset list passed to
    # load_similarity_matrices), so re-running the loading cell after this one
    # would silently use only the plotted datasets.
    for plot_ds_list in ds_lists.values():
        fig = get_mat_plot(curr_sim_metrics_data, mean_res, std_res, plot_ds_list, curr_ordering)
        save_or_show(fig, storing_path / f'mean_std_sim_matrix_{sim_metric}_DS{"_".join(plot_ds_list)}{suffix}.pdf', SAVE)


### Mean vs. STD scatter plot

In [None]:
def plot_mean_std_scatter(mean_data, std_data):
    """Scatter the per-model-pair std against the mean, one panel per metric.

    Args:
        mean_data: Mapping from similarity metric name to the DataFrame of mean
            values.
        std_data: Mapping from similarity metric name to the DataFrame of std
            values.

    Returns:
        The created matplotlib figure, with one panel per entry of the
        notebook-level ``sim_metrics``.
    """
    n = len(sim_metrics)
    # squeeze=False keeps `axs` two-dimensional. With the default, a single
    # metric makes plt.subplots return a bare Axes and indexing it would fail.
    fig, axs = plt.subplots(
        nrows=1,
        ncols=n,
        sharex=True,
        sharey=True,
        figsize=(n * 14 * cm, 10 * cm),
        squeeze=False,
    )
    for i, curr_sim_metric in enumerate(sim_metrics):
        mean_res = mean_data[curr_sim_metric].values
        std_res = std_data[curr_sim_metric].values
        # Strict upper triangle: each unordered model pair once, no diagonal.
        iu2 = np.triu_indices(mean_res.shape[0], k=1)
        g = sns.scatterplot(
            x=mean_res[iu2],
            y=std_res[iu2],
            alpha=0.6,
            ax=axs[0, i],
        )
        g.set_xlabel('Mean CKA value across all datasets', fontsize=fontsizes['label'])
        g.set_ylabel('Std of CKA values across all datasets', fontsize=fontsizes['label'])
        g.set_title(sim_metric_name_mapping[curr_sim_metric], fontsize=fontsizes['title'])
        g.tick_params('both', labelsize=fontsizes['ticks'])

    fig.tight_layout()
    return fig
    

**Original CKA values**

In [None]:
# Scatter plot of the untransformed similarity values.
fig = plot_mean_std_scatter(mean_data=mean_sim_mats, std_data=std_sim_mats)
scatter_path = storing_path / f'mean_std_scatter_plot{suffix}.pdf'
save_or_show(fig, scatter_path, SAVE)

**Tanh transformed CKA values**

In [None]:
# Scatter plot of the tanh-transformed similarity values.
fig = plot_mean_std_scatter(mean_data=mean_tanh_sim_mats, std_data=std_tanh_sim_mats)
scatter_path = storing_path / f'mean_std_tanh_scatter_plot{suffix}.pdf'
save_or_show(fig, scatter_path, SAVE)

**Tan transformed CKA values**

In [None]:
# Scatter plot of the tan-transformed similarity values.
fig = plot_mean_std_scatter(mean_data=mean_tan_sim_mats, std_data=std_tan_sim_mats)
scatter_path = storing_path / f'mean_std_tan_scatter_plot{suffix}.pdf'
save_or_show(fig, scatter_path, SAVE)

#### With categories as colors

In [None]:
# NOTE: disabled code, kept for reference.
# It builds `comb`: one "<objective A>, <objective B>" label for every model pair
# with i < j, written in the order listed in `pairs`; an objective combination
# missing from `pairs` raises ValueError.
# pairs = [
#     ('Image-Text', 'Image-Text'),
#     ('Image-Text', 'Self-Supervised'),
#     ('Image-Text', 'Supervised'),
#     ('Self-Supervised', 'Self-Supervised'),
#     ('Self-Supervised', 'Supervised'),
#     ('Supervised', 'Supervised'),
# ]
# comb = []
# for i, val in enumerate(model_configs['objective']):
#     for j, val2 in enumerate(model_configs['objective']):
#         if i >= j:
#             continue
#         if val == val2 or (val, val2) in pairs:
#             comb.append(f"{val}, {val2}")
#         elif (val2, val) in pairs:
#             comb.append(f"{val2}, {val}")
#         else:
#             raise ValueError("Unknown pair")

In [None]:
# NOTE: disabled code, kept for reference.
# Variant of the mean-vs-std scatter plot coloured by `comb`, which is only
# defined by the disabled cell above. If re-enabled as is, it saves to
# f'mean_std_scatter_plot{suffix}.pdf', the same file name used for the
# uncoloured scatter plot of the original CKA values.
# n = len(sim_metrics)
# fig, axs = plt.subplots(nrows=1, ncols=n, sharey=True, figsize=(n * 14 * cm, 10 * cm))
# for i, curr_sim_metric in enumerate(sim_metrics):
#     mean_res = mean_sim_mats[curr_sim_metric].values
#     std_res = std_sim_mats[curr_sim_metric].values
#     iu2 = np.triu_indices(mean_res.shape[0], k=1)
#     g = sns.scatterplot(
#         x=mean_res[iu2],
#         y=std_res[iu2],
#         hue=comb,
#         alpha=0.6,
#         ax=axs[i],
#         legend=False if i == 0 else True
#     )
#     g.set_xlabel('Mean CKA value across all datasets', fontsize=10)
#     g.set_ylabel('Std of CKA values across all datasets', fontsize=10)
#     g.set_title(sim_metric_name_mapping[curr_sim_metric], fontsize=12)
#     g.tick_params('both', labelsize=10)
#     if i > 0:
#         sns.move_legend(g, "upper left", bbox_to_anchor=(1, 1), frameon=False, fontsize=10)

# fig.tight_layout()
# save_or_show(fig, storing_path / f'mean_std_scatter_plot{suffix}.pdf', SAVE)