In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import spearmanr
import json 
from helper import save_or_show, plot_r_coeff_distribution
from collections import defaultdict
from scipy.stats import ranksums
from statsmodels.stats.multitest import multipletests
import starbars

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Root folder of the precomputed similarity matrices; matrices are looked up
# below it as <dataset>/<metric>/.
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Similarity metrics to analyse (each name is also a folder name on disk).
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# Reference dataset plus the remaining datasets listed in the text file.
x_axis_ds = 'imagenet-subset-10k'
# '/' in a dataset name is replaced by '_' before it is used as a path component.
y_axis_ds = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')
]

# Output directory for aggregated results; created on first run.
storing_path = Path('/home/space/diverse_priors/results/aggregated/r_coeff_dist')
storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the per-model metadata (a mapping of model id -> config entry).
with open('../scripts/models_config.json', 'r') as config_file:
    model_configs = json.load(config_file)

# These models are excluded from the analysis when they are present.
excluded_models = ('SegmentAnything_vit_b', 'DreamSim_dino_vitb16', 'DreamSim_open_clip_vitb32')
for excluded_id in excluded_models:
    model_configs.pop(excluded_id, None)

# Keep only models whose 'alignment' entry is None, in sorted order.
allowed_models = sorted(
    model_id for model_id, cfg in model_configs.items() if cfg['alignment'] is None
)

# One row per allowed model, one column per config field.
model_configs = pd.DataFrame(model_configs).T.loc[allowed_models]
print(model_configs.shape, len(allowed_models))

In [None]:
# Metadata columns of `model_configs` used to colour the plots ...
info_orig_cols = [
    'objective',
    'architecture_class',
    'dataset_class',
    'size_class',
]
# ... and the display name for each of them, in the same order.
info_cols = [
    'Objective',
    'Architecture',
    'Dataset size',
    'Model size',
]

In [None]:
def get_model_ids(fn):
    """Return the lines of the text file `fn`, each stripped of surrounding whitespace.

    Args:
        fn: Path of the file to read (anything accepted by `open`).

    Returns:
        A list with one string per line of the file, in file order.
    """
    with open(fn, 'r') as file:
        return [raw_line.strip() for raw_line in file]


def load_sim_martix(path):
    """Load a similarity matrix from `path` as a DataFrame restricted to allowed models.

    Reads `model_ids.txt` (row/column labels) and `similarity_matrix.pt` from
    `path`, then keeps only the models that also appear in the module-level
    `allowed_models`, in sorted order.

    Args:
        path: `pathlib.Path` of the directory holding both files.

    Returns:
        A square DataFrame indexed and labelled by the retained model ids.

    Raises:
        FileNotFoundError: If `model_ids.txt` is missing from `path`.
    """
    model_ids_fn = path / 'model_ids.txt'
    if not model_ids_fn.exists():
        raise FileNotFoundError(f'{str(model_ids_fn)} does not exist.')
    model_ids = get_model_ids(model_ids_fn)

    sim_mat = torch.load(path / 'similarity_matrix.pt')
    full_df = pd.DataFrame(sim_mat, index=model_ids, columns=model_ids)

    # Sorting makes the row/column order independent of the order in the file.
    kept_models = sorted(set(model_ids) & set(allowed_models))
    return full_df.loc[kept_models, kept_models]

In [None]:
# sim_mats[metric][dataset] -> model-by-model similarity DataFrame with a unit diagonal.
sim_mats = defaultdict(dict)
for sim_metric in sim_metrics:
    for ds in [x_axis_ds] + y_axis_ds:
        sim_df = load_sim_martix(base_path_similarity_matrices / ds / sim_metric)
        # Fill the diagonal on an explicit copy: `DataFrame.values` is not
        # guaranteed to be a writable view of the frame (it is read-only under
        # pandas Copy-on-Write), so mutating it in place may fail or be lost.
        sim_values = sim_df.to_numpy(copy=True)
        np.fill_diagonal(sim_values, 1)
        sim_mats[sim_metric][ds] = pd.DataFrame(
            sim_values, index=sim_df.index, columns=sim_df.columns
        )

In [None]:

# Sanity check before stacking: every dataset must yield a matrix of the same
# shape for a given metric. The loop names deliberately differ from `sim_mats`
# so the dictionary is not rebound by the loop target.
for metric_name, mats_by_ds in sim_mats.items():
    matrix_shapes = {mat.shape for mat in mats_by_ds.values()}
    assert len(matrix_shapes) == 1, f'{metric_name}: inconsistent matrix shapes {matrix_shapes}'

