In [None]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Directory from which the similarity matrices are loaded.
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Similarity metrics analysed in this notebook.
# Alternative: sim_metrics = similarity_metrics
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    # 'rsa_method_correlation_corr_method_spearman',
]

# Dataset identifiers with every '/' replaced by '_'.
y_axis_ds = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')
]
# Full dataset list: the ImageNet subset first, followed by the parsed datasets.
ds_list = ['imagenet-subset-10k', *y_axis_ds]

# Output handling: figures are written below `storing_path` only when SAVE is True.
SAVE = False
storing_path = Path('/home/space/diverse_priors/results/plots/mean_std_sim_matrix')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the model configuration table and the list of models used in this
# analysis. The three listed models are passed as exclusions; the meaning of
# `exclude_alignment=True` is defined by the project helper (not visible here).
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/models_config.json',
    exclude_models=['SegmentAnything_vit_b', 'DreamSim_dino_vitb16', 'DreamSim_open_clip_vitb32'],
    exclude_alignment=True,
)

# Sanity check: configuration table shape and number of allowed models.
print(model_configs.shape, len(allowed_models))

In [None]:
# Each pair is (column name in model_configs, human-readable label used in
# plot titles/legends); the two lists stay index-aligned by construction.
info_orig_cols, info_cols = map(list, zip(
    ('objective', 'Objective'),
    ('architecture_class', 'Architecture'),
    ('dataset_class', 'Dataset size'),
    ('size_class', 'Model size'),
))

In [None]:
# Load the similarity matrices for every metric in `sim_metrics` and every
# dataset in `ds_list`, restricted to `allowed_models`. Later cells index the
# result as sim_mats[metric][dataset].
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)

In [None]:
# For every metric, reduce the per-dataset similarity matrices to an
# element-wise mean and standard deviation across datasets.
mean_sim_mats, std_sim_mats = {}, {}
for sim_metric in sim_metrics:
    # Stack along a new leading axis, so axis 0 enumerates the datasets.
    result = np.stack([*sim_mats[sim_metric].values()], axis=0)
    # Rows and columns of both summaries are labelled with the model names.
    mean_res = pd.DataFrame(np.mean(result, axis=0), index=allowed_models, columns=allowed_models)
    std_res = pd.DataFrame(np.std(result, axis=0), index=allowed_models, columns=allowed_models)
    mean_sim_mats[sim_metric], std_sim_mats[sim_metric] = mean_res, std_res

### Similarity matrices plots

Create a figure whose first row has three subplots showing the similarity matrices for three different datasets (one natural-image dataset, one single-domain dataset, and one structured dataset). The second row contains two subplots showing the mean and the standard deviation of the similarity matrices across all datasets.
We create one such figure for each similarity metric.

In [None]:
# Candidate dataset triples for the first figure row. Every variant starts
# with the ImageNet subset and combines one of two options for the second
# panel with one of two options for the third panel (v1..v4 in that order).
ds_row_1_v1, ds_row_1_v2, ds_row_1_v3, ds_row_1_v4 = (
    ['imagenet-subset-10k', second_ds, third_ds]
    for second_ds in ('wds_vtab_flowers', 'wds_vtab_pets')
    for third_ds in ('wds_vtab_eurosat', 'wds_vtab_pcam')
)

In [None]:
# Metric visualised in the figure below.
curr_sim_metric = sim_metrics[0]

# Pull this metric's per-dataset matrices and its across-dataset mean / std.
curr_sim_metrics_data, mean_res, std_res = (
    per_metric[curr_sim_metric]
    for per_metric in (sim_mats, mean_sim_mats, std_sim_mats)
)

In [None]:
# Overview figure for `curr_sim_metric`:
#   top row    - similarity matrices of the three datasets in `ds_row_1_v1`
#   bottom row - mean (left) and standard deviation (right) across all datasets
fig = plt.figure(figsize=(10, 6))

