In [None]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from itertools import product

from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Root directory holding the pre-computed model-to-model similarity matrices.
base_path_similarity_matrices = Path('/home/space/diverse_priors/model_similarities')

# Similarity metrics to plot; append 'rsa_method_correlation_corr_method_spearman' to include RSA.
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
]

# Non-ImageNet datasets, with '/' replaced by '_' (presumably to match the dataset keys used
# for the stored similarity matrices -- confirm against load_similarity_matrices).
y_axis_ds = [ds.replace('/', '_') for ds in parse_datasets('../scripts/webdatasets_wo_imagenet.txt')]
ds_list = ['imagenet-subset-10k', *y_axis_ds]

# Centimetre -> inch conversion factor (matplotlib figure sizes are given in inches).
cm = 0.393701

# Passed to save_or_show; presumably True writes PDFs to `storing_path`, False shows inline.
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/mean_std_sim_matrix')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Models left out of every plot in this notebook.
excluded_models = ['SegmentAnything_vit_b', 'DreamSim_dino_vitb16', 'DreamSim_open_clip_vitb32']

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path='../scripts/models_config.json',
    exclude_alignment=True,
    exclude_models=excluded_models,
)

# Quick sanity check on what was loaded.
print(model_configs.shape, len(allowed_models))

In [None]:
# Raw `model_configs` column names paired with the human-readable labels used for them.
info_col_labels = {
    'objective': 'Objective',
    'architecture_class': 'Architecture',
    'dataset_class': 'Dataset size',
    'size_class': 'Model size',
}
info_orig_cols = list(info_col_labels.keys())
info_cols = list(info_col_labels.values())

In [None]:
# Nested mapping used below as sim_mats[metric][dataset] -> model-by-model similarity matrix
# (presumably restricted to `allowed_models`).
sim_mats = load_similarity_matrices(
    path=base_path_similarity_matrices,
    sim_metrics=sim_metrics,
    ds_list=ds_list,
    allowed_models=allowed_models,
)

In [None]:
# For every metric, aggregate the per-dataset matrices element-wise: each model pair gets the
# mean and the population (ddof=0) standard deviation of its similarity across datasets.
mean_sim_mats = {}
std_sim_mats = {}
for sim_metric in sim_metrics:
    per_dataset = list(sim_mats[sim_metric].values())
    stacked = np.stack(per_dataset, axis=0)  # (n_datasets, n_models, n_models)
    mean_res = pd.DataFrame(np.mean(stacked, axis=0), index=allowed_models, columns=allowed_models)
    std_res = pd.DataFrame(np.std(stacked, axis=0), index=allowed_models, columns=allowed_models)
    mean_sim_mats[sim_metric], std_sim_mats[sim_metric] = mean_res, std_res

### Similarity matrix plots

Create figures whose first row shows the similarity matrices of three different datasets (one natural-image dataset, one single-domain dataset, and one structured dataset), and whose second row shows two subplots with the mean and the standard deviation of the similarity matrices across all datasets.
We create one such figure for each similarity metric and each dataset selection in `ds_lists`.

In [None]:
# Alternative dataset selections for the first figure row; each combines one natural-image,
# one single-domain and one structured dataset (see the labels in `ds_name_mapping`).
ds_lists = {
    'ds_row_1_v1': ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_eurosat'],
    'ds_row_1_v2': ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam'],
    'ds_row_1_v3': ['imagenet-subset-10k', 'wds_vtab_pets', 'wds_vtab_eurosat'],
    'ds_row_1_v4': ['imagenet-subset-10k', 'wds_vtab_pets', 'wds_vtab_pcam'],
}

# Subplot titles for the datasets that can appear in the first row.
ds_name_mapping = {
    'imagenet-subset-10k': 'ImageNet (natural)',
    'wds_vtab_flowers': 'Flowers (single domain)',
    'wds_vtab_pets': 'Pets (single domain)',
    'wds_vtab_eurosat': 'Eurosat (structured)',
    'wds_vtab_pcam': 'PCAM (structured)',
}