In [None]:
# Metric analysed in the rest of the notebook (third entry of `sim_metrics`).
curr_sim_metric = sim_metrics[2]
# Stack the per-dataset matrices along a new leading axis: (dataset, model, model).
mats_for_metric = list(sim_mats[curr_sim_metric].values())
result = np.stack(mats_for_metric, axis=0)

In [None]:
# Element-wise mean and standard deviation over the stacking axis of `result`.
# NOTE(review): labelling with `allowed_models` assumes every loaded matrix
# contains all allowed models — confirm against the similarity files.
axis_labels = dict(index=allowed_models, columns=allowed_models)
mean_res = pd.DataFrame(result.mean(axis=0), **axis_labels)
std_res = pd.DataFrame(result.std(axis=0), **axis_labels)

In [None]:
# Canonical ordering of the objective pairs used as hue labels.
pairs = [
    ('Image-Text', 'Image-Text'),
    ('Image-Text', 'Self-Supervised'),
    ('Image-Text', 'Supervised'),
    ('Self-Supervised', 'Self-Supervised'),
    ('Self-Supervised', 'Supervised'),
    ('Supervised', 'Supervised'),
]

# One label per unordered model pair (i < j), visited row by row, so the list
# lines up with the strict upper triangle of a model-by-model matrix.
objectives = list(model_configs['objective'])
comb = []
for i, first_obj in enumerate(objectives):
    for second_obj in objectives[i + 1:]:
        if first_obj == second_obj or (first_obj, second_obj) in pairs:
            pair_label = f"{first_obj}, {second_obj}"
        elif (second_obj, first_obj) in pairs:
            # Flip the pair so it matches the canonical ordering in `pairs`.
            pair_label = f"{second_obj}, {first_obj}"
        else:
            raise ValueError("Unknown pair")
        comb.append(pair_label)
    

In [None]:
# Scatter of mean vs. std similarity for every model pair, coloured by objective pair.
mean_upper = np.triu(mean_res, k=1)
std_upper = np.triu(std_res, k=1)
# Select the strict upper triangle by position rather than by `!= 0`: a value
# mask drops pairs whose mean or std is exactly zero, and does so independently
# for x and y, which misaligns them with each other and with `comb`.
pair_rows, pair_cols = np.triu_indices_from(mean_upper, k=1)
metric_label = sim_metric_name_mapping[curr_sim_metric]
g = sns.scatterplot(
    x=mean_upper[pair_rows, pair_cols],
    y=std_upper[pair_rows, pair_cols],
    hue=comb,
    alpha=0.75,
)
# Label with the selected metric instead of a hard-coded "CKA".
g.set_xlabel(f'Mean {metric_label} value over all datasets')
g.set_ylabel(f'Std of {metric_label} values over all datasets')
g.set_title(metric_label)

In [None]:
# Same mean-vs-std scatter as above, titled with the raw metric identifier.
mean_upper = np.triu(mean_res, k=1)
std_upper = np.triu(std_res, k=1)
# Select the strict upper triangle by position rather than by `!= 0`: a value
# mask drops pairs whose mean or std is exactly zero, and does so independently
# for x and y, which misaligns them with each other and with `comb`.
pair_rows, pair_cols = np.triu_indices_from(mean_upper, k=1)
g = sns.scatterplot(
    x=mean_upper[pair_rows, pair_cols],
    y=std_upper[pair_rows, pair_cols],
    hue=comb,
    alpha=0.75,
)
# Metric-neutral labels: the selected metric is not necessarily CKA.
g.set_xlabel('Mean similarity value over all datasets')
g.set_ylabel('Std of similarity values over all datasets')
g.set_title(curr_sim_metric)

In [None]:
# Side-by-side heatmaps of the mean (left) and std (right) similarity matrices.
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(24, 12))
mean_ax, std_ax = axs

sns.heatmap(mean_res, ax=mean_ax)
mean_ax.set_title(f'Mean {sim_metric_name_mapping[curr_sim_metric]} matrix over all datasets')

sns.heatmap(std_res, ax=std_ax)
fig.tight_layout()
std_ax.set_title(f'STD {sim_metric_name_mapping[curr_sim_metric]} matrix over all datasets')

