## Notebook 4.5: *Which model categories influence relative similarity consistency?*
This notebook creates figures for section 4.5. It generates the boxplots showing the distribution of correlation coefficients between dataset pairs for each model category (training objective, architecture, training data diversity, and size). Each plot is colored by model subcategories and sorted by median correlation. We compare these distributions against the baseline correlation coefficients of ungrouped models.

In [None]:
from itertools import combinations
from textwrap import wrap

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.gridspec import GridSpec
from scipy.stats import pearsonr, spearmanr
from itertools import product
import json 
from pathlib import Path
from constants import (
    model_size_order,
    fontsizes,
    fontsizes_cols,
    cat_color_mapping,
    BASE_PATH_RESULTS,
    ds_list_sim_file,
    model_cat_mapping
)
from helper import save_or_show, pp_storing_path, load_all_datasetnames_n_info

sns.set_style('ticks')

#### Global variables

In [None]:
# Datasets
# `ds_list` is used below to filter rows by dataset name; `ds_info` is indexed by
# dataset name and provides (at least) a 'domain' column.
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# Experiment configuration
corr_type = 'pearsonr'  # 'pearsonr', 'spearmanr'
suffix = '_with_rsa'  # '', '_wo_mae', '_with_rsa', '_with_rsa_wo_mae_rotnet_jigsaw'
# Combined tag of correlation type and suffix; used as the last folder of the storing path.
exp_conf = f'{corr_type}{suffix}'

# Path to correlation data
data_path = BASE_PATH_RESULTS / f'aggregated/r_coeff_dist/with_cats_as_anchors/agg_{corr_type}_all_ds{suffix}.csv'
print(data_path)
# Fail early with a pointer to the notebook that produces the missing file.
assert data_path.exists(), f'Path does not exist: {data_path}. Aggregated correlation coefficients across all dataset pairs not found, please run aggregate_consistencies_for_model_set_pairs.ipynb first.'

# Path to sim data
sim_data_path = BASE_PATH_RESULTS / f'aggregated/model_sims/all_metric_ds_model_pair_similarity{suffix}.csv'
assert sim_data_path.exists(), f"Path does not exist: {sim_data_path}. Aggregated similarity data not found, please run aggregate_similarities_across_datasets.ipynb before."

# Version
version = 'arxiv'
# Pick the font-size table for the selected version, then increase every entry by 1.
curr_fontsizes = fontsizes if version == 'arxiv' else fontsizes_cols
curr_fontsizes = {k: v + 1 for k, v in curr_fontsizes.items()}

# Storing path
SAVE = True
# NOTE(review): the root of the storing path is a hard-coded absolute path and is not
# derived from BASE_PATH_RESULTS -- confirm this is intended before running elsewhere.
storing_path = pp_storing_path(Path('/home/space/diverse_priors/results_rebuttal') / 'plots' / 'exp_distr_for_dataset_cats' / exp_conf, SAVE)
print(f"{storing_path=}")

#### Load data

In [None]:
# Read the aggregated correlation coefficients and keep only the rows whose two
# datasets are both part of `ds_list`; the shape is printed before and after filtering.
r_coeff_data = pd.read_csv(data_path)
print(r_coeff_data.shape)
both_ds_selected = r_coeff_data['ds1'].isin(ds_list) & r_coeff_data['ds2'].isin(ds_list)
r_coeff_data = r_coeff_data.loc[both_ds_selected].reset_index(drop=True).copy()
print(r_coeff_data.shape)
r_coeff_data.head()

In [None]:
# Attach the domain of each dataset of a pair, looked up in `ds_info['domain']`.
domain_by_ds = ds_info['domain'].to_dict()
for ds_col in ('ds1', 'ds2'):
    r_coeff_data[f'{ds_col}_cat'] = r_coeff_data[ds_col].map(domain_by_ds)

