In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import spearmanr
import json 
from helper import save_or_show, plot_r_coeff_distribution
from collections import defaultdict
from scipy.stats import ranksums
from statsmodels.stats.multitest import multipletests
import starbars

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Similarity metrics to analyse; each name is a sub-directory below every dataset directory.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# Reference dataset (x axis) and the datasets compared against it (y axis).
# '/' in dataset names is replaced by '_' because the names are used as path components.
x_axis_ds = 'imagenet-subset-10k'
y_axis_ds = [name.replace('/', '_') for name in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')]

# Model metadata; SegmentAnything_vit_b is excluded from the analysis (no-op if absent).
with open('../scripts/models_config.json', 'r') as f:
    model_configs = json.load(f)
model_configs.pop('SegmentAnything_vit_b', None)

# Output location for figures; only created when figures are actually saved.
SAVE = False
storing_path = Path('/home/space/diverse_priors/results/plots/ds_sim__imagenet_sim/r_coeff_distributions')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Sorted ids of all models listed in models_config.json (after exclusions); only these are analysed.
allowed_models = sorted(model_configs)
len(allowed_models)

In [None]:
def get_model_ids(fn):
    """Return the lines of text file ``fn`` with surrounding whitespace stripped.

    Blank lines are kept as empty strings, so there is one entry per line in the file.
    """
    with open(fn, 'r') as handle:
        return [raw_line.strip() for raw_line in handle]


def load_sim_martix(path, allowed=None):
    """Load a square model-by-model similarity matrix from directory ``path``.

    The directory must contain ``model_ids.txt`` (one model id per line, giving the
    row/column order) and ``similarity_matrix.pt`` (the matrix, loaded with ``torch.load``).

    Args:
        path: ``pathlib.Path`` of the directory holding both files.
        allowed: Iterable of model ids to keep. Defaults to the notebook-level
            ``allowed_models``.

    Returns:
        ``pd.DataFrame`` indexed and columned by model id, restricted (rows and columns)
        to the sorted ids present both in ``model_ids.txt`` and in ``allowed``.

    Raises:
        FileNotFoundError: if either file is missing.
        ValueError: if the matrix is not ``len(model_ids) x len(model_ids)``.
    """
    if allowed is None:
        allowed = allowed_models

    model_ids_fn = path / 'model_ids.txt'
    sim_mat_fn = path / 'similarity_matrix.pt'
    for required_fn in (model_ids_fn, sim_mat_fn):
        if not required_fn.exists():
            raise FileNotFoundError(f'{str(required_fn)} does not exist.')
    model_ids = get_model_ids(model_ids_fn)

    # map_location='cpu' lets matrices saved from a GPU load on CPU-only machines.
    sim_mat = torch.load(sim_mat_fn, map_location='cpu')
    # Convert explicitly rather than relying on pandas' implicit __array__ handling,
    # which fails for tensors that require grad.
    if isinstance(sim_mat, torch.Tensor):
        sim_mat = sim_mat.detach().cpu().numpy()
    sim_mat = np.asarray(sim_mat)
    n_models = len(model_ids)
    if sim_mat.shape != (n_models, n_models):
        raise ValueError(
            f'{str(sim_mat_fn)} has shape {sim_mat.shape}, '
            f'expected ({n_models}, {n_models}) from {str(model_ids_fn)}.'
        )
    df = pd.DataFrame(sim_mat, index=model_ids, columns=model_ids)

    available_models = sorted(set(model_ids).intersection(allowed))
    return df.loc[available_models, available_models]

In [None]:
# Nested mapping: similarity metric -> dataset name -> model-by-model similarity DataFrame.
sim_mats = defaultdict(dict)
datasets_to_load = [x_axis_ds, *y_axis_ds]
for sim_metric in sim_metrics:
    metric_mats = sim_mats[sim_metric]
    for ds in datasets_to_load:
        metric_mats[ds] = load_sim_martix(base_path_similarity_matrices / ds / sim_metric)

In [None]:
# Anchor models: the similarities of each anchor to all other models are analysed.
# Later cells split this list into its first three and remaining entries for smaller figures.
anchors = [
    'OpenCLIP_ViT-L-14_openai', 'resnet50', 'simclr-rn50',
    'vit_large_patch16_224', 'vit_large_patch16_224.augreg_in21k', 'dino-vit-base-p16',
]

In [None]:
# Column names of the long-format similarity table.
anchor_col = 'Anchor Model'             # model whose similarities are examined
other_col = 'Other Model'               # model it is compared with
other_ds_col = 'Other Dataset'          # dataset (other than x_axis_ds) the similarity was computed on
sim_metric_col = 'Similarity metric'
sim_ds_col = 'Similarity value DS'      # similarity on `Other Dataset`
sim_imgnet_col = 'Similarity value IN'  # similarity on x_axis_ds (ImageNet subset)

# Keys in `model_configs` and their display names, in matching order.
info_orig_cols = ['objective', 'architecture_class', 'dataset_class', 'size_class']
info_cols = ['Objective', 'Architecture', 'Dataset size', 'Model size']

# Columns kept fixed when melting; x_axis_ds holds the ImageNet similarity before it is renamed.
id_cols = [sim_metric_col, anchor_col, other_col, *info_cols, x_axis_ds]

def get_other_model_info(mid):
    """Return ``(objective, architecture_class, dataset_class, size_class)`` of model ``mid``.

    Values come from the notebook-level ``model_configs``; a missing model id or key raises KeyError.
    """
    cfg = model_configs[mid]
    return tuple(cfg[key] for key in ('objective', 'architecture_class', 'dataset_class', 'size_class'))


def get_melted_sim_values_metric_anchor(anch, met, met_ds_mats):
    """Long-format table of similarities between one anchor model and every other model.

    Args:
        anch: Anchor model id; its row is taken from every similarity matrix.
        met: Similarity-metric key; mapped to its display name via ``sim_metric_name_mapping``.
        met_ds_mats: Mapping dataset name -> similarity DataFrame (model x model) for ``met``.
            Must include ``x_axis_ds``, whose column becomes ``sim_imgnet_col``.

    Returns:
        DataFrame with one row per (other model, dataset other than ``x_axis_ds``), containing
        ``sim_metric_col``, ``anchor_col``, ``other_col``, the ``info_cols`` of the other model,
        ``sim_imgnet_col``, ``other_ds_col`` and ``sim_ds_col``.

    Raises:
        KeyError: if ``anch`` is missing from one of the matrices.
    """
    sim_vals_ds = []
    for ds, curr_sim_mat in met_ds_mats.items():
        if anch not in curr_sim_mat.columns:
            raise KeyError(f'Anchor model {anch!r} is missing from the {met!r} similarity matrix of dataset {ds!r}.')
        # All models except the anchor itself. The allowed_models filter is redundant for
        # matrices loaded via load_sim_martix (already restricted), kept as a safeguard.
        cols = sorted(set(curr_sim_mat.columns).difference([anch]).intersection(allowed_models))
        sim_vals_ds.append(curr_sim_mat.loc[anch, cols].rename(ds))

    # Wide table: one row per other model, one column per dataset.
    anchor_sim_vals = pd.concat(sim_vals_ds, axis=1).reset_index(names=[other_col])
    model_info = pd.DataFrame(
        anchor_sim_vals[other_col].apply(get_other_model_info).tolist(),
        columns=info_cols,
    )
    anchor_sim_vals = pd.concat([anchor_sim_vals, model_info], axis=1)
    # Label with the `met` argument (not a notebook-global loop variable) so the label
    # always matches the matrices passed in.
    anchor_sim_vals[sim_metric_col] = sim_metric_name_mapping[met]
    anchor_sim_vals[anchor_col] = anch

    # Melt the non-reference dataset columns; the x_axis_ds column stays as an id column.
    anchor_sim_vals = pd.melt(
        anchor_sim_vals,
        id_vars=id_cols,
        var_name=other_ds_col,
        value_name=sim_ds_col,
    )
    return anchor_sim_vals.rename(columns={x_axis_ds: sim_imgnet_col})
    

# One long table for every (anchor, similarity metric) combination, stacked row-wise.
# An explicit loop (not a comprehension) keeps `sim_metric` bound at module level while
# each call runs, in case the helper reads that global.
dfs = []
for anchor in anchors:
    for sim_metric, ds_sim_mat in sim_mats.items():
        dfs.append(get_melted_sim_values_metric_anchor(anchor, sim_metric, ds_sim_mat))

all_sims = pd.concat(dfs, axis=0).reset_index(drop=True)

In [None]:
def compute_corr(data):
    """Spearman correlation between the ImageNet and the other-dataset similarity columns of ``data``.

    Intended for ``groupby(...).apply``; only the coefficient is returned, the p-value is dropped.
    """
    rho, _pvalue = spearmanr(data[sim_imgnet_col], data[sim_ds_col])
    return rho


def compute_r_coeffs(data):
    """Spearman r per group for every stratification column in ``info_cols``.

    Returns a list (in ``info_cols`` order) of Series indexed by the stratum value;
    NaN stratum values form their own group (``dropna=False``).
    """
    return [
        data.groupby(strata, dropna=False).apply(compute_corr, include_groups=False)
        for strata in info_cols
    ]

In [None]:
# For every stratification, Spearman r between ImageNet and other-dataset similarities,
# per (metric, anchor, dataset): once per stratum value and once over all models ('All').
r_col = 'r coeff'
r_dfs = []
key_cols = [sim_metric_col, anchor_col, other_ds_col]

# The unstratified correlations do not depend on the stratification -> compute them once.
all_rs_base = (
    all_sims.groupby(key_cols, dropna=False)
    .apply(compute_corr, include_groups=False)
    .reset_index()
)
all_rs_base.columns = key_cols + [r_col]

for strata in info_cols:
    grouping_cols = key_cols + [strata]
    strata_rs = all_sims.groupby(grouping_cols, dropna=False).apply(compute_corr, include_groups=False).reset_index()
    strata_rs.columns = grouping_cols + [r_col]
    # Fresh copy labelled 'All' in this stratification's column (assign does not mutate the base).
    all_rs = all_rs_base.assign(**{strata: 'All'})
    rs = pd.concat([all_rs, strata_rs], axis=0).reset_index(drop=True)
    rs = rs.sort_values(key_cols).reset_index(drop=True)
    r_dfs.append(rs)

In [None]:
import itertools
def get_pairs(df, strata_col):
    """All pairs of ``(stratum value, similarity metric)`` configs to compare statistically.

    Configs are taken from the (stratum, metric) combinations present in ``df`` (rows with NaN
    in either column are skipped), sorted, with the 'All' stratum excluded. Only configs that
    share the same metric are paired, and each unordered pair appears once.
    """
    observed = df[[strata_col, sim_metric_col]].value_counts().sort_index().index.tolist()
    configs = [(stratum, metric) for stratum, metric in observed if stratum != 'All']
    return [
        (first, second)
        for first, second in itertools.combinations(configs, 2)
        if first[1] == second[1]
    ]

def get_config_data(df, config, strata_col):
    """Rows of ``df`` whose stratum equals ``config[0]`` and whose similarity metric equals ``config[1]``."""
    strata_value, metric_name = config[0], config[1]
    keep = df[strata_col].eq(strata_value) & df[sim_metric_col].eq(metric_name)
    return df[keep]

def get_pvals_pairs(pairs, df, strata_col, alpha=0.05):
    """Wilcoxon rank-sum p-value for each pair of ``(stratum, metric)`` configs.

    For each pair, the r coefficients of both configs are matched by (anchor, dataset) label.
    Entries that only one config has, or where either value is NaN, are dropped before testing.

    Args:
        pairs: List of ``((stratum, metric), (stratum, metric))`` tuples, e.g. from ``get_pairs``.
        df: r-coefficient table (typically the rows of a single anchor).
        strata_col: Name of the stratification column.
        alpha: Unused; kept for backward compatibility (thresholding happens after correction).

    Returns:
        List of uncorrected p-values, one per pair, in the order of ``pairs``.
    """
    align_cols = [anchor_col, other_ds_col]
    p_values = []
    for config1, config2 in pairs:
        # Align by label instead of by row position, so a dataset missing for one config
        # cannot shift the comparison or trip a length assertion.
        r1 = get_config_data(df, config1, strata_col).set_index(align_cols)[r_col].rename('r1')
        r2 = get_config_data(df, config2, strata_col).set_index(align_cols)[r_col].rename('r2')
        paired = pd.concat([r1, r2], axis=1, join='inner').dropna()
        _, p_value = ranksums(paired['r1'], paired['r2'])
        p_values.append(p_value)
    return p_values


def correct_anchor_pvalues(anchor_pairs, anchor_pvals, method='hs'):
    """Apply one multiple-testing correction jointly over the p-values of all anchors.

    Args:
        anchor_pairs: Mapping anchor -> list of compared config pairs; used only to sanity-check
            that each anchor has one p-value per pair.
        anchor_pvals: Mapping anchor -> list of raw p-values.
        method: ``method`` argument of statsmodels ``multipletests``. Default ``'hs'``
            (Holm-Sidak); ``'bonferroni'`` and ``'fdr_bh'`` were the alternatives tried before.

    Returns:
        Dict anchor -> array of corrected p-values, in the same order as ``anchor_pvals[anchor]``.

    Raises:
        ValueError: if an anchor's number of p-values differs from its number of pairs.
    """
    anchor_names = list(anchor_pvals)
    for name in anchor_names:
        if name in anchor_pairs and len(anchor_pairs[name]) != len(anchor_pvals[name]):
            raise ValueError(f'{name}: {len(anchor_pairs[name])} pairs but {len(anchor_pvals[name])} p-values.')

    counts = [len(anchor_pvals[name]) for name in anchor_names]
    if sum(counts) == 0:
        return {name: np.empty(0) for name in anchor_names}

    # Concatenate and split in the same (dict) order, so results cannot be attributed to
    # the wrong anchor even if the global `anchors` list is ordered differently.
    flat_pvals = np.concatenate([np.asarray(anchor_pvals[name], dtype=float) for name in anchor_names])
    corrected = multipletests(flat_pvals, method=method)[1]
    return dict(zip(anchor_names, np.split(corrected, np.cumsum(counts)[:-1])))


def add_sign_bars(anchor_pairs, anchor_corr_pvals, axes, strata_col, alpha=0.01):
    """Draw significance bars for corrected p-values below ``alpha`` and drop per-axis legends.

    Anchors are taken in the order of the notebook-level ``anchors``. Prints the total number
    of drawn annotations.

    Args:
        anchor_pairs: Mapping anchor -> list of ``(config1, config2)`` pairs.
        anchor_corr_pvals: Mapping anchor -> corrected p-values aligned with ``anchor_pairs``.
        axes: Either a 2-D grid (one row per anchor; the box plot is in column 0) or a flat
            sequence with one Axes per anchor.
        strata_col: Name of the stratification column (used in the printed summary).
        alpha: Significance threshold applied to the corrected p-values.
    """
    # Accept both layouts: this notebook defines add_sign_bars a second time for flattened
    # axes, and whichever definition is live must not break the other caller on re-runs.
    is_grid = getattr(axes, 'ndim', 1) == 2
    total_annots = 0
    for i, anchor in enumerate(anchors):
        ax = axes[i, 0] if is_grid else axes[i]
        annotations = [
            (config1, config2, pval)
            for (config1, config2), pval in zip(anchor_pairs[anchor], anchor_corr_pvals[anchor])
            if pval < alpha
        ]
        total_annots += len(annotations)
        starbars.draw_annotation(annotations, ax=ax)
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    print(f"{strata_col} {total_annots=}")
    
def get_plot_for_info(df, strata_col, strata_orig_col, use_hist=True):
    """Plot r-coefficient distributions per anchor model, stratified by ``strata_col``.

    Layout: one row per anchor (order of the notebook-level ``anchors``).
    - Column 0 is a box plot of r coefficients per similarity metric.
    - Columns 1.. show one distribution panel per metric.

    Pairwise rank-sum tests between strata are computed per anchor and corrected jointly across
    anchors. Pairs with corrected p < 0.01 are drawn as bars on the box plots.

    Args:
        df: r-coefficient table for this stratification (an entry of ``r_dfs``).
        strata_col: Display name of the stratification column in ``df``.
        strata_orig_col: Matching key in ``model_configs`` (used to label each anchor row).
        use_hist: Histogram with KDE overlay if True, KDE only otherwise.

    Returns:
        ``(fig, (anchor_pairs, anchor_corr_pvals))``.
    """
    n = df[anchor_col].nunique()
    m = df[sim_metric_col].nunique() + 1
    # squeeze=False keeps `axes` 2-D even with a single anchor row, so axes[i, j] always works.
    fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(m*4.5, n*3), squeeze=False)

    handles, labels = [], []
    anchor_pairs = {}
    anchor_pvals = {}

    for i, anchor in enumerate(anchors):
        anchor_data = df[df[anchor_col] == anchor]

        # Pairwise distribution tests between strata (within each similarity metric).
        anchor_pairs[anchor] = get_pairs(anchor_data, strata_col)
        anchor_pvals[anchor] = get_pvals_pairs(anchor_pairs[anchor], anchor_data, strata_col, alpha=0.01)

        # Box plot of r coefficients per metric, coloured by stratum.
        box_ax = axes[i, 0]
        sns.boxplot(data=anchor_data, x=sim_metric_col, y=r_col, hue=strata_col, ax=box_ax, legend=True)
        if i == 0:
            box_ax.set_title('Distribution r coeff. (spearman)', fontsize=13)
            # The first row's legend entries are reused for the shared figure legend.
            handles, labels = box_ax.get_legend_handles_labels()
        box_ax.set_xlabel('')
        box_ax.set_ylabel(f'{anchor}\n({model_configs[anchor][strata_orig_col]})', fontsize=12)
        box_ax.tick_params('both', labelsize=12)

        # One distribution panel per similarity metric.
        for k, (met, met_data) in enumerate(anchor_data.groupby(sim_metric_col, dropna=False)):
            ax = axes[i, k + 1]
            if use_hist:
                sns.histplot(met_data, x=r_col, hue=strata_col, bins=10, multiple='dodge', kde=True, ax=ax, alpha=0.5, legend=False)
            else:
                sns.kdeplot(met_data, x=r_col, hue=strata_col, ax=ax, alpha=0.8, legend=False)
            if i == 0:
                ax.set_title(met, fontsize=13)
            if i == (n - 1):
                ax.set_xlabel('r coeff. (spearman)', fontsize=12)
            else:
                ax.set_xlabel('')
            ax.set_ylabel('')
            ax.tick_params('both', labelsize=12)

    # Multiple-testing correction over all anchors jointly, then significance bars.
    anchor_corr_pvals = correct_anchor_pvalues(anchor_pairs, anchor_pvals)
    add_sign_bars(anchor_pairs, anchor_corr_pvals, axes, strata_col, alpha=0.01)

    fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(1.05, 1), borderaxespad=0., fontsize=13, title=strata_col, title_fontsize=13)
    # Leave space on the right for the figure-level legend.
    fig.tight_layout(rect=[0, 0, 0.925, 1])
    return fig, (anchor_pairs, anchor_corr_pvals)