In [None]:
def draw_vertical_and_horizontal_cats(ax_h, ax_v, width_bar=5, legend_h=False, legend_v=False):
    """Draw colour-coded model-objective bars for the top and left edge of a heatmap.

    Each segment corresponds to one value of ``model_configs['objective']`` and its length is
    the number of models with that objective. Segments are laid out in alphabetical order of
    the objective name: left-to-right on ``ax_h`` and top-to-bottom on ``ax_v``.

    NOTE(review): the bars only line up with the heatmap rows/columns if the models in the
    similarity matrices are ordered by objective in that same alphabetical order -- confirm
    against how ``allowed_models`` / ``model_configs`` are ordered.

    Parameters
    ----------
    ax_h : matplotlib Axes
        Axes receiving the horizontal bar (drawn above the heatmap).
    ax_v : matplotlib Axes
        Axes receiving the vertical bar (drawn left of the heatmap).
    width_bar : float, default 5
        Thickness of both bars.
    legend_h, legend_v : bool, default False
        Whether to attach a "Model objectives" legend to ``ax_h`` / ``ax_v``.
    """
    # value_counts() orders by frequency; sort_index() switches to alphabetical objective order.
    counts = model_configs['objective'].value_counts().sort_index()
    segments = counts.to_list()
    segment_names = counts.index.to_list()
    total = sum(segments)

    colors = sns.color_palette('tab10', len(segments)).as_hex()
    start = 0
    # The vertical bar is stacked downwards from the top so that its first segment sits next to
    # the first heatmap row. The top is derived from the model count (not a fixed number) so it
    # always matches the y-limit set below.
    top = total
    for seg_len, seg_name, color in zip(segments, segment_names, colors):
        ax_h.barh('val',
                  width=seg_len,
                  left=start,
                  height=width_bar,
                  color=color,
                  align='center',
                  label=seg_name)
        ax_v.bar('val',
                 height=seg_len,
                 bottom=top - seg_len,
                 width=width_bar,
                 color=color,
                 align='center',
                 label=seg_name)
        start += seg_len
        top -= seg_len

    if legend_h:
        ax_h.legend(title="Model objectives", bbox_to_anchor=(1.01, 1), loc='upper left', frameon=False)
    if legend_v:
        ax_v.legend(title="Model objectives", bbox_to_anchor=(1.01, 1), loc='upper left', frameon=False)
    ax_h.set_xlim(0, total)
    ax_h.axis('off')
    ax_v.set_ylim(0, total)
    ax_v.axis('off')

cmap_sim_mat = 'flare'  # colormap for the per-dataset and mean similarity heatmaps
cmap_std_mat = 'flare'  # colormap for the std heatmap


def _add_panel_axes(fig, gs, row, col, size):
    """Add the (top bar, left bar, heatmap) axes of one matrix panel anchored at grid cell (row, col).

    The heatmap spans ``size`` x ``size`` grid cells starting at (row + 1, col + 1); the
    objective bars take the single grid row above it and the single grid column left of it.
    """
    rows = slice(row + 1, row + size + 1)
    cols = slice(col + 1, col + size + 1)
    ax_h = fig.add_subplot(gs[row, cols])
    ax_v = fig.add_subplot(gs[rows, col])
    ax_dat = fig.add_subplot(gs[rows, cols])
    return ax_h, ax_v, ax_dat


def _draw_panel(axes, data, title, cmap, **heatmap_kws):
    """Draw ``data`` as a colorbar-less heatmap plus objective bars into an axes triple.

    Returns the heatmap axes so the caller can attach a colorbar to its mesh.
    """
    ax_h, ax_v, ax_dat = axes
    draw_vertical_and_horizontal_cats(ax_h, ax_v)
    sns.heatmap(data, ax=ax_dat, cbar=False, cmap=cmap, **heatmap_kws)
    ax_dat.set_title(title, fontsize=12, y=1.01)
    ax_dat.axis('off')
    return ax_dat


