In [None]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from itertools import product, combinations

from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Root directory under which the similarity matrices and the plot outputs live.
_project_root = Path('/home/space/diverse_priors')
base_path_similarity_matrices = _project_root / 'model_similarities'

# Similarity metrics to plot: raw ids (passed to the loader) and their display names.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
]
sim_metrics_mapped = [sim_metric_name_mapping[metric_id] for metric_id in sim_metrics]

# Web datasets from the list file (its name suggests ImageNet is not in it; it is
# prepended explicitly). '/' is replaced by '_' to match dataset keys such as
# 'wds_vtab_flowers' used below.
y_axis_ds = [ds_name.replace('/', '_') for ds_name in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')]
ds_list = ['imagenet-subset-10k', *y_axis_ds]

# Inches per centimetre: matplotlib figure sizes are given in inches.
cm = 0.393701

# Where the scatter plots are written when SAVE is enabled.
SAVE = True
storing_path = _project_root / 'results' / 'plots' / 'scatter_similarity'
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the model configuration table and the list of models allowed in this analysis.
# `model_configs` is indexed by model id and has an 'objective' column (both are used
# via `.loc[...]` further down); presumably a pandas DataFrame — TODO confirm in helper.
excluded_models = [
    'SegmentAnything_vit_b',
    'DreamSim_dino_vitb16',
    'DreamSim_open_clip_vitb32',
]
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/models_config.json',
    exclude_models=excluded_models,
    exclude_alignment=True,
)

In [None]:
# Similarity matrices for `ds_list` x `sim_metrics` (with `allowed_models` handed to the
# loader), re-keyed from raw metric ids to display names so they can be looked up
# with the 'Similarity metric' values of the aggregated CSV.
# Inner mapping is indexed by dataset name (e.g. sim_mats[metric][ds]).
sim_mats = {
    sim_metric_name_mapping[metric_id]: per_ds_mats
    for metric_id, per_ds_mats in load_similarity_matrices(
        path=base_path_similarity_matrices,
        ds_list=ds_list,
        sim_metrics=sim_metrics,
        allowed_models=allowed_models,
    ).items()
}

In [None]:
# Candidate dataset triplets for the scatter grid: ImageNet plus one single-domain
# dataset (Flowers / Pets) and one structured dataset (Eurosat / PCAM).
# Variants are numbered in itertools.product order:
#   v1 = (flowers, eurosat), v2 = (flowers, pcam), v3 = (pets, eurosat), v4 = (pets, pcam)
ds_lists = {
    f'ds_row_1_v{variant}': ['imagenet-subset-10k', single_domain_ds, structured_ds]
    for variant, (single_domain_ds, structured_ds) in enumerate(
        product(['wds_vtab_flowers', 'wds_vtab_pets'], ['wds_vtab_eurosat', 'wds_vtab_pcam']),
        start=1,
    )
}

# Dataset id -> human-readable axis label.
ds_name_mapping = {
    'imagenet-subset-10k': 'ImageNet (natural)',
    'wds_vtab_flowers': 'Flowers (single domain)',
    'wds_vtab_pets': 'Pets (single domain)',
    'wds_vtab_eurosat': 'Eurosat (structured)',
    'wds_vtab_pcam': 'PCAM (structured)',
}

In [None]:
# CSV of aggregated correlation coefficients ('r coeff') with anchor-model,
# similarity-metric, dataset-pair ('DS 1' / 'DS 2') and comparison columns.
agg_data_path = (
    Path('/home/space/diverse_priors/results')
    / 'aggregated'
    / 'r_coeff_dist'
    / 'agg_corr_coeffs_all_ds.csv'
)

In [None]:
# Anchor models for the scatter plots (raw model ids, in plotting order).
anchors = [
    'OpenCLIP_RN50_openai',
    'simclr-rn50',
    'resnet50',
    'OpenCLIP_ViT-L-14_openai',
    'mae-vit-large-p16',
    'vit_large_patch16_224',
]