In [None]:
# Order-independent key for the two dataset domains of a row: a sorted 2-tuple.
r_coeff_data['ds_cat_pair'] = [
    tuple(sorted(domain_pair))
    for domain_pair in zip(r_coeff_data['ds1_cat'], r_coeff_data['ds2_cat'])
]

In [None]:
# Rename the raw comparison-category labels to their display names. The renaming only
# runs while both raw labels are still present, so re-running the cell is a no-op.
# Series.map yields NaN for any label that is not a key of the mapping.
category_display_names = {
    'Architecture': 'Architecture',
    'Dataset diversity': 'Training data',
    'Objective': 'Training objective',
    'Model size': 'Model size',
}
curr_cats = r_coeff_data['Comparison category'].unique()
has_raw_labels = 'Objective' in curr_cats and 'Dataset diversity' in curr_cats
if has_raw_labels:
    r_coeff_data['Comparison category'] = r_coeff_data['Comparison category'].map(category_display_names)

In [None]:
# Make sure 'cat_pair' holds tuples: parse it when it was read from the CSV as text,
# or build it as the sorted (anchor_cat, other_cat) pair when the column is missing.
if 'cat_pair' in r_coeff_data:
    if isinstance(r_coeff_data['cat_pair'].iloc[0], str):
        # NOTE(review): eval executes whatever expression is stored in the CSV. This is
        # only acceptable for files produced by our own pipeline; a literal-only parser
        # would be safer if the stored strings are plain tuple literals -- confirm.
        r_coeff_data['cat_pair'] = r_coeff_data['cat_pair'].apply(eval)
else:
    r_coeff_data['cat_pair'] = [
        tuple(sorted(cat_pair))
        for cat_pair in zip(r_coeff_data['anchor_cat'], r_coeff_data['other_cat'])
    ]


In [None]:
# Read the aggregated similarity values from `sim_data_path` and show the table shape.
all_sim_data = pd.read_csv(sim_data_path)
all_sim_data.shape

In [None]:
# Order-independent key for the two models of a row: a sorted 2-tuple.
all_sim_data['model_pair'] = [
    tuple(sorted(models))
    for models in zip(all_sim_data['Model 1'], all_sim_data['Model 2'])
]

In [None]:
# Keep only the similarity rows whose dataset ('DS') is part of `ds_list`.
all_sim_data = all_sim_data[all_sim_data['DS'].isin(ds_list)]

In [None]:
# All unordered pairs of dataset domains (a domain paired with itself included),
# each stored as a sorted tuple; the list itself is sorted as well.
unique_ds_cats = list(ds_info.loc[ds_list, 'domain'].unique())
ds_cat_combinations = sorted(
    {tuple(sorted(domain_pair)) for domain_pair in product(unique_ds_cats, repeat=2)}
)
ds_cat_combinations

In [None]:
# For every domain, the array of unique dataset names (restricted to `ds_list`) that
# belong to it; the index of `ds_info` is turned into the 'ds_name' column first.
ds_per_cat = ds_info.loc[ds_list].reset_index(names=['ds_name']).groupby('domain')['ds_name'].unique()

In [None]:
def compute_corr(x, y, corr_type):
    """Return the correlation coefficient between `x` and `y`.

    Args:
        x: First sequence of values.
        y: Second sequence of values.
        corr_type: 'pearsonr' or 'spearmanr', selecting the scipy.stats function.

    Returns:
        The correlation statistic; the p-value is discarded.

    Raises:
        ValueError: If `corr_type` is neither 'pearsonr' nor 'spearmanr'.
    """
    if corr_type == 'pearsonr':
        return pearsonr(x, y)[0]
    if corr_type == 'spearmanr':
        return spearmanr(x, y)[0]
    raise ValueError('Unknown corr type')