def get_matrixplot(curr_sim_metrics_data, mean_res, std_res, ds_list):
    """Build the combined similarity-matrix figure for one similarity metric.

    First row: one heatmap per dataset in ``ds_list`` on a fixed [0, 1] colour scale.
    Second row: the mean matrix (same [0, 1] scale, colorbar on the left) and the std matrix
    (autoscaled, colorbar on the right, plus the model-objective legend). Every heatmap gets
    objective-coloured bars along its top and left edges.

    Parameters
    ----------
    curr_sim_metrics_data : mapping
        Dataset name -> similarity matrix for the current metric; must contain every
        entry of ``ds_list``.
    mean_res, std_res : pandas.DataFrame
        Square model-by-model matrices of the mean / std similarity across datasets.
    ds_list : sequence of str
        Datasets shown in the first row (three in this notebook's configurations).

    Returns
    -------
    matplotlib.figure.Figure
    """
    n, m = mean_res.shape
    n_cols = len(ds_list)

    # Layout in GridSpec cells. Each panel takes (n + 1) x (n + 1) cells: one row/column for
    # the objective bars plus the n x n heatmap.
    row_gap = 10          # empty rows between the two panel rows
    top_col_gap = 10      # empty columns between first-row panels
    bottom_col0 = 25      # first column of the second row; columns 11-13 hold the mean colorbar
    bottom_col_gap = 5    # empty columns between the mean and std panels
    top_step = n + 1 + top_col_gap
    bottom_step = n + 1 + bottom_col_gap
    bottom_row = n + 1 + row_gap
    # The std colorbar sits bottom_col_gap columns right of the std panel (column 145 for 54
    # models). It is derived from the layout so it stays next to the panel for any model count.
    std_cbar_col = bottom_col0 + 2 * bottom_step
    ncols = max(
        3 * (n + 1) + 20,                 # default: three first-row panels
        n_cols * top_step - top_col_gap,  # more than three first-row panels
        std_cbar_col + 3,                 # small n: keep the std colorbar inside the grid
    )

    fig = plt.figure(figsize=(3 * 12 * cm, 2 * 12 * cm))
    gs = gridspec.GridSpec(nrows=2 * (m + 1) + row_gap, ncols=ncols)

    # Create every panel's axes up front, in reading order.
    top_panels = [_add_panel_axes(fig, gs, 0, k * top_step, n) for k in range(n_cols)]
    mean_panel = _add_panel_axes(fig, gs, bottom_row, bottom_col0, n)
    std_panel = _add_panel_axes(fig, gs, bottom_row, bottom_col0 + bottom_step, n)

    vmin, vmax = 0, 1
    for ds, panel in zip(ds_list, top_panels):
        _draw_panel(panel, curr_sim_metrics_data[ds], ds_name_mapping.get(ds, ds),
                    cmap_sim_mat, vmin=vmin, vmax=vmax)

    # Mean: same fixed colour scale as the first row; its colorbar doubles as the first row's.
    mean_ax = _draw_panel(mean_panel, mean_res, 'Mean similarity across all datasets',
                          cmap_sim_mat, vmin=vmin, vmax=vmax)
    mean_cbar_ax = fig.add_subplot(gs[bottom_row:(bottom_row + n + 1), 11:14])
    fig.colorbar(mean_ax.collections[0], cax=mean_cbar_ax, orientation='vertical')

    # Std: autoscaled colour range; the shared objective legend hangs off this panel's top bar.
    std_ax = _draw_panel(std_panel, std_res, 'Std similarities across all datasets', cmap_std_mat)
    std_panel[0].legend(title="Model objectives", bbox_to_anchor=(1.3, 9), loc='upper left',
                        frameon=False, fontsize=11, title_fontsize=11)
    std_cbar_ax = fig.add_subplot(gs[bottom_row:(bottom_row + n + 1), std_cbar_col:(std_cbar_col + 3)])
    fig.colorbar(std_ax.collections[0], cax=std_cbar_ax, orientation='vertical', format="{x:.2f}")
    return fig

