In [None]:
import os
import sys
from collections import defaultdict
from pathlib import Path

import pandas as pd
import torch
from scipy.stats import spearmanr

from helper import load_model_configs_and_allowed_models

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping, anchors

In [None]:
# Root of the project's data tree. Set the DIVERSE_PRIORS_ROOT environment variable to
# run on a machine where the data lives elsewhere; the default is the original location.
data_root = Path(os.environ.get('DIVERSE_PRIORS_ROOT', '/home/space/diverse_priors'))
base_path_similarity_matrices = data_root / 'model_similarities'

# Similarity metrics to analyse. Each name is used as a sub-directory below every
# dataset directory when the matrices are loaded.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
    'rsa_method_correlation_corr_method_spearman',
]

# Reference dataset (x axis) and the datasets it is compared against (y axis).
# '/' in a dataset name is replaced by '_' because the name is used as a single path
# component when the similarity matrices are loaded.
x_axis_ds = 'imagenet-subset-10k'
y_axis_ds = [ds.replace('/', '_') for ds in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')]

# Output directory for the aggregated correlation coefficients.
storing_path = data_root / 'results' / 'aggregated' / 'r_coeff_dist'
storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the model metadata table and the collection of model ids allowed in this analysis.
# The three models listed are excluded explicitly; the meaning of `exclude_alignment`
# is defined by the helper (NOTE(review): confirm in helper.load_model_configs_and_allowed_models).
model_configs, allowed_models = load_model_configs_and_allowed_models(
    exclude_alignment=True,
    exclude_models=[
        'SegmentAnything_vit_b',
        'DreamSim_dino_vitb16',
        'DreamSim_open_clip_vitb32',
    ],
    path='../scripts/models_config.json',
)

print(model_configs.shape, len(allowed_models))

In [None]:
def get_model_ids(fn):
    """Read a model-id list file.

    Parameters
    ----------
    fn : str or os.PathLike
        Text file with one model id per line.

    Returns
    -------
    list of str
        One entry per line, with leading/trailing whitespace (including the newline)
        stripped. Blank lines are kept as empty strings.
    """
    with open(fn, 'r') as handle:
        return [raw_line.strip() for raw_line in handle]


def load_sim_martix(path, allowed=None):
    """Load a stored similarity matrix and restrict it to allowed models.

    Parameters
    ----------
    path : pathlib.Path
        Directory containing ``model_ids.txt`` and ``similarity_matrix.pt``. The ids,
        in file order, are used as the matrix's row and column labels.
    allowed : iterable of str, optional
        Model ids to keep. Defaults to the notebook-global ``allowed_models``.

    Returns
    -------
    pandas.DataFrame
        Square matrix with model ids as index and columns, restricted to the sorted
        intersection of the stored ids and ``allowed``.

    Raises
    ------
    FileNotFoundError
        If ``model_ids.txt`` (or the matrix file) does not exist.
    """
    model_ids_fn = path / 'model_ids.txt'
    sim_mat_fn = path / 'similarity_matrix.pt'
    if not model_ids_fn.exists():
        raise FileNotFoundError(f'{str(model_ids_fn)} does not exist.')
    model_ids = get_model_ids(model_ids_fn)

    # map_location='cpu': a matrix saved from a GPU process must still load on a CPU-only
    # machine, and pandas needs the data in host memory anyway.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which is fine for a plain tensor.
    sim_mat = torch.load(sim_mat_fn, map_location='cpu')
    if isinstance(sim_mat, torch.Tensor):
        # Convert explicitly rather than relying on pandas' implicit array conversion,
        # which fails for tensors that require grad.
        sim_mat = sim_mat.detach().numpy()
    df = pd.DataFrame(sim_mat, index=model_ids, columns=model_ids)

    if allowed is None:
        allowed = allowed_models
    available_models = sorted(set(model_ids).intersection(allowed))
    return df.loc[available_models, available_models]

In [None]:
# Nested mapping: similarity metric -> dataset name -> model-by-model similarity DataFrame.
# The reference dataset is loaded alongside the comparison datasets.
datasets_to_load = [x_axis_ds, *y_axis_ds]
sim_mats = defaultdict(dict)
for sim_metric in sim_metrics:
    for ds in datasets_to_load:
        mat_dir = base_path_similarity_matrices / ds / sim_metric
        sim_mats[sim_metric][ds] = load_sim_martix(mat_dir)

In [None]:
# Column names of the long-format similarity table.
anchor_col = 'Anchor Model'             # id of the anchor model
other_col = 'Other Model'               # id of the model compared against the anchor
other_ds_col = 'Other Dataset'          # comparison (y-axis) dataset of a row
sim_metric_col = 'Similarity metric'    # display name of the similarity metric
sim_ds_col = 'Similarity value DS'      # similarity measured on the comparison dataset
sim_imgnet_col = 'Similarity value IN'  # similarity measured on the reference dataset (x_axis_ds)

# Model metadata: raw column names in `model_configs` and their display names.
# The two lists are parallel: entry i of one corresponds to entry i of the other.
info_orig_cols = ['objective', 'architecture_class', 'dataset_class', 'size_class']
info_cols = ['Objective', 'Architecture', 'Dataset size', 'Model size']

# Columns held fixed when melting the per-dataset similarities into long format.
id_cols = [anchor_col, other_col, sim_metric_col, *info_cols, x_axis_ds]

# Columns of the final correlation table.
comp_cat_col = 'Comparison category'
comp_cat_orig_col = 'Comparison category (orig. name)'
comp_val_col = 'Comparison values'
r_col = 'r coeff'


def get_other_model_info(mid):
    """Return the metadata tuple for model ``mid``.

    Reads the row labelled ``mid`` from the global ``model_configs`` table and returns
    its ``objective``, ``architecture_class``, ``dataset_class`` and ``size_class``
    values, in that order.
    """
    row = model_configs.loc[mid]
    fields = ('objective', 'architecture_class', 'dataset_class', 'size_class')
    return tuple(row[field] for field in fields)


def get_melted_sim_values_metric_anchor(anch, met, met_ds_mats):
    """Collect one anchor's similarities to every other model, per dataset, in long format.

    Parameters
    ----------
    anch : str
        Anchor model id; must be a row and column label of every matrix in ``met_ds_mats``.
    met : str
        Similarity-metric key; its display name is looked up in ``sim_metric_name_mapping``.
    met_ds_mats : dict[str, pandas.DataFrame]
        Dataset name -> model-by-model similarity matrix for metric ``met``. Must include
        the reference dataset ``x_axis_ds``.

    Returns
    -------
    pandas.DataFrame
        One row per (other model, non-reference dataset). Each row holds:

        - the similarity on that dataset (``sim_ds_col``);
        - the similarity on ``x_axis_ds`` (column ``sim_imgnet_col``);
        - the other model's metadata (``info_cols``);
        - the metric display name;
        - the anchor id.

        Models missing from some dataset's matrix get NaN there, because the
        per-dataset rows are aligned on model id.
    """
    sim_vals_ds = []
    for ds, curr_sim_mat in met_ds_mats.items():
        # Every model in this matrix except the anchor itself, restricted to allowed ids.
        cols = curr_sim_mat.columns.tolist()
        cols.remove(anch)
        cols = sorted(set(cols).intersection(allowed_models))
        row_sim_mat = curr_sim_mat.loc[anch, cols]
        row_sim_mat.name = ds  # becomes the column name after concat
        sim_vals_ds.append(row_sim_mat)

    # Wide table: one row per other model, one column per dataset.
    anchor_sim_vals = pd.concat(sim_vals_ds, axis=1)
    anchor_sim_vals = anchor_sim_vals.reset_index(names=[other_col])
    model_info = pd.DataFrame(anchor_sim_vals[other_col].apply(get_other_model_info).tolist(),
                              columns=info_cols)
    anchor_sim_vals = pd.concat([anchor_sim_vals, model_info], axis=1)
    # Use the `met` argument, not a module-level loop variable, so the label always
    # matches the matrices that were passed in.
    anchor_sim_vals[sim_metric_col] = sim_metric_name_mapping[met]
    anchor_sim_vals[anchor_col] = anch
    # Keep the reference-dataset column as an id column; melt the remaining dataset columns.
    anchor_sim_vals = pd.melt(anchor_sim_vals,
                              id_vars=id_cols,
                              var_name=other_ds_col,
                              value_name=sim_ds_col,
                              )
    return anchor_sim_vals.rename(columns={x_axis_ds: sim_imgnet_col})


# One long-format piece per (anchor model, similarity metric), stacked into a single table.
dfs = []
for anchor in anchors:
    for sim_metric in sim_mats.keys():
        dfs.append(
            get_melted_sim_values_metric_anchor(anchor, sim_metric, sim_mats[sim_metric])
        )

all_sims = pd.concat(dfs, axis=0)
all_sims = all_sims.reset_index(drop=True)

In [None]:
def compute_corr(data):
    """Spearman correlation between reference-dataset and comparison-dataset similarities.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows of one group; must contain the columns named by ``sim_imgnet_col`` and
        ``sim_ds_col``.

    Returns
    -------
    float
        Spearman's rho. Under scipy's default ``nan_policy='propagate'``, a NaN in either
        column yields NaN.
    """
    rho = spearmanr(data[sim_imgnet_col], data[sim_ds_col])[0]
    return rho


# For each metadata stratum, compute Spearman r per (metric, anchor, dataset, stratum value),
# plus an "All" row per (metric, anchor, dataset) that ignores the stratum.
base_grouping_cols = [sim_metric_col, anchor_col, other_ds_col]

# The "All" baseline does not depend on the stratum, so compute it once, not once per stratum.
all_rs_base = (
    all_sims.groupby(base_grouping_cols, dropna=False)
    .apply(compute_corr, include_groups=False)
    .reset_index(name=r_col)
)

r_dfs = []
for strata in info_cols:
    grouping_cols = base_grouping_cols + [strata]

    strata_rs = (
        all_sims.groupby(grouping_cols, dropna=False)
        .apply(compute_corr, include_groups=False)
        .reset_index(name=r_col)
    )

    # Fresh copy per stratum, with the stratum column set to the 'All' marker.
    all_rs = all_rs_base.assign(**{strata: 'All'})

    rs = pd.concat([all_rs, strata_rs], axis=0).reset_index(drop=True)
    rs = rs.sort_values(base_grouping_cols).reset_index(drop=True)
    r_dfs.append(rs)

In [None]:
# Tag each per-stratum table with its comparison category (display and raw name), and
# rename its stratum column to a shared name so that the tables can be stacked.
for idx, frame in enumerate(r_dfs):
    display_name = info_cols[idx]
    raw_name = info_orig_cols[idx]
    frame[comp_cat_col] = display_name
    frame[comp_cat_orig_col] = raw_name
    frame.rename(columns={display_name: comp_val_col}, inplace=True)

In [None]:
# Stack the per-category tables. ignore_index=True gives the result a unique 0..n-1 index
# instead of repeating each piece's own row labels.
r_df = pd.concat(r_dfs, axis=0, ignore_index=True)

In [None]:
# Persist the aggregated correlation table; the positional row index is not written.
r_df.to_csv(storing_path / 'agg_corr_coeffs.csv', index=False)