In [None]:
import sys 
import pandas as pd 
import numpy as np
from pathlib import Path
from itertools import combinations
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

from constants import ds_info_file, cat_name_mapping, fontsizes
from helper import load_ds_info, save_or_show
import textwrap

sys.path.append('..')
from scripts.helper import parse_datasets

In [None]:
# Dataset identifiers with '/' replaced by '_'.
# NOTE(review): presumably this matches how datasets are named in the 'DS'
# column of the similarity CSV -- confirm against the aggregation script.
ds_list = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')
]

# Per-dataset metadata; the functions below read its 'domain' and 'name' columns.
ds_info = load_ds_info(ds_info_file)

## Output configuration
SAVE = True
corr_method = 'pearsonr'
storing_path = Path(f'/home/space/diverse_priors/results/plots/dist_r_coeff_dataset_cats/{corr_method}')
if SAVE:
    # Only touch the filesystem when figures are actually going to be written.
    storing_path.mkdir(parents=True, exist_ok=True)

## Load data
orig_sim_data = pd.read_csv('/home/space/diverse_priors/results/aggregated/model_sims/all_metric_ds_model_pair_similarity.csv')
# The 'Objective pair' column is stored as text and turned back into Python
# objects here.
# NOTE(review): eval() executes arbitrary expressions found in the CSV. If the
# cells are plain tuple literals, ast.literal_eval would be the safe drop-in.
orig_sim_data['Objective pair'] = orig_sim_data['Objective pair'].apply(eval)

# Objective-category pairs to produce a heatmap for. The leading ('', '') entry
# disables filtering in filter_data, i.e. it selects all model pairs.
combinations_objectives = [
    ('', ''),
    ('Image-Text', 'Image-Text'),
    ('Image-Text', 'Self-Supervised'),
    ('Image-Text', 'Supervised'),
    ('Self-Supervised', 'Self-Supervised'),
    ('Self-Supervised', 'Supervised'),
    ('Supervised', 'Supervised'),
]

In [None]:
def filter_data(mcat1, mcat2, sim_data, curr_suffix):
    """Restrict the similarity table to one (unordered) pair of objective categories.

    Parameters
    ----------
    mcat1, mcat2 : str
        Objective categories. If either one is empty, no filtering is done.
    sim_data : pandas.DataFrame
        Table with an 'Objective pair' column holding a sortable pair of
        category names per row.
    curr_suffix : str
        File-name suffix built so far.

    Returns
    -------
    (pandas.DataFrame, str)
        When filtering applies: a copy of the matching rows and the suffix
        prefixed with the mapped category names. Otherwise `sim_data` itself
        (not a copy) and the unchanged suffix.
    """
    # An empty category on either side means "use every model pair".
    if not (len(mcat1) > 0 and len(mcat2) > 0):
        return sim_data, curr_suffix

    # Order inside a pair is irrelevant, so compare sorted forms.
    wanted_pair = sorted([mcat1, mcat2])
    row_matches = sim_data['Objective pair'].apply(lambda pair: sorted(pair) == wanted_pair)
    filtered = sim_data[row_matches].copy()

    tagged_suffix = f"_{cat_name_mapping[mcat1]}_{cat_name_mapping[mcat2]}" + curr_suffix
    return filtered, tagged_suffix
    
    
def get_ds_combs(df):
    """Collect the datasets to compare and allocate an empty correlation matrix.

    Reads the notebook-level `ds_list` (allowed dataset ids) and `ds_info`
    (metadata indexed by dataset id, with a 'domain' column).

    Parameters
    ----------
    df : pandas.DataFrame
        Similarity table with a 'DS' column.

    Returns
    -------
    available_ds : list
        Datasets present in both `df['DS']` and `ds_list`, ordered by domain
        and alphabetically within a domain.
    available_domains : list
        Sorted unique values of `ds_info['domain']` (all domains in the
        metadata, not only those of `available_ds`).
    combs_DS : list of tuple
        Every unordered pair of `available_ds`.
    corr_mat : pandas.DataFrame
        Float frame indexed and labelled by `available_ds`, initially all NaN.
    """
    # Only datasets that occur in the data AND are on the allowed list.
    present = set(df['DS'].unique()).intersection(ds_list)
    available_domains = sorted(ds_info['domain'].unique().tolist())

    # Composite key: primary = rank of the dataset's domain, secondary = the
    # dataset id itself. Equivalent to sorting alphabetically and then doing a
    # stable sort by domain.
    available_ds = sorted(
        present,
        key=lambda ds: (available_domains.index(ds_info.loc[ds, 'domain']), ds),
    )

    combs_DS = list(combinations(available_ds, 2))
    corr_mat = pd.DataFrame(index=available_ds, columns=available_ds).astype('float')
    return available_ds, available_domains, combs_DS, corr_mat


