In [None]:
import sys
from itertools import combinations
from pathlib import Path
from matplotlib.lines import Line2D
from scipy.stats import spearmanr, pearsonr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import json
from matplotlib.lines import Line2D
from constants import (
    anchor_name_mapping, 
    available_data, 
    exclude_models, 
    exclude_models_w_mae, 
    cat_name_mapping, 
    ds_info_file, 
    model_config_file,
    fontsizes,
    model_cat_mapping,
    cat_color_mapping
)
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show, get_fmt_name, load_ds_info

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Get list of all datasets; '/' is replaced by '_' (presumably so the names match
# the dataset identifiers used in the similarity CSV -- TODO confirm).
ds_list = parse_datasets('../scripts/webdatasets_w_insub10k.txt')
ds_list = [ds.replace('/', '_') for ds in ds_list]

ds_info = load_ds_info(ds_info_file)

# Subset of datasets whose pairwise similarity scatter plots are drawn
ds_lists = dict(
    ds_row_1_v2=['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam'],
)
curr_ds_list = ds_lists['ds_row_1_v2']

# Experiment configuration
# Options: '', '_wo_mae'. Any suffix containing 'mae' selects `exclude_models_w_mae`
# in the model-loading cell; the suffix is also appended to saved file names.
suffix = '_wo_mae'

# Root shared by the similarity input and the plot output
# NOTE(review): absolute, machine-specific path -- adjust per environment.
results_root = Path('/home/space/diverse_priors/results')

# Path to all similarities
sim_data_path = results_root / 'aggregated' / 'model_sims' / 'all_metric_ds_model_pair_similarity.csv'

SAVE = True

storing_path = results_root / 'plots' / 'scatter_similarity_v5'
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)


# Centimetre -> inch conversion factor (1 cm = 0.393701 in)
cm = 0.393701

In [None]:
# Pick the model-exclusion list: the `exclude_models_w_mae` variant is used
# whenever the experiment suffix mentions 'mae'.
if 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Load model configs plus the list of models that survive the exclusion.
# `exclude_alignment=True` is forwarded unchanged to the helper (see helper for its exact meaning).
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_models=curr_excl_models,
    exclude_alignment=True,
)

In [None]:
# Load the similarity table (one row per metric / dataset / model pair).
# Columns are renamed positionally, so this assumes the CSV has exactly these 9
# columns in this order (pandas raises ValueError on a length mismatch).
all_similarities = pd.read_csv(sim_data_path)
all_similarities.columns = ['Similarity metric', 'DS', 'Model 1', 'Model 2', 'Similarity value', 'Objective', 'Architecture', 'Dataset diversity',  'Model size']
print(all_similarities.shape)

# Keep only pairs where BOTH models are allowed. `.copy()` detaches the result from
# the unfiltered frame so later column assignments on it do not trigger
# SettingWithCopyWarning.
both_allowed = all_similarities['Model 1'].isin(allowed_models) & all_similarities['Model 2'].isin(allowed_models)
all_similarities = all_similarities[both_allowed].copy()
print(all_similarities.shape)

In [None]:
import ast


def _normalize_cat_pair(value):
    """Turn a stored category pair into a sorted tuple of display names.

    ``value`` is expected to be the string repr of a tuple of raw category labels
    (one per model), e.g. "('a', 'b')" -- presumably; TODO confirm against the CSV.
    It is parsed with ``ast.literal_eval``, which only accepts Python literals,
    unlike ``eval``. The first two labels are mapped through ``cat_name_mapping``
    and sorted, so (A, B) and (B, A) collapse to the same pair.

    Tuples are returned unchanged: they were already converted by an earlier run
    of this cell, so re-running the cell is a no-op instead of a TypeError.
    """
    if isinstance(value, tuple):
        return value
    pair = ast.literal_eval(value)
    return tuple(sorted([cat_name_mapping[pair[0]], cat_name_mapping[pair[1]]]))


# Each value of model_cat_mapping names a category column of all_similarities.
model_cats = list(model_cat_mapping.values())
for cat in model_cats:
    # Plain column assignment replaces the column instead of setting values in place.
    all_similarities[cat] = all_similarities[cat].apply(_normalize_cat_pair)

In [None]:
## Keep only the datasets selected for plotting (curr_ds_list)
selected_ds_mask = all_similarities['DS'].isin(curr_ds_list)
model_similarities = all_similarities.loc[selected_ds_mask].copy().reset_index(drop=True)

