In [None]:
import os
import sys
from itertools import combinations
from pathlib import Path
from textwrap import wrap

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from scipy import stats
from scipy.stats import pearsonr, spearmanr

from constants import (
    model_size_order,
    fontsizes,
    cat_color_mapping
)
from helper import save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

sns.set_style('ticks')

#### Global variables

In [None]:
# Datasets: names from the list file with '/' replaced by '_', so they can be
# compared against the 'ds1'/'ds2' columns of the aggregated CSV.
ds_list = [ds.replace('/', '_') for ds in parse_datasets('../scripts/webdatasets_w_insub10k.txt')]

# Experiment configuration
corr_type = 'spearmanr'  # 'pearsonr', 'spearmanr'
assert corr_type in {'pearsonr', 'spearmanr'}, f'unsupported corr_type: {corr_type!r}'
suffix = ''  # '', '_wo_mae'
exp_conf = f'{corr_type}{suffix}'

# Root of the shared results tree. Set the DIVERSE_PRIORS_RESULTS environment
# variable to run on another machine without editing the notebook.
results_root = Path(os.environ.get('DIVERSE_PRIORS_RESULTS', '/home/space/diverse_priors/results'))

# Path to correlation data
base_path = results_root / 'aggregated' / 'r_coeff_dist' / 'with_cats_as_anchors'
data_path = base_path / f'agg_{corr_type}_all_ds_with_rsa{suffix}.csv'

# Storing path (one sub-directory per experiment configuration)
SAVE = True
storing_path = results_root / 'plots' / 'dist_r_coeff_cats_as_anchors_comp_local_global' / exp_conf
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

# Plotting helper: centimetre-to-inch factor (1 / 2.54), since matplotlib figure sizes are in inches.
cm = 0.393701

#### Load data

In [None]:
# Load the aggregated r coefficients and keep only rows where both datasets are in ds_list.
r_coeff_data = (
    pd.read_csv(data_path)
    .loc[lambda df: df['ds1'].isin(ds_list) & df['ds2'].isin(ds_list)]
    .reset_index(drop=True)
)
# A dataset-name mismatch would silently empty the frame and yield blank plots
# further down; fail loudly here instead.
assert not r_coeff_data.empty, f'no rows in {data_path} match the datasets in ds_list'

In [None]:
# For each pair of similarity metrics (x, y), align their r coefficients on
# (ds1, ds2, anchor_cat, other_cat) and scatter them against each other.
# The annotation reports the Pearson correlation of the paired coefficients.
combs = [('CKA linear', 'CKA RBF 0.4'), ('CKA linear', 'RSA spearman')]
key_cols = ['ds1', 'ds2', 'anchor_cat', 'other_cat']
fig, axs = plt.subplots(nrows=1, ncols=len(combs), figsize=(12, 5))  # Wide figure so every panel stays readable

# np.atleast_1d keeps the loop valid even if combs has a single entry (axs is then one Axes).
for ax, (x, y) in zip(np.atleast_1d(axs), combs):
    dat1 = r_coeff_data[r_coeff_data['Similarity metric'] == x].set_index(key_cols)
    dat2 = r_coeff_data[r_coeff_data['Similarity metric'] == y].set_index(key_cols)
    # pd.concat(axis=1) aligns on the index, so each key must identify exactly one row.
    assert dat1.index.is_unique, f'duplicate {key_cols} rows for metric {x!r}'
    assert dat2.index.is_unique, f'duplicate {key_cols} rows for metric {y!r}'

    dat1.columns = [col + ' sm1' for col in dat1.columns]
    dat2.columns = [col + ' sm2' for col in dat2.columns]
    # The concat is an outer alignment: keys present for only one metric get NaN.
    # Drop them so the plotted points and pearsonr use the same paired observations.
    dat_concat = pd.concat([dat1, dat2], axis=1).dropna(subset=['r coeff sm1', 'r coeff sm2'])

    sns.scatterplot(data=dat_concat, x="r coeff sm1", y="r coeff sm2", alpha=0.25, s=5, ax=ax)

    ax.set_xlabel(f'r coeff. ({x})', fontsize=fontsizes['label'])
    ax.set_ylabel(f'r coeff. ({y})', fontsize=fontsizes['label'])
    ax.tick_params(labelsize=fontsizes['ticks'])

    if len(dat_concat) < 2:
        # pearsonr needs at least two paired observations; leave the panel unannotated.
        continue
    r, p = pearsonr(dat_concat['r coeff sm1'], dat_concat['r coeff sm2'])
    # Report the actual p-value rather than asserting "< 0.001" unconditionally.
    p_text = 'p-value < 0.001' if p < 0.001 else f'p-value = {p:.3f}'
    ax.text(0.05, 0.95, f'Overall r = {r:.2f}\n{p_text}', transform=ax.transAxes,
            verticalalignment='top', fontsize=fontsizes['legend'])

save_or_show(fig, storing_path / f'consistency_local_global_scatter_plot{suffix}.pdf', SAVE)