## Notebook to create the subplots for Figure 1
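This notebook creates the subplots for Figure 1. It loads the pairwise model similarity matrices computed on each dataset, selects the self-supervised × supervised block for two datasets (`imagenet-subset-10k` and `wds_vtab_flowers`), and plots that block as one heatmap per dataset plus a scatterplot comparing the two.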

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from constants import (
    BASE_PATH_PROJECT,
    BASE_PATH_RESULTS,
    ds_info_file,
    exclude_models,
    exclude_models_w_mae,
    model_config_file
)
from helper import (
    load_all_datasetnames_n_info,
    load_model_configs_and_allowed_models,
    load_similarity_matrices,
    pp_storing_path,
    save_or_show
)

#### Global variables

In [None]:
# Directory handed to `load_similarity_matrices` as its `path` argument.
base_path_similarity_matrices = BASE_PATH_PROJECT / 'model_similarities'

# Metric names requested from `load_similarity_matrices`; the
# `cka_kernel_linear_unbiased` entry is the one selected for the plots.
sim_metrics = ['cka_kernel_rbf_unbiased_sigma_0.4', 'cka_kernel_linear_unbiased']

### Config datasets
# `ds_list` is forwarded to `load_similarity_matrices`; `ds_info` is the second
# value returned by the helper (presumably per-dataset metadata - confirm in helper).
ds_list, ds_info = load_all_datasetnames_n_info(
    ds_info_file,
    verbose=True,
)

# Datasets of interest, in plotting order: the first becomes `ds_A`, the
# second becomes `ds_B`.
ds_oi = [
    'imagenet-subset-10k',
    'wds_vtab_flowers',
]

### Model sets
# Model identifiers grouped under the category name used as key. The lists are
# used as row/column labels when slicing the similarity matrices with `.loc`.
model_sets = {
    'Image-Text': [
        'OpenCLIP_EVA02-L-14_merged2b_s4b_b131k',
        'OpenCLIP_RN50_openai',
        'OpenCLIP_ViT-L-14_laion2b_s32b_b82k',
        'OpenCLIP_ViT-L-14_laion400m_e32',
        'vit_huge_patch14_clip_224.laion2b',
    ],
    'Self-Supervised': [
        'dino-xcit-medium-24-p16',
        'dino-xcit-small-12-p16',
        'dinov2-vit-large-p14',
        'simclr-rn50',
        'vicreg-rn50',
    ],
    'Supervised': [
        'beit_large_patch16_224.in22k_ft_in22k',
        'deit3_base_patch16_224.fb_in22k_ft_in1k',
        'efficientnet_b7',
        'resnet152',
        'resnet50',
        'vgg19',
        'vit_large_patch16_224',
    ],
}

## Storing info

# `suffix` selects the model-exclusion list used when the model configs are
# loaded; an empty string leaves that list empty.
# Alternative setting: suffix = '_wo_mae'
suffix = ''

# Flag forwarded to `pp_storing_path` and `save_or_show`
# (presumably False means figures are shown rather than written - confirm in helper).
SAVE = False
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots/figure_1_subplots', SAVE)

#### Load model configs and similarity matrices

In [None]:
# Choose the exclusion list from `suffix`: an empty suffix gives an empty list,
# a suffix containing 'mae' gives `exclude_models_w_mae`, anything else gives
# `exclude_models`.
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# `allowed_models` is forwarded to `load_similarity_matrices` when the
# similarity matrices are loaded.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file, exclude_models=curr_excl_models, exclude_alignment=True
)


In [None]:
# Load the similarity matrices (as the helper's name suggests) from
# `base_path_similarity_matrices` for the given datasets, metrics and allowed
# models. The result is indexed by metric name and then by dataset name in the
# cells that follow.
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)

### Heatmaps and Scatterplot

In [None]:
# Metric used for the heatmaps and the scatterplot; it is one of the entries
# listed in `sim_metrics`.
curr_sim_metric = 'cka_kernel_linear_unbiased'
# Similarity data of that metric; indexed by dataset name in the next cell.
sim_data = sim_mats[curr_sim_metric]

In [None]:
# Row labels (phi) and column labels (theta) of the similarity slices.
phi, theta = model_sets['Self-Supervised'], model_sets['Supervised']
print(f'{len(phi)=}, {len(theta)=}')

# The same phi x theta slice, taken from each of the two datasets of interest.
ds_A = sim_data[ds_oi[0]].loc[phi, theta]
ds_B = sim_data[ds_oi[1]].loc[phi, theta]

# Common value range over both slices; `get_heatmap` reads these two names.
vmin = min(table.min().min() for table in (ds_A, ds_B))
vmax = max(table.max().max() for table in (ds_A, ds_B))

In [None]:
def get_heatmap(df, palette):
    """Draw `df` transposed as a bare heatmap and return its figure.

    The colour range is taken from the module-level `vmin` and `vmax`, so
    every call uses whatever range those names currently hold.

    Args:
        df: Table to plot; it must provide `.shape` and `.T`. `df.shape` is
            used directly as the figure size (width, height) in inches.
        palette: Colormap forwarded to seaborn as `cmap`.

    Returns:
        The matplotlib figure containing the heatmap, without tick labels
        and without a colour bar.
    """
    fig, ax = plt.subplots(figsize=df.shape)
    heatmap_options = dict(
        xticklabels=False,
        yticklabels=False,
        cbar=False,
        vmin=vmin,
        vmax=vmax,
        cmap=palette,
    )
    sns.heatmap(df.T, ax=ax, **heatmap_options)
    return fig

#### Heatmaps

In [None]:
# Heatmap of the first dataset's slice; `save_or_show` receives the figure,
# the target path `ds_A_heatmap.pdf` and the SAVE flag.
fig = get_heatmap(ds_A, palette='Purples')
save_or_show(fig, storing_path / 'ds_A_heatmap.pdf', SAVE)

In [None]:
# Heatmap of the second dataset's slice; `save_or_show` receives the figure,
# the target path `ds_B_heatmap.pdf` and the SAVE flag.
fig = get_heatmap(ds_B, palette='OrRd')
save_or_show(fig, storing_path / 'ds_B_heatmap.pdf', SAVE)

#### Scatterplot

In [None]:
# Flatten both slices once. Their entries pair up because both slices were
# selected with the same row and column labels.
flat_A = ds_A.values.flatten()
flat_B = ds_B.values.flatten()

# Correlation coefficient between the two datasets' similarity values.
corr_coef = np.corrcoef(flat_A, flat_B)[0, 1]
print(corr_coef)

# Create the plot: second dataset on the x-axis, first dataset on the y-axis.
scatter_fig, scatter_ax = plt.subplots(figsize=(6, 4))
sns.regplot(
    x=flat_B,
    y=flat_A,
    color='darkgrey',
    scatter_kws={'alpha': 1, 's': 75},
    line_kws={'alpha': 1, 'ls': '--', 'lw': 3},
    ci=None,
    ax=scatter_ax,
)
sns.despine(ax=scatter_ax)

# Hide tick marks and tick labels on both axes.
scatter_ax.tick_params(
    axis='both',
    bottom=False,
    left=False,
    labelbottom=False,
    labelleft=False,
)
save_or_show(scatter_fig, storing_path / 'scatter_phi_theta.pdf', SAVE)