In [None]:
import sys
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.ticker as ticker
import scipy

from sklearn.manifold import MDS, TSNE
from sklearn.cluster import SpectralClustering, AffinityPropagation

from constants import (
    exclude_models, 
    exclude_models_w_mae, 
    ds_name_mapping, 
    model_categories, 
    model_cat_mapping, 
    model_config_file, 
    ds_info_file,
    fontsizes,
    cat_name_mapping,
    sim_metric_name_mapping,
    model_size_order,
    cat_color_mapping
)
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show, load_ds_info, get_fmt_name

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# --- Inputs ---------------------------------------------------------------
# Directory handed to load_similarity_matrices as the root of the stored matrices.
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Similarity metrics to load and plot (one figure per metric further below).
sim_metrics = ['cka_kernel_rbf_unbiased_sigma_0.4', 'cka_kernel_linear_unbiased']

# Dataset identifiers from the dataset list file, with '/' replaced by '_'.
ds_list = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')
]

ds_info = load_ds_info(ds_info_file)

# Datasets of interest: the subset of datasets that is actually embedded and plotted.
ds_oi = ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam']

# `suffix` is appended to the output file names and also drives which model
# exclusion list is used when loading the model configs. Alternative: suffix = ''
suffix = '_wo_mae'

# Centimetre-to-inch factor (1 cm = 0.393701 in); not used in the visible cells.
cm = 0.393701

# --- Outputs --------------------------------------------------------------
# SAVE is forwarded to save_or_show together with the target path.
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/clustering_on_sim_vals')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Choose the exclusion list from `suffix`:
#   empty suffix           -> exclude nothing
#   suffix containing 'mae' -> exclude_models_w_mae
#   any other suffix        -> exclude_models
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)


In [None]:
# Load the similarity matrices for the configured datasets and metrics,
# passing `allowed_models` so the helper can restrict the model set.
# Downstream code indexes the result as sim_mats[sim_metric][dataset] and
# expects each entry to be a DataFrame whose index holds the model identifiers.
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)

In [None]:
def get_embedder(manifold_method, n_components = 6):
    """Build a manifold embedder that operates on a precomputed dissimilarity matrix.

    Parameters
    ----------
    manifold_method : str
        'mds' or 'tsne' (matched case-insensitively).
    n_components : int, default 6
        Dimensionality of the embedding.

    Returns
    -------
    embedder
        An unfitted ``MDS`` or ``TSNE`` instance with ``random_state=42``.
    emb_cols : list of str
        Column labels for the embedding dimensions, e.g. ``['TSNE 1', 'TSNE 2']``.

    Raises
    ------
    ValueError
        If ``manifold_method`` is neither 'mds' nor 'tsne'. Previously every
        unknown name silently produced a t-SNE embedder whose column labels
        claimed a different method.
    """
    method = manifold_method.lower()

    if method == 'mds':
        embedder = MDS(n_components=n_components,
                       normalized_stress='auto',
                       dissimilarity='precomputed',
                       random_state=42,
                       eps=1e-4,
                       n_init=5)
    elif method == 'tsne':
        # scikit-learn's default Barnes-Hut solver only supports fewer than
        # 4 output dimensions; fall back to the exact solver otherwise so that
        # the default n_components=6 does not raise.
        tsne_solver = 'barnes_hut' if n_components < 4 else 'exact'
        embedder = TSNE(n_components=n_components,
                        learning_rate='auto',
                        init='random',
                        perplexity=5,
                        metric='precomputed',
                        random_state=42,
                        method=tsne_solver)
    else:
        raise ValueError(
            f"Unknown manifold_method {manifold_method!r}; expected 'mds' or 'tsne'."
        )

    emb_cols = [f"{manifold_method.upper()} {i+1}" for i in range(n_components)]
    return embedder, emb_cols


def get_clusterer(cluster_method, n_clusters=6):
    """Return an unfitted clustering estimator that consumes a precomputed affinity matrix.

    ``cluster_method == 'spectral'`` yields ``SpectralClustering`` with
    ``n_clusters`` clusters; every other value yields ``AffinityPropagation``,
    in which case ``n_clusters`` is not used. Both use ``random_state=42``.
    """
    if cluster_method != 'spectral':
        # Fallback branch: affinity propagation; n_clusters is ignored here.
        return AffinityPropagation(
            damping=0.75,
            affinity='precomputed',
            random_state=42,
        )

    return SpectralClustering(
        n_clusters=n_clusters,
        affinity='precomputed',
        assign_labels='kmeans',
        random_state=42,
    )
        