# Raw model id -> display label used in legends, titles and the aggregated CSV.
anchor_name_mapping = {
    'OpenCLIP_RN50_openai': 'OpenCLIP RN50',
    'OpenCLIP_ViT-L-14_openai': 'OpenCLIP ViT-L',
    'resnet50': 'ResNet-50',
    'vit_large_patch16_224': 'ViT-L',
    'simclr-rn50': 'SimCLR RN50',
    'dino-vit-base-p16': 'DINO ViT-B',
    'dinov2-vit-large-p14': 'DINOv2 ViT-L',
    'mae-vit-large-p16': 'MAE ViT-L',
}
# Inverse lookup: display label -> raw model id.
anchor_named_2_orig = dict(zip(anchor_name_mapping.values(), anchor_name_mapping.keys()))
# Every known display label, and the labels of the selected `anchors` (same order).
anchor_nm_val_list = [*anchor_name_mapping.values()]
anchor_nm_val_list_v2 = list(map(anchor_name_mapping.__getitem__, anchors))

# Column names of the aggregated correlation-coefficient CSV.
anchor_col, sim_metric_col, comp_cat_col, comp_cat_orig_col, comp_val_col = (
    'Anchor Model',
    'Similarity metric',
    'Comparison category',
    'Comparison category (orig. name)',
    'Comparison values',
)


In [None]:
# Dataset triplet used for the scatter grid below (one of the `ds_lists` variants).
scatter_ds_variant = 'ds_row_1_v2'
curr_ds_list = ds_lists[scatter_ds_variant]

In [None]:
from matplotlib.lines import Line2D
from scipy import stats

