In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

from constants import (
    model_size_order,
    fontsizes
)
from helper import save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

sns.set_style('ticks')

#### Global variables

In [None]:
# --- Similarity matrices -------------------------------------------------
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')
sim_metrics = ['cka_kernel_rbf_unbiased_sigma_0.4', 'cka_kernel_linear_unbiased']
# Display names for the metrics above, looked up in the shared name mapping.
sim_metrics_mapped = [sim_metric_name_mapping[metric] for metric in sim_metrics]

# --- Datasets ------------------------------------------------------------
# Dataset identifiers with every '/' replaced by '_'.
ds_list = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')
]

# --- Experiment configuration --------------------------------------------
corr_type = 'pearsonr'  # 'pearsonr', 'spearmanr'
suffix = ''  # '', '_wo_mae'
exp_conf = corr_type + suffix

# --- Path to correlation data --------------------------------------------
# NOTE(review): the active base_path is a user-specific absolute path; the
# commented-out line below is the alternative location under /home/space.
# base_path = Path('/home/space/diverse_priors/results/aggregated/r_coeff_dist/with_cats_as_anchors')
base_path = Path(
    '/Users/lciernik/Documents/TUB/projects/divers_prios/results/aggregated/r_coeff_dist/with_cats_as_anchors')
data_path = base_path / f'agg_{corr_type}_all_ds{suffix}.csv'

# --- Storing path --------------------------------------------------------
# Figures are only written to disk (and the directory only created) if SAVE is True.
SAVE = False
storing_path = Path(
    f'/home/space/diverse_priors/results/plots/dist_r_coeff_cats_as_anchors/{exp_conf}'
)
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

# --- Plotting helper -----------------------------------------------------
# Centimetre -> inch conversion factor; multiplied into figure sizes below.
cm = 0.393701

#### Load data

In [None]:
# Read the aggregated correlation coefficients and preview the first rows.
r_coeff_data = pd.read_csv(filepath_or_buffer=data_path)
r_coeff_data.head(n=5)

In [None]:
# Order-independent key for each (anchor_cat, other_cat) combination: the two
# category names sorted and stored as a tuple, so (A, B) and (B, A) coincide.
# Built with zip instead of a row-wise DataFrame.apply: it avoids creating one
# Series per row and also works when the frame has no rows.
r_coeff_data['cat_pair'] = [
    tuple(sorted([anchor_cat, other_cat]))
    for anchor_cat, other_cat in zip(r_coeff_data['anchor_cat'], r_coeff_data['other_cat'])
]

#### Plotting helper functions

In [None]:
def get_distribution_plot(df_r_coeffs):
    """Draw one boxen-plot panel per anchor category.

    Each panel (one per unique value of ``anchor_cat``) shows the distribution
    of the ``'r coeff'`` column, split along the x-axis by ``other_cat``.

    Parameters
    ----------
    df_r_coeffs : pandas.DataFrame
        Must contain the columns ``'anchor_cat'``, ``'other_cat'`` and
        ``'r coeff'``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure owned by the seaborn ``FacetGrid``.
    """
    n_plots = df_r_coeffs['anchor_cat'].nunique()

    # Use a 2- or 3-column grid when the panels divide evenly (2 is tried
    # first); otherwise put all panels in a single row.
    col_wrap = next((n_cols for n_cols in (2, 3) if n_plots % n_cols == 0), n_plots)

    aspect = 1.25 if n_plots < 4 else 1.5

    g = sns.catplot(
        data=df_r_coeffs,
        x='other_cat',
        y='r coeff',
        col='anchor_cat',
        kind='boxen',
        hue='other_cat',
        palette='tab10',
        height=6 * cm,
        aspect=aspect,
        col_wrap=col_wrap,
    )
    g.set_titles('{col_name}')
    g.set_xlabels('')
    g.set_ylabels('Correlation coefficient')
    # Faint reference line at r = 0.5, drawn behind the data in every panel.
    for ax in g.axes.flat:
        ax.axhline(0.5, c='grey', ls=':', alpha=0.5, zorder=-1)

    # `FacetGrid.fig` is a deprecated alias; `figure` is the supported attribute.
    return g.figure



In [None]:
from textwrap import wrap
import matplotlib.patches as mpatches

# Reverse lookup of seaborn's xkcd colour table: hex code -> colour name.
hex2name = {hex_code: color_name for color_name, hex_code in sns.xkcd_rgb.items()}