def get_all_correlations(df, corr_type, col_data):
    """Correlate `col_data` between every dataset pair of every dataset-category pair.

    Category pairs come from the module-level `ds_cat_combinations`; the datasets of a
    category come from the module-level `ds_per_cat`. When both categories of a pair
    are the same, unordered pairs of distinct datasets are used, otherwise the full
    cross product of the two dataset lists.

    Args:
        df: DataFrame with a 'DS' column and a `col_data` column.
        corr_type: Correlation type forwarded to `compute_corr`.
        col_data: Name of the column whose values are correlated.

    Returns:
        DataFrame with one row per dataset pair and the columns 'ds1', 'ds2',
        'ds_cat_pair' (the category pair encoded with json.dumps) and 'r coeff'.
    """
    # Split the values by dataset once instead of re-filtering `df` twice per dataset
    # pair. groupby keeps the original row order inside every group.
    values_by_ds = {ds: ds_rows[col_data].values for ds, ds_rows in df.groupby('DS', sort=False)}
    # A dataset without any row yields an empty array, exactly like a boolean filter would.
    no_values = df[col_data].values[:0]

    r_vals = []
    for ds_cat1, ds_cat2 in ds_cat_combinations:
        if ds_cat1 == ds_cat2:
            curr_ds_tuple_list = combinations(list(ds_per_cat.loc[ds_cat1]), 2)
        else:
            curr_ds_tuple_list = product(list(ds_per_cat.loc[ds_cat1]), list(ds_per_cat.loc[ds_cat2]))

        for ds1, ds2 in curr_ds_tuple_list:
            # NOTE(review): the two value arrays are paired by position, which assumes
            # the rows of every dataset appear in the same order in `df` -- confirm.
            r_vals.append({
                'ds1': ds1,
                'ds2': ds2,
                'ds_cat_pair': json.dumps((ds_cat1, ds_cat2)),
                'r coeff': compute_corr(
                    values_by_ds.get(ds1, no_values),
                    values_by_ds.get(ds2, no_values),
                    corr_type,
                ),
            })
    return pd.DataFrame(r_vals)


def get_sim_metric_dist_info(df):
    """Summarize the dataset-pair correlations of 'Similarity value' when no model grouping is used.

    The correlations are computed with the module-level `corr_type` and summarized with
    `describe()` separately for every 'ds_cat_pair'.
    """
    return (
        get_all_correlations(df, corr_type, 'Similarity value')
        .groupby('ds_cat_pair')['r coeff']
        .describe()
    )

# Baseline distribution statistics per similarity metric, shown as they are computed.
dist_info_no_cats = {}
for sim_metric, data in all_sim_data.groupby('Similarity metric'):
    metric_dist_info = get_sim_metric_dist_info(data)
    dist_info_no_cats[sim_metric] = metric_dist_info
    print(sim_metric)
    display(metric_dist_info)

#### Plotting helper functions

In [None]:
def tuple2string(tup_dat):
    """Format the first two entries of `tup_dat` as 'first, second'."""
    first, second = tup_dat[0], tup_dat[1]
    return f"{first}, {second}"


def wrap_labels(ax, width, break_long_words=False):
    """Re-set the x tick labels of `ax` with their text wrapped to `width` characters.

    Args:
        ax: Axes whose x tick labels are rewritten.
        width: Maximum line width passed to `textwrap.wrap`.
        break_long_words: Forwarded to `textwrap.wrap`.
    """
    tick_positions = ax.get_xticks()
    wrapped_labels = []
    for tick_label in ax.get_xticklabels():
        text_lines = wrap(tick_label.get_text(), width, break_long_words=break_long_words)
        wrapped_labels.append('\n'.join(text_lines))
    ax.set_xticks(tick_positions, wrapped_labels, rotation=0, ha='center')


