In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import spearmanr
import json 
from helper import save_or_show, plot_r_coeff_distribution, load_model_configs_and_allowed_models
from collections import defaultdict
from scipy.stats import ranksums
from statsmodels.stats.multitest import multipletests
import starbars

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Keys of the representation-similarity metrics of interest.
# NOTE(review): in the visible cells this list is only referenced from commented-out code.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# NOTE(review): not referenced anywhere in the visible cells — confirm it is still needed.
x_axis_ds = 'imagenet-subset-10k'

# SAVE is forwarded to save_or_show() for every figure; when True the output
# directory is also created up front (including missing parents).
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/ds_sim__imagenet_sim/r_coeff_distributions_v3')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)


# CSV with the aggregated correlation coefficients that is read into `r_df` below.
# The commented-out path is an alternative aggregate file that is currently not used.
# agg_data_path = Path('/home/space/diverse_priors/results/aggregated/r_coeff_dist/agg_corr_coeffs.csv')
agg_data_path = Path('/home/space/diverse_priors/results/aggregated/r_coeff_dist/agg_corr_coeffs_all_ds.csv')

In [None]:
# Load the model configuration table and the list of allowed models through the
# project helper. `model_configs` is later indexed by anchor model id
# (`model_configs.loc[anchor]`) to look up each anchor's own category.
# NOTE(review): the exact effect of `exclude_models` / `exclude_alignment` is
# defined in helper.load_model_configs_and_allowed_models — confirm there.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/models_config.json',
    exclude_models=['SegmentAnything_vit_b', 'DreamSim_dino_vitb16', 'DreamSim_open_clip_vitb32'],
    exclude_alignment=True,
)

# Quick sanity check of what was loaded.
print(model_configs.shape, len(allowed_models))

In [None]:
# Anchor model ids, in the order in which their panels are drawn
# (the plotting helpers below iterate over this list and index axes by position).
anchors = [
    'OpenCLIP_RN50_openai', 
    'simclr-rn50', 
    'resnet50', 
    'OpenCLIP_ViT-L-14_openai',
    'mae-vit-large-p16',
    'vit_large_patch16_224',
]

# Model id -> display name. Contains more entries than `anchors`
# (the DINO / DINOv2 ids are not part of the anchor list above).
anchor_name_mapping = {
    'OpenCLIP_RN50_openai': 'OpenCLIP RN50',
    'OpenCLIP_ViT-L-14_openai': 'OpenCLIP ViT-L',
    'resnet50': 'ResNet-50',
    'vit_large_patch16_224': 'ViT-L',
    'simclr-rn50': 'SimCLR RN50',
    'dino-vit-base-p16': 'DINO ViT-B',
    'dinov2-vit-large-p14': 'DINOv2 ViT-L',
    'mae-vit-large-p16': 'MAE ViT-L'
}
# Display names of all mapped models (mapping order) and of the anchors only (anchor order).
# NOTE(review): in the visible cells both lists are only referenced from commented-out code.
anchor_nm_val_list = list(anchor_name_mapping.values())
anchor_nm_val_list_v2 = [anchor_name_mapping[mid] for mid in anchors]


# Column names of the aggregated correlation-coefficient table (`r_df`).
anchor_col = 'Anchor Model'
sim_metric_col = 'Similarity metric'
comp_cat_col = 'Comparison category'
comp_cat_orig_col = 'Comparison category (orig. name)'
comp_val_col = 'Comparison values'
r_col = 'r coeff'

In [None]:
# Read the aggregated correlation coefficients, keep only the rows of the selected
# anchor models, and replace the anchor ids by their display names.
r_df = pd.read_csv(agg_data_path)
r_df = r_df[r_df[anchor_col].isin(anchors)].copy().reset_index(drop=True)
# After this line the anchor column holds display names, so later filters must
# compare against `anchor_name_mapping[anchor]` rather than the raw id.
r_df[anchor_col] = r_df[anchor_col].map(anchor_name_mapping)

In [None]:
# curr_sim_met = sim_metric_name_mapping[sim_metrics[1]]
# print(curr_sim_met)
# tmp = r_df[r_df[sim_metric_col] == curr_sim_met]

In [None]:
# cm = 0.393701
# # ncols = tmp[comp_cat_col].nunique()
# ncols = 2
# fig, axes = plt.subplots(nrows=1, ncols= ncols, figsize=(ncols*10*cm,16*cm), sharex=True, sharey=True)

