In [None]:
import sys
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.ticker as ticker
import scipy

from sklearn.manifold import MDS, TSNE
from sklearn.cluster import SpectralClustering, AffinityPropagation

from constants import (
    exclude_models, 
    exclude_models_w_mae, 
    ds_name_mapping, 
    model_categories, 
    model_cat_mapping, 
    model_config_file, 
    ds_info_file,
    fontsizes,
    cat_name_mapping,
    sim_metric_name_mapping
)
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show, load_ds_info, get_fmt_name

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# --- Notebook configuration ---------------------------------------------------

# Root directory that is handed to `load_similarity_matrices` further down.
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Similarity metrics to load. Disabled alternatives are kept as comments.
# sim_metrics = similarity_metrics
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    # 'rsa_method_correlation_corr_method_spearman',
]

# Dataset identifiers with every '/' replaced by '_'.
ds_list = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')
]

ds_info = load_ds_info(ds_info_file)

# The two datasets compared against each other in the plots of this notebook.
ds_oi = ['imagenet-subset-10k', 'wds_vtab_flowers']

# A non-empty suffix switches on model exclusion when the model configs are loaded.
suffix = ''
# suffix = '_wo_mae'

# NOTE(review): presumably the inches-per-centimetre factor for figure sizes;
# it is not used in the cells visible here — confirm before removing.
cm = 0.393701

# When SAVE is true the output folder is created up front.
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/figure_1_subplots')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Decide which models to exclude: nothing for an empty suffix, the MAE-specific
# list when the suffix mentions 'mae', and the generic list for any other suffix.
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Read the model configuration table together with the list of admissible models.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)


In [None]:
# Load the similarity matrices for the configured metrics and datasets,
# restricted to the allowed models. The result is later indexed as
# sim_mats[metric][dataset].
similarity_loader_kwargs = dict(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)
sim_mats = load_similarity_matrices(**similarity_loader_kwargs)

In [None]:
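# Disabled snippet: randomly samples 5, 6 and 7 model ids from the first three
# `objective` groups of `model_configs`. Presumably this is how the fixed
# `model_sets` in the next cell were originally drawn — TODO confirm.
# nr_models_per_model_set = [5, 6, 7]
# model_sets = {}
# for i, (obj, data) in enumerate(model_configs.reset_index(names=['mid']).groupby('objective')):
#     model_sets[obj] = sorted(data['mid'].sample(nr_models_per_model_set[i]).tolist())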

In [None]:
# Fixed selection of model ids, grouped under three category labels.
model_sets = {
    'Image-Text': [
        'OpenCLIP_EVA02-L-14_merged2b_s4b_b131k',
        'OpenCLIP_RN50_openai',
        'OpenCLIP_ViT-L-14_laion2b_s32b_b82k',
        'OpenCLIP_ViT-L-14_laion400m_e32',
        'vit_huge_patch14_clip_224.laion2b',
    ],
    'Self-Supervised': [
        'dino-xcit-medium-24-p16',
        'dino-xcit-small-12-p16',
        'dinov2-vit-large-p14',
        'dinov2-vit-small-p14',
        'simclr-rn50',
        'vicreg-rn50',
    ],
    'Supervised': [
        'beit_large_patch16_224.in22k_ft_in22k',
        'deit3_base_patch16_224.fb_in22k_ft_in1k',
        'efficientnet_b7',
        'resnet152',
        'resnet50',
        'vgg19',
        'vit_large_patch16_224',
    ],
}

In [None]:
# Metric key whose matrices are used for every plot below.
curr_sim_metric = 'cka_kernel_linear_unbiased'
sim_data = sim_mats[curr_sim_metric]

In [None]:
# phi indexes the rows (self-supervised models), theta the columns (supervised models).
phi = model_sets['Self-Supervised']
theta = model_sets['Supervised']
print(f'{len(phi)=}, {len(theta)=}')

# Same phi x theta slice of the similarity table for each dataset of interest.
first_ds, second_ds = ds_oi[0], ds_oi[1]
ds_A = sim_data[first_ds].loc[phi, theta]
ds_B = sim_data[second_ds].loc[phi, theta]

# Colour-scale limits shared by both heatmaps: overall min / max across the two slices.
vmin = min(frame.min().min() for frame in (ds_A, ds_B))
vmax = max(frame.max().max() for frame in (ds_A, ds_B))

In [None]:
def get_heatmap(df, palette):
    """Draw ``df`` transposed as a bare heatmap and return the figure.

    A new figure is opened with ``figsize=df.shape``. Tick labels and the colour
    bar are switched off. The colour limits are read from the notebook-level
    ``vmin`` / ``vmax`` variables rather than derived from ``df``.

    Args:
        df: Table of values to plot; it is rendered as ``df.T``.
        palette: Colormap handed to seaborn as ``cmap``.

    Returns:
        The matplotlib figure that holds the heatmap.
    """
    fig = plt.figure(figsize=df.shape)

    heatmap_options = dict(
        xticklabels=False,
        yticklabels=False,
        cbar=False,
        vmin=vmin,
        vmax=vmax,
        cmap=palette,
    )
    # With no explicit axes, seaborn draws into the figure that was just opened.
    sns.heatmap(df.T, **heatmap_options)
    return fig

In [None]:
# Heatmap of the first dataset's phi x theta similarities (purple palette).
fig = get_heatmap(ds_A, palette='Purples')
save_or_show(fig, storing_path / 'ds_A_heatmap.pdf', SAVE)

In [None]:
# Heatmap of the second dataset's phi x theta similarities (orange-red palette).
fig = get_heatmap(ds_B, palette='OrRd')
save_or_show(fig, storing_path / 'ds_B_heatmap.pdf', SAVE)

In [None]:
# Flatten each similarity table once. Entries stay paired because ds_A and ds_B
# were both selected with the same .loc[phi, theta] row/column order.
sims_A = ds_A.values.flatten()
sims_B = ds_B.values.flatten()

# Pearson correlation between the two datasets' similarity values.
corr_coef = np.corrcoef(sims_A, sims_B)[0, 1]
print(corr_coef)

# Scatter plot with a dashed regression line and no confidence band.
fig, ax = plt.subplots(figsize=(6, 4))
sns.regplot(
    x=sims_A,
    y=sims_B,
    color='darkgrey',
    line_kws=dict(alpha=1, ls='--', lw=3),
    scatter_kws=dict(alpha=1, s=75),
    ci=None,
    ax=ax,
)

# Optional annotation with the correlation coefficient (currently disabled).
# ax.text(
#     0.7, 0.1,  # Position of the text (x, y) in axis coordinates
#     f'Pearson r = {corr_coef:.2f}',  # Correlation coefficient rounded to 2 decimal places
#     horizontalalignment='left',
#     verticalalignment='top',
#     transform=ax.transAxes,  # Use axis coordinates (0 to 1)
#     fontsize=fontsizes['label'],
#     bbox=dict(edgecolor='white', facecolor='white', alpha=0.5)  # Optional: Add background to the text
# )

sns.despine(ax=ax)

# Hide tick marks and tick labels on both axes; only the axis labels remain.
ax.tick_params(axis='both', bottom=False, left=False, labelbottom=False, labelleft=False)

# Raw strings keep the LaTeX backslashes literal: '\p' in a normal string is an
# invalid escape sequence and '\t' would silently turn into a tab.
ax.set_xlabel(
    r'CKA($\phi$, $\theta$) on ' + ds_info.loc[ds_oi[0], 'name'],
    fontsize=fontsizes['label'],
    color='mediumpurple',
);
ax.set_ylabel(
    r'CKA($\phi$, $\theta$) on ' + ds_info.loc[ds_oi[1], 'name'],
    fontsize=fontsizes['label'],
    color='firebrick',
);

save_or_show(fig, storing_path / 'scatter_phi_theta.pdf', SAVE)