# 2 x 6 grid: top-row panels span two grid columns each, bottom-row panels
# span three grid columns each.
gs = gridspec.GridSpec(2, 6)
axs = [
    fig.add_subplot(gs[0, :2]),
    fig.add_subplot(gs[0, 2:4]),
    fig.add_subplot(gs[0, 4:]),
    fig.add_subplot(gs[1, :3]),
    fig.add_subplot(gs[1, 3:]),
]

# Top row: one heatmap per selected dataset, titled with the dataset name.
for i, ds in enumerate(ds_row_1_v1):
    sns.heatmap(curr_sim_metrics_data[ds], ax=axs[i], cbar=False)
    axs[i].set_title(ds)

# Bottom left: mean similarity across datasets.
sns.heatmap(mean_res, ax=axs[3], cbar=False)
axs[3].set_title('Mean similarity matrix across all datasets')

# Bottom right: standard deviation across datasets. This panel previously
# re-plotted `mean_res` under a "Mean" title, so the std was never shown.
sns.heatmap(std_res, ax=axs[4], cbar=False)
axs[4].set_title('Std of similarity matrices across all datasets')

plt.tight_layout()

# Either writes the figure below `storing_path` or displays it, depending on SAVE.
save_or_show(SAVE, storing_path, f'mean_std_sim_matrix_{curr_sim_metric}')

### Mean vs. STD scatter plot

In [None]:
# curr_sim_metric = sim_metrics[2]
# result = np.stack(list(sim_mats[curr_sim_metric].values()), axis=0)

In [None]:
# mean_res = pd.DataFrame(result.mean(axis=0), index=allowed_models, columns=allowed_models)
# std_res = pd.DataFrame(result.std(axis=0), index=allowed_models, columns=allowed_models)

In [None]:
# Canonical (ordered) objective pairs; a mixed pair is always labelled in the
# order in which it appears here.
pairs = [
    ('Image-Text', 'Image-Text'),
    ('Image-Text', 'Self-Supervised'),
    ('Image-Text', 'Supervised'),
    ('Self-Supervised', 'Self-Supervised'),
    ('Self-Supervised', 'Supervised'),
    ('Supervised', 'Supervised'),
]

# One label per unordered model pair (i < j), visited row by row. This is the
# same order in which a strict upper triangle is traversed in row-major order.
objectives = list(model_configs['objective'])
comb = []
for i in range(len(objectives)):
    for j in range(i + 1, len(objectives)):
        first, second = objectives[i], objectives[j]
        if first == second or (first, second) in pairs:
            label = f"{first}, {second}"
        elif (second, first) in pairs:
            # Flip into the canonical order.
            label = f"{second}, {first}"
        else:
            raise ValueError("Unknown pair")
        comb.append(label)


In [None]:
# Mean-vs-std scatter of the pairwise similarities, one panel per metric.
# Every point is one unordered model pair, coloured by its objective pair.
cm = 0.393701  # inches per centimetre, so the figure size can be given in cm
n = len(sim_metrics)
# squeeze=False keeps `axs` an array even for a single metric; without it
# plt.subplots returns a bare Axes and `axs[i]` fails.
fig, axs = plt.subplots(
    nrows=1, ncols=n, sharey=True, figsize=(n*10*cm, 10*cm), squeeze=False
)
axs = axs.ravel()
for i, curr_sim_metric in enumerate(sim_metrics):
    mean_res = mean_sim_mats[curr_sim_metric]
    std_res = std_sim_mats[curr_sim_metric]

    # Select the strict upper triangle by position. The previous value-based
    # masks (`!= 0`, applied separately to mean and std) dropped any entry
    # that is exactly zero, which misaligns x, y and the hue labels.
    upper_idx = np.triu_indices_from(mean_res.to_numpy(), k=1)
    mean_vals = mean_res.to_numpy()[upper_idx]
    std_vals = std_res.to_numpy()[upper_idx]

    # `comb` is built from model_configs in the same row-major i < j order.
    # NOTE(review): this assumes model_configs rows follow the order of
    # allowed_models - confirm against the loader.
    if len(comb) != len(mean_vals):
        raise ValueError(
            f"Expected {len(mean_vals)} pair labels in `comb`, got {len(comb)}"
        )

    g = sns.scatterplot(
        x=mean_vals,
        y=std_vals,
        hue=comb,
        alpha=0.75,
        ax=axs[i]
    )
    g.set_xlabel('Mean CKA value across all datasets')
    g.set_ylabel('Std of CKA values across all datasets')
    g.set_title(sim_metric_name_mapping[curr_sim_metric])