In [None]:
def embed_ds_list(ds_list, sim_metric, sim_matrices, embedder, emb_cols, clustering=None):
    """Embed the model-similarity matrix of each dataset and collect the results.

    For every dataset the similarity matrix is turned into a dissimilarity
    matrix (``1 - similarity``, clipped at 0), embedded with ``embedder`` and
    annotated with the model name, the dataset display name and one column per
    model category. Optionally a cluster label per model is added.

    Reads the notebook-level globals ``ds_info``, ``model_configs``,
    ``model_cat_mapping`` and ``cat_name_mapping``.

    Parameters
    ----------
    ds_list : iterable of str
        Dataset keys to process; each must be a key of ``sim_matrices[sim_metric]``
        and an index label of ``ds_info``.
    sim_metric : str
        Key selecting the similarity metric in ``sim_matrices``.
    sim_matrices : mapping
        ``sim_matrices[sim_metric][dataset]`` is a DataFrame of pairwise model
        similarities whose index holds the model identifiers.
    embedder
        Object with ``fit_transform(dissimilarity_array)``; it is refit per dataset.
    emb_cols : list of str
        Column names for the embedding dimensions.
    clustering : optional
        Object with ``fit_predict(similarity_array, y=None)``. When given, a
        categorical ``'Cluster'`` column is added (fit on the similarities,
        not the dissimilarities).

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of the per-dataset frames. The index restarts
        at 0 for every dataset.

    Raises
    ------
    ValueError
        From ``pd.concat`` when ``ds_list`` is empty.
    """
    sim_data = sim_matrices[sim_metric]
    embed_list = []
    for ds in ds_list:
        sim_mat = sim_data[ds]
        models = sim_mat.index

        # Similarities marginally above 1 (e.g. floating-point error on the
        # diagonal) would give negative dissimilarities, which scikit-learn's
        # t-SNE rejects for precomputed distances. Clip them to 0.
        dissim_mat = (1 - sim_mat).clip(lower=0)

        embs = pd.DataFrame(embedder.fit_transform(dissim_mat.values), columns=emb_cols)
        embs['Model'] = models.tolist()
        embs['Dataset'] = ds_info.loc[ds, 'name']
        for cat, cat_name in model_cat_mapping.items():
            # .values drops the model index so assignment is positional.
            embs[cat_name] = model_configs.loc[models, cat].map(cat_name_mapping).values

        # Explicit None check: do not rely on the estimator's truthiness.
        if clustering is not None:
            embs['Cluster'] = clustering.fit_predict(sim_mat.values, y=None)
            embs['Cluster'] = embs['Cluster'].astype('category')

        embed_list.append(embs)

    return pd.concat(embed_list, axis=0)

In [None]:
# 2-D t-SNE embedder shared by both plotting cells below; emb_cols holds the
# two axis labels. The embedder is created with random_state=42 and is refit
# for every dataset inside embed_ds_list.
embedder, emb_cols = get_embedder('tsne', 2)

