In [1]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import spearmanr, pearsonr
from itertools import product
import textwrap
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable

from constants import exclude_models, exclude_models_w_mae, cat_name_mapping, ds_info_file, model_config_file, fontsizes
from helper import load_model_configs_and_allowed_models, save_or_show, load_ds_info

sys.path.append('..')
from scripts.helper import parse_datasets

from clip_benchmark.analysis.utils import retrieve_performance

In [2]:
# Root of the aggregated (pre-computed) results.
base_path_aggregated = Path('/home/space/diverse_priors/results/aggregated')

### Similarity data: pairwise model similarities for all metrics and datasets
sim_data = pd.read_csv(base_path_aggregated / 'model_sims/all_metric_ds_model_pair_similarity.csv')

### Performance data: dataset ids from the list file, with '/' replaced by '_'
ds_list_perf = [ds.replace('/', '_') for ds in parse_datasets('../scripts/webdatasets_w_in1k.txt')]

ds_info = load_ds_info(ds_info_file)

results_root = '/home/space/diverse_priors/results/linear_probe/single_model'

### Datasets to include: every performance dataset except 'cifar100-coarse' and 'entity13',
### plus 'imagenet-subset-10k'; `remaining_ds` lists the performance datasets left out.
ds_to_include = set(ds_list_perf).difference({'cifar100-coarse', 'entity13'})
ds_to_include |= {'imagenet-subset-10k'}
remaining_ds = sorted(set(ds_list_perf) - ds_to_include)

## Output configuration
suffix = ''
# suffix = '_wo_mae'  # any suffix containing 'mae' selects `exclude_models_w_mae` further down

SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/scatter_sim_vs_performance_v2')
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [3]:
## Keep only similarity rows for the datasets selected in `ds_to_include` (shape printed before/after)
print(sim_data.shape)
if ds_to_include:
    keep_rows = sim_data['DS'].isin(ds_to_include)
    sim_data = sim_data.loc[keep_rows].reset_index(drop=True)
print(sim_data.shape)

(100800, 9)
(92736, 9)


In [4]:
## Attach each dataset's domain as 'DS category', then replace the dataset id by its display name.
## Both lookups go through ds_info.loc, so an id missing from ds_info raises KeyError.
ds_ids = sim_data['DS']
sim_data['DS category'] = ds_ids.apply(lambda ds_id: ds_info.loc[ds_id, 'domain'])
sim_data['DS'] = ds_ids.apply(lambda ds_id: ds_info.loc[ds_id, 'name'])

In [5]:
## Post-process 'pair' columns
def pp_pair_col(df_col, mapping=None):
    """Turn stringified category pairs into readable "A, B" labels.

    Each cell of ``df_col`` holds the string form of an indexable pair of
    category keys (presumably a tuple such as "('a', 'b')", as read back from CSV).
    The string is parsed with ``ast.literal_eval`` instead of ``eval``, so cell
    contents can only be Python literals and are never executed as code.

    Args:
        df_col: pandas Series of stringified pairs.
        mapping: dict from category key to display name; defaults to
            ``cat_name_mapping``.

    Returns:
        Series of strings "<name of first key>, <name of second key>".

    Raises:
        ValueError / SyntaxError: if a cell is not a valid Python literal.
        KeyError: if a key is missing from ``mapping``.
    """
    import ast  # local import keeps this cell self-contained

    if mapping is None:
        mapping = cat_name_mapping
    return df_col.apply(ast.literal_eval).apply(lambda pair: f"{mapping[pair[0]]}, {mapping[pair[1]]}")


# Every column whose name contains 'pair' holds stringified pairs; convert them column by column.
pair_columns = list(filter(lambda name: 'pair' in name, sim_data.columns))
sim_data[pair_columns] = sim_data[pair_columns].apply(pp_pair_col, axis=0)
# Trailing None entry -- presumably a "no pair grouping" option for code outside this view.
pair_columns.append(None)

In [6]:
# Any suffix mentioning 'mae' switches to `exclude_models_w_mae` (by its name, presumably the list that also drops MAE models).
if 'mae' in suffix:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file, exclude_models=curr_excl_models, exclude_alignment=True
)

Nr. models original=64


