## Notebook 4.6: *Do dataset categories influence relative similarity consistency?*

This notebook creates the figures for Section 4.6. We visualize the correlation coefficients between all pairs of datasets, computed both on all model pairs (no grouping) and on subsets of model pairs grouped by their training objectives. This gives a comprehensive view of consistency patterns across the different dataset categories.


In [None]:
import ast
import textwrap
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.gridspec import GridSpec
from scipy.stats import pearsonr, spearmanr

from constants import (
    BASE_PATH_RESULTS,
    cat_name_mapping,
    ds_list_sim_file,
    fontsizes
)
from helper import (
    load_all_datasetnames_n_info,
    pp_storing_path,
    save_or_show
)

#### Global variables and data loading

In [None]:
# Load dataset information
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# Version and plotting info
version = '3_opt'
# Font sizes one step larger than the shared defaults.
curr_fontsizes = {k: v + 1 for k, v in fontsizes.items()}

## Storing information
corr_method = 'pearsonr'  # spearmanr, pearsonr
SAVE = True
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'final' / version / 'sec_4_6_r_coeff_mats', SAVE)

## Load data
orig_sim_data_path = BASE_PATH_RESULTS / 'aggregated/model_sims/all_metric_ds_model_pair_similarity.csv'
# Raise explicitly instead of using `assert`, so the check is not stripped under `python -O`.
if not orig_sim_data_path.exists():
    raise FileNotFoundError(
        f"Path does not exist: {orig_sim_data_path}. Aggregated similarity data not found, please run aggregate_similarities_across_datasets.ipynb before."
    )
orig_sim_data = pd.read_csv(orig_sim_data_path)

## Process Objective column
# The CSV stores each objective pair as the text of a Python literal (presumably a tuple such as
# "('Image-Text', 'Supervised')" — confirm against the aggregation notebook). ast.literal_eval
# parses literals only, unlike eval, which would execute arbitrary expressions found in the file.
orig_sim_data['Objective pair'] = orig_sim_data['Objective pair'].apply(ast.literal_eval)

# Objective-pair subsets to plot; ('all', 'all') means "no filtering".
combinations_objectives = [('all', 'all'),
                           ('Self-Supervised', 'Self-Supervised'),
                           ('Image-Text', 'Supervised'), ]

In [None]:
# Override the display names of two datasets; the 'name' column is used as heatmap tick labels.
for _ds_key, _display_name in {
    'wds_vtab_diabetic_retinopathy': "Diabet. Retino.",
    'wds_voc2007': "VOC 2007",
}.items():
    ds_info.loc[_ds_key, 'name'] = _display_name

#### Helper functions

In [None]:
def filter_data(mcat1, mcat2, sim_data, curr_suffix=None):
    """Keep only the rows whose objective pair matches ``{mcat1, mcat2}`` (order-insensitive).

    Parameters
    ----------
    mcat1, mcat2 : objective category names. If either is empty, no filtering is done and
        ``sim_data`` itself is returned.
    sim_data : DataFrame with an 'Objective pair' column of two-element sequences.
    curr_suffix : optional file-name suffix. When given (and filtering happens), it is prefixed
        with the short names of both categories from ``cat_name_mapping``.

    Returns
    -------
    (filtered DataFrame, possibly-updated suffix)
    """
    if not (len(mcat1) > 0 and len(mcat2) > 0):
        return sim_data, curr_suffix

    wanted = sorted([mcat1, mcat2])
    keep = sim_data['Objective pair'].apply(lambda pair: sorted(pair) == wanted)
    filtered = sim_data[keep].copy()
    if curr_suffix is not None:
        curr_suffix = f"_{cat_name_mapping[mcat1]}_{cat_name_mapping[mcat2]}" + curr_suffix
    return filtered, curr_suffix