# cmaps = {
#     'Objective': sns.color_palette('tab10', n_colors=4),
#     'Dataset size': sns.color_palette('tab10', n_colors=1) + sns.color_palette("flare", n_colors=4),
# }
# handles, labels = [], []
# for i, comp_cat in enumerate(['Objective', 'Dataset size']):
#     df = tmp[tmp[comp_cat_col] == comp_cat]
#     df = df.sort_values(anchor_col, key= lambda x: x.apply(lambda y: anchor_nm_val_list.index(y)))
#     nr_hue = df[comp_val_col].nunique()
#     ax = axes[i]
#     sns.boxplot(
#         data=df, 
#         x=r_col, 
#         y=anchor_col, 
#         hue=comp_val_col, 
#         legend=True,
#         ax=ax,
#         palette=cmaps[comp_cat]
#     )

#     ax.set_title(comp_cat)
#     ax.set_ylabel('')
#     ax.set_xlabel('Correlation coefficient')
#     thandles, tlabels = ax.get_legend_handles_labels()
#     handles += thandles if i==0 else thandles[1:]
#     labels += tlabels if i==0 else tlabels[1:]
#     ax.get_legend().remove()

#     for mid in ax.get_yticks()[:-1]:
#         ax.axhline(mid+0.5, ls=':', c='grey', alpha=0.5, lw=0.75)
# fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(1.11, 0.9), borderaxespad=0., frameon=False, title='Model category')
# fig.subplots_adjust(wspace=0.1)

In [None]:
# cm = 0.393701
# ncols = 2
# fig, axes = plt.subplots(nrows=1, ncols= ncols, figsize=(ncols*12*cm,15*cm), sharex=True, sharey=False)


# handles, labels = [], []
    
# for i, comp_cat in enumerate(['Objective', 'Dataset size']):
#     df = tmp[tmp[comp_cat_col] == comp_cat]
#     ax = axes[i]
#     sns.boxplot(
#         data=df, 
#         x=r_col, 
#         y=comp_val_col, 
#         hue=anchor_col,
#         hue_order=anchor_nm_val_list_v2,
#         legend=True,
#         ax=ax,
#         palette='tab10'
#     )

#     ax.set_title(comp_cat)
#     ax.set_ylabel('')
#     ax.set_xlabel('Correlation coefficient')
#     if i==0:
#         handles, labels = ax.get_legend_handles_labels()
#     ax.get_legend().remove()
#     for mid in ax.get_yticks()[:-1]:
#         ax.axhline(mid+0.5, ls=':', c='grey', alpha=0.5, lw=0.75)

# fig.subplots_adjust(wspace=0.3)
# fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(1.07, 0.95), borderaxespad=0.,frameon=False, title='Anchor Model')
# fig.tight_layout(rect=[0, 0, 0.9, 1])

In [None]:
import itertools
def get_pairs(df, strata_col):
    """List the pairs of (stratum, similarity metric) configurations to compare.

    The distinct ``(strata_col value, similarity metric)`` combinations reported
    by ``value_counts()`` on ``df`` are taken in sorted index order, combinations
    whose stratum equals ``'All'`` are discarded, and every unordered pair of the
    remaining combinations that shares the same similarity metric is returned.

    Args:
        df: Data frame containing ``strata_col`` and the similarity-metric column.
        strata_col: Name of the column holding the stratum values.

    Returns:
        List of ``((stratum_a, metric), (stratum_b, metric))`` tuples, where the
        first element precedes the second in the sorted combination list.
    """
    observed = df[[strata_col, sim_metric_col]].value_counts().sort_index().index.tolist()
    configs = [(stratum, metric) for stratum, metric in observed if stratum != 'All']
    # combinations() yields each unordered pair once, earlier element first.
    return [
        (first, second)
        for first, second in itertools.combinations(configs, 2)
        if first[1] == second[1]
    ]


def get_config_data(df, config, strata_col):
    """Return the rows of ``df`` that belong to one configuration.

    Args:
        df: Data frame containing ``strata_col`` and the similarity-metric column.
        config: Sequence whose first element is the stratum value and whose
            second element is the similarity-metric name.
        strata_col: Name of the column holding the stratum values.

    Returns:
        The subset of ``df`` whose stratum and similarity metric both match ``config``.
    """
    in_stratum = df[strata_col] == config[0]
    has_metric = df[sim_metric_col] == config[1]
    return df[in_stratum & has_metric]


def get_pvals_pairs(pairs, df, strata_col, alpha=0.05):
    """Compute a rank-sum p-value for every pair of configurations.

    For each pair, the ``r_col`` values of both configurations are extracted and
    re-indexed from zero. Every position at which either series is NaN is
    dropped from both series, the two series are asserted to have the same
    length, and ``scipy.stats.ranksums`` is applied to them.

    Args:
        pairs: Iterable of ``(config1, config2)`` tuples as produced by ``get_pairs``.
        df: Data frame the configurations are selected from.
        strata_col: Name of the column holding the stratum values.
        alpha: Accepted for interface compatibility; not used by this function.

    Returns:
        List of (uncorrected) p-values, one per pair, in the order of ``pairs``.

    Raises:
        AssertionError: If the two series differ in length after NaN removal.
    """
    def _r_values(config):
        # Positional (0..n-1) index so that NaN positions can be shared across both series.
        return get_config_data(df, config, strata_col)[r_col].reset_index(drop=True)

    p_values = []
    for first_config, second_config in pairs:
        first = _r_values(first_config)
        second = _r_values(second_config)
        nan_positions = [*np.where(first.isna())[0], *np.where(second.isna())[0]]
        first = first.drop(labels=nan_positions)
        second = second.drop(labels=nan_positions)
        assert len(first) == len(second)
        _, p_value = ranksums(first, second)
        p_values.append(p_value)
    return p_values


