<!-- ## Notebook 4.3: *Do representational similarities cluster according to model categories?*
This notebook creates the figures for Section 4.3. It produces scatter plots of the models' representational similarities embedded in a 2D space, with points colored by model category. The t-SNE embeddings are computed for only three datasets. -->
# TODO: fill

In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import SpectralClustering, AffinityPropagation
from sklearn.manifold import MDS, TSNE

from constants import (
    BASE_PATH_PROJECT,
    BASE_PATH_RESULTS,
    cat_color_mapping,
    cat_name_mapping,
    ds_list_sim_file,
    exclude_models,
    exclude_models_w_mae,
    fontsizes,
    fontsizes_cols,
    model_cat_mapping,
    model_config_file,
    model_size_order,
    sim_metric_name_mapping
)
from helper import (
    load_all_datasetnames_n_info,
    load_model_configs_and_allowed_models,
    load_similarity_matrices,
    pp_storing_path,
    save_or_show
)

#### Global variables

In [2]:
# --- Paths -------------------------------------------------------------------
# Directory holding the precomputed model-similarity matrices.
base_path_similarity_matrices = BASE_PATH_PROJECT / 'model_similarities'

# --- Analysis settings -------------------------------------------------------
# Keys of the similarity metrics to load.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
]

# Datasets whose similarity matrices are embedded and plotted below.
ds_oi = ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam']

# Model-filtering suffix: '' disables filtering; a suffix containing 'mae'
# (e.g. '_wo_mae') selects `exclude_models_w_mae`, and any other non-empty
# suffix selects `exclude_models` (see the model-loading cell).
suffix = ''

# NOTE(review): `version` is not used in the visible cells.
version = 'arxiv'

# --- Dataset metadata ----------------------------------------------------------
# All dataset names plus their info table (ds_info is indexed by dataset key).
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# --- Output --------------------------------------------------------------------
SAVE = True
# NOTE(review): the 'q1_ssl_pairs' sub-folder name does not obviously match this
# notebook's topic (section 4.3) — confirm it is the intended destination.
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'q1_ssl_pairs', SAVE)
storing_path




PosixPath('/home/space/diverse_priors/results_rebuttal/plots/q1_ssl_pairs')

#### Load model configurations and similarity matrices

In [3]:
# Pick the models to exclude from the filtering suffix: none for an empty
# suffix, `exclude_models_w_mae` when the suffix mentions 'mae', and
# `exclude_models` for any other suffix.
if not suffix:
    curr_excl_models = []
elif 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Read the model configurations and the list of models kept for the analysis
# (called with `exclude_alignment=True`).
model_configs, allowed_models = load_model_configs_and_allowed_models(
    exclude_alignment=True,
    exclude_models=curr_excl_models,
    path=model_config_file,
)


Nr. models original=64


In [4]:
# Similarity matrices for every metric in `sim_metrics` and every dataset in
# `ds_list`; the result is indexed as sim_mats[metric][dataset]. Passing
# `allowed_models` presumably restricts the matrices to those models —
# confirm in helper.load_similarity_matrices.
sim_mats = load_similarity_matrices(
    allowed_models=allowed_models,
    sim_metrics=sim_metrics,
    ds_list=ds_list,
    path=base_path_similarity_matrices,
)

### Embedding and clustering of representational similarities
The representational similarities are embedded in a 2D space using t-SNE (MDS is also supported by `get_embedder`, but only t-SNE is used below). Each similarity is turned into a dissimilarity between two models, computed as 1 - similarity, and the resulting dissimilarity matrices are passed directly (as precomputed distances) to the embedding methods.

In [5]:
## Define embedding and clustering methods
def get_embedder(manifold_method, n_components=6):
    """Build an unfitted manifold embedder for precomputed dissimilarities.

    Parameters
    ----------
    manifold_method : str
        ``'mds'`` for metric MDS or ``'tsne'`` for t-SNE (case-insensitive).
    n_components : int
        Dimensionality of the embedding.

    Returns
    -------
    embedder : sklearn.manifold.MDS or sklearn.manifold.TSNE
        Estimator whose ``fit_transform`` expects a square dissimilarity matrix.
    emb_cols : list of str
        Column names ``"<METHOD> 1"`` ... ``"<METHOD> <n_components>"``.

    Raises
    ------
    ValueError
        If ``manifold_method`` is neither ``'mds'`` nor ``'tsne'``.
    """
    method = manifold_method.lower()
    if method == 'mds':
        embedder = MDS(n_components=n_components,
                       normalized_stress='auto',
                       dissimilarity='precomputed',
                       random_state=42,
                       eps=1e-4,
                       n_init=5)
    elif method == 'tsne':
        embedder = TSNE(n_components=n_components,
                        learning_rate='auto',
                        # PCA initialisation cannot be combined with a
                        # precomputed metric, hence random init.
                        init='random',
                        perplexity=5,
                        metric='precomputed',
                        # Barnes-Hut t-SNE only supports fewer than 4
                        # components; fall back to the exact solver otherwise.
                        method='barnes_hut' if n_components < 4 else 'exact',
                        random_state=42)
    else:
        raise ValueError(
            f"Unknown manifold_method {manifold_method!r}; expected 'mds' or 'tsne'."
        )

    emb_cols = [f"{manifold_method.upper()} {i + 1}" for i in range(n_components)]
    return embedder, emb_cols