def create_custom_legend(color_maps):
    """Add a single-row, frameless legend with one colored patch per entry of `color_maps`.

    Args:
        color_maps: Mapping from category label to color.
    """
    legend_patches = []
    for cat, color in color_maps.items():
        legend_patches.append(mpatches.Patch(color=color, label=cat))

    legend_kwargs = dict(
        handles=legend_patches,
        ncols=len(legend_patches),  # all entries in one row
        title="",
        loc='center',
        bbox_to_anchor=(0.5, -0.25),
        fontsize=curr_fontsizes['ticks'],
        frameon=False,
    )
    plt.legend(**legend_kwargs)


def get_colored_tick_labels(label, method_colors):
    """Split `label` on ', ' and pair every part with its color.

    Args:
        label: Tick label text, parts separated by ', '.
        method_colors: Mapping from part to color; missing parts fall back to 'black'.

    Returns:
        List of (part, color) tuples in label order.
    """
    colored_label = []
    for method in label.split(', '):
        colored_label.append((method, method_colors.get(method, 'black')))
    return colored_label


def set_colored_labels(ax, method_colors, y_pos_init=-0.05, y_height=-0.11):
    """Replace the x tick labels of `ax` with stacked, individually colored text.

    Every tick label is split on ', ' (see `get_colored_tick_labels`); each part is
    drawn centered below its tick, one part per line, in the color assigned to it.

    Args:
        ax: Axes whose x tick labels are replaced.
        method_colors: Mapping from label part to color; unknown parts are drawn black.
        y_pos_init: Vertical position of the first part (y component of
            `ax.get_xaxis_transform()`).
        y_height: Vertical offset added for every further part of the same tick.
    """
    # Read positions and texts before the default labels are removed.
    x_ticks = ax.get_xticks()
    labels = [label.get_text() for label in ax.get_xticklabels()]
    ax.set_xticklabels([])  # Remove existing tick labels

    # The renderer / bounding-box bookkeeping that used to live here was never read;
    # dropping it also avoids relying on `canvas.get_renderer()`, which not every
    # canvas exposes.
    for x_tick, label in zip(x_ticks, labels):
        y_pos = y_pos_init
        for method, color in get_colored_tick_labels(label, method_colors):
            ax.text(
                x_tick, y_pos,
                method.split(' DS')[0],  # keep only the text before the first ' DS'
                color=color,
                fontsize=curr_fontsizes['label'],
                ha='center',
                va='top',
                transform=ax.get_xaxis_transform(),
                rotation=0,
            )
            y_pos += y_height