def add_sign_bars(anchor_pairs, anchor_corr_pvals, axes, strata_col, alpha=0.01):
    """Annotate each anchor panel with significance bars and drop its legend.

    For every anchor in the module-level ``anchors`` list (panel ``axes[i]`` for
    the i-th anchor), the configuration pairs whose corrected p-value is strictly
    below ``alpha`` are passed to ``starbars.draw_annotation``. The per-panel
    legend is removed afterwards, if one exists.

    Args:
        anchor_pairs: Mapping anchor id -> list of ``(config1, config2)`` pairs.
        anchor_corr_pvals: Mapping anchor id -> corrected p-values, aligned with
            the pairs of the same anchor.
        axes: Indexable collection of matplotlib axes, one per anchor, in
            ``anchors`` order.
        strata_col: Accepted for interface compatibility; not used by this function.
        alpha: Significance threshold applied to the corrected p-values.
    """
    for i, anchor in enumerate(anchors):
        ax = axes[i]
        annotations = [
            (config1, config2, pval)
            for (config1, config2), pval in zip(anchor_pairs[anchor], anchor_corr_pvals[anchor])
            if pval < alpha
        ]
        starbars.draw_annotation(annotations, ax=ax)
        # get_legend() returns None when the panel has no legend; guard so that a
        # legend-less panel does not abort the whole figure with an AttributeError.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        
def correct_anchor_pvalues(anchor_pairs, anchor_pvals):
    """Apply one Bonferroni correction jointly across the p-values of all anchors.

    The p-values of all anchors in the module-level ``anchors`` list are pooled,
    corrected together with ``multipletests(..., method='bonferroni')``, and
    split back per anchor. Pooling and splitting both follow the order of
    ``anchors``, so the result does not depend on the insertion order of
    ``anchor_pvals``.

    Args:
        anchor_pairs: Mapping anchor id -> list of compared configuration pairs.
        anchor_pvals: Mapping anchor id -> uncorrected p-values, aligned with
            the pairs of the same anchor.

    Returns:
        Dict mapping each anchor id to an array of corrected p-values, aligned
        with that anchor's entries in ``anchor_pvals``.

    Raises:
        ValueError: If an anchor's number of p-values differs from its number of pairs.
    """
    per_anchor = []
    for anchor in anchors:
        pvals = np.asarray(anchor_pvals[anchor], dtype=float)
        if len(pvals) != len(anchor_pairs[anchor]):
            raise ValueError(
                f"Anchor {anchor!r}: {len(pvals)} p-values for "
                f"{len(anchor_pairs[anchor])} pairs"
            )
        per_anchor.append(pvals)

    all_pvals = np.concatenate(per_anchor) if per_anchor else np.empty(0, dtype=float)
    if all_pvals.size == 0:
        # Nothing to correct; multipletests is not guaranteed to accept empty input.
        return {anchor: np.empty(0, dtype=float) for anchor in anchors}

    # Alternative correction methods that were considered: 'hs', 'fdr_bh'.
    corrected_all_pvals = multipletests(all_pvals, method='bonferroni')[1]

    anchor_corr_pvals = {}
    idx = 0
    for anchor, pvals in zip(anchors, per_anchor):
        anchor_corr_pvals[anchor] = corrected_all_pvals[idx:(idx + len(pvals))]
        idx += len(pvals)
    return anchor_corr_pvals