def tuple2string(tup_dat):
    """Render the first two entries of ``tup_dat`` as ``"<first>, <second>"``."""
    first, second = tup_dat[0], tup_dat[1]
    return f"{first}, {second}"


# Render all Matplotlib text through LaTeX and load the xcolor package, which
# provides the \textcolor command.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{xcolor}',
})


def color_tick_labels(label):
    """Turn a ``', '``-separated label into a coloured LaTeX string.

    Every part of ``label`` is wrapped in ``\\textcolor``: red for the part
    equal to ``'SSL'``, blue for anything else. The parts are joined with a
    comma followed by a newline and the result is enclosed in ``$...$``.
    """
    colored_parts = []
    for part in label.split(', '):
        if part == 'SSL':
            colored_parts.append(r"\textcolor{{red}}{{{}}}".format(part))
        else:
            colored_parts.append(r"\textcolor{{blue}}{{{}}}".format(part))

    # One coloured entry per line, wrapped in $...$ for LaTeX math mode.
    joined = ',\n'.join(colored_parts)
    return r'${}$'.format(joined)


def wrap_labels(ax, width, break_long_words=False, color_tick_labels=False):
    """Re-set the x tick labels of ``ax`` as wrapped or coloured text.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose current x ticks and tick labels are read and replaced.
    width : int
        Maximum line width passed to ``textwrap.wrap``; only used when
        ``color_tick_labels`` is false.
    break_long_words : bool, default False
        Forwarded to ``textwrap.wrap``.
    color_tick_labels : bool, default False
        If true, each label is passed through the module-level
        ``color_tick_labels`` function instead of being wrapped.
    """
    x_ticks = ax.get_xticks()
    labels = [label.get_text() for label in ax.get_xticklabels()]
    if color_tick_labels:
        # The boolean parameter shadows the module-level function with the
        # same name, so calling `color_tick_labels(label)` here would call a
        # bool. Fetch the function from the module namespace instead; the
        # parameter name is kept for backward compatibility.
        colorize = globals()['color_tick_labels']
        wrapped_labels = [colorize(label) for label in labels]
    else:
        wrapped_labels = ['\n'.join(wrap(label, width, break_long_words=break_long_words)) for label in labels]
    ax.set_xticks(x_ticks, wrapped_labels, rotation=0, ha='center')


def create_custom_legend(color_maps):
    """Add a single-row, frameless legend with one colour patch per category.

    ``color_maps`` maps a category label to its colour. The legend is added
    through ``plt.legend``, i.e. to the current pyplot axes, and anchored at
    (0.5, -0.25) in axes coordinates.
    """
    # One patch per (category, colour) entry, in mapping order.
    legend_patches = []
    for cat, color in color_maps.items():
        legend_patches.append(mpatches.Patch(color=color, label=cat))

    legend_kwargs = dict(
        handles=legend_patches,
        ncols=len(legend_patches),
        title="",
        loc='center',
        bbox_to_anchor=(0.5, -0.25),
        fontsize=fontsizes['ticks'],
        frameon=False,
    )
    plt.legend(**legend_kwargs)


