In [None]:
import itertools
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from scipy.stats import kendalltau, pearsonr, spearmanr

from constants import (
    exclude_models,
    exclude_models_w_mae,
    cat_name_mapping,
    model_config_file,
    fontsizes,
    fontsizes_cols,
    model_cat_mapping,
    BASE_PATH_RESULTS,
    BASE_PATH_PROJECT,
    ds_list_sim_file,
    cat_color_mapping,
    sim_metric_name_mapping
)
from helper import (
    load_model_configs_and_allowed_models,
    save_or_show,
    pp_storing_path,
    load_all_datasetnames_n_info,
    load_similarity_matrices,
)

In [None]:
# Directory holding the precomputed model-similarity matrices (passed to
# `load_similarity_matrices` below).
base_path_similarity_matrices = BASE_PATH_PROJECT / 'model_similarities_in1k_places'

# Identifiers of the similarity metrics to load; each one becomes a key of `sim_mats`
# and is looped over by the plotting cells.
sim_metrics = [
    'cka_kernel_linear_unbiased',
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'rsa_method_correlation_corr_method_spearman',
]

# Dataset names plus an info table; `ds_info.loc[<dataset>, 'name']` is used for panel titles.
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# Model filtering suffix (alternative value: '_wo_mae').
# NOTE(review): not referenced by any cell visible in this notebook section.
suffix = ''  # '_wo_mae'

# Version tag; only referenced by the commented-out storing path below.
version = 'arxiv'

# Correlation used for the consistency analysis. `correlation()` checks for
# 'spearmanr', 'kendalltau' and 'pearsonr'. The value is also part of the output path.
corr_method = 'spearmanr'

# Output handling: SAVE is forwarded to `pp_storing_path` and `save_or_show`.
SAVE = True
# Alternative, project-relative location:
# storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'final' / version / 'in1k_vs_places', SAVE)
# NOTE(review): the active path is an absolute, user-specific location and will not exist on
# other machines - consider deriving it from BASE_PATH_RESULTS / BASE_PATH_PROJECT.
storing_path = pp_storing_path(f'/home/lciernik/projects/divers-priors/results_local/corr_mats_ds/in1k_vs_places/{corr_method}', SAVE)

In [None]:
# Load the model configurations and the list of models admitted to the analysis.
# Only `allowed_models` is consumed by the visible cells (it filters the similarity matrices).
# NOTE(review): the config path is relative to the notebook's working directory, so the
# notebook must be started from its own folder - confirm, or build the path from BASE_PATH_PROJECT.
# NOTE(review): presumably `exclude_alignment=True` drops alignment-specific models; verify in helper.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/configs/models_config_places_vs_in1k.json',
    exclude_models=[],
    exclude_alignment=True,
)

In [None]:
# Load the similarity matrices for every requested metric and dataset.
# The result is used below as a nested mapping: `sim_mats[metric][dataset]` is a
# label-indexed matrix (accessed via `.loc[order, order]`, `.values` and `.index`).
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    ds_list=ds_list,
    sim_metrics=sim_metrics,
    allowed_models=allowed_models,
)

### Create heatmaps

In [None]:
# Row/column order for the heatmaps: the Places365-trained variant of every
# architecture first, followed by the plain (non-Places365) variants.
_archs = ['alexnet', 'densenet161', 'resnet18', 'resnet50']
order = [f'{arch}_places365' for arch in _archs] + _archs