fig.tight_layout()
save_or_show(SAVE, storing_path, 'mean_std_scatter_plot')

In [None]:
# Single-panel mean-vs-std scatter for the metric currently held in
# `mean_res` / `std_res` / `curr_sim_metric` (set by an earlier cell), titled
# with the metric's display name.
# Alternative view: sns.jointplot(x=mean_vals, y=std_vals, hue=comb)

# Select the strict upper triangle by position instead of masking on
# `!= 0`: value-based masks drop entries that are exactly zero and can leave
# x, y and `comb` with different lengths.
upper_idx = np.triu_indices_from(mean_res.to_numpy(), k=1)
mean_vals = mean_res.to_numpy()[upper_idx]
std_vals = std_res.to_numpy()[upper_idx]

g = sns.scatterplot(
    x=mean_vals,
    y=std_vals,
    hue=comb,
    alpha=0.75
)
g.set_xlabel('Mean CKA value over all datasets')
g.set_ylabel('Std of CKA values over all datasets')
g.set_title(sim_metric_name_mapping[curr_sim_metric])

In [None]:
# Single-panel mean-vs-std scatter for the metric currently held in
# `mean_res` / `std_res` / `curr_sim_metric` (set by an earlier cell), titled
# with the raw metric key.

# Select the strict upper triangle by position instead of masking on
# `!= 0`: value-based masks drop entries that are exactly zero and can leave
# x, y and `comb` with different lengths.
upper_idx = np.triu_indices_from(mean_res.to_numpy(), k=1)
mean_vals = mean_res.to_numpy()[upper_idx]
std_vals = std_res.to_numpy()[upper_idx]

g = sns.scatterplot(
    x=mean_vals,
    y=std_vals,
    hue=comb,
    alpha=0.75
)
g.set_xlabel('Mean CKA value over all datasets')
g.set_ylabel('Std of CKA values over all datasets')
g.set_title(curr_sim_metric)

In [None]:
# Side-by-side heatmaps of the mean (left) and std (right) similarity matrix
# for the current metric, with colour bars and a shared y-axis.
metric_label = sim_metric_name_mapping[curr_sim_metric]
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(24, 12))
mean_ax, std_ax = axs

sns.heatmap(mean_res, ax=mean_ax)
mean_ax.set_title(f'Mean {metric_label} matrix over all datasets')

sns.heatmap(std_res, ax=std_ax)
# The layout is computed here, before the right-hand title is set.
fig.tight_layout()
std_ax.set_title(f'STD {metric_label} matrix over all datasets')

In [None]:
# The mean similarity matrix serves as the affinity for the analyses below;
# its complement (1 - similarity) serves as the precomputed distance.
similarity_mat = mean_res
dissimilarity_mat = 1 - similarity_mat

In [None]:
from sklearn.cluster import SpectralClustering, AffinityPropagation, DBSCAN
from sklearn.manifold import TSNE, MDS

In [None]:
# Names of the cluster-label columns added to `model_configs`; the list is
# (re)started in this cell.
clustering_cols = []

# Spectral clustering into six clusters, using the mean similarity matrix as
# a precomputed affinity.
clustering = SpectralClustering(
    n_clusters=6,
    affinity='precomputed',
    assign_labels='cluster_qr',
    random_state=0,
)
spec_labels = clustering.fit_predict(similarity_mat.values, y=None)