def get_dist_plot_for_cat(r_coeff_subdata):
    """Box plot of correlation coefficients per category pair, with gradient boxes.

    One box is drawn per value of ``cat_pair``, ordered by decreasing median of
    ``'r coeff'``. Each box is filled with a horizontal colour gradient running
    from the colour of the pair's first category to that of its second.

    Parameters
    ----------
    r_coeff_subdata : pandas.DataFrame
        Must contain the columns ``'ds1'``, ``'ds2'``, ``'cat_pair'`` (2-tuples
        of category names), ``'r coeff'``, ``'anchor_cat'``, ``'other_cat'``
        and ``'Comparison category'``. The frame is not modified.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure; its axes remain the current pyplot axes.
    """
    # Remove duplicates (rows repeating the same datasets, pair and coefficient).
    r_coeff_subdata_wo_dup = r_coeff_subdata.drop_duplicates(
        subset=['ds1', 'ds2', 'cat_pair', 'r coeff']
    ).reset_index(drop=True)

    # Get color maps. For the 'Model size' comparison the predefined ordering
    # is used; otherwise take the sorted set of categories found in either
    # column. Both columns are flattened before np.unique so this also works
    # when they contain different numbers of distinct categories.
    if r_coeff_subdata_wo_dup['Comparison category'].unique()[0] == 'Model size':
        sub_cats = model_size_order
    else:
        sub_cats = np.unique(
            r_coeff_subdata_wo_dup[['anchor_cat', 'other_cat']].to_numpy().ravel()
        ).tolist()
    color_maps = dict(zip(sub_cats, sns.color_palette('tab10', len(sub_cats)).as_hex()))

    # Get sorting order: category pairs by decreasing median coefficient.
    sorting_order = r_coeff_subdata_wo_dup.groupby('cat_pair')['r coeff'].median().sort_values(
        ascending=False).index.tolist()
    colors = [(color_maps[cat1], color_maps[cat2]) for (cat1, cat2) in sorting_order]
    sorting_order = [tuple2string(tup_data) for tup_data in sorting_order]

    # Convert tuples to strings so they can serve as x-axis labels.
    r_coeff_subdata_wo_dup['cat_pair'] = r_coeff_subdata_wo_dup['cat_pair'].apply(tuple2string)

    # Plot on an explicit figure/axes pair; the new axes become the current
    # pyplot axes, which create_custom_legend relies on.
    fig, ax = plt.subplots(figsize=(len(sorting_order) + 2, 4))
    sns.boxplot(
        r_coeff_subdata_wo_dup,
        x='cat_pair',
        y='r coeff',
        order=sorting_order,
        ax=ax,
    )

    wrap_labels(ax, 10, break_long_words=False)
    ax.tick_params(axis='x', which='major', labelsize=fontsizes['ticks'])
    ax.tick_params(axis='y', which='major', labelsize=fontsizes['ticks'])

    # 1 x 256 ramp shared by all boxes; only the colormap differs per box.
    gradient = np.linspace(0, 1, 256).reshape(1, -1)
    for patch, color in zip(ax.patches, colors):
        # Make the box itself transparent; the gradient image drawn behind it
        # provides the fill.
        patch.set_facecolor('none')

        vertices = patch.get_path().vertices

        # Get box position and dimensions
        x = vertices[0, 0]  # left x position
        width = vertices[2, 0] - x  # width of the box
        y_bottom = vertices[1, 1]  # bottom of the box
        y_top = vertices[2, 1]  # top of the box

        cmap = LinearSegmentedColormap.from_list('gradient', [color[0], color[1]])
        ax.imshow(gradient, aspect='auto', cmap=cmap,
                  extent=(x, x + width, y_bottom, y_top), zorder=-1)

    create_custom_legend(color_maps)

    # imshow changes the axis limits, so set them explicitly afterwards.
    ax.set_xlim(-1, len(colors))
    ylim = (r_coeff_subdata_wo_dup['r coeff'].min() - 0.05, r_coeff_subdata_wo_dup['r coeff'].max() + 0.05)
    ax.set_ylim(ylim)
    ax.set_xlabel('')
    ax.set_ylabel('Correlation coefficient', fontsize=fontsizes['label'])
    # Faint reference line at r = 0.5.
    ax.axhline(0.5, c='grey', ls=':', alpha=0.5, zorder=-1)
    return fig


#### Plotting

In [None]:
# Rows of the 'Objective' comparison measured with 'CKA linear', sorted by
# increasing coefficient.
# NOTE(review): `tmp` is not referenced by the later cells visible here.
tmp = (
    r_coeff_data
    .loc[
        (r_coeff_data['Comparison category'] == 'Objective')
        & (r_coeff_data['Similarity metric'] == 'CKA linear')
    ]
    .sort_values('r coeff')
    .copy()
)

In [None]:
# Gradient box plot for each (comparison category, similarity metric) group.
for group_key, group_data in r_coeff_data.groupby(['Comparison category', 'Similarity metric']):
    category, metric = group_key
    fig = get_dist_plot_for_cat(group_data)
    plt.title(f"{category} - {metric}")
    out_file = storing_path / f'dist_r_coeff_{category}_{metric}.pdf'
    save_or_show(fig, out_file, SAVE)
    # NOTE(review): this break ends the loop after the first group, so only
    # one figure is produced - confirm this is intended.
    break

In [None]:
# Per-anchor boxen plots for each (comparison category, similarity metric) group.
for group_key, group_data in r_coeff_data.groupby(['Comparison category', 'Similarity metric']):
    fig = get_distribution_plot(group_data)
    # File-name friendly versions of the group keys: spaces -> '_', lower case.
    curr_cat, curr_sim = (part.replace(" ", "_").lower() for part in group_key)
    out_file = storing_path / f'dist_r_coeff_cat_anchor_{curr_cat}_{curr_sim}.pdf'
    save_or_show(fig, out_file, SAVE)
    # NOTE(review): this break ends the loop after the first group, so only
    # one figure is produced - confirm this is intended.
    break