In [None]:
# Histogram and KDE versions of each per-stratification figure. The returned pairs and
# corrected p-values are stored per stratification for reuse in later figures.
pvalue_information = {}

for use_hist in (True, False):
    suffix = 'hist' if use_hist else 'kde'
    for df, strata_col, strata_orig_col in zip(r_dfs, info_cols, info_orig_cols):
        fig, pvalue_information[strata_col] = get_plot_for_info(df, strata_col, strata_orig_col, use_hist=use_hist)
        save_or_show(fig, storing_path / f'dist_r_coeffs_{strata_col.replace(" ", "_")}_{suffix}.pdf', SAVE)

In [None]:
# Similarity metrics in order of first appearance in the r-coefficient table
# (used as facet-row order and as palette keys below).
row_order = r_dfs[0][sim_metric_col].unique()

# Anchor column orders: all anchors, plus the first three / remaining ones for smaller figures.
col_order = anchors
col_order1, col_order2 = anchors[:3], anchors[3:]

In [None]:
def get_only_dist_plot(anchor_data, strata_col, strata_col_orig, col_order, metric_order=None):
    """Overlapping KDE grid of r coefficients: rows = similarity metrics, columns = anchors.

    Args:
        anchor_data: r-coefficient table (an entry of ``r_dfs``).
        strata_col: Column used for hue (the stratification).
        strata_col_orig: Matching key in ``model_configs``; each anchor's value is shown in its column title.
        col_order: Anchor models to show, in column order.
        metric_order: Similarity metrics to show, in row order. Defaults to the notebook-level ``row_order``.

    Returns:
        The matplotlib Figure of the grid.
    """
    if metric_order is None:
        metric_order = row_order

    g = sns.FacetGrid(anchor_data,
                      row=sim_metric_col,
                      row_order=metric_order,
                      col=anchor_col, col_order=col_order,
                      hue=strata_col, aspect=7, height=.5, sharey=False)
    # One filled KDE per stratum in every panel.
    g.map(sns.kdeplot, r_col,
          bw_adjust=.5, clip_on=False,
          fill=True, alpha=0.1, linewidth=1.25)

    # Negative spacing makes neighbouring panels overlap.
    g.figure.subplots_adjust(hspace=-.2, wspace=-.15)

    # Replace the default facet titles: top row shows the anchor and its stratum value.
    g.set_titles('')
    for anchor_name, ax in zip(col_order, g.axes[0, :]):
        ax.set_title(f'{anchor_name}\n({model_configs[anchor_name][strata_col_orig]})', fontsize=13)

    # Remove y details that don't play well with the overlap; label rows by metric instead.
    g.set(yticks=[], ylabel="")
    for metric_name, ax in zip(metric_order, g.axes[:, 0]):
        ax.set_ylabel(f"{metric_name}", rotation=0, loc='bottom', fontsize=12)

    g.set(xticks=[-0.5, 0, 0.5, 1], xlabel="corr. coeff.")
    g.despine(bottom=True, left=True)

    # Transparent figure and panel backgrounds so overlapping panels don't hide each other.
    g.figure.patch.set_facecolor('none')
    g.add_legend(title=strata_col, fontsize=12, title_fontsize=12)
    for ax in g.axes.flat:
        ax.patch.set_alpha(0)
    # `figure` is the non-deprecated accessor (already used above) for the grid's Figure.
    return g.figure

