In [None]:
import sys
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import json
from matplotlib.lines import Line2D
from constants import (
    anchor_name_mapping, 
    available_data, 
    exclude_models, 
    exclude_models_w_mae, 
    cat_name_mapping, 
    ds_info_file, 
    model_config_file,
    fontsizes,
    model_cat_mapping,
    cat_color_mapping
)
from scipy.stats import spearmanr, pearsonr
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show, get_fmt_name, load_ds_info

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# --- Datasets ---------------------------------------------------------------
# Full dataset list; '/' in dataset identifiers is replaced by '_'.
ds_list = [
    ds_name.replace('/', '_')
    for ds_name in parse_datasets('../scripts/webdatasets_w_insub10k.txt')
]

ds_info = load_ds_info(ds_info_file)

# Named dataset subsets; the one selected below drives all plots in this notebook.
ds_lists = {
    'ds_row_1_v2': ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam'],
}
curr_ds_list = ds_lists['ds_row_1_v2']

# --- Experiment configuration -----------------------------------------------
corr_type = 'spearmanr'  # one of: 'pearsonr', 'spearmanr'
suffix = '_wo_mae'  # one of: '', '_wo_mae'
exp_conf = f'{corr_type}{suffix}'

# --- Input locations ----------------------------------------------------------
# Aggregated correlation coefficients for the chosen correlation type / suffix.
base_path = Path('/home/space/diverse_priors/results/aggregated/r_coeff_dist/with_cats_as_anchors')
agg_data_path = base_path / f'agg_{corr_type}_all_ds{suffix}.csv'

# Pairwise model similarities for all metrics and datasets.
sim_data_path = Path('/home/space/diverse_priors/results/aggregated/model_sims/all_metric_ds_model_pair_similarity.csv')

# --- Output -------------------------------------------------------------------
SAVE = True

storing_path = Path(
    f'/home/space/diverse_priors/results/plots/scatter_similarity_v4/{exp_conf}'
)
# Only create the output directory when figures are actually written.
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

# NOTE(review): presumably the centimetre-to-inch conversion factor for figure
# sizes; not referenced in the visible cells — confirm before removing.
cm = 0.393701

In [None]:
# Choose the exclusion list by experiment name: any configuration whose name
# contains 'mae' (e.g. the '_wo_mae' suffix) uses `exclude_models_w_mae`.
if 'mae' in exp_conf:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Load the model configurations together with the models allowed for this run.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    exclude_alignment=True,
    exclude_models=curr_excl_models,
    path=model_config_file,
)

In [None]:
# Read the pairwise similarity table and give its columns display-ready names.
all_similarities = pd.read_csv(sim_data_path)
# The assignment is positional: the file must contain exactly these nine columns in this order.
all_similarities.columns = [
    'Similarity metric',
    'DS',
    'Model 1',
    'Model 2',
    'Similarity value',
    'Objective',
    'Architecture',
    'Dataset diversity',
    'Model size',
]

In [None]:
# Category columns hold string-encoded pairs; turn each into a sorted tuple of
# display names (looked up in `cat_name_mapping`).
model_cats = list(model_cat_mapping.values())
for cat in model_cats:
    # NOTE(review): `eval` executes whatever text is stored in the CSV. If the
    # file is not fully trusted, `ast.literal_eval` would be the safer parser.
    # This cell is also not re-runnable without reloading `all_similarities`.
    all_similarities.loc[:, cat] = all_similarities.loc[:, cat].apply(eval)

    def _to_sorted_display_pair(pair):
        # Only the first two entries of the parsed pair are used.
        first, second = cat_name_mapping[pair[0]], cat_name_mapping[pair[1]]
        return tuple(sorted([first, second]))

    all_similarities.loc[:, cat] = all_similarities.loc[:, cat].apply(_to_sorted_display_pair)

In [None]:
# Keep only similarities measured on the currently selected datasets (independent copy, fresh index).
model_similarities = (
    all_similarities
    .loc[all_similarities['DS'].isin(curr_ds_list)]
    .copy()
    .reset_index(drop=True)
)

In [None]:
# Load the aggregated correlation coefficients and keep only rows where both
# datasets belong to the current subset.
r_df = pd.read_csv(agg_data_path)
# `.copy()` makes the filtered frame independent, so adding a column below does
# not write into a slice of the loaded frame (avoids SettingWithCopyWarning).
r_df = r_df[r_df['ds1'].isin(curr_ds_list) & r_df['ds2'].isin(curr_ds_list)].copy()
# Order-independent key for the (anchor category, other category) pair.
r_df['cat_pair'] = r_df[['anchor_cat', 'other_cat']].apply(lambda x: tuple(sorted(x.tolist())), axis=1)
r_df = r_df[['ds1', 'ds2', 'cat_pair', 'r coeff','Comparison category', 'Similarity metric']]
# Sorting the pair can make two rows identical; keep the first occurrence of each.
r_df = r_df.drop_duplicates()