def get_scatter_for_pp_r_df(r_df, with_reg_line=False, ds_list=None):
    """Scatter anchor-to-model similarities of two datasets against each other.

    Grid layout: one row per similarity metric in ``r_df`` and one column per
    unordered pair ``(ds1, ds2)`` of datasets. In each panel every non-anchor model
    is drawn at ``(sim(anchor, model) on ds1, sim(anchor, model) on ds2)``, coloured
    by the model's 'objective' in ``model_configs`` and marked by anchor. The legend
    lists, per anchor, the correlation coefficients stored in ``r_df``.

    Relies on notebook globals: ``sim_mats``, ``model_configs``,
    ``anchor_named_2_orig``, ``ds_name_mapping``, ``cm`` and the column-name
    constants (``sim_metric_col``, ``anchor_col``, ``comp_val_col``).

    Parameters
    ----------
    r_df : pd.DataFrame
        Aggregated coefficients, already filtered to the anchors / datasets /
        metrics to plot. Needs the metric, anchor and comparison-value columns plus
        'DS 1', 'DS 2' and 'r coeff'; anchor values must be display names.
    with_reg_line : bool
        If True, draw one dotted regression line per objective (``sns.regplot``);
        otherwise a plain scatter.
    ds_list : list of str, optional
        Datasets whose pairwise combinations form the columns. Defaults to the
        notebook-level ``curr_ds_list``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``r_df`` has no similarity metric or ``ds_list`` yields no dataset pair.
    """
    if ds_list is None:
        ds_list = curr_ds_list
    ds_pairs = list(combinations(ds_list, 2))

    nrows = r_df[sim_metric_col].nunique()
    # One column per dataset *pair*; this equals len(ds_list) only for three datasets.
    ncols = len(ds_pairs)
    if nrows == 0 or ncols == 0:
        raise ValueError(
            f"Nothing to plot: {nrows} similarity metric(s) in r_df and "
            f"{ncols} dataset pair(s) from {len(ds_list)} dataset(s)."
        )

    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row/column.
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(ncols * 10 * cm, nrows * 9 * cm),
        sharex=True, sharey=True, squeeze=False,
    )

    def pp_simmat(sim_mat, anch):
        """Return row ``anch`` of ``sim_mat`` without the anchor's self-similarity entry."""
        return sim_mat.loc[anch].drop(index=anch)

    # One marker per anchor (display name), so several anchors can share a panel.
    markers = {
        'OpenCLIP RN50': 'o',     # Circle
        'OpenCLIP ViT-L': 's',    # Square
        'ResNet-50': '^',         # Triangle Up
        'ViT-L': 'v',             # Triangle Down
        'SimCLR RN50': 'D',       # Diamond
        'DINO ViT-B': '*',        # Star
        'DINOv2 ViT-L': 'p',      # Pentagon
        'MAE ViT-L': 'H',         # Hexagon
    }

    # Rotate the first four tab10 colours so 'All' takes the fourth one and the
    # three objectives keep the first three.
    palette = sns.color_palette('tab10', 4).as_hex()
    palette = palette[-1:] + palette[:-1]
    colors = dict(zip(['All', 'Image-Text', 'Self-Supervised', 'Supervised'], palette))

    # Short labels for the legend entries.
    cat_mapping = {'All': 'All', 'Image-Text': 'IT', 'Self-Supervised': 'SS', 'Supervised': 'S'}

    for i, (sim_metric, data) in enumerate(r_df.groupby(sim_metric_col)):
        sim_mat_metric = sim_mats[sim_metric]
        for j, (ds1, ds2) in enumerate(ds_pairs):
            ax = axes[i, j]
            handles, labels = [], []
            # Rows may store the pair as (ds1, ds2) or (ds2, ds1): match either order.
            pair_mask = (
                ((data['DS 1'] == ds1) & (data['DS 2'] == ds2))
                | ((data['DS 1'] == ds2) & (data['DS 2'] == ds1))
            )
            subset_data = data[pair_mask]
            nanchors = subset_data[anchor_col].nunique()
            for anchor, anch_data in subset_data.groupby(anchor_col):
                anchor_id = anchor_named_2_orig[anchor]
                pp_ds1 = pp_simmat(sim_mat_metric[ds1], anchor_id)
                # Align on model id so each point pairs the same model on both datasets,
                # even if the two matrices list models in a different order.
                pp_ds2 = pp_simmat(sim_mat_metric[ds2], anchor_id).reindex(pp_ds1.index)
                cat_values = np.array([model_configs.loc[c, 'objective'] for c in pp_ds1.index])

                if with_reg_line:
                    for val in np.unique(cat_values):
                        idxs = cat_values == val
                        sns.regplot(
                            x=pp_ds1[idxs],
                            y=pp_ds2[idxs],
                            marker=markers[anchor],
                            ax=ax,
                            color=colors[val],
                            line_kws=dict(alpha=0.75, ls=':', lw=0.75),
                            scatter_kws=dict(alpha=0.5),
                            ci=None,
                        )
                else:
                    curr_colors = [colors[c] for c in cat_values]
                    ax.scatter(x=pp_ds1, y=pp_ds2, c=curr_colors, marker=markers[anchor], alpha=0.5)

                # Invisible handle that acts as a per-anchor heading in the legend.
                handles.append(Line2D([0], [0], linestyle='None', marker=None))
                labels.append(anchor)
                for _, row in anch_data[[comp_val_col, 'r coeff']].iterrows():
                    comp_val = row[comp_val_col]
                    # 'All' entry is colour-only; per-objective entries reuse the anchor marker.
                    cur_marker = None if comp_val == 'All' else markers[anchor]
                    handles.append(Line2D([0], [0], color=colors[comp_val], marker=cur_marker,
                                          linestyle='None', markersize=7, alpha=0.5))
                    labels.append(f"r {cat_mapping[comp_val]}: {row['r coeff']:.2f}")

            if handles:
                ax.legend(handles=handles, labels=labels, fontsize=10, framealpha=0.5, frameon=True,
                          title='', loc='lower right', ncols=2 if nanchors > 1 else 1)
            # Axes share x, so only the bottom row carries the x label.
            ax.set_xlabel(ds_name_mapping[ds1] if i == nrows - 1 else "", fontsize=10)
            if j == 0:
                ax.set_ylabel(f"{sim_metric}\n{ds_name_mapping[ds2]}", fontsize=10)
            else:
                ax.set_ylabel(ds_name_mapping[ds2], fontsize=10)

    fig.tight_layout()
    return fig

