In [None]:
import json
import sys
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import wasserstein_distance

from constants import (anchor_name_mapping, available_data, exclude_models, exclude_models_w_mae, model_config_file)
from constants import sim_metric_name_mapping
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show

# Make the parent directory importable so the `scripts` package can be resolved.
sys.path.append('..')
from scripts.helper import parse_datasets

In [None]:
# --- Paths and plotting configuration -------------------------------------------------
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Raw similarity-metric keys and their display names (via sim_metric_name_mapping).
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
]
sim_metrics_mapped = [sim_metric_name_mapping[metric_key] for metric_key in sim_metrics]

# Dataset names from the list file, with '/' replaced by '_'.
ds_list = [ds_name.replace('/', '_') for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')]

# Inches per centimetre: matplotlib figure sizes are specified in inches.
cm = 0.393701

# Anchor models shown in the plots.
anchors = [
    'OpenCLIP_RN50_openai',
    'OpenCLIP_ViT-L-14_openai',
    'simclr-rn50',
    'dinov2-vit-large-p14',
    'resnet50',
    'vit_large_patch16_224',
]

# NOTE(review): positional pick from constants.available_data; confirm index 3 is the intended file.
curr_data = available_data[3]
curr_data_wo_ext = curr_data.partition('.')[0]  # everything before the first '.'
agg_data_path = Path(f'/home/space/diverse_priors/results/aggregated/r_coeff_dist/{curr_data}')
tmp = pd.read_csv(agg_data_path)
print(curr_data, tmp.shape)

# Output directory for figures/tables; only created when saving is enabled.
SAVE = True
storing_path = Path(f'/home/space/diverse_priors/results/plots/distribution_similarity/{curr_data_wo_ext}')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Pick the exclusion list depending on whether the aggregated data file name contains 'mae'.
curr_excl_models = exclude_models if 'mae' not in curr_data else exclude_models_w_mae

# Load the model configuration table and the allowed-model list (semantics of both,
# and of exclude_alignment, are defined in helper.load_model_configs_and_allowed_models).
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file, exclude_models=curr_excl_models, exclude_alignment=True,
)

### Correlation-coefficient distributions (KDE line plots)

In [None]:
# Re-read the aggregated r-coefficient table so later filtering starts from the full data.
r_df = pd.read_csv(filepath_or_buffer=agg_data_path)

In [None]:
# For every comparison category, plot a grid of KDE curves of 'r coeff':
# one row per similarity metric, one column per anchor model, one curve per
# comparison value. Each anchor's title is coloured like the curve of the
# comparison value the anchor itself has in model_configs.
#
# NOTE: this rebinds r_df to the selected anchors only, so later cells
# (e.g. the Wasserstein distances) operate on this filtered frame as well.
r_df = r_df[r_df['Anchor Model'].isin(anchors)]
sim_mets = sim_metrics_mapped
n = len(anchors)  # number of subplot columns (anchor models)
m = len(sim_metrics_mapped)  # number of subplot rows (similarity metrics)

# Set to True to also draw the aggregate 'All' group; it is always placed last.
include_all = False

MODEL_SIZE_ORDER = ['small', 'medium', 'large', 'xlarge']
# Title colours assume the KDE curves use the tab10 cycle (matplotlib's default when no
# seaborn theme/palette is set). NOTE(review): confirm no other theme is applied.
TITLE_PALETTE = sns.color_palette('tab10', 10).as_hex()


def _hue_order_for(values, comp_cat, include_all):
    """Return the legend/colour order for the comparison values of one category.

    Args:
        values: iterable of comparison values present for the category.
        comp_cat: display name of the comparison category.
        include_all: whether the aggregate 'All' value should be kept.

    Returns:
        For 'Model size': MODEL_SIZE_ORDER followed by any unexpected sizes
        (alphabetically); otherwise the values sorted alphabetically. 'All' is
        dropped unless ``include_all`` is set and it is present, then it is last.
    """
    levels = set(values)
    has_all = 'All' in levels
    levels.discard('All')
    if comp_cat == 'Model size':
        order = MODEL_SIZE_ORDER + sorted(levels - set(MODEL_SIZE_ORDER))
    else:
        order = sorted(levels)
    if include_all and has_all:
        # Place 'All' last explicitly instead of relying on it sorting first.
        order.append('All')
    return order