def get_r_coeff(x,y, method='pearsonr'):
    """Return the correlation coefficient between `x` and `y`.

    Parameters
    ----------
    x, y : array-like
        Paired samples, passed straight to the scipy routine.
    method : str
        'pearsonr' or 'spearmanr'.

    Returns
    -------
    The coefficient only; the second value returned by scipy (the p-value)
    is discarded.

    Raises
    ------
    ValueError
        If `method` is neither of the supported names.
    """
    if method == 'pearsonr':
        coefficient, _pvalue = pearsonr(x, y)
        return coefficient
    if method == 'spearmanr':
        coefficient, _pvalue = spearmanr(x, y)
        return coefficient
    raise ValueError("Unknown method")


def get_two_ds_data(ds1, ds2, all_sims):
    """Build a model-pair x dataset table of similarity values for two datasets.

    Parameters
    ----------
    ds1, ds2 : dataset identifiers as they appear in `all_sims['DS']`.
    all_sims : pandas.DataFrame
        Needs the columns 'DS', 'Model 1', 'Model 2' and 'Similarity value'.
        It is not modified.

    Returns
    -------
    pandas.DataFrame
        Rows are "<Model 1>, <Model 2>" labels, columns are the dataset ids,
        cells are 'Similarity value' aggregated with pivot_table's default
        (mean) when a (model pair, dataset) combination occurs more than once.
    """
    in_either_ds = all_sims['DS'].isin([ds1, ds2])
    # Work on a copy so the helper column never leaks into the caller's frame.
    subset = all_sims.loc[in_either_ds].copy()
    subset['model_pair'] = subset['Model 1'] + ", " + subset['Model 2']
    return subset.pivot_table(
        values='Similarity value',
        index='model_pair',
        columns='DS',
    )


def fill_corr_mat(df, combs_DS, corr_mat, method=None):
    """Fill `corr_mat` with pairwise dataset correlations (symmetric, unit diagonal).

    For every dataset pair the per-model-pair similarity values of the two
    datasets are correlated and the coefficient is written to both
    `[ds1, ds2]` and `[ds2, ds1]`.

    Parameters
    ----------
    df : pandas.DataFrame
        Similarity table forwarded to `get_two_ds_data`.
    combs_DS : iterable of (ds1, ds2)
        Dataset pairs to evaluate; both ids must be labels of `corr_mat`.
    corr_mat : pandas.DataFrame
        Matrix to fill. It is modified in place and also returned.
    method : str, optional
        Correlation method handed to `get_r_coeff`. Defaults to the
        notebook-level `corr_method`.

    Returns
    -------
    pandas.DataFrame
        The same `corr_mat` object.
    """
    if method is None:
        # Keep the original behaviour: fall back to the notebook-wide setting.
        method = corr_method

    for ds1, ds2 in combs_DS:
        # The pivot has one column per dataset; correlation is symmetric, so
        # which dataset ends up in column 0 does not matter.
        pair_sims = get_two_ds_data(ds1, ds2, df).to_numpy()
        corr = float(get_r_coeff(pair_sims[:, 0], pair_sims[:, 1], method=method))
        corr_mat.loc[ds1, ds2] = corr
        corr_mat.loc[ds2, ds1] = corr

    # Write the diagonal through the pandas indexer instead of mutating
    # `corr_mat.values`: that array is not guaranteed to be a writable view
    # (it is read-only under pandas Copy-on-Write).
    for pos in range(min(corr_mat.shape)):
        corr_mat.iat[pos, pos] = 1
    return corr_mat


