In [None]:
import textwrap
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.manifold import TSNE

from constants import sim_metric_name_mapping as name_mapping

In [None]:
def check_path(path):
    """Ensure that ``path`` exists on disk.

    Parameters
    ----------
    path : pathlib.Path
        Location to verify.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Directory {str(path)} does not exist.")


# Name of the dataset sub-directory below both root directories.
dataset = 'imagenet-subset-10k'

# Root directories of the model similarity matrices and of the clusterings.
model_sim_root_path = Path('/home/space/diverse_priors/model_similarities')
clustering_root_path = Path('/home/space/diverse_priors/clustering/')

# Dataset-specific directories; both have to exist before anything is loaded.
model_sim_path = model_sim_root_path / dataset
clustering_path = clustering_root_path / dataset
for required_dir in (model_sim_path, clustering_path):
    check_path(required_dir)

In [None]:
# When True, figures are written below `storing_path`.
SAVE = True
# Output directory for the figures created in this notebook.
storing_path = clustering_path.joinpath('plots', 'models_filtered_tuned_wd_in1k')

if SAVE:
    # Create the output directory (and missing parents) up front.
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
def get_model_ids(fn):
    """Read model identifiers from a text file holding one identifier per line.

    Surrounding whitespace is stripped from every line. Blank lines are
    skipped so that, e.g., trailing empty lines do not turn into empty model
    ids.

    Parameters
    ----------
    fn : str or path-like
        Path of the text file.

    Returns
    -------
    list of str
        The identifiers in file order.
    """
    with open(fn, 'r') as file:
        stripped_lines = (line.strip() for line in file)
        return [line for line in stripped_lines if line]

In [None]:
# Load every similarity matrix found below `model_sim_path`. Each matrix is keyed by the
# name of the directory that contains it and labelled with the model ids stored next to it.
sim_mats = {}
storing_paths = {}
# rglob already searches recursively, so no leading '**/' is needed in the pattern.
for sim_method in model_sim_path.rglob("similarity_matrix.pt"):
    print(sim_method)
    model_ids_fn = sim_method.parent / 'model_ids.txt'
    if not model_ids_fn.exists():
        raise FileNotFoundError(f'{str(model_ids_fn)} does not exist.')
    model_ids = get_model_ids(model_ids_fn)

    # NOTE(review): torch.load unpickles the file; only load matrices from trusted locations.
    # map_location='cpu' allows loading matrices that were saved from a GPU.
    sim_mat = np.array(torch.load(sim_method, map_location='cpu'))
    # Set the diagonal on the array itself, before wrapping it in a DataFrame: writing through
    # `DataFrame.values` only works while pandas hands out a writable view of its data.
    np.fill_diagonal(sim_mat, 1)
    sim_mats[sim_method.parent.name] = pd.DataFrame(sim_mat, index=model_ids, columns=model_ids)

# Keep only the metrics listed in `name_mapping`, in the order of `name_mapping`.
missing_metrics = [x for x in name_mapping.keys() if x not in sim_mats]
if missing_metrics:
    raise KeyError(f'No similarity matrix found for: {missing_metrics}')
sim_mats = {x: sim_mats[x] for x in name_mapping.keys()}

In [None]:
# All similarity matrices must share one identical index; the first matrix
# serves as the reference the others are compared against.
first_index = next(iter(sim_mats.values())).index

all_same_index = True
for candidate in sim_mats.values():
    if not candidate.index.equals(first_index):
        all_same_index = False
        break

if not all_same_index:
    raise ValueError('All DataFrames must have the same index.')

In [None]:
# Label-assignment methods whose clusterings are loaded; cluster_labels.csv files of
# any other method are skipped when the labels are read.
# Wider selection kept for reference: ['kmeans', 'discretize', 'cluster_qr']
lbl_assignment_methods = ['cluster_qr']

In [None]:
# One empty label table per (similarity metric, label-assignment method),
# indexed like the similarity matrices; columns are added when labels are loaded.
clustering_labels = {}
for metric_key in sim_mats:
    clustering_labels[metric_key] = {
        method: pd.DataFrame(index=first_index) for method in lbl_assignment_methods
    }

In [None]:
## For each clustering get all cluster_labels.csv
for clust_fn in clustering_path.rglob("cluster_labels.csv"):
    # The three directories above the file are read as
    # <metric>/<..._<number of clusters>>/<label assignment method>.
    # Path.parts is used instead of splitting on '/' so this does not depend on the separator.
    metric_key, num_clusters_dir, lbl_assignment = clust_fn.parts[-4:-1]
    if lbl_assignment not in lbl_assignment_methods:
        continue
    if metric_key not in clustering_labels:
        # Only metrics with a loaded similarity matrix have a label table to fill.
        print(f"Skipping {clust_fn}: no similarity matrix loaded for '{metric_key}'.")
        continue

    print(clust_fn)
    num_clusters = f"{num_clusters_dir.split('_')[-1]} clusters"

    df = pd.read_csv(clust_fn, index_col='Unnamed: 0')
    # The assignment aligns on the index; models missing from the CSV end up as NaN.
    clustering_labels[metric_key][lbl_assignment].loc[:, num_clusters] = df['cluster'].astype('category')

In [None]:
# Inspect the label table loaded for the 'cka_kernel_linear_unbiased' metric with 'cluster_qr'.
clustering_labels['cka_kernel_linear_unbiased']['cluster_qr']

In [None]:
def _cluster_count_key(col):
    """Sort key that orders '<N> clusters' column names by the integer N."""
    head = str(col).split(' ')[0]
    # Names without a leading integer are placed after the numeric ones, alphabetically.
    return (0, int(head), '') if head.isdigit() else (1, 0, str(col))


# Order the columns by their number of clusters (a plain string sort would put
# '10 clusters' before '2 clusters') and keep only the models that have a label
# in the '3 clusters' column.
clustering_labels = {
    k: {
        k1: df.loc[df['3 clusters'].notna(), sorted(df.columns.tolist(), key=_cluster_count_key)].copy()
        for k1, df in sub_dict.items()
    }
    for k, sub_dict in clustering_labels.items()
}

In [None]:
# Preview the first rows of the 'cka_kernel_linear_unbiased' / 'cluster_qr' label table.
clustering_labels['cka_kernel_linear_unbiased']['cluster_qr'].head()

In [None]:
# Sorted ids of the models present in the 'cka_kernel_linear_unbiased' / 'cluster_qr' label table.
available_models = sorted(clustering_labels['cka_kernel_linear_unbiased']['cluster_qr'].index.tolist())

In [None]:
# Embed the models in 2D with t-SNE, once per similarity metric. The similarity
# matrix is restricted to `available_models` and turned into a precomputed
# dissimilarity matrix (1 - similarity).
tsne_embeddings = {}
for key, sim_mat in sim_mats.items():
    sim_subset = sim_mat.loc[available_models, available_models]
    tsne = TSNE(
        n_components=2,
        learning_rate='auto',
        init='random',
        perplexity=10,
        metric='precomputed',
        random_state=42,
    )
    tsne_embeddings[key] = tsne.fit_transform(1 - sim_subset.values)

In [None]:
# The figures are created for a single label-assignment method: the last one in the list.
lbl_assignment_method = lbl_assignment_methods[-1]
lbl_assignment_method

In [None]:
# Label tables of the selected label-assignment method, one per similarity metric.
# The rows are put into the order of `available_models`, i.e. the row order of the
# t-SNE embeddings: the plots pair embedding rows and labels by position, so both
# have to list the models in the same order. A model missing from a table raises a KeyError.
lbl_assignment_clustering_labels = {
    k: sub_dict[lbl_assignment_method].loc[available_models]
    for k, sub_dict in clustering_labels.items()
}

In [None]:
# Overview grid: one column per similarity metric, one row per clustering (label column).
n = len(sim_mats.keys())
m = max(df.shape[1] for df in lbl_assignment_clustering_labels.values())
# squeeze=False keeps `axs` two-dimensional even when there is only one row or one column.
fig, axs = plt.subplots(nrows=m, ncols=n, figsize=(3 * n, 3 * m), squeeze=False)

for i, (key, embedd) in enumerate(tsne_embeddings.items()):
    for j, col in enumerate(lbl_assignment_clustering_labels[key]):
        ax = axs[j, i]
        lbls = lbl_assignment_clustering_labels[key][col]
        sns.scatterplot(
            x=embedd[:, 0],
            y=embedd[:, 1],
            hue=lbls,
            palette='tab10',
            legend=False,
            s=75,
            ax=ax,
            alpha=0.6
        )
        # Metric names label the top row, clustering names the first column.
        if j == 0:
            ax.set_title(f'{name_mapping[key]}', fontsize=16)
        if i == 0:
            ax.set_ylabel(f'{col}', fontsize=16)
fig.tight_layout()
if SAVE:
    fig.savefig(storing_path / f'tsne_clustering_{lbl_assignment_method}_overview.pdf')
    plt.close(fig)
else:
    # plt.show() displays all open figures; it does not take a figure argument.
    plt.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import plotly.express as px

# Define number of rows and columns for the subplots
n = len(sim_mats.keys())
# The label table with the most clusterings defines the rows of the grid.
row_labels = list(max(lbl_assignment_clustering_labels.values(), key=lambda df: df.shape[1]).columns)
m = len(row_labels)

# Create subplot grid
fig = make_subplots(
    rows=m,
    cols=n,
    subplot_titles=[f'{name_mapping[key]}' for key in tsne_embeddings] * m
)

# Collect every cluster label that occurs in any clustering, so that each label has a colour.
categories = sorted({
    label
    for df in lbl_assignment_clustering_labels.values()
    for col in df
    for label in df[col].dropna().unique()
})

color_palette = px.colors.qualitative.T10
# The palette is cycled when there are more labels than colours.
label_to_color = {label: color_palette[int(label) % len(color_palette)] for label in categories}
# Colour for models without a label in a clustering.
missing_color = 'lightgrey'

for i, (key, embedd) in enumerate(tsne_embeddings.items()):
    for j, col in enumerate(lbl_assignment_clustering_labels[key]):
        lbls = lbl_assignment_clustering_labels[key][col]
        colors = [missing_color if pd.isna(label) else label_to_color[label] for label in lbls]
        scatter = go.Scatter(
            x=embedd[:, 0],
            y=embedd[:, 1],
            mode='markers',
            marker=dict(color=colors,
                        size=7),
            showlegend=False,
            hovertext=lbls.index.tolist()
        )
        fig.add_trace(scatter, row=j + 1, col=i + 1)

# Update axis labels and titles
for i, col in enumerate(row_labels):
    fig.update_yaxes(title_text=f'{col}', row=i + 1, col=1)

for i in range(n):
    fig.update_xaxes(title_text='t-SNE Dimension', row=m, col=i + 1)

fig.update_layout(height=300 * m, width=300 * n, title_text="t-SNE Plots", showlegend=False)

if SAVE:
    fig_html = storing_path / f'tsne_clustering_{lbl_assignment_method}_overview.html'
    pio.write_html(fig, file=fig_html, auto_open=True)
fig.show()


In [None]:
# One figure per (similarity metric, clustering); the legend lists the models of each cluster.
for key, embedd in tsne_embeddings.items():
    curr_storing_path = storing_path / key
    print(f"Creating single plots for {key} ...")
    labels_df = lbl_assignment_clustering_labels[key]
    for col in labels_df:

        lbls = labels_df[col]

        # Explicit figure/axes so that saving and closing act on exactly this plot.
        fig, ax = plt.subplots()
        sns.scatterplot(
            x=embedd[:, 0],
            y=embedd[:, 1],
            hue=lbls,
            palette='tab10',
            s=75,
            alpha=0.6,
            ax=ax
        )
        # Model ids per cluster label, wrapped into lines of at most 50 characters.
        group_models_spec = labels_df.index.to_series().groupby(lbls, observed=False).unique()
        lbl_spec = group_models_spec.apply(lambda x: '\n'.join(textwrap.wrap(', '.join(x), width=50)))
        legend = ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        legend.set_title('Labels')
        # Replace each legend entry (the cluster label) by the models of that cluster. Entries
        # are matched by their text rather than by position, so the cluster labels do not
        # have to be 0..K-1; entries without a match keep their text.
        members_by_label = {str(label): members for label, members in lbl_spec.items()}
        for legend_text in legend.get_texts():
            legend_text.set_text(members_by_label.get(legend_text.get_text(), legend_text.get_text()))

        ax.set_title(
            f"{name_mapping[key]} on {dataset.capitalize()} with {len(lbl_spec)} clusters ({lbl_assignment_method}).")
        if SAVE:
            curr_storing_path.mkdir(parents=True, exist_ok=True)
            fig.savefig(curr_storing_path / f"{col.replace(' ', '_')}_{lbl_assignment_method}.pdf", bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()