In [None]:
def get_ds_name(ds):
    """Return the display label for dataset `ds`: its 'name' and 'domain' joined by a space.

    Both fields are looked up in the row of the module-level `ds_info` indexed by `ds`.
    """
    info_row = ds_info.loc[ds]
    name, domain = info_row['name'], info_row['domain']
    return f"{name} {domain}"

def get_r_val(ds1, ds2, r_data):
    """Return a copy of the rows of `r_data` whose 'ds1' and 'ds2' are both in {ds1, ds2}.

    The match is order-independent: (ds1, ds2) and (ds2, ds1) rows are both
    returned, as are rows where both columns hold the same one of the two names.
    The original row index is kept.
    """
    pair = [ds1, ds2]
    first_matches = r_data['ds1'].isin(pair)
    second_matches = r_data['ds2'].isin(pair)
    selected = r_data[first_matches & second_matches]
    return selected.copy()


def get_ds_subsets(ds1, ds2, df, cat_col):
    """Split `df` into the rows of dataset `ds1` and the rows of dataset `ds2`.

    Each subset is an independent copy, sorted by ('Model 1', 'Model 2') and
    re-indexed from 0, so equal positions in the two frames line up whenever
    both datasets contain the same model pairs.

    Only the first frame's `cat_col` is converted to strings; the second frame
    keeps the original values.

    Returns:
        (ds1_data, ds2_data) as a tuple of DataFrames.
    """
    def _rows_for(ds_name):
        # Copy first so the caller's frame is never modified.
        subset = df[df['DS'] == ds_name].copy()
        return subset.sort_values(['Model 1', 'Model 2']).reset_index(drop=True)

    ds1_data = _rows_for(ds1)
    ds2_data = _rows_for(ds2)

    # String form lets callers compare the column against str(tuple) keys.
    ds1_data.loc[:, cat_col] = ds1_data.loc[:, cat_col].apply(str)
    return ds1_data, ds2_data

def plot_reg(ax, x, y, color):
    """Draw a scatter plot of `y` against `x` with a dashed regression line on `ax`.

    Args:
        ax: Matplotlib axes to draw on.
        x: Values for the horizontal axis.
        y: Values for the vertical axis.
        color: Colour used for both the points and the regression line.

    Both axes are limited to [-0.01, 1.01] and the tick labels use
    `fontsizes['ticks']`.
    """
    sns.regplot(
        x=x,
        y=y,
        color=color,
        line_kws=dict(alpha=1, ls='--', lw=3),
        scatter_kws=dict(alpha=0.5, s=5),
        ci=None,  # no confidence band around the regression line
        ax=ax,
    )
    ax.set_xlim([-0.01, 1.01])
    # Previously the x-limits were set twice and the y-limits never.
    ax.set_ylim([-0.01, 1.01])
    ax.tick_params('both', labelsize=fontsizes['ticks'])

def set_ylabel(ax, j, ds2, name, color):
    """Label the y-axis of `ax` with the display name of dataset `ds2`.

    For the first column (`j == 0`) a rotated row title `name` in `color` is
    also placed to the left of the axes, in axes coordinates.
    """
    is_first_column = (j == 0)
    if is_first_column:
        ax.text(
            -0.2, 0.5, name,
            fontsize=fontsizes['title'],
            transform=ax.transAxes,
            va='center', ha='right',
            rotation=90,
            color=color,
        )
    # Every column gets the dataset name as its y-axis label.
    ax.set_ylabel(get_ds_name(ds2), fontsize=fontsizes['label'])
    