def rename_idx_cols(corr_mat, available_ds):
    """Relabel rows and columns of `corr_mat` with the datasets' display names.

    The names are looked up one by one in the notebook-level `ds_info` frame
    (column 'name'), in the order given by `available_ds`.

    Parameters
    ----------
    corr_mat : pandas.DataFrame
        Matrix whose axes are relabelled in place.
    available_ds : iterable
        Dataset ids, one per row/column of `corr_mat`, in axis order.

    Returns
    -------
    pandas.DataFrame
        The same `corr_mat` object with new index and column labels.
    """
    display_names = []
    for ds_id in available_ds:
        display_names.append(ds_info.loc[ds_id, 'name'])

    # Both axes describe the same datasets, so they share one label list.
    corr_mat.index = display_names
    corr_mat.columns = display_names
    return corr_mat


def plot_heatmap(corr_mat, available_ds, available_domains, vmin = -0.2, vmax=1):
    """Draw the dataset-by-dataset correlation heatmap with domain separators.

    Parameters
    ----------
    corr_mat : pandas.DataFrame
        Square matrix to plot; its labels become the tick labels.
    available_ds : list
        Dataset ids in the same order as the rows/columns of `corr_mat`;
        used to look up each dataset's domain in the notebook-level `ds_info`.
    available_domains : list of str
        Domain names written above the heatmap, in order.
    vmin, vmax : float
        Colour-scale limits passed to seaborn.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure (also the current pyplot figure).
    """
    fig, ax = plt.subplots(figsize=(9, 7))
    sns.heatmap(corr_mat, square=True, cmap='mako', cbar=False, vmin=vmin, vmax=vmax, ax=ax)

    # A separator goes after every position where the domain of consecutive
    # datasets changes.
    domains = ds_info.loc[available_ds, 'domain'].to_numpy()
    boundaries = np.flatnonzero(domains[:-1] != domains[1:]) + 1
    for boundary in boundaries:
        ax.axhline(boundary, c='black', ls=":")
        ax.axvline(boundary, c='black', ls=":")

    # Fixed label slots: (x, y, put each word on its own line?).
    # NOTE(review): these coordinates are hard-coded and look tuned for one
    # specific dataset ordering -- confirm they still sit above the right
    # groups when the dataset list changes.
    label_slots = [
        (5, -1.5, True),
        (13, -1.5, True),
        (17.75, -1, False),
        (21.75, -1, False),
    ]
    # zip() stops at the shorter sequence, so fewer than four domains no
    # longer raises IndexError; with four or more the output is unchanged.
    for (x_pos, y_pos, one_word_per_line), domain in zip(label_slots, available_domains):
        label = '\n'.join(domain.split(' ')) if one_word_per_line else domain
        ax.text(x_pos, y_pos, label, ha='center', va='top', fontsize=10, color='black')

    # The heatmap's mesh is the first collection on the axes; attach the
    # colour bar to it manually since cbar=False was passed above.
    fig.colorbar(ax.collections[0], ax=ax, orientation='vertical', pad=0.025, aspect=40)

    fig.tight_layout()
    return fig
    

In [None]:
# One heatmap per (similarity metric, objective-category pair).
for sim_metric in ['CKA linear', 'CKA RBF 0.4']:

    metric_rows = orig_sim_data['Similarity metric'] == sim_metric
    sim_data = orig_sim_data[metric_rows].reset_index().copy()

    # File-name tag for the metric, e.g. 'CKA linear' -> '_cka_linear'.
    suffix = "_" + sim_metric.replace(" ", "_").lower()

    for mcat1, mcat2 in combinations_objectives:

        sim_data_flat, curr_suffix = filter_data(mcat1, mcat2, sim_data, suffix)

        available_ds, available_domains, combs_DS, corr_mat = get_ds_combs(sim_data_flat)

        # Compute the correlations first, then swap dataset ids for display names.
        corr_mat = rename_idx_cols(
            fill_corr_mat(sim_data_flat, combs_DS, corr_mat),
            available_ds,
        )

        fig = plot_heatmap(corr_mat, available_ds, available_domains, vmin = -0.2, vmax=1)

        # plot_heatmap returns the current figure, so it can be handed on directly.
        save_or_show(fig, storing_path / f'grouped_heatmap{curr_suffix}.pdf', SAVE)