In [None]:
# The mean similarity matrix is used directly as the affinity ...
similarity_mat = mean_res
# ... and its complement with respect to 1 as the precomputed distance.
dissimilarity_mat = 1 - similarity_mat

In [None]:
from sklearn.cluster import SpectralClustering, AffinityPropagation, DBSCAN
from sklearn.manifold import TSNE, MDS

In [None]:
# Names of the cluster-label columns added to `model_configs`.
clustering_cols = []

# Spectral clustering on the precomputed similarity (affinity) matrix.
spectral_params = dict(
    n_clusters=6,
    affinity='precomputed',
    assign_labels='cluster_qr',
    random_state=0,
)
clustering = SpectralClustering(**spectral_params)

spec_labels = clustering.fit_predict(similarity_mat.values, y=None)
model_configs['spec_labels'] = spec_labels
clustering_cols.append('spec_labels')

In [None]:
# Affinity propagation on the precomputed similarity matrix.
affinity_params = dict(
    damping=0.85,
    affinity='precomputed',
    random_state=5,
)
clustering = AffinityPropagation(**affinity_params)

aff_labels = clustering.fit_predict(similarity_mat.values, y=None)
model_configs['aff_labels'] = aff_labels
clustering_cols.append('aff_labels')

In [None]:
# DBSCAN on the precomputed dissimilarity (1 - mean similarity) matrix.
# NOTE(review): if the similarities lie in [-1, 1], no dissimilarity exceeds 2,
# so eps=3 would place every model in a single neighbourhood — confirm that
# this eps is intended.
dbscan_params = dict(eps=3, min_samples=2, metric='precomputed')
clustering = DBSCAN(**dbscan_params)

dbs_labels = clustering.fit_predict(dissimilarity_mat.values, y=None)
model_configs['dbs_labels'] = dbs_labels
clustering_cols.append('dbs_labels')

In [None]:
# 2-D embeddings of the models computed from the precomputed dissimilarities.
# Both methods start from a random initialisation, so they are seeded to make
# the figures reproducible across kernel restarts (the clustering cells are
# seeded as well).
embedding_seed = 0
embs = {
    'tsne': TSNE(n_components=2,
                 learning_rate='auto',
                 init='random',
                 perplexity=10,
                 metric='precomputed',
                 random_state=embedding_seed).fit_transform(dissimilarity_mat.values),
    'mds': MDS(n_components=2,
               normalized_stress='auto',
               dissimilarity='precomputed',
               random_state=embedding_seed).fit_transform(dissimilarity_mat.values),
}

In [None]:
from itertools import product

In [None]:
# Grid of embedding scatter plots: one row per metadata column, one column per
# embedding method, coloured by the metadata category.
nrows, ncols = len(info_orig_cols), len(embs)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4 * ncols, 4 * nrows))
axes = axes.flatten()

for row, (m_cat_col, legend_title) in enumerate(zip(info_orig_cols, info_cols)):
    for col, (embed_type, curr_embed) in enumerate(embs.items()):
        ax = axes[row * ncols + col]
        sns.scatterplot(
            x=curr_embed[:, 0],
            y=curr_embed[:, 1],
            hue=model_configs[m_cat_col],
            palette='tab10',
            ax=ax,
            s=100,
        )
        ax.set_title(f"{legend_title} – {embed_type.upper()}")
        # Only panels after the first one in a row keep a legend, moved
        # outside the axes on the right.
        if col == 0:
            ax.get_legend().remove()
        else:
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), title=legend_title)

    

In [None]:
# Grid of embedding scatter plots: one row per clustering result, one column
# per embedding method, coloured by the assigned cluster label.
nrows, ncols = len(clustering_cols), len(embs)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4 * ncols, 4 * nrows))
axes = axes.flatten()

for row, clust_col in enumerate(clustering_cols):
    for col, (embed_type, curr_embed) in enumerate(embs.items()):
        ax = axes[row * ncols + col]
        sns.scatterplot(
            x=curr_embed[:, 0],
            y=curr_embed[:, 1],
            hue=model_configs[clust_col],
            palette='tab10',
            ax=ax,
            s=100,
        )
        ax.set_title(f"{clust_col} – {embed_type.upper()}")
        # Only panels after the first one in a row keep a legend, moved
        # outside the axes on the right.
        if col == 0:
            ax.get_legend().remove()
        else:
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), title=clust_col)