In [None]:
#### VERTICAL ORDERING
# One figure per similarity metric: rows = model categories, columns = datasets.
# The legend of each category is attached to the right-most panel of its row.
for sim_metric in sim_metrics:

    all_embeddings = embed_ds_list(ds_oi, sim_metric, sim_mats, embedder, emb_cols, None)

    ncols = all_embeddings['Dataset'].nunique()
    nrows = len(model_cat_mapping)
    # squeeze=False keeps `axes` two-dimensional even when there is only one
    # row or one column, so axes[i, j] is always valid.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(2.5*ncols, 2*nrows),
        squeeze=False,
    )

    # The per-dataset split does not depend on the category, so compute it once.
    ds_groups = list(all_embeddings.groupby('Dataset', sort=False))

    for i, col in enumerate(model_cat_mapping.values()):
        hue_order = model_size_order if col == 'Model size' else sorted(all_embeddings[col].unique())
        col_palette = {hue: cat_color_mapping[hue] for hue in hue_order}

        for j, (ds, ds_data) in enumerate(ds_groups):
            ax = axes[i, j]
            is_first_row = i == 0
            is_last_row = i == (nrows - 1)
            is_first_col = j == 0
            is_last_col = j == (ncols - 1)

            sns.scatterplot(
                ds_data,
                x=emb_cols[0],
                y=emb_cols[1],
                hue=col,
                hue_order=hue_order,
                palette=col_palette,
                ax=ax,
                legend=is_last_col,
            )
            sns.despine(ax=ax)
            # Embedding coordinates have no interpretable scale; hide the ticks.
            ax.set_xticks([])
            ax.set_yticks([])

            # Axis labels only on the outer panels of the grid.
            if is_first_col:
                ax.set_ylabel(emb_cols[1], fontsize=fontsizes['label'])
            else:
                ax.set_ylabel('')

            if is_last_col:
                # Place the legend outside the axes, to the right of the row.
                sns.move_legend(
                    ax,
                    loc="upper left",
                    bbox_to_anchor=(1, 1),
                    frameon=False,
                    fontsize=fontsizes['ticks'],
                    title_fontsize=fontsizes['legend'],
                )

            if is_last_row:
                ax.set_xlabel(emb_cols[0], fontsize=fontsizes['label'])
            else:
                ax.set_xlabel('')

            # Dataset names as column headers on the top row only.
            if is_first_row:
                ax.set_title(ds, fontsize=fontsizes['title'])

    fig.tight_layout()
    # NOTE(review): 'tnse' in the file name looks like a typo for 'tsne'; it is
    # left unchanged so existing references to the output file keep working.
    save_or_show(fig, storing_path / f'tnse_{sim_metric}_across_categories{suffix}_v.pdf', SAVE)

In [None]:
#### HORIZONTAL ORDERING
# One figure per similarity metric: rows = datasets, columns = model categories.
# The legend of each category is placed above the top panel of its column.
for sim_metric in sim_metrics:

    all_embeddings = embed_ds_list(ds_oi, sim_metric, sim_mats, embedder, emb_cols, None)

    nrows = all_embeddings['Dataset'].nunique()
    ncols = len(model_cat_mapping)
    # squeeze=False keeps `axes` two-dimensional even when there is only one
    # row or one column, so axes[row_idx, col_idx] is always valid.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(3*ncols, 2.8*nrows),
        squeeze=False,
    )

    # The per-dataset split does not depend on the category, so compute it once.
    ds_groups = list(all_embeddings.groupby('Dataset', sort=False))

    for col_idx, col in enumerate(model_cat_mapping.values()):
        hue_order = model_size_order if col == 'Model size' else sorted(all_embeddings[col].unique())
        col_palette = {hue: cat_color_mapping[hue] for hue in hue_order}

        for row_idx, (ds, ds_data) in enumerate(ds_groups):
            ax = axes[row_idx, col_idx]
            is_first_row = row_idx == 0
            is_last_row = row_idx == (nrows - 1)

            sns.scatterplot(
                ds_data,
                x=emb_cols[0],
                y=emb_cols[1],
                hue=col,
                hue_order=hue_order,
                palette=col_palette,
                ax=ax,
                legend=is_first_row,
                s=50,
                alpha=0.75,
                linewidth=0,
            )
            sns.despine(ax=ax)
            # Embedding coordinates have no interpretable scale; hide the ticks.
            ax.set_xticks([])
            ax.set_yticks([])

            if col_idx == 0:
                # Raw string: '\m' is not a valid escape sequence in a normal
                # string literal. The resulting label text is unchanged:
                # italic dataset name (mathtext) above the axis label.
                ax.set_ylabel(
                    r'$\mathit{' + ds + '}$' + '\n' + emb_cols[1],
                    fontsize=fontsizes['label'],
                )
            else:
                ax.set_ylabel('')

            if is_first_row:
                # Legend centred above the top panel; two columns once there
                # are more than two entries.
                sns.move_legend(
                    ax,
                    loc="lower center",
                    bbox_to_anchor=(0.5, 1.05),
                    frameon=False,
                    fontsize=fontsizes['ticks'],
                    title_fontsize=fontsizes['legend'],
                    ncols=2 if len(hue_order) > 2 else 1,
                )

            if is_last_row:
                ax.set_xlabel(emb_cols[0], fontsize=fontsizes['label'])
            else:
                ax.set_xlabel('')

    # Manual panel spacing is used here instead of fig.tight_layout().
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    # NOTE(review): 'tnse' in the file name looks like a typo for 'tsne'; it is
    # left unchanged so existing references to the output file keep working.
    save_or_show(fig, storing_path / f'tnse_{sim_metric}_across_categories{suffix}_h.pdf', SAVE)