def get_clusterer(cluster_method, n_clusters=6):
    """Return an unfitted clustering estimator that takes precomputed affinities.

    ``'spectral'`` selects spectral clustering with ``n_clusters`` clusters and
    k-means label assignment. Any other value selects affinity propagation
    (damping 0.75); ``n_clusters`` is not used in that case.
    """
    if cluster_method != 'spectral':
        return AffinityPropagation(
            damping=0.75,
            affinity='precomputed',
            random_state=42,
        )

    return SpectralClustering(
        n_clusters=n_clusters,
        affinity='precomputed',
        assign_labels='kmeans',
        random_state=42,
    )


In [6]:
def embed_ds_list(ds_list, sim_metric, sim_matrices, embedder, emb_cols, clustering=None):
    """Embed each dataset's model-similarity matrix into a low-dimensional space.

    For every dataset in ``ds_list`` the matrix ``sim_matrices[sim_metric][ds]``
    is converted to dissimilarities (``1 - similarity``, clipped at 0) and
    embedded with ``embedder.fit_transform``.

    Parameters
    ----------
    ds_list : iterable of str
        Dataset keys; each must be in ``sim_matrices[sim_metric]`` and in the
        index of the notebook-global ``ds_info``.
    sim_metric : str
        Key selecting the similarity metric in ``sim_matrices``.
    sim_matrices : dict
        Nested mapping ``metric -> dataset -> pandas.DataFrame`` whose index
        holds model ids (presumably a square model-by-model matrix — confirm).
    embedder : estimator with ``fit_transform``
        Expected to accept precomputed dissimilarities (see ``get_embedder``).
    emb_cols : list of str
        Column names for the embedding dimensions.
    clustering : estimator with ``fit_predict``, optional
        If given, it is fit on the raw similarity matrix and its labels are
        stored in a categorical ``'Cluster'`` column.

    Returns
    -------
    pandas.DataFrame
        One row per (dataset, model): the embedding columns, ``'Model'``,
        ``'Dataset'`` (display name from ``ds_info``), one column per entry of
        ``model_cat_mapping`` and optionally ``'Cluster'``; fresh RangeIndex.

    Raises
    ------
    ValueError
        If ``ds_list`` is empty.
    """
    sim_data = sim_matrices[sim_metric]
    embed_list = []
    for ds in ds_list:
        sim_mat = sim_data[ds]
        models = sim_mat.index

        # Similarities at or above 1 (e.g. floating-point noise on the diagonal,
        # or unbiased CKA estimates exceeding 1) would give negative
        # "distances", which precomputed t-SNE rejects; clip them to 0.
        dissim_mat = (1 - sim_mat).clip(lower=0)

        embs = pd.DataFrame(embedder.fit_transform(dissim_mat.values), columns=emb_cols)
        embs['Model'] = models.tolist()
        embs['Dataset'] = ds_info.loc[ds, 'name']
        for cat, cat_name in model_cat_mapping.items():
            embs[cat_name] = model_configs.loc[models, cat].map(cat_name_mapping).values

        if clustering is not None:
            # NOTE(review): clustering uses raw similarities as affinities;
            # spectral clustering expects non-negative values — confirm the
            # chosen metric cannot produce negatives.
            labels = clustering.fit_predict(sim_mat.values, y=None)
            embs['Cluster'] = pd.Series(labels).astype('category')

        embed_list.append(embs)

    if not embed_list:
        raise ValueError("ds_list is empty; there is nothing to embed.")

    # Per-dataset frames each start at index 0; renumber so row labels are unique.
    return pd.concat(embed_list, axis=0, ignore_index=True)

In [7]:
# 2-D t-SNE embedder (and its column names) shared by the cells below.
embedder, emb_cols = get_embedder(manifold_method='tsne', n_components=2)

For each similarity metric, we first embed all similarity matrices (one per dataset) into a low-dimensional space. The embeddings are then visualized in scatter plots in which the points are colored according to the model categories; one scatter plot is created per dataset and similarity metric.

In [8]:
sim_metric = 'cka_kernel_linear_unbiased'

# Embed the matrices of the datasets of interest for this metric, without clustering.
all_embeddings = embed_ds_list(
    ds_list=ds_oi,
    sim_metric=sim_metric,
    sim_matrices=sim_mats,
    embedder=embedder,
    emb_cols=emb_cols,
    clustering=None,
)
all_embeddings['Dataset'].unique()

array(['ImageNet-1k', 'Flowers', 'PCAM'], dtype=object)

In [9]:
# One panel per dataset; axis columns come from the embedder (e.g. 'TSNE 1').
datasets = all_embeddings['Dataset'].unique()
n_ds = len(datasets)
x_col, y_col = emb_cols[0], emb_cols[1]
emb_name = x_col.rsplit(' ', 1)[0]  # 'TSNE 1' -> 'TSNE'; used in the file name