def plot_scatter_subset_data(sim_data, r_val_data, cat_col, cat_ois):
    """Plot, for every pair of datasets, how model-pair similarities on one dataset relate to the other.

    The figure is a grid with one column per pair of datasets from the
    module-level `curr_ds_list` and one row per entry of `cat_ois`, plus a
    leading row that pools all model pairs.

    * Row 0: all model pairs in one colour; the legend shows the correlation
      computed here according to the module-level `corr_type`.
    * Row i >= 1: only model pairs whose `cat_col` tuple contains the anchor
      category `cat_ois[i-1]`, drawn as one regression/scatter per partner
      category; the legend shows the coefficient read from `r_val_data`.

    Args:
        sim_data: Similarity rows with columns 'DS', 'Model 1', 'Model 2',
            'Similarity value' and `cat_col` (tuples of category names).
        r_val_data: Correlation rows with columns 'ds1', 'ds2', 'cat_pair'
            (tuple) and 'r coeff'.
        cat_col: Name of the category column of `sim_data` to split by.
        cat_ois: Category names to use as anchors and as partner categories.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If `corr_type` is neither 'spearmanr' nor 'pearsonr', or
            (from `.item()`) if `r_val_data` does not hold exactly one
            coefficient for a dataset pair / category pair combination.

    Note:
        The boolean mask built from the first dataset's frame is applied to
        both frames, so both datasets must contain the same model pairs
        (`get_ds_subsets` sorts them identically).
    """
    corr_funcs = {'spearmanr': spearmanr, 'pearsonr': pearsonr}
    if corr_type not in corr_funcs:
        raise ValueError("Unknown correlation computation type")
    corr_func = corr_funcs[corr_type]

    ds_pairs = list(combinations(curr_ds_list, 2))
    n, m = (len(cat_ois) + 1), len(ds_pairs)
    # squeeze=False keeps `axs` two-dimensional even for a single row or column,
    # so `axs[i, j]` is always valid.
    fig, axs = plt.subplots(nrows=n, ncols=m, figsize=(4*m, 3*n),
                            sharex=True, sharey=True, squeeze=False)

    color_all = 'darkgrey'
    legend_kws = dict(fontsize=fontsizes['ticks'], framealpha=0.5, frameon=True,
                      title='', loc='lower right', edgecolor='white')

    def _legend_marker(color):
        # Stand-alone marker so the legend dot is larger than the scatter points.
        return Line2D([0], [0], color=color, marker='o', linestyle='None',
                      markersize=7, alpha=0.5)

    # Row 0: all model pairs pooled.
    for j, (ds1, ds2) in enumerate(ds_pairs):
        ax = axs[0, j]

        ds1_data, ds2_data = get_ds_subsets(ds1, ds2, sim_data, cat_col)
        plot_reg(ax, ds1_data['Similarity value'], ds2_data['Similarity value'], color_all)

        r_val_combi, _ = corr_func(ds1_data['Similarity value'], ds2_data['Similarity value'])

        ax.legend(handles=[_legend_marker(color_all)],
                  labels=[f"r : {r_val_combi:.2f}"],
                  **legend_kws)

        ax.set_xlabel("")
        set_ylabel(ax, j, ds2, "All model pairs", 'black')

    # Rows 1..n-1: one anchor category per row.
    for i, cat_oi in enumerate(cat_ois, start=1):
        # Similarity rows whose category pair involves the anchor category.
        # Uses the `cat_col` argument (not a module-level loop variable).
        sim_data_cat_oi = sim_data[sim_data[cat_col].apply(lambda x: cat_oi in x)]
        # Correlation rows for exactly those category pairs.
        r_vals_cat_oi = r_val_data[r_val_data['cat_pair'].isin(sim_data_cat_oi[cat_col].unique())]

        for j, (ds1, ds2) in enumerate(ds_pairs):
            ax = axs[i, j]

            # Correlation values for this dataset pair, keyed by str(tuple).
            curr_r_vals = get_r_val(ds1, ds2, r_vals_cat_oi)
            curr_r_vals.loc[:, 'cat_pair'] = curr_r_vals.loc[:, 'cat_pair'].apply(str)

            # Similarity values for this dataset pair.
            ds1_data, ds2_data = get_ds_subsets(ds1, ds2, sim_data_cat_oi, cat_col)

            handles = []
            labels = []
            for sub_cat in cat_ois:
                # Category pairs are stored sorted; match either ordering.
                pair_keys = [str(tuple([sub_cat, cat_oi])), str(tuple([cat_oi, sub_cat]))]

                idx = ds1_data.loc[:, cat_col].isin(pair_keys)
                plot_reg(ax, ds1_data.loc[idx, 'Similarity value'],
                         ds2_data.loc[idx, 'Similarity value'],
                         cat_color_mapping[sub_cat])

                idx_r = curr_r_vals.loc[:, 'cat_pair'].isin(pair_keys)
                # Exactly one coefficient is expected; `.item()` raises otherwise.
                r_val_combi = curr_r_vals.loc[idx_r, 'r coeff'].item()

                handles.append(_legend_marker(cat_color_mapping[sub_cat]))
                labels.append(f"r {sub_cat}: {r_val_combi:.2f}")

            # Only the bottom row carries the x-axis label.
            if i == (n-1):
                ax.set_xlabel(get_ds_name(ds1), fontsize=fontsizes['label'])
            else:
                ax.set_xlabel("")

            set_ylabel(ax, j, ds2, f"Anchor category: {cat_oi}", cat_color_mapping[cat_oi])

            ax.legend(handles=handles, labels=labels, **legend_kws)

    fig.tight_layout()

    return fig

In [None]:
# Create one figure per (model category, similarity metric) combination.
for cat in model_cats:
    # All individual category names that occur in any pair of this category column.
    sub_cats = sorted({element for tup in model_similarities[cat].unique() for element in tup})
    # Correlation coefficients belonging to this comparison category.
    cat_r_df = r_df[r_df['Comparison category'] == cat].copy().reset_index(drop=True)

    for sim_metric, data in model_similarities.groupby('Similarity metric'):
        cat_r_df_sim_met = cat_r_df[cat_r_df['Similarity metric'] == sim_metric]
        fig = plot_scatter_subset_data(data, cat_r_df_sim_met, cat, sub_cats)

        # File name: lower-case, with spaces replaced by underscores.
        curr_cat, curr_sim = (
            label.replace(" ", "_").lower() for label in (cat, sim_metric)
        )
        save_or_show(fig, storing_path / f'{curr_cat}_{curr_sim}.pdf', SAVE)