In [None]:
# Plot every single anchor and every pair of anchors, each with and without
# per-objective regression lines.
anchor_combinations = list(combinations(anchors, 1)) + list(combinations(anchors, 2))

# Read the aggregated coefficients once (loop-invariant) and apply the filters that do
# not depend on the anchor combination: current datasets, plotted metrics, objective split.
r_df_all = pd.read_csv(agg_data_path)
r_df_all = r_df_all[
    r_df_all['DS 1'].isin(curr_ds_list)
    & r_df_all['DS 2'].isin(curr_ds_list)
    & r_df_all[sim_metric_col].isin(sim_metrics_mapped)
    & (r_df_all[comp_cat_col] == 'Objective')
]

for curr_anchors in anchor_combinations:
    # Rows for this combination, with anchor ids replaced by display names as expected
    # by get_scatter_for_pp_r_df.
    r_df = r_df_all[r_df_all[anchor_col].isin(curr_anchors)].copy().reset_index(drop=True)
    r_df[anchor_col] = r_df[anchor_col].map(anchor_name_mapping)
    for draw_reg in [True, False]:
        fig = get_scatter_for_pp_r_df(r_df, with_reg_line=draw_reg)
        # NOTE: "_wih_reg" (sic) is kept so previously generated figure paths stay stable.
        fig_name = f'scatter{"_wih_reg" if draw_reg else ""}_M{"_".join(curr_anchors)}_DS{"_".join(curr_ds_list)}.pdf'
        save_or_show(fig, storing_path / fig_name, SAVE)
        # This loop creates dozens of figures; release each one once handled.
        plt.close(fig)

### Line plots of correlation-coefficient distributions

In [None]:
# Fresh, unfiltered copy of the aggregated coefficients for the distribution plots.
r_df = pd.read_csv(agg_data_path)

# Output directory for the distribution plots and the Wasserstein-distance table.
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots') / 'distribution_similarity'
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Anchors for the distribution grid: one column per anchor, in this order.
anchors = [
    'OpenCLIP_RN50_openai',
    'OpenCLIP_ViT-L-14_openai',
    'simclr-rn50',
    'mae-vit-large-p16',
    'resnet50',
    'vit_large_patch16_224',
]

r_df = r_df[r_df[anchor_col].isin(anchors)]
sim_mets = r_df[sim_metric_col].unique()
n_anchors = len(anchors)
n_metrics = len(sim_mets)

# Whether the 'All' comparison group is drawn alongside the per-value curves.
include_all = False


def _hue_order_for(comp_values, put_all_last):
    """Sorted unique comparison values; if ``put_all_last``, 'All' (when present) goes last."""
    levels = sorted(comp_values.unique())
    if put_all_last and 'All' in levels:
        levels = [v for v in levels if v != 'All'] + ['All']
    return levels