In [None]:
# KDE-only grids: once with every anchor, then split into two smaller figures.
kde_anchor_groups = [('all', col_order), ('part1', col_order1), ('part2', col_order2)]
for df, strata_col, strata_orig_col in zip(r_dfs, info_cols, info_orig_cols):
    for group_name, group_anchors in kde_anchor_groups:
        fig = get_only_dist_plot(df, strata_col, strata_orig_col, group_anchors)
        save_or_show(fig, storing_path / f'dist_r_coeffs_{strata_col.replace(" ", "_")}_only_kde_{group_name}.pdf', SAVE)


In [None]:
def plot_box_per_info_col(subset, strata_col, strata_col_orig):
    """Box plots of r coefficients per stratum (x) and similarity metric (hue), one panel per anchor.

    Panels follow the notebook-level ``anchors`` order, three per row. Each panel title shows
    the anchor and its own stratum value. Each x category gets a translucent tab10 background band.

    Args:
        subset: r-coefficient table (an entry of ``r_dfs``).
        strata_col: Stratification column plotted on the x axis.
        strata_col_orig: Matching key in ``model_configs``.

    Returns:
        The matplotlib Figure.
    """
    # zip: works for up to three metrics instead of requiring exactly three.
    metric_palette = dict(zip(row_order, ['lightgrey', 'beige', 'white']))
    g = sns.catplot(
        subset,
        x=strata_col,
        y=r_col,
        hue=sim_metric_col,
        col=anchor_col,
        col_order=anchors,
        col_wrap=3,
        kind="box",
        height=2.5,
        aspect=1.5,
        palette=metric_palette)
    g.set_ylabels('Correlation coefficients', fontsize=10)
    g.set_xlabels('')
    g.set_titles('')
    tab10_cols = list(sns.color_palette("tab10").as_hex())
    comparison_type = list(subset[strata_col].unique())
    for anchor, ax in zip(anchors, g.axes):
        strata_type = model_configs[anchor][strata_col_orig]
        # Title colour is meant to match the band of the anchor's own stratum; fall back to
        # black (instead of raising ValueError) if that stratum is absent from `subset`.
        if strata_type in comparison_type:
            title_color = tab10_cols[comparison_type.index(strata_type) % len(tab10_cols)]
        else:
            title_color = 'black'
        ax.set_title(f'{anchor}\n({strata_type})', color=title_color, fontsize=11)
        ax.tick_params('x', rotation=20, labelsize=10)
        ax.tick_params('y', labelsize=10)
        # One background band per x category; cycle the palette beyond 10 categories.
        for k, mid in enumerate(ax.get_xticks()):
            ax.axvspan(mid - 0.5, mid + 0.5, facecolor=tab10_cols[k % len(tab10_cols)], alpha=0.1, zorder=-1)

    g.figure.tight_layout(rect=[0, 0, 0.9, 1])
    return g.figure