def get_dist_plot_for_cat(r_coeff_subdata, curr_all_ds_info, verbose=False):
    """Create a standalone boxplot figure of 'r coeff' per category pair.

    Boxes are ordered by descending median and filled with a horizontal gradient
    between the colors of the two categories of the pair. A grey band (25%-75%) and a
    dotted line (50%) taken from `curr_all_ds_info` are drawn behind the boxes.

    Args:
        r_coeff_subdata: DataFrame with the columns 'ds1', 'ds2', 'cat_pair'
            (2-tuples), 'r coeff', 'anchor_cat', 'other_cat' and 'Comparison category'.
        curr_all_ds_info: Mapping or Series providing the keys '25%', '50%' and '75%'.
        verbose: If True, display summary statistics per category pair.

    Returns:
        The matplotlib figure holding the plot.
    """
    # remove duplicates
    print('R coeff data before duplicate removal: ', r_coeff_subdata.shape)
    r_coeff_subdata_wo_dup = r_coeff_subdata[
        ~r_coeff_subdata[['ds1', 'ds2', 'cat_pair', 'r coeff']].duplicated()].reset_index(drop=True)
    print('R coeff data after duplicate removal: ', r_coeff_subdata_wo_dup.shape)

    # Get color maps
    if r_coeff_subdata_wo_dup['Comparison category'].unique()[0] == 'Model size':
        sub_cats = model_size_order
    else:
        # Unique values over both columns at once. A column-wise `apply(np.unique)`
        # returns ragged results when the two columns differ in their number of
        # unique values, which a subsequent np.unique cannot handle.
        sub_cats = list(np.unique(r_coeff_subdata_wo_dup[['anchor_cat', 'other_cat']].to_numpy()))
    color_maps = {cat: cat_color_mapping[cat] for cat in sub_cats}

    # Get sorting order
    sorting_order = r_coeff_subdata_wo_dup.groupby('cat_pair')['r coeff'].median().sort_values(
        ascending=False).index.tolist()
    colors = [(color_maps[cat1], color_maps[cat2]) for (cat1, cat2) in sorting_order]
    sorting_order = [tuple2string(tup_data) for tup_data in sorting_order]

    # Convert tuples to strings
    r_coeff_subdata_wo_dup['cat_pair'] = r_coeff_subdata_wo_dup['cat_pair'].apply(tuple2string)

    tmp = r_coeff_subdata_wo_dup.groupby('cat_pair')['r coeff'].describe().sort_values('mean')
    # The flag is computed as mean - 2 * std > 0.5.
    tmp['Mean - std over 0.5'] = (tmp['mean'] - 2 * tmp['std']) > 0.5
    if verbose:
        display(tmp)

    # Plot
    add_inches = 2.5 if version == 'arxiv' else 3.5
    fig = plt.figure(figsize=(len(sorting_order) + add_inches, 4))
    g = sns.boxplot(
        r_coeff_subdata_wo_dup,
        x='cat_pair',
        y='r coeff',
        order=sorting_order
    )

    ## info on the correlations when not using any categories
    g.axhline(curr_all_ds_info['50%'], c='grey', ls=':', alpha=0.5, zorder=-1)
    g.axhspan(curr_all_ds_info['25%'], curr_all_ds_info['75%'], fc='lightgrey', alpha=0.5, zorder=-1)

    ## color label
    set_colored_labels(g, color_maps)
    g.tick_params(axis='y', which='major', labelsize=curr_fontsizes['ticks'])

    ## adapt color boxes
    gradient = np.linspace(0, 1, 256).reshape(1, -1)  # same ramp for every box
    for patch, (left_color, right_color) in zip(g.patches, colors):
        patch.set_facecolor('none')

        # Get box position and dimensions from the path of the box patch
        vertices = patch.get_path().vertices
        x_left = vertices[0, 0]  # left x position
        x_right = vertices[2, 0]  # right x position
        y_bottom = vertices[1, 1]  # bottom of the box
        y_top = vertices[2, 1]  # top of the box

        cmap = LinearSegmentedColormap.from_list('gradient', [left_color, right_color])
        g.imshow(gradient, aspect='auto', cmap=cmap,
                 extent=(x_left, x_right, y_bottom, y_top), zorder=-1)

    # Limits are set after the imshow calls, which adjust the view limits of the axes.
    g.set_xlim(-1, len(colors))
    ylim = (r_coeff_subdata_wo_dup['r coeff'].min() - 0.05, r_coeff_subdata_wo_dup['r coeff'].max() + 0.05)
    g.set_ylim(ylim)
    g.set_xlabel('')
    g.set_ylabel('Correlation coefficient', fontsize=curr_fontsizes['label'])

    return fig


#### Plotting