# Fix the training-objective order once for all panels. Without it seaborn
# orders hue/style levels by first appearance within each panel's data, so an
# objective could get different colours/markers in different panels while only
# the last panel's legend is shown.
hue_col = 'Training objective'
hue_order = list(all_embeddings[hue_col].dropna().unique())
# Exactly one marker per level (seaborn warns or errors, depending on version,
# when the marker list length does not match the number of levels).
marker_pool = ['o', 's', '^', 'D', 'v', 'P', 'X', '*']
markers = [marker_pool[k % len(marker_pool)] for k in range(len(hue_order))]

# squeeze=False always returns a 2-D array of axes, so one dataset needs no
# special case.
fig, axes = plt.subplots(nrows=1, ncols=n_ds, figsize=(n_ds * 4, 4), squeeze=False)

# Legend handles/labels for the shared figure legend.
handles, labels = None, None

for i, (ax, ds) in enumerate(zip(axes[0], datasets)):
    tmp_df = all_embeddings[all_embeddings['Dataset'] == ds]

    sns.scatterplot(
        data=tmp_df,
        x=x_col,
        y=y_col,
        hue=hue_col,
        style=hue_col,
        hue_order=hue_order,
        style_order=hue_order,
        markers=markers,
        ax=ax,
        s=50,
        alpha=0.75,
    )

    # Keep the axes frame but hide ticks and tick labels: embedding
    # coordinates carry no intrinsic meaning.
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    ax.set_title(ds, fontsize=14)
    ax.set_xlabel(x_col, fontsize=12)
    # Only the left-most panel carries a y-axis label.
    ax.set_ylabel(y_col if i == 0 else '', fontsize=12)

    # With the fixed level order every panel has identical legend entries;
    # take them from the last panel.
    if i == n_ds - 1:
        handles, labels = ax.get_legend_handles_labels()

    # Replace the per-panel legends by the single figure legend below.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

# Single legend below all panels.
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.5, 0.1),  # Position at bottom center
    ncol=min(len(labels), 3),
    frameon=False,
    title="",
    fontsize=14,
)

# Reserve the bottom strip of the figure for the legend.
plt.tight_layout(rect=[0, 0.1, 1, 0.95])

if SAVE:
    fn = storing_path / f"{emb_name.lower()}_all_{sim_metric}.pdf"
    fig.savefig(fn)
    print(f"Stored figure at {fn=}")
    plt.close(fig)
else:
    fig.show()

Stored figure at fn=PosixPath('/home/space/diverse_priors/results_rebuttal/plots/q1_ssl_pairs/tsne_all_cka_kernel_linear_unbiased.pdf')


In [10]:
import plotly.express as px

# Axis columns come from the embedder (e.g. 'TSNE 1' / 'TSNE 2').
x_col, y_col = emb_cols[0], emb_cols[1]
emb_name = x_col.rsplit(' ', 1)[0]  # 'TSNE 1' -> 'TSNE'
# Human-readable method name for the figure title.
emb_title = {'TSNE': 't-SNE'}.get(emb_name, emb_name)

# Pin the colour order so each training objective keeps the same colour in
# every per-dataset figure (plotly otherwise orders categories by first
# appearance in each figure's data).
hue_col = 'Training objective'
hue_order = list(all_embeddings[hue_col].dropna().unique())

for ds in all_embeddings['Dataset'].unique():
    tmp_df = all_embeddings[all_embeddings['Dataset'] == ds]

    # Interactive scatter; hovering over a point shows the model name.
    fig = px.scatter(
        tmp_df,
        x=x_col,
        y=y_col,
        hover_data=['Model'],  # list form is also accepted by older plotly versions
        color=hue_col,
        category_orders={hue_col: hue_order},
        template='plotly_white'
    )

    fig.update_layout(
        title=f'{emb_title} of {sim_metric_name_mapping[sim_metric]} matrices evaluated on {ds}',
        legend_title_text='Training objective',
        xaxis_title=x_col,
        yaxis_title=y_col
    )

    if SAVE:
        # Keep the dataset display name usable as part of a file name.
        ds_slug = ds.lower().replace(' ', '_').replace('/', '_')
        fn = storing_path / f"{emb_name.lower()}_{ds_slug}_{sim_metric}.html"
        fig.write_html(fn)
        print(f"Stored figure at {fn=}")
    else:
        fig.show()

Stored figure at fn=PosixPath('/home/space/diverse_priors/results_rebuttal/plots/q1_ssl_pairs/tsne_imagenet-1k_cka_kernel_linear_unbiased.html')
Stored figure at fn=PosixPath('/home/space/diverse_priors/results_rebuttal/plots/q1_ssl_pairs/tsne_flowers_cka_kernel_linear_unbiased.html')
Stored figure at fn=PosixPath('/home/space/diverse_priors/results_rebuttal/plots/q1_ssl_pairs/tsne_pcam_cka_kernel_linear_unbiased.html')