In [None]:
def get_ds_subsets(ds1, ds2, df, cat_col):
    """Return the rows of ``df`` for two datasets, aligned row-by-row on model pair.

    Each subset is sorted by ('Model 1', 'Model 2') and re-indexed from 0, so row
    ``i`` of both results describes the same model pair. Callers compare the two
    'Similarity value' columns positionally, so the alignment is verified here
    rather than silently assumed.

    Args:
        ds1, ds2: values of the 'DS' column to select.
        df: frame with at least 'DS', 'Model 1', 'Model 2' and ``cat_col`` columns.
        cat_col: category column; converted to ``str`` in both results
            (e.g. a tuple becomes its repr) so values can be compared/grouped.

    Returns:
        Tuple ``(ds1_data, ds2_data)`` of new DataFrames.

    Raises:
        ValueError: if the two datasets do not cover the same sequence of
            (Model 1, Model 2) pairs.
    """
    pair_cols = ['Model 1', 'Model 2']

    def process_ds(ds):
        ds_data = df[df['DS'] == ds].copy().sort_values(pair_cols).reset_index(drop=True)
        ds_data[cat_col] = ds_data[cat_col].astype(str)
        return ds_data

    ds1_data, ds2_data = process_ds(ds1), process_ds(ds2)
    if not ds1_data[pair_cols].equals(ds2_data[pair_cols]):
        raise ValueError(
            f"Datasets {ds1!r} and {ds2!r} do not cover the same model pairs; "
            "their rows cannot be compared positionally."
        )
    return ds1_data, ds2_data

def plot_reg(ax, x, y, color):
    """Draw a scatter of (x, y) plus a dashed regression line on ``ax``.

    No confidence band is drawn (``ci=None``). Both axes are fixed to
    [-0.01, 1.01], and tick labels use ``fontsizes['ticks']``.
    """
    point_style = {'alpha': 0.5, 's': 5}
    fit_style = {'alpha': 1, 'ls': '--', 'lw': 3}
    sns.regplot(x=x, y=y, color=color, ci=None, ax=ax,
                scatter_kws=point_style, line_kws=fit_style)
    unit_bounds = (-0.01, 1.01)
    ax.set(xlim=unit_bounds, ylim=unit_bounds)
    ax.tick_params('both', labelsize=fontsizes['ticks'])

def set_ylabel(ax, j, ds2, name, color):
    """Label the y-axis of ``ax`` with the display name of dataset ``ds2``.

    In the first column (``j == 0``), ``name`` is also written in ``color`` as a
    vertical row title to the left of the axes.
    """
    is_first_column = j == 0
    if is_first_column:
        ax.text(-0.2, 0.5, name, transform=ax.transAxes, rotation=90,
                ha='right', va='center', color=color, fontsize=fontsizes['title'])
    ax.set_ylabel(get_ds_name(ds2), fontsize=fontsizes['label'])

def get_ds_name(ds):
    """Return the label "<name> <domain>" for dataset ``ds``, looked up in ``ds_info``."""
    row = ds_info.loc[ds]
    name, domain = row['name'], row['domain']
    return f"{name} {domain}"

def compute_r(x, y, corr_type):
    """Return the correlation coefficient between ``x`` and ``y``.

    Args:
        x, y: 1-D sequences of numbers (lists, arrays or pandas Series).
        corr_type: 'pearsonr' or 'spearmanr'; selects the scipy.stats function.

    Returns:
        The coefficient (first element of the scipy result). For equal-length
        inputs with fewer than two observations, NaN is returned, since no
        correlation is defined. ``pearsonr`` would raise there, and a
        within-category subset can contain a single model pair.

    Raises:
        ValueError: if ``corr_type`` is not supported (or scipy rejects the input,
            e.g. for mismatched lengths).
    """
    corr_funcs = {'spearmanr': spearmanr, 'pearsonr': pearsonr}
    if corr_type not in corr_funcs:
        raise ValueError("Unknown correlation computation type")
    if len(x) == len(y) and len(x) < 2:
        return float('nan')
    return corr_funcs[corr_type](x, y)[0]

def plot_all_pairs(ax, ds1_data, ds2_data, color_all, corr_type, frameon):
    """Plot every model pair's similarity on ds1 (x) against ds2 (y).

    Draws the scatter and regression line in ``color_all``. A single-entry
    legend shows the ``corr_type`` coefficient over all pairs. The x label is
    cleared.
    """
    sims_x = ds1_data['Similarity value']
    sims_y = ds2_data['Similarity value']
    plot_reg(ax, sims_x, sims_y, color_all)
    r_all = compute_r(sims_x, sims_y, corr_type)
    marker = Line2D([0], [0], color=color_all, marker='o', linestyle='None',
                    markersize=7, alpha=0.5)
    ax.legend(
        handles=[marker],
        labels=[f"r coeff.: {r_all:.2f}"],
        fontsize=fontsizes['ticks'],
        framealpha=0.5,
        frameon=frameon,
        title='',
        loc='lower right',
        edgecolor='white',
    )
    ax.set_xlabel("")

