In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import spearmanr

from helper import save_or_show, plot_r_coeff_distribution

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Root of all inputs/outputs for this analysis; change this single path when running elsewhere.
diverse_priors_root = Path('/home/space/diverse_priors')
base_path_similarity_matrices = diverse_priors_root / 'model_similarities'

# Similarity metrics to compare; each name is a sub-directory under every dataset directory.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# Reference dataset (x axis) and the datasets compared against it (y axis).
x_axis_ds = ['imagenet-subset-10k']
# '/' in dataset names is replaced by '_' because the names are used as directory names below.
y_axis_ds = [ds.replace('/', '_') for ds in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')]

# Optionally keep only model pairs that contain this anchor model (None = keep all pairs).
# Other settings used: None, "OpenCLIP_ViT-L-14_openai" (anchor model 1).
anchor_model = "resnet50"  # ANCHOR MODEL 2
suffix = f'_anchor_{anchor_model}' if anchor_model else ''

SAVE = True
storing_path = diverse_priors_root / 'results' / 'plots' / 'ds_sim__imagenet_sim'
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
def get_model_ids(fn):
    """Read model ids from a text file, one id per line.

    Surrounding whitespace is stripped from every line. Blank lines (e.g. an extra newline
    at the end of the file) are skipped; otherwise they would be returned as empty model ids.

    Args:
        fn: path of the text file.

    Returns:
        list[str]: the model ids in file order.
    """
    with open(fn, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file if line.strip()]


def load_sim_martix(path):
    """Load a saved model-by-model similarity matrix as a labelled DataFrame.

    ``path`` must contain ``model_ids.txt`` (read with ``get_model_ids``) and
    ``similarity_matrix.pt`` (loaded with ``torch.load``).

    Args:
        path: ``pathlib.Path`` of the directory holding both files.

    Returns:
        pd.DataFrame whose index and columns are the model ids.

    Raises:
        FileNotFoundError: if either file is missing.
        ValueError: if the matrix is not (n_models, n_models).
    """
    model_ids_fn = path / 'model_ids.txt'
    sim_mat_fn = path / 'similarity_matrix.pt'
    for required_fn in (model_ids_fn, sim_mat_fn):
        if not required_fn.exists():
            raise FileNotFoundError(f'{str(required_fn)} does not exist.')
    model_ids = get_model_ids(model_ids_fn)

    # map_location='cpu' lets matrices saved from a GPU process load on CPU-only machines.
    # NOTE(review): torch.load unpickles data; only point this at trusted result directories.
    sim_mat = torch.load(sim_mat_fn, map_location='cpu')
    if isinstance(sim_mat, torch.Tensor):
        sim_mat = sim_mat.detach().cpu().numpy()
    sim_mat = np.asarray(sim_mat)

    n_models = len(model_ids)
    if sim_mat.shape != (n_models, n_models):
        raise ValueError(
            f'Similarity matrix in {str(sim_mat_fn)} has shape {sim_mat.shape}, '
            f'expected ({n_models}, {n_models}) from {str(model_ids_fn)}.'
        )
    return pd.DataFrame(sim_mat, index=model_ids, columns=model_ids)

In [None]:
# Raw similarity matrices, nested as sim_mats[axis][sim_metric][dataset] -> DataFrame.
axis_datasets = {'x_axis': x_axis_ds, 'y_axis': y_axis_ds}
sim_mats = {axis: {} for axis in axis_datasets}

for metric in sim_metrics:
    # Create both per-metric containers before loading anything for this metric.
    for axis in axis_datasets:
        sim_mats[axis][metric] = {}
    for axis, datasets in axis_datasets.items():
        for dataset in datasets:
            matrix_dir = base_path_similarity_matrices / dataset / metric
            sim_mats[axis][metric][dataset] = load_sim_martix(matrix_dir)

In [None]:
def process_sim_mat(df, anchor):
    """Flatten a square similarity matrix into a long table of unique model pairs.

    Only the strict upper triangle is kept, so each unordered pair appears once and the
    diagonal (self-similarity) is dropped. Pair labels are written with the two model ids in
    sorted order ("a, b" with a <= b). The same two models therefore get the same label
    regardless of the row order of this particular matrix, so pairs from different datasets
    can be matched by label. This assumes the similarity metric is symmetric.

    Args:
        df: square DataFrame whose index and columns are the same model ids.
        anchor: optional model id. When truthy, only pairs in which one of the two models
            is exactly this id are kept (exact match, not substring).

    Returns:
        pd.DataFrame with columns 'Model pair' and 'Similarity value' and a fresh RangeIndex.
    """
    upper_triangle = np.triu(np.ones(df.shape, dtype=bool), k=1)
    long_df = (
        df.where(upper_triangle)
        .reset_index(names=['models_1'])
        .melt(id_vars='models_1', var_name='models_2', value_name='Similarity value')
        # Drops the masked diagonal/lower triangle (and any genuinely NaN similarity).
        .dropna()
        .reset_index(drop=True)
    )

    if anchor:
        # Compare whole ids: a substring test on the joined label would also match e.g.
        # 'wide_resnet50_2' for anchor 'resnet50'.
        has_anchor = (long_df['models_1'] == anchor) | (long_df['models_2'] == anchor)
        long_df = long_df[has_anchor].reset_index(drop=True)

    swap = long_df['models_1'] > long_df['models_2']
    first = long_df['models_1'].where(~swap, long_df['models_2'])
    second = long_df['models_2'].where(~swap, long_df['models_1'])
    return pd.DataFrame({
        'Model pair': first + ', ' + second,
        'Similarity value': long_df['Similarity value'],
    })


In [None]:
def _flatten_with_labels(raw_df, metric_key, dataset_name):
    """Flatten one similarity matrix and tag every row with its dataset and metric."""
    return process_sim_mat(raw_df, anchor_model).assign(**{
        'Dataset': dataset_name,
        'sim_metric': metric_key,
        'Similarity metric': sim_metric_name_mapping[metric_key],
    })


# The x-axis (reference) matrices are replaced in place by their flattened tables, because
# get_imagenet_sim_value looks pairs up in sim_mats['x_axis'].
# NOTE: this makes the cell non-idempotent; re-run the loading cell before running it again.
for metric_key, per_dataset in sim_mats['x_axis'].items():
    for dataset_name, raw_df in per_dataset.items():
        per_dataset[dataset_name] = _flatten_with_labels(raw_df, metric_key, dataset_name)

# The y-axis tables are stacked into one long frame.
y_dfs = [
    _flatten_with_labels(raw_df, metric_key, dataset_name)
    for metric_key, per_dataset in sim_mats['y_axis'].items()
    for dataset_name, raw_df in per_dataset.items()
]

y_df = pd.concat(y_dfs, ignore_index=True)

In [None]:
def get_imagenet_sim_value(row, dataset='imagenet-subset-10k', x_axis_mats=None):
    """Look up a row's model-pair similarity on the reference (x-axis) dataset.

    Intended for ``DataFrame.apply(..., axis=1)`` over the flattened y-axis table.

    Args:
        row: mapping/Series with keys 'Model pair' and 'sim_metric'.
        dataset: reference dataset key inside ``x_axis_mats[sim_metric]``.
        x_axis_mats: nested dict ``{sim_metric: {dataset: flattened DataFrame}}``. Defaults to
            the notebook-global ``sim_mats['x_axis']`` (read only when this is None).

    Returns:
        The 'Similarity value' of the matching pair.

    Raises:
        ValueError: if the pair is missing from, or duplicated in, the reference table.
    """
    if x_axis_mats is None:
        x_axis_mats = sim_mats['x_axis']
    model_pair = row['Model pair']
    sim_metric = row['sim_metric']
    reference_df = x_axis_mats[sim_metric][dataset]
    matches = reference_df.loc[reference_df['Model pair'] == model_pair, 'Similarity value']
    if len(matches) != 1:
        problem = 'No entry' if len(matches) == 0 else f'{len(matches)} entries'
        raise ValueError(
            f'{problem} for model pair {model_pair!r} ({sim_metric}) in {dataset}; '
            f'expected exactly one.'
        )
    return matches.iloc[0]

In [None]:
# For every (dataset, metric, model pair) row, attach the same pair's similarity on the
# ImageNet subset so the two values can be plotted against each other.
imagenet_sim_values = y_df.apply(get_imagenet_sim_value, axis=1)
y_df['Similarity value ImageNet subset'] = imagenet_sim_values

In [None]:
# Facet rows (one per y-axis dataset) in alphabetical order.
row_order = sorted(set(y_df['Dataset']))

In [None]:
x_col = "Similarity value ImageNet subset"
y_col = "Similarity value"

# One scatter panel per (dataset row, similarity-metric column): x is a pair's similarity on
# the ImageNet subset, y is the same pair's similarity on the row's dataset.
g = sns.relplot(
    data=y_df,
    x=x_col,
    y=y_col,
    col="Similarity metric",
    row="Dataset",
    row_order=row_order,
    height=2,
    aspect=1.5,
)

# Panel titles show only the metric; the dataset name goes into the left-most y label.
g.set_titles('{col_name}')
for ax, ds in zip(g.axes[:, 0], row_order):
    ax.set_ylabel(f'Similarity value\n{ds}')


def add_correlation(data, **kws):
    """Write the Spearman rank correlation of x_col vs y_col into the current facet.

    ``**kws`` absorbs the extra keywords (e.g. ``color``/``label``) that
    ``FacetGrid.map_dataframe`` passes to the plotting function.
    """
    ax = plt.gca()
    corr, _ = spearmanr(data[x_col], data[y_col])
    ax.text(.05, .95, f'r = {corr:.2f}', transform=ax.transAxes,
            fontsize=11, verticalalignment='top')


g.map_dataframe(add_correlation)

anchor_note = f' ( anchor = {anchor_model} )' if anchor_model else ''
# `figure` is the documented FacetGrid attribute; `fig` is only a legacy alias.
g.figure.suptitle(f"Correlation diversity in ImageNet vs other datasets{anchor_note}", y=1)
g.figure.tight_layout()

save_or_show(g.figure, storing_path / f'diversity_imgnet_vs_ds{suffix}.pdf', SAVE)

In [None]:
# Presumably shows, per similarity metric, how the x-vs-y correlation coefficient is
# distributed over the datasets (see helper.plot_r_coeff_distribution to confirm).
distr_plot_path = storing_path / f'distr_corr_coeff_over_datasets{suffix}.pdf'
fig = plot_r_coeff_distribution(y_df, "Similarity metric", x_col, y_col, "Dataset")
save_or_show(fig, distr_plot_path, SAVE)