In [7]:
## Keep a pair only if both of its models are in `allowed_models`
m1_allowed = sim_data['Model 1'].isin(allowed_models)
m2_allowed = sim_data['Model 2'].isin(allowed_models)
sim_data = sim_data.loc[m1_allowed & m2_allowed].reset_index(drop=True)

#### Retrieve the downstream task performances (loaded from a cached CSV; the recomputation code is kept commented out)

In [8]:
# import warnings

# # Ignore UserWarnings
# warnings.filterwarnings("ignore", category=UserWarning)

# res = []
# for ds, mid in product(ds_list_perf, allowed_models):
#     performance = retrieve_performance(
#         model_id=mid, 
#         dataset_id=ds, 
#         metric_column='test_lp_acc1',
#         results_root='/home/space/diverse_priors/results/linear_probe/single_model',
#         regularization="weight_decay",
#         allow_db_results=False
#     )
#     res.append({
#         'DS': ds,
#         'Model': mid,
#         'TestAcc': performance
#     })
# perf_res = pd.DataFrame(res)

In [9]:
# perf_res.to_csv(base_path_aggregated/ f'single_model_performance/all_ds{suffix}.csv', index=False)

In [10]:
# Cached per-(dataset, model) test accuracies; presumably written by the commented-out cells above.
perf_csv = base_path_aggregated / 'single_model_performance/all_ds.csv'
perf_res = pd.read_csv(perf_csv)

In [11]:
# Mirror the similarity-data preprocessing on the performance table: restrict datasets,
# attach domains, switch to display names, then keep only allowed models.
if ds_to_include:
    perf_res = perf_res.loc[perf_res['DS'].isin(ds_to_include)].reset_index(drop=True)
ds_id_col = perf_res['DS']
perf_res['DS category'] = ds_id_col.apply(lambda ds_id: ds_info.loc[ds_id, 'domain'])
perf_res['DS'] = ds_id_col.apply(lambda ds_id: ds_info.loc[ds_id, 'name'])
perf_res = perf_res.loc[perf_res['Model'].isin(allowed_models)].reset_index(drop=True)

#### Combine model similarities and performance measures

In [12]:
def get_model_perf(row, perf_df=None):
    """Look up the test accuracies of both models of one similarity row.

    Args:
        row: Series with 'Model 1', 'Model 2' and 'DS' entries (one row of the
            similarity table, as passed by ``DataFrame.apply(..., axis=1)``).
        perf_df: table with 'Model', 'DS' and 'TestAcc' columns; defaults to
            the global ``perf_res``.

    Returns:
        Tuple (accuracy of model 1, accuracy of model 2, absolute difference).

    Raises:
        ValueError: if a (model, dataset) combination has no entry or more
            than one entry in ``perf_df``.
    """
    if perf_df is None:
        perf_df = perf_res
    ds_mask = perf_df['DS'] == row['DS']  # shared by both look-ups

    def _lookup(model):
        # Exactly one accuracy must exist per (model, dataset); say which one is off.
        matches = perf_df.loc[ds_mask & (perf_df['Model'] == model), 'TestAcc']
        if len(matches) != 1:
            raise ValueError(
                f"Expected exactly one 'TestAcc' for model={model!r}, DS={row['DS']!r}; "
                f"found {len(matches)}."
            )
        return matches.item()

    m1_perf = _lookup(row['Model 1'])
    m2_perf = _lookup(row['Model 2'])
    return m1_perf, m2_perf, np.abs(m1_perf - m2_perf)


In [13]:
# One (model 1 perf., model 2 perf., |gap|) triple per similarity row, in the same row order.
perf_triples = sim_data.apply(get_model_perf, axis=1).tolist()
performance_per_pair = pd.DataFrame(
    perf_triples,
    columns=['Model 1 perf.', 'Model 2 perf.', 'abs. diff. perf.'],
).reset_index(drop=True)

In [14]:
# Both frames carry a fresh RangeIndex in the same row order, so a column-wise concat pairs rows 1:1.
sim_data_new = pd.concat([sim_data, performance_per_pair], axis='columns')

#### Compute the correlations between the performance gaps and the model similarities