for comp_cat in r_df['Comparison category'].unique():
    subset = r_df[r_df['Comparison category'] == comp_cat]
    comp_cat_orig = subset['Comparison category (orig. name)'].unique()[0]
    if not include_all:
        subset = subset[subset['Comparison values'] != 'All']
    if subset.empty:
        continue  # nothing to plot for this category

    # One hue order per category (not per panel): a comparison value gets the same
    # colour in every panel, so the single shared legend is valid for all of them.
    hue_order = _hue_order_for(subset['Comparison values'].unique(), comp_cat, include_all)

    # squeeze=False keeps `axes` 2-D even with a single anchor or a single metric.
    fig, axes = plt.subplots(nrows=m, ncols=n, figsize=(n * 6 * cm, m * 3.5 * cm),
                             sharex=True, sharey=False, squeeze=False)
    for col_idx, anchor in enumerate(anchors):
        # The anchor's own value in this category decides its title colour; if it is
        # not among the plotted values, the default title colour is used.
        anchor_value = model_configs.loc[anchor][comp_cat_orig]
        title_kwargs = {'fontsize': 11}
        if anchor_value in hue_order and hue_order.index(anchor_value) < len(TITLE_PALETTE):
            title_kwargs['color'] = TITLE_PALETTE[hue_order.index(anchor_value)]

        for row_idx, sim_metric in enumerate(sim_mets):
            ax = axes[row_idx, col_idx]
            group_data = subset[(subset['Anchor Model'] == anchor) & (subset['Similarity metric'] == sim_metric)]

            sns.kdeplot(
                group_data,
                x='r coeff',
                hue='Comparison values',
                hue_order=hue_order,
                ax=ax
            )

            ax.axvline(0.5, c='grey', ls=':', alpha=0.5, zorder=-1, lw=0.75)
            ax.set_xlim([-0.5, 1.1])

            ax.set_title(f"{anchor_name_mapping[anchor]}" if row_idx == 0 else "", **title_kwargs)
            ax.set_xlabel('Correlation coefficient' if row_idx == m - 1 else '')
            ax.set_ylabel(f"{sim_metric}\nDensity" if col_idx == 0 else '', fontsize=11)

            # Keep a single legend in the top-right panel. kdeplot may not create a
            # legend (e.g. for an empty panel), so guard against None.
            legend = ax.get_legend()
            if legend is None:
                continue
            if row_idx == 0 and col_idx == n - 1:
                sns.move_legend(ax, loc='upper left', title=comp_cat, bbox_to_anchor=(1, 1), fontsize=11,
                                title_fontsize=11, frameon=False)
            else:
                legend.remove()

    fig.tight_layout()
    save_or_show(fig,
                 storing_path / f'line_plot_{comp_cat.replace(" ", "_")}.pdf',
                 SAVE)


#### Wasserstein distances between similarity-metric distributions

In [None]:
# Pairwise 1-D Wasserstein distances between the 'r coeff' distributions of every pair
# of similarity metrics, computed per (anchor model, comparison category, comparison value).
# NOTE: r_df is the anchor-filtered frame rebound in the plotting cell above.
# `combinations` and `wasserstein_distance` are imported in the first cell.

sim_metrics = r_df['Similarity metric'].unique()
metric_combs = list(combinations(sim_metrics, 2))

rows = []

for key, group_data in r_df.groupby(['Anchor Model', 'Comparison category', 'Comparison values']):
    # Split the group by metric once instead of re-filtering it for every metric pair.
    r_vals_by_metric = {
        met: met_data['r coeff'].to_numpy()
        for met, met_data in group_data.groupby('Similarity metric')
    }
    for met1, met2 in metric_combs:
        r_vals_met1 = r_vals_by_metric.get(met1, np.empty(0))
        r_vals_met2 = r_vals_by_metric.get(met2, np.empty(0))
        # wasserstein_distance raises on an empty distribution; record NaN for a
        # metric missing from this group instead of aborting the whole table.
        if r_vals_met1.size == 0 or r_vals_met2.size == 0:
            ws_dist = np.nan
        else:
            ws_dist = wasserstein_distance(r_vals_met1, r_vals_met2)
        rows.append({
            'Anchor Model': key[0],
            'Comparison category': key[1],
            'Comparison values': key[2],
            'Metric1': met1,
            'Metric2': met2,
            # NOTE(review): 'wasserstei' looks like a typo of 'wasserstein'; kept unchanged
            # because consumers of ws_distances.csv may rely on this column name.
            'wasserstei': ws_dist
        })
ws_df = pd.DataFrame(rows)

In [None]:
# Persist the pairwise Wasserstein distances in the same output directory as the plots.
if SAVE:
    ws_csv_path = storing_path / 'ws_distances.csv'
    ws_df.to_csv(ws_csv_path, index=False)