In [None]:
# One figure per similarity metric and per first-row dataset selection.
# The inner loop variable is deliberately not named `ds_list`: that global holds the full
# dataset list passed to load_similarity_matrices, and shadowing it here would make a re-run
# of the loading cell silently use only the last three-dataset selection.
for sim_metric in sim_metrics:
    curr_sim_metrics_data = sim_mats[sim_metric]
    mean_res = mean_sim_mats[sim_metric]
    std_res = std_sim_mats[sim_metric]
    for panel_ds in ds_lists.values():
        fig = get_matrixplot(curr_sim_metrics_data, mean_res, std_res, panel_ds)
        save_or_show(fig, storing_path / f'mean_std_sim_matrix_{sim_metric}_DS{"_".join(panel_ds)}.pdf', SAVE)

### Mean vs. standard deviation scatter plot

In [None]:
# Canonical (alphabetically ordered) objective pairs used as scatter-plot hue labels.
pairs = [
    ('Image-Text', 'Image-Text'),
    ('Image-Text', 'Self-Supervised'),
    ('Image-Text', 'Supervised'),
    ('Self-Supervised', 'Self-Supervised'),
    ('Self-Supervised', 'Supervised'),
    ('Supervised', 'Supervised'),
]
known_pairs = set(pairs)

# One label per model pair (i < j), in the same row-major order as the strict upper triangle
# of the similarity matrices. Assumes model_configs rows follow the similarity-matrix model
# order -- TODO confirm.
objectives = model_configs['objective'].to_list()
comb = []
for row, first in enumerate(objectives):
    for second in objectives[row + 1:]:
        if first == second or (first, second) in known_pairs:
            ordered = (first, second)
        elif (second, first) in known_pairs:
            ordered = (second, first)
        else:
            raise ValueError("Unknown pair")
        comb.append(f"{ordered[0]}, {ordered[1]}")


In [None]:
# Scatter every model pair's mean similarity (x) against its std across datasets (y),
# coloured by the pair's combination of training objectives (`comb`).
n = len(sim_metrics)
fig, axs = plt.subplots(nrows=1, ncols=n, sharey=True, figsize=(n*14*cm, 10*cm))
# With ncols == 1, plt.subplots returns a bare Axes rather than an array; normalise so axs[i] works.
axs = np.atleast_1d(axs)
for i, curr_sim_metric in enumerate(sim_metrics):
    mean_res = mean_sim_mats[curr_sim_metric]
    std_res = std_sim_mats[curr_sim_metric]
    # Strict upper triangle in row-major order -- the same (i < j) order in which `comb` was
    # built. Selecting by position (instead of filtering values != 0) keeps x, y and hue aligned
    # even if some mean or std value is exactly zero.
    rows, cols = np.triu_indices_from(mean_res, k=1)
    assert len(rows) == len(comb), "model pairs and objective labels are misaligned"
    is_last = i == n - 1
    ax = sns.scatterplot(
        x=np.asarray(mean_res)[rows, cols],
        y=np.asarray(std_res)[rows, cols],
        hue=comb,
        alpha=0.75,
        ax=axs[i],
        legend=is_last,  # a single shared legend, on the right-most panel
    )
    ax.set_xlabel('Mean CKA value across all datasets', fontsize=10)
    ax.set_ylabel('Std of CKA values across all datasets', fontsize=10)
    ax.set_title(sim_metric_name_mapping[curr_sim_metric], fontsize=12)
    ax.tick_params('both', labelsize=10)
    if is_last:
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), frameon=False, fontsize=10)

fig.tight_layout()
save_or_show(fig, storing_path/'mean_std_scatter_plot.pdf', SAVE)