def plot_within_category(ax, ds1_data, ds2_data, cat_col, color_all, cat_colors, corr_type, frameon):
    """Highlight within-category model pairs on a ds1-vs-ds2 similarity scatter.

    Pairs whose two models fall into different categories are drawn as plain
    ``color_all`` points. For every category whose pairs share the category, a
    colored scatter plus regression line is drawn, and its correlation is added
    to the legend.

    Args:
        ax: matplotlib Axes to draw on.
        ds1_data, ds2_data: row-aligned frames (same model pair at the same index)
            with 'Similarity value' and ``cat_col``. ``cat_col`` holds string
            reprs of 2-tuples of category names, e.g. "('A', 'A')".
        cat_col: category column name.
        color_all: color of the between-category points.
        cat_colors: mapping from category name to color.
        corr_type: 'pearsonr' or 'spearmanr'.
        frameon: whether the legend is drawn with a frame.
    """
    import ast

    # Parse the stored tuple repr once. literal_eval only accepts Python literals, unlike eval.
    cat_pairs = ds1_data[cat_col].apply(ast.literal_eval)
    within_cat = cat_pairs.apply(lambda pair: pair[0] == pair[1])

    ax.scatter(x=ds1_data.loc[~within_cat, 'Similarity value'], y=ds2_data.loc[~within_cat, 'Similarity value'],
               c=color_all, alpha=0.5, s=5)
    ds1_data, ds2_data = ds1_data.loc[within_cat, :], ds2_data.loc[within_cat, :]

    handles, labels = [], []
    for cat_pair in ds1_data[cat_col].unique():
        ds1_sub = ds1_data[ds1_data[cat_col] == cat_pair]
        # Select ds2 rows by ds1's index so x and y always refer to the same model pairs.
        ds2_sub = ds2_data.loc[ds1_sub.index]
        # Within-category pair: both entries are the same category name.
        cat_name = ast.literal_eval(cat_pair)[0]
        curr_color = cat_colors[cat_name]

        plot_reg(ax, ds1_sub['Similarity value'], ds2_sub['Similarity value'], curr_color)
        r_val_combi = compute_r(ds1_sub['Similarity value'], ds2_sub['Similarity value'], corr_type)
        handles.append(Line2D([0], [0], color=curr_color, marker='o', linestyle='None', markersize=7, alpha=0.5))
        labels.append(f"r {cat_name} pairs: {r_val_combi:.2f}")

    ax.legend(handles=handles, labels=labels, fontsize=fontsizes['ticks'], framealpha=0.5,
              frameon=frameon, title='', loc='lower right', edgecolor='white')

def main_plot(curr_ds_list, tmp, cat_col, color_all, cat_colors, corr_type, frameon,
              within_label="Within training objective"):
    """Build the 2-row grid of pairwise-dataset similarity scatter plots.

    There is one column per unordered pair (ds1, ds2) of ``curr_ds_list``: ds1 on
    the x-axis, ds2 on the y-axis. The top row shows all model pairs. The bottom
    row highlights within-category pairs.

    Args:
        curr_ds_list: dataset identifiers (values of the 'DS' column); at least two.
        tmp: similarity frame, expected to hold a single similarity metric.
        cat_col: category column used for the bottom row.
        color_all: color for all / between-category pairs.
        cat_colors: mapping from category name to color.
        corr_type: 'pearsonr' or 'spearmanr'.
        frameon: whether legends are drawn with a frame.
        within_label: row title of the bottom row. Defaults to the original
            hard-coded text; override it when ``cat_col`` is not the training
            objective.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if fewer than two datasets are given (no pair to plot).
    """
    combs = list(combinations(curr_ds_list, 2))
    if not combs:
        raise ValueError("main_plot needs at least two datasets to form a pair")
    n, m = 2, len(combs)
    # squeeze=False keeps `axs` 2-D even for a single dataset pair (m == 1).
    # Otherwise axs[row, j] would fail on the 1-D array subplots returns.
    fig, axs = plt.subplots(nrows=n, ncols=m, figsize=(5*m, 4*n), sharex=True, sharey=True, squeeze=False)

    for j, (ds1, ds2) in enumerate(combs):
        ds1_data, ds2_data = get_ds_subsets(ds1, ds2, tmp, cat_col)

        # Upper row: all pairs
        plot_all_pairs(axs[0, j], ds1_data, ds2_data, color_all, corr_type, frameon)
        set_ylabel(axs[0, j], j, ds2, "All model pairs", 'black')

        # Lower row: within category
        plot_within_category(axs[1, j], ds1_data, ds2_data, cat_col, color_all, cat_colors, corr_type, frameon)
        axs[1, j].set_xlabel(get_ds_name(ds1), fontsize=fontsizes['label'])
        set_ylabel(axs[1, j], j, ds2, within_label, 'black')

    return fig

In [None]:
cat_col = 'Objective'
color_all = 'darkgrey'
frameon = True

# One figure per (similarity metric, correlation type).
for sim_metric, data in model_similarities.groupby('Similarity metric'):
    # Categories, colors and the file-name stem depend only on the metric's data,
    # not on the correlation type, so they are computed once per metric.
    # Each cat_col value is a tuple of category names.
    cat_ois = sorted({cat for pair in data[cat_col].unique() for cat in pair})
    cat_colors = dict(zip(cat_ois, sns.color_palette('tab10', len(cat_ois))))
    curr_sim = sim_metric.lower().replace(' ', '_')

    for corr_type in ['pearsonr', 'spearmanr']:
        print(sim_metric, corr_type)
        fig = main_plot(curr_ds_list, data, cat_col, color_all, cat_colors, corr_type=corr_type, frameon=frameon)
        save_or_show(fig, storing_path / f'{cat_col}_{corr_type}_{curr_sim}{suffix}.pdf', SAVE)