# One box-plot figure per stratification (objective, architecture, dataset size, model size).
for df, strata_col, strata_orig_col in zip(r_dfs, info_cols, info_orig_cols):
    fig = plot_box_per_info_col(df, strata_col, strata_orig_col)
    out_fn = storing_path / f'dist_r_coeffs_{strata_col.replace(" ", "_")}_only_box_per_info_col.pdf'
    save_or_show(fig, out_fn, SAVE)

In [None]:
def plot_violin_per_info_col(subset, strata_col, strata_col_orig):
    """Split violin plots of r coefficients per stratum, comparing the first two similarity metrics.

    Only rows whose metric is in ``row_order[:2]`` are plotted, one metric per violin half
    (``split=True``). There is one panel per anchor (notebook-level ``anchors`` order), three per row.

    Args:
        subset: r-coefficient table (an entry of ``r_dfs``).
        strata_col: Stratification column plotted on the x axis.
        strata_col_orig: Matching key in ``model_configs``.

    Returns:
        The matplotlib Figure.
    """
    subset = subset[subset[sim_metric_col].isin(row_order[:2])]
    # zip: does not require two metrics to exist in row_order.
    metric_palette = dict(zip(row_order[:2], ['lightgrey', 'beige']))
    g = sns.catplot(
        subset,
        x=strata_col,
        y=r_col,
        hue=sim_metric_col,
        col=anchor_col,
        col_order=anchors,
        col_wrap=3,
        kind="violin",
        split=True,
        inner='quart',
        height=2.5,
        aspect=1.5,
        density_norm='width',
        palette=metric_palette)
    g.set_ylabels('Correlation coefficients', fontsize=10)
    g.set_xlabels('')
    g.set_titles('')
    tab10_cols = list(sns.color_palette("tab10").as_hex())
    comparison_type = list(subset[strata_col].unique())
    for anchor, ax in zip(anchors, g.axes):
        strata_type = model_configs[anchor][strata_col_orig]
        # Title colour is meant to match the band of the anchor's own stratum; fall back to
        # black (instead of raising ValueError) if that stratum is absent from `subset`.
        if strata_type in comparison_type:
            title_color = tab10_cols[comparison_type.index(strata_type) % len(tab10_cols)]
        else:
            title_color = 'black'
        ax.set_title(f'{anchor}\n({strata_type})', color=title_color, fontsize=11)
        ax.tick_params('x', rotation=20, labelsize=10)
        ax.tick_params('y', labelsize=10)
        # One background band per x category; cycle the palette beyond 10 categories.
        for k, mid in enumerate(ax.get_xticks()):
            ax.axvspan(mid - 0.5, mid + 0.5, facecolor=tab10_cols[k % len(tab10_cols)], alpha=0.1, zorder=-1)

    g.figure.tight_layout(rect=[0, 0, 0.9, 1])
    return g.figure


