## Notebook appendix E: *Does local or global similarity differ?*
This notebook creates the scatter plots of Appendix E. The corresponding boxplots are generated in `4_5_model_cats_influencing_similarity_consistency.ipynb`.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy import stats

from constants import (
    BASE_PATH_RESULTS,
    ds_list_sim_file,
    fontsizes,
    fontsizes_cols
)
from helper import (
    load_all_datasetnames_n_info,
    pp_storing_path,
    save_or_show
)

# Use seaborn's 'ticks' style for every figure produced in this notebook.
sns.set_style(style='ticks')

#### Global variables

In [None]:
# Datasets included in the analysis, together with their metadata.
ds_list, ds_info = load_all_datasetnames_n_info(ds_list_sim_file, verbose=False)

# Experiment configuration. `corr_type` selects which aggregated correlation file is
# read; `suffix` selects an optional variant of that file (and of the output figure name).
corr_type = 'pearsonr'  # options: 'pearsonr', 'spearmanr'
suffix = ''  # options: '', '_wo_mae'
exp_conf = f'{corr_type}{suffix}'  # not used further below in this notebook

# Aggregated correlation coefficients across all dataset pairs; these are produced by
# aggregate_consistencies_for_model_set_pairs.ipynb, so fail early if they are missing.
data_path = BASE_PATH_RESULTS / f'aggregated/r_coeff_dist/with_cats_as_anchors/agg_{corr_type}_all_ds_with_rsa{suffix}.csv'
assert data_path.exists(), f'Path does not exist: {data_path}. Aggregated correlation coefficients across all dataset pairs not found, please run aggregate_consistencies_for_model_set_pairs.ipynb first.'

# Whether figures are saved (passed to pp_storing_path and save_or_show).
SAVE = True

# Output version: 'arxiv' uses the default font sizes, any other value the column-layout ones.
version = 'arxiv'
if version == 'arxiv':
    curr_fontsizes = fontsizes
else:
    curr_fontsizes = fontsizes_cols

# Output directory for this appendix's figures, one sub-tree per version.
storing_path = pp_storing_path(BASE_PATH_RESULTS / 'plots' / 'final' / version / 'app_E_local_vs_global', SAVE)

#### Load data

In [None]:
# Keep only rows whose two datasets are both in the selected dataset list.
r_coeff_data = (
    pd.read_csv(data_path)
    .loc[lambda frame: frame['ds1'].isin(ds_list) & frame['ds2'].isin(ds_list)]
    .reset_index(drop=True)
    .copy()
)

#### Scatter plots comparing the correlation coefficients of two similarity metrics across all model set pairs and dataset pairs

In [None]:
# Pairs of similarity metrics to compare: (x-axis metric, y-axis metric), one panel each.
combs = [('CKA linear', 'CKA RBF 0.4'), ('CKA linear', 'RSA spearman')]
# Columns identifying a single correlation coefficient; the two metrics are aligned on them.
PAIR_KEYS = ['ds1', 'ds2', 'anchor_cat', 'other_cat']


def pair_metric_coeffs(data, metric_x, metric_y, keys):
    """Place the rows of two similarity metrics side by side, aligned on `keys`.

    Parameters
    ----------
    data : pd.DataFrame
        Long-format frame with a 'Similarity metric' column, the `keys` columns
        and the remaining per-metric columns (e.g. 'r coeff').
    metric_x, metric_y : str
        Values of 'Similarity metric' to pair up.
    keys : list of str
        Columns used as the alignment index (must be a list, not a tuple, for set_index).

    Returns
    -------
    pd.DataFrame
        Indexed by `keys`; columns of `metric_x` carry the suffix ' sm1' and columns
        of `metric_y` the suffix ' sm2'. Keys present for only one metric get NaN in
        the other metric's columns (pd.concat aligns with an outer join).
    """
    per_metric = [
        data[data['Similarity metric'] == metric].set_index(keys).add_suffix(col_suffix)
        for metric, col_suffix in ((metric_x, ' sm1'), (metric_y, ' sm2'))
    ]
    # NOTE(review): pd.concat(axis=1) aligns on `keys`; if a key can occur more than once
    # per metric the alignment is ambiguous — confirm the key is unique per metric.
    return pd.concat(per_metric, axis=1)


def format_p_value(p, threshold=0.001):
    """Return a short p-value label; values below `threshold` are shown as '< threshold'."""
    if p < threshold:
        return f'p-value < {threshold:g}'
    return f'p-value = {p:.3f}'


fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))

for ax, (x, y) in zip(axs, combs):
    dat_concat = pair_metric_coeffs(r_coeff_data, x, y, PAIR_KEYS)
    # Keep only keys with a coefficient for both metrics. Incomplete pairs are not drawn
    # anyway, but a single NaN would make pearsonr return NaN for the whole panel.
    dat_pairs = dat_concat.dropna(subset=['r coeff sm1', 'r coeff sm2'])

    sns.scatterplot(data=dat_pairs, x='r coeff sm1', y='r coeff sm2', alpha=0.5, s=10, ax=ax)

    ax.set_xlabel(f'r coeff. ({x})', fontsize=curr_fontsizes['label'])
    ax.set_ylabel(f'r coeff. ({y})', fontsize=curr_fontsizes['label'])
    ax.tick_params(labelsize=curr_fontsizes['ticks'])

    # pearsonr needs at least two points; without them there is nothing to annotate.
    if len(dat_pairs) >= 2:
        r, p = stats.pearsonr(dat_pairs['r coeff sm1'], dat_pairs['r coeff sm2'])
        # Report the actual p-value instead of always claiming p < 0.001.
        ax.text(0.05, 0.95, f'Overall r = {r:.2f}\n{format_p_value(p)}', transform=ax.transAxes,
                verticalalignment='top', fontsize=curr_fontsizes['legend'])

plt.subplots_adjust(wspace=0.2 if version == 'arxiv' else 0.3)
save_or_show(fig, storing_path / f'consistency_local_global_scatter_plot{suffix}.pdf', SAVE)