# One figure per comparison category: rows = similarity metrics, columns = anchors,
# each panel a KDE of the correlation coefficients per comparison value.
for comp_cat in r_df[comp_cat_col].unique():
    subset = r_df[r_df[comp_cat_col] == comp_cat]
    # model_configs column for this category; the anchor's own value picks its title colour.
    comp_cat_orig = subset[comp_cat_orig_col].unique()[0]
    # squeeze=False keeps `axes` 2-D even with a single metric or a single anchor.
    fig, axes = plt.subplots(
        nrows=n_metrics, ncols=n_anchors,
        figsize=(n_anchors * 6 * cm, n_metrics * 3.5 * cm),
        sharex=True, sharey=False, squeeze=False,
    )
    for col_idx, anchor in enumerate(anchors):
        for row_idx, sim_metric in enumerate(sim_mets):
            ax = axes[row_idx, col_idx]
            group_data = subset[(subset[anchor_col] == anchor) & (subset[sim_metric_col] == sim_metric)]
            if not include_all:
                group_data = group_data[group_data[comp_val_col] != 'All']

            hue_order = _hue_order_for(group_data[comp_val_col], include_all)
            # Explicit tab10 palette (what seaborn uses by default under matplotlib's default
            # colour cycle) so the title colour below is guaranteed to match the curves.
            palette = sns.color_palette('tab10', len(hue_order)).as_hex() if hue_order else []
            if hue_order:
                sns.kdeplot(
                    group_data,
                    x='r coeff',
                    hue=comp_val_col,
                    hue_order=hue_order,
                    palette=palette,
                    ax=ax,
                )

            ax.axvline(0.5, c='grey', ls=':', alpha=0.5, zorder=-1, lw=0.75)
            ax.set_xlim([-0.5, 1.1])

            # Colour the anchor's title like the curve of its own comparison value, if drawn.
            title_kwargs = {'fontsize': 11}
            anchor_value = model_configs.loc[anchor][comp_cat_orig]
            if anchor_value in hue_order:
                title_kwargs['color'] = palette[hue_order.index(anchor_value)]
            ax.set_title(anchor_name_mapping[anchor] if row_idx == 0 else "", **title_kwargs)
            ax.set_xlabel('Correlation coefficient' if row_idx == n_metrics - 1 else '')
            ax.set_ylabel(f"{sim_metric}\nDensity" if col_idx == 0 else '', fontsize=11)

            # Keep a single legend (top-right panel, moved outside); drop all others.
            legend = ax.get_legend()
            if legend is not None:
                if row_idx == 0 and col_idx == n_anchors - 1:
                    sns.move_legend(ax, loc='upper left', title=comp_cat, bbox_to_anchor=(1, 1),
                                    fontsize=11, title_fontsize=11, frameon=False)
                else:
                    legend.remove()

    fig.tight_layout()
    save_or_show(fig,
                 storing_path / f'line_plot_{comp_cat.replace(" ", "_")}.pdf',
                 SAVE)
    plt.close(fig)
    

#### Wasserstein distances between the coefficient distributions of each pair of similarity metrics

In [None]:
from itertools import combinations
from scipy.stats import wasserstein_distance

# Display names of the similarity metrics present in the (anchor-filtered) table.
# Stored under its own name so the config-cell `sim_metrics` (raw metric ids used to
# load the similarity matrices) is not overwritten if earlier cells are re-run.
ws_sim_metrics = r_df[sim_metric_col].unique()
metric_combs = list(combinations(ws_sim_metrics, 2))

# Output schema; kept explicit so an empty result still has the expected columns.
# NOTE(review): 'wasserstei' looks like a truncated 'wasserstein'; kept as-is because
# readers of ws_distances.csv may depend on it — rename everywhere at once if not.
ws_columns = ['Anchor Model', 'Comparison category', 'Comparison values', 'Metric1', 'Metric2', 'wasserstei']

rows = []
for (anchor, comp_cat, comp_val), group_data in r_df.groupby([anchor_col, comp_cat_col, comp_val_col]):
    for met1, met2 in metric_combs:
        r_vals_met1 = group_data.loc[group_data[sim_metric_col] == met1, 'r coeff']
        r_vals_met2 = group_data.loc[group_data[sim_metric_col] == met2, 'r coeff']
        # wasserstein_distance raises on an empty sample; record NaN for a group where a
        # metric has no coefficients instead of aborting the whole table.
        if r_vals_met1.empty or r_vals_met2.empty:
            ws_dist = np.nan
        else:
            ws_dist = wasserstein_distance(r_vals_met1, r_vals_met2)
        rows.append({
            'Anchor Model': anchor,
            'Comparison category': comp_cat,
            'Comparison values': comp_val,
            'Metric1': met1,
            'Metric2': met2,
            'wasserstei': ws_dist,
        })
ws_df = pd.DataFrame(rows, columns=ws_columns)

In [None]:
# Persist the pairwise Wasserstein distances next to the distribution plots.
ws_distances_path = storing_path / 'ws_distances.csv'
if SAVE:
    ws_df.to_csv(ws_distances_path, index=False)