# One split-violin figure per stratification.
for df, strata_col, strata_orig_col in zip(r_dfs, info_cols, info_orig_cols):
    fig = plot_violin_per_info_col(df, strata_col, strata_orig_col)
    out_fn = storing_path / f'dist_r_coeffs_{strata_col.replace(" ", "_")}_only_violin_per_info_col.pdf'
    save_or_show(fig, out_fn, SAVE)

In [None]:
def add_sign_bars(anchor_pairs, anchor_corr_pvals, axes, strata_col, alpha=0.01):
    """Draw significance bars (corrected p < ``alpha``) on each anchor's box plot and drop its legend.

    Anchors are taken in the order of the notebook-level ``anchors``. Unlike the earlier
    definition in this notebook, nothing is printed.

    Args:
        anchor_pairs: Mapping anchor -> list of ``(config1, config2)`` pairs.
        anchor_corr_pvals: Mapping anchor -> corrected p-values aligned with ``anchor_pairs``.
        axes: Flat sequence with one Axes per anchor (as in ``plot_box_per_sim_metric``), or a
            2-D grid with one row per anchor (box plot in column 0).
        strata_col: Unused; kept so the signature matches the earlier definition.
        alpha: Significance threshold applied to the corrected p-values.
    """
    # Accepting both layouts keeps the earlier 2-D-grid caller working if it is re-run
    # after this cell has redefined the function.
    is_grid = getattr(axes, 'ndim', 1) == 2
    total_annots = 0
    for i, anchor in enumerate(anchors):
        ax = axes[i, 0] if is_grid else axes[i]
        annotations = [
            (config1, config2, pval)
            for (config1, config2), pval in zip(anchor_pairs[anchor], anchor_corr_pvals[anchor])
            if pval < alpha
        ]
        total_annots += len(annotations)
        starbars.draw_annotation(annotations, ax=ax)
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        