def plot_box_per_sim_metric(subset, strata_name, strata_col, strata_col_orig):
    """Draw one box-plot panel per anchor model, with significance annotations.

    Each panel shows the distribution of ``r_col`` per similarity metric (x-axis),
    split by the values of ``strata_col`` (hue). Pairwise rank-sum tests between
    strata are computed per anchor, corrected jointly across all anchors, and
    the significant ones are drawn as bars. The panel title carries the anchor's
    own category, coloured like the matching hue level.

    Relies on the module-level ``anchors``, ``anchor_name_mapping``,
    ``model_configs`` and the column-name constants.

    Args:
        subset: Rows of the correlation-coefficient table for one comparison category.
        strata_name: Title of the shared figure legend.
        strata_col: Column of ``subset`` whose values define the hue levels.
        strata_col_orig: Column of ``model_configs`` holding each anchor's own
            category for this comparison.

    Returns:
        The created matplotlib figure.
    """
    n_cols = 3
    # Ceiling division: enough rows to give every anchor its own panel
    # (2 rows for the six anchors configured in this notebook).
    n_rows = max(1, -(-len(anchors) // n_cols))

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 3.55, n_rows * 2.6),
        sharey=False,
        sharex=False,
        squeeze=False,  # always a 2-D array, so flatten() works for any grid size
    )
    axes = axes.flatten()

    handles, labels = [], []

    # Fix the hue order and the colour of every hue level once for the whole
    # figure. Otherwise seaborn derives both per panel from that panel's data,
    # so title colours and the shared legend (taken from the first panel only)
    # are not guaranteed to match the boxes of the other panels.
    comparison_type = list(subset[strata_col].dropna().unique())
    tab10_cols = list(sns.color_palette("tab10", n_colors=max(len(comparison_type), 1)).as_hex())
    palette = dict(zip(comparison_type, tab10_cols))

    anchor_pairs = {}
    anchor_pvals = {}
    for i, anchor in enumerate(anchors):
        ax = axes[i]
        anchor_data = subset[subset[anchor_col] == anchor_name_mapping[anchor]]

        # Pairwise distribution statistics between strata (uncorrected p-values).
        anchor_pairs[anchor] = get_pairs(anchor_data, strata_col)
        anchor_pvals[anchor] = get_pvals_pairs(anchor_pairs[anchor], anchor_data, strata_col, alpha=0.01)

        sns.boxplot(
            data=anchor_data,
            x=sim_metric_col,
            y=r_col,
            hue=strata_col,
            hue_order=comparison_type,
            palette=palette,
            ax=ax,
        )
        strata_type = model_configs.loc[anchor, strata_col_orig]
        # Fall back to black if the anchor's own category is not among the plotted levels.
        ax.set_title(
            f'{anchor_name_mapping[anchor]} ({strata_type})',
            color=palette.get(strata_type, 'black'),
            fontsize=11,
        )
        ax.set_xlabel("")  # Remove x-axis label for clarity

        ax.legend(loc='best')
        if i == 0:
            # Handles for the shared figure legend; per-panel legends are removed
            # later by add_sign_bars().
            handles, labels = ax.get_legend_handles_labels()
        if i % n_cols == 0:
            # Only the left-most panel of each row carries the y-axis label.
            ax.set_ylabel("Correlation coefficient", fontsize=10)
        else:
            ax.set_ylabel("")
        ax.tick_params('both', labelsize=10)
        # Dotted separators half-way between neighbouring x categories.
        for mid in ax.get_xticks()[:-1]:
            ax.axvline(mid + 0.5, ls=':', c='grey', alpha=0.5, lw=0.75)

    # Hide grid cells that have no anchor assigned.
    for unused_ax in axes[len(anchors):]:
        unused_ax.set_visible(False)

    anchor_corr_pvals = correct_anchor_pvalues(anchor_pairs, anchor_pvals)

    # NOTE(review): tight_layout() below recomputes the subplot spacing, so this
    # manual spacing presumably does not survive into the final figure — confirm.
    fig.subplots_adjust(hspace=-.2, wspace=-.15)

    add_sign_bars(anchor_pairs, anchor_corr_pvals, axes, strata_col, alpha=0.01)

    # Shared legend, anchored outside the right edge of the panel grid.
    fig.legend(handles, labels,
               bbox_to_anchor=(1.15, 0.60),
               borderaxespad=0.,
               fontsize=11,
               title=strata_name,
               title_fontsize=11,
               frameon=False,
              )

    fig.tight_layout()
    return fig


# Relabel the 'Dataset size' comparison category as 'Dataset diversity' (modifies r_df in place).
r_df.loc[r_df[comp_cat_col] == 'Dataset size', comp_cat_col] = 'Dataset diversity'
# Every figure is produced twice: once with the 'All' comparison value included
# and once with those rows removed.
for include_all in [True, False]:
    if not include_all:
        subset = r_df[r_df[comp_val_col]!='All']
    else:
        subset = r_df
    # One figure per comparison category.
    for strata_cat, df in subset.groupby(comp_cat_col):
        # Original category name, used inside the plot function as the column of
        # `model_configs` to look up each anchor's own category. Only the first
        # unique value is taken — assumes one original name per category.
        strata_cat_orig = df[comp_cat_orig_col].unique()[0]
        fig = plot_box_per_sim_metric(df, strata_cat, comp_val_col, strata_cat_orig)
        # File name encodes the category (spaces -> underscores) and whether 'All' was included.
        save_or_show(fig, storing_path / f'dist_r_coeffs_{strata_cat.replace(" ", "_")}_only_box_per_sim_metric{"_include_all" if include_all else ""}.pdf', SAVE)