def get_ds_combs(df):
    """Prepare the dataset order, domains, dataset pairs and an empty matrix for ``df``.

    Only datasets that appear both in ``df['DS']`` and in the global ``ds_list`` are kept.

    Parameters
    ----------
    df : similarity rows with a 'DS' column of dataset keys.

    Returns
    -------
    available_ds : dataset keys ordered by (domain, display name) from ``ds_info``; the key
        itself breaks any remaining ties.
    available_domains : sorted list of the domains of ``available_ds``.
    combs_DS : list of all unordered dataset pairs ``(ds1, ds2)``, following ``available_ds`` order.
    corr_mat : float DataFrame of NaNs indexed and columned by ``available_ds``.
    """
    present = set(df['DS'].unique()).intersection(ds_list)
    available_ds = sorted(
        present,
        key=lambda ds: (ds_info.loc[ds, 'domain'], ds_info.loc[ds, 'name'], ds),
    )
    # Only domains that actually occur among the selected datasets, consistent with
    # get_all_ds_corr_mat, so domain labels refer to domains present in the matrix.
    available_domains = sorted(ds_info.loc[available_ds, 'domain'].unique().tolist())
    combs_DS = list(combinations(available_ds, 2))
    corr_mat = pd.DataFrame(index=available_ds, columns=available_ds).astype('float')
    return available_ds, available_domains, combs_DS, corr_mat


def get_r_coeff(x, y, method='pearsonr'):
    """Return the correlation coefficient between ``x`` and ``y``.

    Parameters
    ----------
    x, y : equal-length sequences of numbers.
    method : ``'pearsonr'`` or ``'spearmanr'`` (the scipy.stats function to use).

    Raises
    ------
    ValueError : if ``method`` is not one of the supported names.
    """
    for name, estimator in (('pearsonr', pearsonr), ('spearmanr', spearmanr)):
        if method == name:
            statistic, _p_value = estimator(x, y)
            return statistic
    raise ValueError("Unknown method")


def get_two_ds_data(ds1, ds2, all_sims):
    """Return a model-pair x dataset table of similarity values for datasets ``ds1`` and ``ds2``.

    Rows are keyed by ``"<Model 1>, <Model 2>"``; columns are the dataset keys. Duplicate
    entries are averaged (``pivot_table``'s default aggregation). ``all_sims`` is not modified.
    """
    subset = all_sims[all_sims['DS'].isin([ds1, ds2])]
    subset = subset.assign(model_pair=subset['Model 1'] + ", " + subset['Model 2'])
    return subset.pivot_table(
        index='model_pair',
        columns='DS',
        values='Similarity value',
    )


def fill_corr_mat(df, combs_DS, corr_mat, method=None):
    """Fill ``corr_mat`` with dataset-pair correlation coefficients computed from ``df``.

    For every pair ``(ds1, ds2)`` in ``combs_DS``, the per-model-pair similarity values of the
    two datasets are tabulated with ``get_two_ds_data`` and correlated with ``get_r_coeff``. The
    coefficient is written symmetrically, and the diagonal is set to 1.

    Parameters
    ----------
    df : similarity rows with 'DS', 'Model 1', 'Model 2' and 'Similarity value' columns.
    combs_DS : iterable of ``(ds1, ds2)`` dataset keys; both keys must label ``corr_mat``.
    corr_mat : square float DataFrame; it is modified in place.
    method : ``'pearsonr'`` or ``'spearmanr'``; defaults to the notebook-wide ``corr_method``.

    Returns
    -------
    ``corr_mat`` (the same object).
    """
    if method is None:
        method = corr_method
    for ds1, ds2 in combs_DS:
        ds_sims = get_two_ds_data(ds1, ds2, df)
        corr = float(get_r_coeff(ds_sims.values[:, 0], ds_sims.values[:, 1], method=method))
        corr_mat.loc[ds1, ds2] = corr
        corr_mat.loc[ds2, ds1] = corr

    # Set the diagonal through the DataFrame. Under pandas Copy-on-Write, ``corr_mat.values`` can
    # be a read-only view, so np.fill_diagonal(corr_mat.values, 1) is not safe.
    for i in range(min(corr_mat.shape)):
        corr_mat.iat[i, i] = 1.0
    return corr_mat


def rename_idx_cols(corr_mat, available_ds):
    """Relabel rows and columns of ``corr_mat`` with the display names from ``ds_info``.

    ``available_ds`` must list the dataset keys in the current row/column order of ``corr_mat``.
    The matrix is relabelled in place and also returned.
    """
    display_names = []
    for ds_key in available_ds:
        display_names.append(ds_info.loc[ds_key, 'name'])
    corr_mat.index = display_names
    corr_mat.columns = display_names
    return corr_mat