In [15]:
def get_correlation(subset_data):
    """Correlate model similarity with the absolute performance gap in one group.

    Args:
        subset_data: DataFrame with 'Similarity value' and 'abs. diff. perf.' columns.

    Returns:
        dict with the Spearman ('spearmanr') and Pearson ('pearsonr')
        coefficients; the p-values are discarded.
    """
    sims = subset_data['Similarity value']
    gaps = subset_data['abs. diff. perf.']
    return {
        'spearmanr': spearmanr(sims, gaps)[0],
        'pearsonr': pearsonr(sims, gaps)[0],
    }


# One row per (similarity metric, dataset) with columns 'spearmanr' and 'pearsonr'.
grouped_cols = sim_data_new.groupby(['Similarity metric', 'DS'])[['Similarity value', 'abs. diff. perf.']]
coeff_dicts = grouped_cols.apply(get_correlation)
r_coeffs = pd.DataFrame(coeff_dicts.tolist(), index=coeff_dicts.index)

#### Plot the per-dataset correlation coefficients (strip plots) for each dataset category

In [16]:
# Attach ds_info metadata (e.g. 'domain') to the correlation table. 'DS' already holds display
# names, so the join key is ds_info's 'name' column; the merge suffixes our 'DS' as 'DS_x' and
# ds_info's id column as 'DS_y', which is dropped. Exact duplicate rows (presumably from several
# ids sharing one display name) are removed before sorting.
r_coeffs_tmp = r_coeffs.reset_index()
r_coeffs_tmp['name'] = r_coeffs_tmp['DS']
tmp = (
    r_coeffs_tmp
    .merge(ds_info.reset_index(names=['DS']), how='left', on='name')
    .drop(columns=['DS_y'])
    .drop_duplicates()
    .sort_values(['Similarity metric', 'domain', 'spearmanr'])
    .reset_index(drop=True)
)
if SAVE:
    fn = storing_path / 'corr_perf_vs_sim_per_ds.csv'
    tmp.to_csv(fn, index=False)

In [17]:
# Long format for plotting: one row per (similarity metric, dataset, correlation metric).
# value_vars is explicit so that extra metadata columns merged in from ds_info are not
# melted into 'Correlation metric' alongside the two coefficients.
melted_ds_perf_sim_corr = pd.melt(
    tmp,
    id_vars=['Similarity metric', 'DS_x', 'name', 'domain'],
    value_vars=['spearmanr', 'pearsonr'],
    var_name='Correlation metric',
    value_name='Correlation coefficient'
)

In [18]:
def wrap_labels(ax, width=10, break_long_words=False):
    """Wrap the x tick labels of ``ax`` onto multiple lines, centered.

    Args:
        ax: matplotlib Axes whose x tick labels are rewritten in place.
        width: maximum line width passed to ``textwrap.fill``.
        break_long_words: passed to ``textwrap.fill``; when False, words longer
            than ``width`` stay unbroken.

    Returns:
        None.
    """
    labels = [
        textwrap.fill(label.get_text(), width=width, break_long_words=break_long_words)
        for label in ax.get_xticklabels()
    ]
    # Pin the current tick positions before replacing the labels: calling set_xticklabels
    # without fixed ticks triggers matplotlib's FixedFormatter/FixedLocator UserWarning, and
    # the labels could fall out of sync if the locator later picks different ticks.
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(labels, ha='center')

# Per-dataset correlation coefficients, one panel per similarity metric, grouped by dataset domain.
# NOTE: catplot's default kind is 'strip' (jittered points), not 'swarm', despite the output file name.
domain_order = sorted(melted_ds_perf_sim_corr['domain'].unique())
g = sns.catplot(
    melted_ds_perf_sim_corr,
    x='domain',
    y='Correlation coefficient',
    col='Similarity metric',
    hue='Correlation metric',
    order=domain_order,
    height=3,
    aspect=1.1,
    legend_out=True,  # keep the legend outside the panel grid
)

# Dashed grey reference line at -0.5 in every panel.
g.map(plt.axhline, y=-0.5, color='grey', linestyle='--')

# Long domain names would overlap, so wrap them onto several lines.
for ax in g.axes.flat:
    wrap_labels(ax)

# Panel title = similarity metric only; the domain tick labels make an x-axis label redundant.
g.set_titles("{col_name}")
g.set_xlabels("")

# Tighten the layout, then shrink the horizontal gap between panels.
plt.tight_layout()
g.fig.subplots_adjust(wspace=0.1)