In [None]:
# ## Filtered similarity metric
# sim_metric = 'CKA linear'
# sim_data_flat = orig_sim_data[orig_sim_data['Similarity metric']==sim_metric].reset_index().copy()
# suffix = "_" + sim_metric.replace(" ", "_").lower()

# combinations_objectives = [('', ''),
#                             ('Image-Text', 'Image-Text'),
#                             ('Image-Text', 'Self-Supervised'), 
#                             ('Image-Text', 'Supervised'), 
#                             ('Self-Supervised', 'Self-Supervised'), 
#                             ('Self-Supervised', 'Supervised'), 
#                             ('Supervised', 'Supervised')]
# mcat1, mcat2 = combinations_objectives[0]

# if len(mcat1)>0 and len(mcat2)>0:
#     sim_data_flat = sim_data_flat[sim_data_flat['Objective pair'].apply(lambda x: sorted(x) == sorted([mcat1, mcat2]))].copy()
#     suffix = f"_{cat_name_mapping[mcat1]}_{cat_name_mapping[mcat2]}" + suffix

# sim_data_flat['Objective pair'].value_counts().sort_index()

In [None]:
# n_ds = sim_data_flat['DS'].nunique()
# available_ds = set(sim_data_flat['DS'].unique()).intersection(ds_list)
# available_ds = sorted(list(available_ds))
# available_domains = sorted(ds_info['domain'].unique().tolist())
# available_ds = sorted(available_ds, key = lambda x: available_domains.index(ds_info.loc[x, 'domain']))
# combs_DS = list(combinations(available_ds, 2))
# corr_mat = pd.DataFrame(index=available_ds, columns=available_ds).astype('float')

In [None]:
# def get_r_coeff(x,y, method='pearsonr'):
#     if method=='pearsonr':
#         corr, _  = pearsonr(x, y)
#     elif method=='spearmanr':
#         corr, _  = spearmanr(x, y)
#     else:
#         raise ValueError("Unknown method")
#     return corr 


# def get_two_ds_data(ds1, ds2, all_sims):
#     data_2_ds = all_sims[all_sims['DS'].isin([ds1, ds2])].copy()
#     data_2_ds['model_pair'] = data_2_ds['Model 1'] + ", " + data_2_ds['Model 2']
#     ds_similarities = pd.pivot_table(
#         data_2_ds,
#         columns='DS',
#         index='model_pair',
#         values= 'Similarity value',
#     )
#     return ds_similarities

In [None]:
# for ds1, ds2 in combs_DS:
#     ds_sims = get_two_ds_data(ds1, ds2, sim_data_flat)
#     corr = get_r_coeff(ds_sims.values[:,0], ds_sims.values[:,1], method=corr_method)
#     corr_mat.loc[ds1, ds2] = float(corr)
#     corr_mat.loc[ds2, ds1] = float(corr)

# np.fill_diagonal(corr_mat.values, 1)

In [None]:
# new_naming = [ds_info.loc[ds, 'name'] for ds in available_ds]
# corr_mat.index = new_naming
# corr_mat.columns = new_naming

In [None]:
# cmap = sns.color_palette('tab10', len(available_domains)).as_hex()
# domain_color_mapping={domain: color for domain, color in  zip(available_domains, cmap)}

In [None]:
# def color_labels(g, axis):
#     lbl_iter = g.get_xticklabels if axis == 'x' else g.get_yticklabels
#     for label in lbl_iter():
#         tick_text = label.get_text()
#         curr_color = domain_color_mapping[ds_info.loc[ds_info['name'] == tick_text, 'domain'].iloc[0]]
#         label.set_color(curr_color)
#         label.set(fontsize=fontsizes['ticks'])