def get_all_correlations(df, corr_type, col_data):
    """Correlate ``col_data`` between every pair of datasets in the global ``ds_list``.

    Values are matched by model pair ('Model 1', 'Model 2') before correlating, so the result
    does not depend on the row order of ``df``. Duplicate rows for the same model pair and
    dataset are averaged (``pivot_table`` default). Model pairs that are missing or NaN in either
    dataset are left out of that dataset pair's correlation.

    Parameters
    ----------
    df : rows with 'DS', 'Model 1', 'Model 2' and ``col_data`` columns.
    corr_type : ``'pearsonr'`` or ``'spearmanr'``, forwarded to ``get_r_coeff``.
    col_data : name of the column holding the values to correlate.

    Returns
    -------
    DataFrame with columns 'ds1', 'ds2', 'r coeff'; one row per pair in
    ``combinations(ds_list, 2)`` order.
    """
    # One row per model pair, one column per dataset.
    per_pair = df.pivot_table(index=['Model 1', 'Model 2'], columns='DS', values=col_data)

    r_vals = []
    for ds1, ds2 in combinations(ds_list, 2):
        # reindex keeps the lookup well-defined even if a dataset has no rows in df.
        aligned = per_pair.reindex(columns=[ds1, ds2]).dropna()
        r_vals.append({
            'ds1': ds1,
            'ds2': ds2,
            'r coeff': get_r_coeff(aligned[ds1].values, aligned[ds2].values, corr_type),
        })

    return pd.DataFrame(r_vals)


def get_all_ds_corr_mat(sim_metric):
    """Build the dataset-by-dataset consistency matrix over all model pairs for ``sim_metric``.

    Reads the aggregated CSV at ``orig_sim_data_path`` and keeps the rows for datasets in
    ``ds_list`` and the given similarity metric. It then correlates the datasets pairwise with
    ``get_all_correlations`` using ``corr_method``.

    Returns
    -------
    corr_mat : square DataFrame ordered by (domain, name) and labelled with display names.
    ordered_ds : list of dataset keys in the row/column order of ``corr_mat``.
    domains : sorted list of the domains of those datasets.
    """
    all_sim_data = pd.read_csv(orig_sim_data_path)
    all_sim_data = all_sim_data[all_sim_data['DS'].isin(ds_list)]
    sim_data = all_sim_data[all_sim_data['Similarity metric'] == sim_metric]
    r_corrs = get_all_correlations(sim_data, corr_method, 'Similarity value')

    corr_mat = pd.DataFrame(columns=ds_list, index=ds_list, dtype=float)
    # A plain loop rather than DataFrame.apply, since only the side effect of filling is needed.
    for row in r_corrs.to_dict('records'):
        corr_mat.loc[row['ds1'], row['ds2']] = row['r coeff']
        corr_mat.loc[row['ds2'], row['ds1']] = row['r coeff']
    # Set the diagonal through the DataFrame: under pandas Copy-on-Write, corr_mat.values can be
    # a read-only view, so np.fill_diagonal on it is not safe.
    for i in range(min(corr_mat.shape)):
        corr_mat.iat[i, i] = 1.0

    # Reorder by (domain, display name), then relabel with display names.
    new_index = ds_info.loc[corr_mat.index, :].sort_values(['domain', 'name']).index
    corr_mat = corr_mat.loc[new_index, new_index]
    corr_mat.index = ds_info.loc[corr_mat.index, 'name']
    corr_mat.columns = ds_info.loc[corr_mat.columns, 'name']
    return corr_mat, list(new_index), sorted(ds_info.loc[new_index, 'domain'].unique())