# Re-anchor the legend: its lower-left corner goes to bbox point (1, 0.5).
sns.move_legend(g, bbox_to_anchor=(1, 0.5), loc='lower left')

save_or_show(g.fig, storing_path / 'swarm_corr_perf_vs_sim_per_ds_cat.pdf', SAVE)

  ax.set_xticklabels(labels, ha='center')
  ax.set_xticklabels(labels, ha='center')


stored img at /home/space/diverse_priors/results/plots/scatter_sim_vs_performance_v2/swarm_corr_perf_vs_sim_per_ds_cat.pdf.


#### Bar plot of the per-dataset correlation coefficients for CKA linear, colored by dataset category

In [20]:
# Bar plot of the per-dataset correlations for one similarity metric: a Pearson and a Spearman
# bar per dataset, colored by dataset domain (darker palette = Pearson, lighter = Spearman).
domain_colors = {
    'pearsonr': {
        'Natural (multi-domain)': '#1f77b4',
        'Natural (single-domain)': '#ff7f0e',
        'Specialized': '#2ca02c',
        'Structured': '#d62728'
    },
    'spearmanr': {
        'Natural (multi-domain)': '#aec7e8',
        'Natural (single-domain)': '#ffbb78',
        'Specialized': '#98df8a',
        'Structured': '#ff9896'
    }
}
BAR_SIM_METRIC = 'CKA linear'  # similarity metric shown in this figure

df = melted_ds_perf_sim_corr[melted_ds_perf_sim_corr['Similarity metric'] == BAR_SIM_METRIC]

fig, bar_ax = plt.subplots(figsize=(10, 5))

# One slot per dataset (order of first appearance); Pearson sits left of the slot center, Spearman right.
unique_names = df['name'].unique()
x = np.arange(len(unique_names))
width = 0.4
name_to_pos = {name: pos for pos, name in zip(x, unique_names)}
bar_offsets = {'pearsonr': -width / 2, 'spearmanr': width / 2}

for metric, offset in bar_offsets.items():
    data = df[df['Correlation metric'] == metric]
    # Each dataset must contribute exactly one bar per correlation metric.
    assert not data['name'].duplicated().any(), f'duplicate dataset names for {metric}'
    # Place bars by dataset name instead of relying on identical row order across metrics.
    positions = data['name'].map(name_to_pos).to_numpy() + offset
    colors = [domain_colors[metric][domain] for domain in data['domain']]
    bar_ax.bar(positions, data['Correlation coefficient'], width, label=metric, color=colors)

# Axis styling plus dotted grey guide lines at -0.3, -0.5 and -0.7.
bar_ax.set_ylabel('Correlation Coefficient', fontsize=12)
bar_ax.set_xticks(x)
bar_ax.set_xticklabels(unique_names, rotation=45, ha='right', fontsize=11)
bar_ax.tick_params('both', labelsize=11)
for level in (-.3, -.5, -.7):
    bar_ax.axhline(level, alpha=0.5, ls=':', c='grey', zorder=-1)

# Custom legend: one patch per domain (Pearson palette) plus two patches explaining the shades.
domain_patches = [plt.Rectangle((0, 0), 1, 1, fc=color, label=domain)
                  for domain, color in domain_colors['pearsonr'].items()]
metric_patches = [
    plt.Rectangle((0, 0), 1, 1, fc='gray', label='Pearson coefficient (Darker)'),
    plt.Rectangle((0, 0), 1, 1, fc='gray', alpha=0.7, label='Spearman coefficient (Lighter)')
]
bar_ax.legend(handles=domain_patches + metric_patches,
              loc='upper left',
              bbox_to_anchor=(1, 1),
              title='Domains and Metrics',
              frameon=False,
              fontsize=11,
              title_fontsize=11)
fig.tight_layout()
# 'CKA linear' -> 'cka_linear', so the output file name is unchanged for the default metric.
sim_tag = BAR_SIM_METRIC.lower().replace(' ', '_')
save_or_show(fig, storing_path / f'bar_corr_perf_vs_sim_per_ds_cat_{sim_tag}.pdf', SAVE)

stored img at /home/space/diverse_priors/results/plots/scatter_sim_vs_performance_v2/bar_corr_perf_vs_sim_per_ds_cat_cka_linear.pdf.