# plt.figure(figsize=(9,7))

# g = sns.heatmap(corr_mat, square=True, cmap='mako', cbar=False, vmin = -0.2, vmax=1)

# tmp = np.where(~(ds_info.loc[available_ds,'domain'].iloc[:-1].values == ds_info.loc[available_ds,'domain'].iloc[1:].values))[0]
# tmp +=1

# for val in tmp:
#     g.axhline(val, c='black', ls=":")
#     g.axvline(val, c='black', ls=":")

# g.text(5 , -1.5, '\n'.join(available_domains[0].split(' ')), ha='center', va='top', fontsize=10, color='black')
# g.text(13 , -1.5, '\n'.join(available_domains[1].split(' ')), ha='center', va='top', fontsize=10, color='black')
# g.text(17.75 , -1, available_domains[2], ha='center', va='top', fontsize=10, color='black')
# g.text(21.75 , -1, available_domains[3], ha='center', va='top', fontsize=10, color='black')


# cbar = plt.colorbar(g.collections[0], ax=g, orientation='vertical', pad=0.025, aspect=40)

# plt.tight_layout()
# save_or_show(plt.gcf(), storing_path / f'grouped_heatmap{suffix}.pdf', SAVE)

In [None]:
# def color_labels(g, axis):
#     lbl_iter = g.ax_heatmap.get_xticklabels if axis == 'x' else g.ax_heatmap.get_yticklabels
#     for label in lbl_iter():
#         tick_text = label.get_text()
#         curr_color = domain_color_mapping[ds_info.loc[ds_info['name'] == tick_text, 'domain'].iloc[0]]
#         label.set_color(curr_color)
#         label.set(fontsize=fontsizes['ticks'])
# vmin = -0.2
# vmax = 1
# g = sns.clustermap(corr_mat, 
#                row_colors=ds_info.loc[available_ds,'domain'].map(domain_color_mapping).tolist(), 
#                col_colors=ds_info.loc[available_ds,'domain'].map(domain_color_mapping).tolist(),
#                row_cluster=False,
#                col_cluster=False,
#                cbar_pos=None,
#                cmap = 'Greys',
#                figsize=(8,8), 
#                    vmin=vmin,
#                    vmax=vmax
#               )

# color_labels(g, 'x')
# color_labels(g, 'y')

In [None]:
# corr_mat.index = available_ds
# corr_mat.columns = available_ds

# idxs = np.triu_indices_from(corr_mat.values, k=1)
# corr_mat.index[idxs[0]].values
# #corr_mat.values[idxs]
# flattend_corr_mat = pd.DataFrame({
#     'DS 1': corr_mat.index[idxs[0]].values,
#     'DS 2': corr_mat.columns[idxs[1]].values,
#     'R coeff': corr_mat.values[idxs]
# })

In [None]:
# flattend_corr_mat['Domain pair'] = flattend_corr_mat[['DS 1', 'DS 2']].apply(
#     lambda x: ",\n".join(sorted([ds_info.loc[x['DS 1'], 'domain'], ds_info.loc[x['DS 2'], 'domain']])),
#     axis=1
# )

In [None]:
# order = flattend_corr_mat['Domain pair'].value_counts().sort_index().index.tolist()

In [None]:
# import textwrap
# plt.figure(figsize=(8,5))
# g = sns.swarmplot(
#     flattend_corr_mat,
#     x='Domain pair',
#     y='R coeff',
#     size=3,
#     order=order
# )
# g.set_xlabel('')
# g.set_ylabel('Correlation coefficient')
# g.tick_params('x', rotation=90)
# g.axhline(0.5, c='grey', zorder=-1, ls=":", alpha=0.5)
# if len(mcat1)>0 and len(mcat2)>0:
#     g.set_title(f'Correlation coefficient across {mcat1} and {mcat2} model pairs.')
# else:
#     g.set_title(f'Correlation coefficient across all model pairs.')
# plt.tight_layout()
# save_or_show(plt.gcf(), storing_path / f'swarm{suffix}.pdf', SAVE)