def plot_heatmap_v2(ax, corr_mat, available_ds, available_domains, vmin=-0.44, vmax=1, set_title=True, cbar=False,
                    cmap='mako'):
    """Draw a dataset-by-dataset consistency heatmap on ``ax`` with dotted domain separators.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    corr_mat : square DataFrame of correlation coefficients.
    available_ds : dataset keys in the row/column order of ``corr_mat``. Used to look up each
        dataset's domain in ``ds_info`` and draw separators where the domain changes.
    available_domains : domain names written as labels when ``set_title`` is True. Only the first
        four are drawn, at hand-tuned positions.
    vmin, vmax : colour-scale limits.
    set_title : whether to write the domain labels.
    cbar : forwarded to ``sns.heatmap`` (draw seaborn's own colour bar).
    cmap : colour map name.

    Returns
    -------
    The same ``ax``.
    """
    sns.heatmap(corr_mat, square=True, cmap=cmap, cbar=cbar, vmin=vmin, vmax=vmax, ax=ax)

    # Position i is a boundary when datasets i-1 and i belong to different domains.
    domains = ds_info.loc[available_ds, 'domain'].values
    boundaries = np.where(domains[:-1] != domains[1:])[0] + 1

    for val in boundaries:
        ax.axhline(val, c='black', ls=":", lw=2)
        ax.axvline(val, c='black', ls=":", lw=2)

    ax.tick_params('y', labelsize=curr_fontsizes['label'])
    ax.tick_params('x', pad=0.1, labelsize=curr_fontsizes['label'] - 1)
    labels = ax.get_xticklabels()
    ax.set_xticklabels(labels, rotation=45, ha='right')

    if set_title:
        # Hand-tuned (x, y) data coordinates, one per domain label, for this dataset layout.
        text_pos = [(5, -2.2), (13, -2.2), (17.75, -2.2), (21.75, -2.2)]
        formatters = [
            lambda d: '\n'.join(d.split(' ')).replace('ain', '.'),
            lambda d: '\n'.join(d.split(' ')).replace('ain', '.'),
            lambda d: '-\n'.join(textwrap.wrap(d, width=7)),
            lambda d: '-\n'.join(textwrap.wrap(d, width=5)),
        ]
        # zip stops at the shortest input, so fewer than four domains no longer raises IndexError.
        for (x, y), fmt, domain in zip(text_pos, formatters, available_domains):
            ax.text(x, y, fmt(domain), ha='center',
                    va='top', fontsize=curr_fontsizes['title'], color='black')

    ax.set_xlabel('')
    ax.set_ylabel('')
    return ax


def setup_figure(m, size_fig=7, size_bar=0.25, wspace=0.05):
    """Create a one-row figure with ``m`` equally sized panels plus a narrow colour-bar column.

    All panels share their x and y axes with the first one.

    Returns
    -------
    (figure, GridSpec with ``m + 1`` columns, list of the ``m`` panel axes)
    """
    width_ratios = [size_fig] * m + [size_bar]
    fig = plt.figure(figsize=(m * size_fig + size_bar, size_fig))
    gs = GridSpec(1, m + 1, width_ratios=width_ratios, wspace=wspace)

    first_ax = fig.add_subplot(gs[0, 0])
    axs = [first_ax]
    for col in range(1, m):
        axs.append(fig.add_subplot(gs[0, col], sharex=first_ax, sharey=first_ax))
    return fig, gs, axs


def add_colorbar(gs, cmap='mako', vmin=-0.44, vmax=1, fig=None):
    """Add a standalone colour bar in the last column of ``gs``.

    Parameters
    ----------
    gs : GridSpec whose last column (row 0) receives the colour-bar axes.
    cmap, vmin, vmax : colour map and limits; they should match the heatmaps it describes.
    fig : figure to add the colour-bar axes to. When omitted, the current pyplot figure
        (``plt.gcf()``) is used instead of an implicit module-level variable.
    """
    if fig is None:
        fig = plt.gcf()
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])
    cax = fig.add_subplot(gs[0, -1])
    fig.colorbar(sm, cax=cax)
    cax.tick_params(labelsize=curr_fontsizes['label'])


def update_suffix(suffix, mcat1, mcat2):
    """Append a file-name fragment describing the objective pair ``(mcat1, mcat2)`` to ``suffix``.

    Names are lower-cased with spaces replaced by underscores. Identical categories give
    ``"_2_<cat>"``; different ones give ``"_<cat1>_n_<cat2>"``.
    """
    first = mcat1.lower().replace(' ', '_')
    if mcat1 == mcat2:
        return suffix + f"_2_{first}"
    second = mcat2.lower().replace(' ', '_')
    return suffix + f"_{first}_n_{second}"

### Compute consistency matrices on different subsets of the model pairs' representational similarities

In [None]:
# One figure per similarity metric: a row of consistency matrices, one per objective-pair subset
# in combinations_objectives, followed by a shared colour bar.
m = len(combinations_objectives)