def plot_box_per_sim_metric(subset, strata_col, strata_col_orig, anchor_pairs, anchor_corr_pvals):
    """Box plots of r coefficients per similarity metric (x) and stratum (hue), one panel per anchor.

    Panels follow the notebook-level ``anchors`` order on a grid three panels wide. Significance
    bars from the precomputed pairs and corrected p-values (threshold 0.01) are added. A single
    legend is placed to the right of the figure.

    Args:
        subset: r-coefficient table (an entry of ``r_dfs``).
        strata_col: Stratification column used as hue.
        strata_col_orig: Matching key in ``model_configs`` (shown in each panel title).
        anchor_pairs: Mapping anchor -> compared config pairs (from ``get_plot_for_info``).
        anchor_corr_pvals: Mapping anchor -> corrected p-values aligned with ``anchor_pairs``.

    Returns:
        The matplotlib Figure.
    """
    n_cols = 3
    # Enough rows for every anchor (2 rows for the 6 default anchors) instead of a fixed 2x3 grid.
    n_rows = max(1, -(-len(anchors) // n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3.55, n_rows * 2.6), sharey=False, sharex=False)
    axes = axes.flatten()

    handles, labels = [], []
    tab10_cols = list(sns.color_palette("tab10").as_hex())
    comparison_type = list(subset[strata_col].unique())
    for i, anchor in enumerate(anchors):
        ax = axes[i]
        sns.boxplot(
            data=subset[subset[anchor_col] == anchor],
            x=sim_metric_col,
            y=r_col,
            hue=strata_col,
            ax=ax
        )
        strata_type = model_configs[anchor][strata_col_orig]
        ax.set_title(f'{anchor}\n({strata_type})', color=tab10_cols[comparison_type.index(strata_type)], fontsize=11)
        ax.set_xlabel("")

        ax.legend(loc='best')
        if i == 0:
            # The first panel's legend entries are reused for the shared figure legend.
            handles, labels = ax.get_legend_handles_labels()
        if i % n_cols == 0:
            ax.set_ylabel("Correlation coefficient", fontsize=10)
        else:
            ax.set_ylabel("")
        ax.tick_params('both', labelsize=10)
        # Dotted separators between the similarity-metric groups.
        for mid in ax.get_xticks()[:-1]:
            ax.axvline(mid + 0.5, ls=':', c='grey', alpha=0.5, lw=0.75)

    # Hide grid cells left over when the number of anchors is not a multiple of n_cols.
    for ax in axes[len(anchors):]:
        ax.set_visible(False)

    add_sign_bars(anchor_pairs, anchor_corr_pvals, axes, strata_col, alpha=0.01)

    fig.legend(handles, labels,
               bbox_to_anchor=(1.15, 0.60),
               borderaxespad=0.,
               fontsize=11,
               title=strata_col,
               title_fontsize=11,
               frameon=False,
              )

    # tight_layout recomputes all subplot spacing; the legend sits outside the figure area.
    fig.tight_layout()
    return fig

# Per-metric box plots, reusing the pairs and corrected p-values computed in the
# per-stratification distribution figures above.
for df, strata_col, strata_orig_col in zip(r_dfs, info_cols, info_orig_cols):
    anchor_pairs, anchor_corr_pvals = pvalue_information[strata_col]
    fig = plot_box_per_sim_metric(df, strata_col, strata_orig_col, anchor_pairs, anchor_corr_pvals)
    out_fn = storing_path / f'dist_r_coeffs_{strata_col.replace(" ", "_")}_only_box_per_sim_metric.pdf'
    save_or_show(fig, out_fn, SAVE)