model_configs['spec_labels'] = spec_labels
clustering_cols += ['spec_labels']

In [None]:
# Affinity propagation on the mean similarity matrix, passed as a
# precomputed affinity; the number of clusters is chosen by the algorithm.
clustering = AffinityPropagation(damping=0.85,
                                 affinity='precomputed',
                                 random_state=5)

model_configs['aff_labels'] = clustering.fit_predict(similarity_mat.values, y=None)
# Register the column only once, so re-running this cell does not add a
# duplicate row to the clustering grid plot.
if 'aff_labels' not in clustering_cols:
    clustering_cols.append('aff_labels')

In [None]:
# DBSCAN on the precomputed dissimilarity matrix (1 - mean similarity).
# NOTE(review): if the similarities lie within [0, 1], every pairwise
# distance is <= 1 < eps=3 and all models would fall into a single cluster -
# confirm that eps=3 is intended.
clustering = DBSCAN(eps=3, min_samples=2, metric='precomputed')

model_configs['dbs_labels'] = clustering.fit_predict(dissimilarity_mat.values, y=None)
# Register the column only once, so re-running this cell does not add a
# duplicate row to the clustering grid plot.
if 'dbs_labels' not in clustering_cols:
    clustering_cols.append('dbs_labels')

In [None]:
# 2-D embeddings of the models, computed from the precomputed dissimilarity
# matrix (1 - mean similarity). Both estimators are stochastic (t-SNE is
# randomly initialised here), so they are seeded to make the embeddings and
# the scatter plots built on them reproducible across kernel restarts.
embs = {
    'tsne': TSNE(n_components=2,
                 learning_rate='auto',
                 init='random',
                 perplexity=10,
                 metric='precomputed',
                 random_state=0).fit_transform(dissimilarity_mat.values),
    'mds': MDS(n_components=2,
               normalized_stress='auto',
               dissimilarity='precomputed',
               random_state=0).fit_transform(dissimilarity_mat.values),
}

In [None]:
from itertools import product

In [None]:
# Grid of embedding scatter plots: one row per model-metadata column, one
# column per embedding type, coloured by the metadata value.
nrows, ncols = len(info_orig_cols), len(embs)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4 * ncols, 4 * nrows))
axes = axes.flatten()

for row, m_cat_col in enumerate(info_orig_cols):
    row_label = info_cols[row]
    for col, (embed_type, curr_embed) in enumerate(embs.items()):
        ax = axes[row * ncols + col]
        sns.scatterplot(
            x=curr_embed[:, 0],
            y=curr_embed[:, 1],
            hue=model_configs[m_cat_col],
            palette='tab10',
            ax=ax,
            s=100,
        )
        ax.set_title(f"{row_label} – {embed_type.upper()}")
        if col > 0:
            # Remaining columns keep the legend, moved outside the axes.
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), title=row_label)
        else:
            # The first column of each row drops its legend.
            ax.get_legend().remove()



In [None]:
# Grid of embedding scatter plots: one row per clustering result, one column
# per embedding type, coloured by the assigned cluster label.
nrows, ncols = len(clustering_cols), len(embs)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4 * ncols, 4 * nrows))
axes = axes.flatten()

for row, clust_col in enumerate(clustering_cols):
    for col, (embed_type, curr_embed) in enumerate(embs.items()):
        ax = axes[row * ncols + col]
        sns.scatterplot(
            x=curr_embed[:, 0],
            y=curr_embed[:, 1],
            hue=model_configs[clust_col],
            palette='tab10',
            ax=ax,
            s=100,
        )
        ax.set_title(f"{clust_col} – {embed_type.upper()}")
        if col > 0:
            # Remaining columns keep the legend, moved outside the axes.
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), title=clust_col)
        else:
            # The first column of each row drops its legend.
            ax.get_legend().remove()