for sim_metric in ['CKA RBF 0.4', 'CKA linear']:
    print(sim_metric)

    sim_data = orig_sim_data[orig_sim_data['Similarity metric'] == sim_metric].reset_index().copy()

    # File-name suffix; extended below with one fragment per objective-pair subset.
    suffix = "_" + sim_metric.replace(" ", "_").lower()

    # NOTE: keep the name `fig` — add_colorbar reads the module-level `fig` instead of
    # receiving the figure as an argument.
    fig, gs, axs = setup_figure(m, size_fig=6.5, size_bar=0.25, wspace=0.05)

    for i, (mcat1, mcat2) in enumerate(combinations_objectives):
        ax = axs[i]
        if mcat1 == 'all' and mcat2 == 'all':
            # All model pairs: get_all_ds_corr_mat re-reads the aggregated CSV itself.
            corr_mat, available_ds, available_domains = get_all_ds_corr_mat(sim_metric)
        else:
            # The suffix returned by filter_data is discarded; the file name is built by update_suffix.
            sim_data_flat, _ = filter_data(mcat1, mcat2, sim_data, suffix)
            available_ds, available_domains, combs_DS, corr_mat = get_ds_combs(sim_data_flat)
            corr_mat = fill_corr_mat(sim_data_flat, combs_DS, corr_mat)
            corr_mat = rename_idx_cols(corr_mat, available_ds)

        plot_heatmap_v2(ax, corr_mat, available_ds, available_domains)

        # Panels share the y axis, so only the left-most one shows dataset names.
        if i > 0:
            plt.setp(ax.get_yticklabels(), visible=False)

        suffix = update_suffix(suffix, mcat1, mcat2)

    add_colorbar(gs, cmap='mako', vmin=-0.44, vmax=1)

    save_or_show(plt.gcf(), storing_path / f'grouped_heatmap{suffix}.pdf', SAVE)

### NEW: Compute consistency matrices on different subsets of the model pairs' representational similarities (one large matrix for all model pairs, two smaller ones for subsets)

In [None]:
# Similarity metric for the combined figure; get_subset_data reads the global `sim_data`.
sim_metric = 'CKA linear'
sim_data = (
    orig_sim_data
    .loc[lambda frame: frame['Similarity metric'] == sim_metric]
    .reset_index()
    .copy()
)

In [None]:
def get_subset_data(mcat1, mcat2):
    """Compute the dataset consistency matrix restricted to one objective-pair subset.

    Uses the module-level ``sim_data`` (already filtered to one similarity metric).

    Returns
    -------
    (matrix labelled with display names, ordered dataset keys, sorted domains)
    """
    subset, _unused_suffix = filter_data(mcat1, mcat2, sim_data, None)
    ds_order, domains, ds_pairs, empty_mat = get_ds_combs(subset)
    filled = fill_corr_mat(subset, ds_pairs, empty_mat)
    return rename_idx_cols(filled, ds_order), ds_order, domains

In [None]:
# One large consistency matrix over all model pairs (top row, with colour bar) and two smaller
# matrices for model-pair subsets (bottom row). GridSpec is imported at the top of the notebook.
fig = plt.figure(figsize=(12, 16))  # Increased height
gs = GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1.6, 1])  # Larger top section

ax1 = fig.add_subplot(gs[0, :])
ax3 = fig.add_subplot(gs[1, 0])
ax4 = fig.add_subplot(gs[1, 1])

# Shared colour-scale limits, passed explicitly to every heatmap.
vmin = -0.44
vmax = 1

corr_mat, available_ds, available_domains = get_all_ds_corr_mat(sim_metric)
plot_heatmap_v2(ax1, corr_mat, available_ds, available_domains, vmin=vmin, vmax=vmax, cbar=True)
# With cbar=True, seaborn's colour-bar axes is the most recently added axes of the figure.
ax1.figure.axes[-1].tick_params(labelsize=curr_fontsizes['label'])

corr_mat, available_ds, available_domains = get_subset_data('Self-Supervised', 'Self-Supervised')
plot_heatmap_v2(ax3, corr_mat, available_ds, available_domains, vmin=vmin, vmax=vmax, set_title=False)
ax3.set_xticks([])
ax3.set_yticks([])
ax3.set_xlabel('SSL', fontsize=curr_fontsizes['title'])

corr_mat, available_ds, available_domains = get_subset_data('Image-Text', 'Supervised')
plot_heatmap_v2(ax4, corr_mat, available_ds, available_domains, vmin=vmin, vmax=vmax, set_title=False)
ax4.set_xticks([])
ax4.set_yticks([])
ax4.set_xlabel('Img-Txt & Sup', fontsize=curr_fontsizes['title'])

plt.subplots_adjust(hspace=0.25, wspace=0.05)
save_or_show(fig, storing_path / f'grouped_heatmap_{corr_method}_one_big_two_small.pdf', SAVE)