In [None]:
def get_dist_plot_for_cat_ax(ax, r_coeff_subdata, curr_all_ds_info, verbose=False, ylim=(-0.6, 1.05)):
    """Draw the boxplot of 'r coeff' per category pair into an existing axes.

    Boxes are ordered by descending median and filled with a horizontal gradient
    between the colors of the two categories of the pair. A grey band (25%-75%) and a
    dotted line (50%) taken from `curr_all_ds_info` are drawn behind the boxes.

    Args:
        ax: Axes to draw into.
        r_coeff_subdata: DataFrame with the columns 'ds1', 'ds2', 'cat_pair'
            (2-tuples), 'r coeff', 'anchor_cat', 'other_cat' and 'Comparison category'.
        curr_all_ds_info: Mapping or Series providing the keys '25%', '50%' and '75%'.
        verbose: If True, display summary statistics per category pair.
        ylim: y-axis limits applied to `ax`.
    """
    # remove duplicates
    r_coeff_subdata_wo_dup = r_coeff_subdata[
        ~r_coeff_subdata[['ds1', 'ds2', 'cat_pair', 'r coeff']].duplicated()].reset_index(drop=True)

    # Get color maps
    if r_coeff_subdata_wo_dup['Comparison category'].unique()[0] == 'Model size':
        sub_cats = model_size_order
    else:
        # Unique values over both columns at once. A column-wise `apply(np.unique)`
        # returns ragged results when the two columns differ in their number of
        # unique values, which a subsequent np.unique cannot handle.
        sub_cats = list(np.unique(r_coeff_subdata_wo_dup[['anchor_cat', 'other_cat']].to_numpy()))
    color_maps = {cat: cat_color_mapping[cat] for cat in sub_cats}

    # Get sorting order
    sorting_order = r_coeff_subdata_wo_dup.groupby('cat_pair')['r coeff'].median().sort_values(
        ascending=False).index.tolist()
    colors = [(color_maps[cat1], color_maps[cat2]) for (cat1, cat2) in sorting_order]
    sorting_order = [tuple2string(tup_data) for tup_data in sorting_order]

    # Convert tuples to strings
    r_coeff_subdata_wo_dup['cat_pair'] = r_coeff_subdata_wo_dup['cat_pair'].apply(tuple2string)

    tmp = r_coeff_subdata_wo_dup.groupby('cat_pair')['r coeff'].describe().sort_values('mean')
    # The flag is computed as mean - 2 * std > 0.5.
    tmp['Mean - std over 0.5'] = (tmp['mean'] - 2 * tmp['std']) > 0.5
    if verbose:
        display(tmp)

    ax = sns.boxplot(
        r_coeff_subdata_wo_dup,
        x='cat_pair',
        y='r coeff',
        order=sorting_order,
        ax=ax
    )

    ## info on the correlations when not using any categories
    ax.axhline(curr_all_ds_info['50%'], c='grey', ls=':', lw=3, alpha=0.75, zorder=-1)
    ax.axhspan(curr_all_ds_info['25%'], curr_all_ds_info['75%'], fc='lightgrey', alpha=0.75, zorder=-1)

    ## color label
    set_colored_labels(ax, color_maps)
    ax.tick_params(axis='y', which='major', labelsize=curr_fontsizes['ticks'])

    ## adapt color boxes
    gradient = np.linspace(0, 1, 256).reshape(1, -1)  # same ramp for every box
    for patch, (left_color, right_color) in zip(ax.patches, colors):
        patch.set_facecolor('none')

        # Get box position and dimensions from the path of the box patch
        vertices = patch.get_path().vertices
        x_left = vertices[0, 0]  # left x position
        x_right = vertices[2, 0]  # right x position
        y_bottom = vertices[1, 1]  # bottom of the box
        y_top = vertices[2, 1]  # top of the box

        cmap = LinearSegmentedColormap.from_list('gradient', [left_color, right_color])
        ax.imshow(gradient, aspect='auto', cmap=cmap,
                  extent=(x_left, x_right, y_bottom, y_top), zorder=-1)

    # Limits are set after the imshow calls, which adjust the view limits of the axes.
    # (A preceding set_xlim(-1, len(colors)) was dropped: it was overridden right away.)
    ax.set_ylim(ylim)
    ax.set_xlim(-0.75, len(sorting_order) - 0.25)
    ax.set_xlabel('')
    ax.set_ylabel('Correlation coeff.', fontsize=curr_fontsizes['label'])