In [None]:
# One figure per similarity metric, with one heatmap panel per dataset.
# `m` is the number of grid columns and `bsize` the edge length (inches) of one panel.
m = 4
bsize = 3
for metric, curr_sim_mats in sim_mats.items():
    n_panels = len(curr_sim_mats)
    # Ceil-division: derive the row count from the number of datasets so the grid always
    # fits (a fixed 6x4 grid overflows for more than 24 datasets).
    n = max(1, -(-n_panels // m))
    fig, axes = plt.subplots(
        nrows=n, ncols=m, figsize=(m * bsize, n * bsize),
        sharex=True, sharey=True, squeeze=False,
    )
    axes = axes.flatten()
    for i, (k, v) in enumerate(curr_sim_mats.items()):
        # Reorder rows and columns so that all panels share the same model order.
        sns.heatmap(v.loc[order, order], ax=axes[i])
        axes[i].set_title(ds_info.loc[k, 'name'])
    # Hide every panel that did not receive a heatmap (not just the last one).
    for unused_ax in axes[n_panels:]:
        unused_ax.axis("off")
    save_or_show(fig, storing_path / f'heatmaps_{metric}.pdf', SAVE)

### Create boxplots

In [None]:
def get_pair_cat(row):
    """Return the sorted pair of training-dataset labels for one model pair.

    Parameters
    ----------
    row : mapping
        Must provide the model names under the keys 'model_M1' and 'model_M2'.

    Returns
    -------
    tuple of str
        Two labels, each 'Places365' (model name contains 'places') or 'IN1K'
        (otherwise), sorted so the result does not depend on the model order.
    """
    def _train_ds(model_name):
        return 'Places365' if 'places' in model_name else 'IN1K'

    labels = [_train_ds(row[key]) for key in ('model_M1', 'model_M2')]
    labels.sort()
    return tuple(labels)
        

def pp_sim_mat(df):
    """Flatten the strict upper triangle of a similarity matrix into a long table.

    Parameters
    ----------
    df : pd.DataFrame
        Similarity matrix. The index labels are used to name both models of a
        pair, i.e. rows and columns are taken to be in the same order.

    Returns
    -------
    pd.DataFrame
        One row per entry above the diagonal with the columns 'sim_values',
        'model_M1', 'model_M2', 'model_pair' (tuple of both names) and
        'train_ds_pair' (result of `get_pair_cat`).
    """
    # k=1 skips the diagonal, so self-similarities are excluded.
    row_pos, col_pos = np.triu_indices_from(df.values, k=1)
    model_names = df.index.values

    flat_df = pd.DataFrame({
        'sim_values': df.values[row_pos, col_pos],
        'model_M1': model_names[row_pos],
        'model_M2': model_names[col_pos],
    })

    model_cols = flat_df[['model_M1', 'model_M2']]
    flat_df['model_pair'] = model_cols.apply(tuple, axis=1)
    flat_df['train_ds_pair'] = model_cols.apply(get_pair_cat, axis=1)
    return flat_df


def pp_all_sim_mats(sm):
    """Apply `pp_sim_mat` to every matrix of a mapping, keeping the keys and their order."""
    flattened = {}
    for ds_name, sim_mat in sm.items():
        flattened[ds_name] = pp_sim_mat(sim_mat)
    return flattened
    
    

In [None]:
def correlation(x, y, method=None):
    """Correlation coefficient between two equally long samples.

    Parameters
    ----------
    x, y : array-like
        The two samples to correlate.
    method : str, optional
        One of 'spearmanr', 'kendalltau' or 'pearsonr'. When omitted, the
        notebook-level setting `corr_method` is used.

    Returns
    -------
    float
        The correlation statistic; the p-value is discarded.

    Raises
    ------
    ValueError
        If the method name is not one of the supported ones (instead of
        silently returning None).
    """
    corr_funcs = {
        'spearmanr': spearmanr,
        'kendalltau': kendalltau,
        'pearsonr': pearsonr,
    }
    if method is None:
        method = corr_method
    if method not in corr_funcs:
        raise ValueError(
            f"Unknown correlation method {method!r}; expected one of {sorted(corr_funcs)}."
        )
    corr, _ = corr_funcs[method](x, y)
    return corr


def compute_corrs_ds_pair(p1, p2):
    """Correlate the similarity values of two datasets, per pair category and overall.

    Parameters
    ----------
    p1, p2 : pd.DataFrame
        Flattened similarity tables with a 'sim_values' column; `p1` also needs
        'train_ds_pair'. Rows of `p2` are selected by the index labels of `p1`,
        so both tables are expected to list the model pairs in the same order.

    Returns
    -------
    dict
        One entry per 'train_ds_pair' group of `p1` (in groupby order) followed
        by the key 'all' for the correlation over all rows.
    """
    corrs = {
        pair_cat: correlation(group['sim_values'], p2.loc[group.index, 'sim_values'])
        for pair_cat, group in p1.groupby('train_ds_pair')
    }
    corrs['all'] = correlation(p1['sim_values'], p2['sim_values'])
    return corrs


def compute_all_consistencies(ds_pairs, pp_sm):
    """Collect the similarity-value correlations for every pair of datasets.

    Parameters
    ----------
    ds_pairs : iterable of (ds1, ds2)
        Dataset pairs to compare; both names must be keys of `pp_sm`.
    pp_sm : mapping
        Dataset name -> flattened similarity table (see `pp_sim_mat`).

    Returns
    -------
    pd.DataFrame
        One row per dataset pair with the columns 'IN1k, IN1k',
        'IN1k, Places365', 'Places365, Places365', 'All combs', 'ds1', 'ds2'.
        A category that does not occur for a pair is filled with NaN, and an
        empty `ds_pairs` yields an empty frame with these columns.
    """
    # Explicit key -> label mapping, so the column names never depend on the
    # order or number of keys returned by `compute_corrs_ds_pair`.
    column_labels = {
        ('IN1K', 'IN1K'): 'IN1k, IN1k',
        ('IN1K', 'Places365'): 'IN1k, Places365',
        ('Places365', 'Places365'): 'Places365, Places365',
        'all': 'All combs',
    }

    records = []
    for ds1, ds2 in ds_pairs:
        # Read from the `pp_sm` argument rather than a notebook-level global.
        pair_corrs = compute_corrs_ds_pair(pp_sm[ds1], pp_sm[ds2])
        record = {column_labels.get(key, key): value for key, value in pair_corrs.items()}
        record['ds1'] = ds1
        record['ds2'] = ds2
        records.append(record)

    return pd.DataFrame(records, columns=[*column_labels.values(), 'ds1', 'ds2'])

In [None]:
# Colour per training dataset, used for the tick labels and box gradients of the
# box plots: IN1k reuses the project-wide category colour, Places365 gets a fixed blue.
color_maps = dict(
    IN1k=cat_color_mapping['IN1k'],
    Places365='#4593c8',
)
color_maps

In [None]:
def get_colored_tick_labels(label, method_colors):
    """Split a ', '-separated tick label and attach a colour to every part.

    Parameters
    ----------
    label : str
        Tick label such as 'IN1k, Places365'.
    method_colors : dict
        Maps a label part to its colour.

    Returns
    -------
    list of (str, str)
        (part, colour) tuples in label order; parts without an entry in
        `method_colors` get the fallback colour '#335114'.
    """
    fallback_color = '#335114'
    colored_label = []
    for method in label.split(', '):
        colored_label.append((method, method_colors.get(method, fallback_color)))
    return colored_label

def set_colored_labels(ax, method_colors, y_pos_init=-0.05, y_height=-0.07):
    """Replace the x tick labels of `ax` by stacked, individually coloured text.

    Every tick label is split with `get_colored_tick_labels`; each part is drawn
    on its own line below the axis in the colour assigned to it.

    Parameters
    ----------
    ax : matplotlib axes
        Axes whose x tick labels are replaced (the originals are cleared).
    method_colors : dict
        Maps a label part to its colour.
    y_pos_init : float
        Vertical position of the first line of a tick; y is interpreted by the
        x-axis transform of `ax`, x is the tick position.
    y_height : float
        Vertical offset added for every further line of the same tick.

    Returns
    -------
    list of list
        For every tick, the colours of its label parts in drawing order.
    """
    x_ticks = ax.get_xticks()
    labels = [tick_label.get_text() for tick_label in ax.get_xticklabels()]
    ax.set_xticklabels([])  # Remove existing tick labels

    colors = []
    for x_tick, label in zip(x_ticks, labels):
        y_pos = y_pos_init
        tick_colors = []
        for method, color in get_colored_tick_labels(label, method_colors):
            ax.text(
                x_tick, y_pos,
                method.split(' DS')[0],  # keep only the text before a ' DS' suffix
                color=color,
                fontsize=fontsizes['ticks'],
                ha='center',
                va='top',
                transform=ax.get_xaxis_transform(),
                rotation=0,
            )
            y_pos += y_height
            tick_colors.append(color)
        colors.append(tick_colors)
    return colors


def get_box_plot(corrs):
    """Draw a box plot of the correlation columns with colour-gradient boxes.

    Parameters
    ----------
    corrs : pd.DataFrame
        Table as returned by `compute_all_consistencies`; passed to seaborn in
        wide form. The y limits are derived from its numeric columns only.

    Returns
    -------
    (matplotlib figure, matplotlib axes)
        A newly created figure and the axes holding the box plot.
    """
    # Draw on an explicitly created figure instead of the implicit "current" one, so
    # repeated calls can never end up on an axes left open by a previous call.
    fig, ax = plt.subplots()
    g = sns.boxplot(corrs, ax=ax, linewidth=1.5)

    # One list of colours per tick, in tick order (also replaces the tick labels).
    colors = set_colored_labels(g, color_maps)

    for patch, color in zip(g.patches, colors):
        # Make the box transparent; its fill is replaced by the gradient image below.
        patch.set_facecolor('none')

        vertices = patch.get_path().vertices
        if len(color) == 1:
            # A single colour becomes a gradient from that colour to itself.
            color = color * 2

        # Box geometry, read from the patch path.
        # NOTE(review): assumes the vertex layout seaborn uses for its box patches - confirm
        # when upgrading seaborn.
        x = vertices[0, 0]  # left x position
        width = vertices[2, 0] - x  # width of the box
        y_bottom = vertices[1, 1]  # bottom of the box
        y_top = vertices[2, 1]  # top of the box

        cmap = LinearSegmentedColormap.from_list('gradient', [color[0], color[1]])

        # Horizontal left-to-right gradient drawn behind the box outline.
        gradient = np.linspace(0, 1, 256).reshape(1, -1)
        g.imshow(gradient, aspect='auto', cmap=cmap,
                 extent=(x, x + width, y_bottom, y_top), zorder=-1)

    # The axis limits are set explicitly after the `imshow` calls.
    g.set_xlim(-1, len(colors))

    ylim = (corrs.min(numeric_only=True).min(numeric_only=True) - 0.05,
            corrs.max(numeric_only=True).max(numeric_only=True) + 0.05)
    g.set_ylim(ylim)

    g.tick_params('y', labelsize=fontsizes['ticks'])
    return fig, g



In [None]:
# For every similarity metric: flatten the per-dataset similarity matrices, correlate the
# similarity values of every pair of datasets, and plot the correlations as a box plot.
for metric, curr_sim_mats in sim_mats.items():
    # Dataset name -> long table of upper-triangle similarity values.
    pp_sim_mats = pp_all_sim_mats(curr_sim_mats)
    
    # All unordered pairs of datasets.
    ds_pairs = list(itertools.combinations(pp_sim_mats.keys(), 2))
    
    # One row of correlation values per dataset pair.
    all_correlations = compute_all_consistencies(ds_pairs, pp_sim_mats)

    fig, ax = get_box_plot(all_correlations)
    ax.set_title(sim_metric_name_mapping[metric], fontsize=fontsizes['title'])
    
    # One output file per metric inside `storing_path`; SAVE is forwarded to the helper.
    save_or_show(fig, storing_path / f'consistency_{metric}.pdf', SAVE)