def create_subplot_figure(subset_data, curr_all_ds_info):
    """Build the four-panel figure with one boxplot panel per model category.

    The first row holds two panels with a 2:1 width ratio; the second and third rows
    each hold one full-width panel. All panels share the y-axis of the first one and
    use y-limits derived from the range of 'r coeff' in `subset_data` plus a 5% margin.

    Args:
        subset_data: DataFrame with (at least) the columns 'r coeff' and
            'Comparison category', plus those needed by `get_dist_plot_for_cat_ax`.
        curr_all_ds_info: Reference statistics forwarded to `get_dist_plot_for_cat_ax`.

    Returns:
        The created matplotlib figure.
    """
    extra_width = 0.5 if version == 'arxiv' else 2.5
    fig = plt.figure(figsize=(10 + extra_width, 3.5 * 3))

    # Common y-limits: data range widened by 5% on both sides.
    r_min = subset_data['r coeff'].min()
    r_max = subset_data['r coeff'].max()
    margin = 0.05 * (r_max - r_min)
    ylim = (r_min - margin, r_max + margin)

    # 3 x 2 grid; the left column is twice as wide as the right one.
    gs = GridSpec(3, 2, figure=fig, width_ratios=[2, 1], height_ratios=[1, 1, 1])

    # The first panel is the y-axis reference for all the others.
    ax1 = fig.add_subplot(gs[0, 0])
    axes = [ax1]
    for grid_slot in (gs[0, 1], gs[1, :], gs[2, :]):
        axes.append(fig.add_subplot(grid_slot, sharey=ax1))
    # The panel right next to the first one does not repeat the y tick labels.
    plt.setp(axes[1].get_yticklabels(), visible=False)

    for cat, ax in zip(model_cat_mapping.values(), axes):
        is_curr_cat = subset_data['Comparison category'] == cat
        get_dist_plot_for_cat_ax(ax, subset_data[is_curr_cat], curr_all_ds_info, verbose=False, ylim=ylim)
        ax.set_title(cat, fontsize=curr_fontsizes['title'])

    fig.tight_layout()
    plt.subplots_adjust(hspace=0.45)
    return fig

In [None]:
# for group_key, group_data in r_coeff_data.groupby(['Comparison category', 'Similarity metric', 'ds_cat_pair']):
#     print(group_key)
#     curr_all_ds_info = dist_info_no_cats[group_key[1]].loc[json.dumps(group_key[2]),:]
#     print(curr_all_ds_info)
#     fig = get_dist_plot_for_cat(group_data, curr_all_ds_info, False)

#     curr_cat, curr_sim, ds_cat_pair = group_key
#     curr_cat = curr_cat.replace(" ", "_").lower()
#     curr_sim = curr_sim.replace(" ", "_").lower()

#     ds_cat_pair = ds_cat_pair[0].replace(" ", "_").lower() +"_n_"+ ds_cat_pair[1].replace(" ", "_").lower()

#     save_or_show(fig, storing_path / f'dist_r_coeff_cat_anchor_box_{curr_cat}_{curr_sim}_{ds_cat_pair}.pdf', SAVE)

In [None]:
# One figure per (similarity metric, dataset-category pair); the ungrouped baseline is
# looked up with the same JSON encoding of the category pair that was used to build it.
for group_key, subset_data in r_coeff_data.groupby(['Similarity metric', 'ds_cat_pair']):
    print(group_key)
    baseline_row_key = json.dumps(group_key[1])
    curr_all_ds_info = dist_info_no_cats[group_key[0]].loc[baseline_row_key, :]
    fig = create_subplot_figure(subset_data, curr_all_ds_info)

    # File-name friendly versions of the group key: lower case, spaces -> underscores.
    curr_sim = group_key[0].replace(" ", "_").lower()
    ds_cat_pair = "_n_".join(ds_cat.replace(" ", "_").lower() for ds_cat in group_key[1])
    fig_file = storing_path / f'dist_r_coeff_cat_anchor_box_all_cats_{curr_sim}_{ds_cat_pair}.pdf'
    save_or_show(fig, fig_file, SAVE)