## Performance tables for each model on all task and probe type combinations

The following notebook generates the performance tables for each model on all task and probe type combinations. It loads the aggregated results (`complete_set_of_run.pkl`) from the experiments and formats them into tables for easy comparison.

In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import sys

from itertools import product
import numpy as np

sys.path.append('..')
sys.path.append('../..')

from constants import BASE_PATH_PROJECT, FOLDER_SUBSTRING, experiment_with_probe_type_order_list
from helper import style_multimodel_heatmap, init_plotting_params

In [None]:
# Apply the project's shared plotting configuration from the `helper` module
# (NOTE(review): presumably global matplotlib/style defaults — see helper.init_plotting_params).
init_plotting_params()

In [None]:
# Output switch: in the cells here SAVE is only tested for truthiness — a truthy
# value writes the CSV tables below `base_storing_path`, a falsy one keeps
# everything in the notebook.
SAVE = 'both'

base_storing_path = BASE_PATH_PROJECT.joinpath(f"results_{FOLDER_SUBSTRING}_rebuttal", "plots")
if SAVE:
    base_storing_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Aggregated results of every experiment run, one row per run.
# NOTE: read_pickle can execute arbitrary code; only load files produced by this
# project's own aggregation step.
all_runs = pd.read_pickle(
    BASE_PATH_PROJECT.joinpath(
        f'results_{FOLDER_SUBSTRING}_rebuttal', 'aggregated', 'complete_set_of_run.pkl'
    )
)

In [None]:
# Restrict the runs used for the tables:
#  - drop runs with nr_layers == 1 that are flagged as contains_intermediate;
#  - keep only the 'cae' and 'linear' probe types.
# Boolean masks are used instead of `drop(index=...)`: dropping by index label
# would also remove unrelated rows if the loaded frame has duplicate index labels.
single_layer_intermediate = (all_runs['nr_layers'] == 1) & all_runs['contains_intermediate']
selected_probe = all_runs['probe_type'].isin(['cae', 'linear'])
all_runs = all_runs.loc[~single_layer_intermediate & selected_probe].reset_index(drop=True)

### Sanity checks on the number of runs per model/task/probe type combination

In [None]:
# Long table with one row per (dataset, Experiment) pair and the number of runs
# for that pair in the 'count' column, ordered by dataset then experiment.
check_runs = (
    all_runs[['dataset', 'Experiment']]
    .value_counts()
    .sort_index()
    .reset_index()
)

In [None]:
# Wide run-count table: rows = datasets, columns = experiments.
# Pairs without any run come out of `pivot` as NaN, which turns the whole table
# into floats; fill them with 0 and cast back so counts stay integers
# (also in the CSV written below).
check_runs_pivot = pd.pivot(
    check_runs,
    index='dataset',
    columns='Experiment',
    values='count',
).fillna(0).astype(int)

In [None]:
# Put the experiment columns in the canonical order and datasets alphabetically.
# `reindex` (unlike `df[order_list]`) does not raise a KeyError when an
# experiment of the order list has no runs left after filtering; such an
# experiment shows up as a column of 0 runs, which is what this check should flag.
check_runs_pivot = (
    check_runs_pivot
    .reindex(columns=experiment_with_probe_type_order_list, fill_value=0)
    .sort_index()
)
if SAVE:
    fn = base_storing_path / "per_experiment_eval_count" / 'check_runs_pivot.csv'
    fn.parent.mkdir(parents=True, exist_ok=True)
    check_runs_pivot.to_csv(fn)

In [None]:
# Shade the whole count table on one shared colour scale (axis=None); with the
# reversed 'Reds_r' map the lowest counts get the strongest colour.
check_runs_pivot_styled = (
    check_runs_pivot
    .style
    .background_gradient(axis=None, cmap='Reds_r')
)
check_runs_pivot_styled

In [None]:
# Total number of runs per experiment column, summed over all datasets.
check_runs_pivot.sum(axis='index')

In [None]:
# Every base model left after filtering, in alphabetical order.
selected_models = sorted(all_runs['base_model'].drop_duplicates())
selected_models

## Per model performance tables for all task and probe type combinations

In [None]:
# Wide performance table: rows = 'dataset_fmt', columns = (base_model, Experiment),
# values = the 'test_lp_bal_acc1' metric of each run.
# NOTE: this intentionally reuses the name `check_runs_pivot`; the run-count table
# built above is no longer available after this cell.
check_runs_pivot = pd.pivot(
    all_runs,
    index='dataset_fmt',
    columns=['base_model', 'Experiment'],
    values='test_lp_bal_acc1',
)

# Canonical column order: every selected model crossed with every experiment.
curr_col_order = list(product(selected_models, experiment_with_probe_type_order_list))
# (model, experiment) pairs that have no run at all; they become all-NaN columns.
diff = set(curr_col_order) - set(check_runs_pivot.columns.tolist())

# Add the missing columns and apply the canonical order in one step instead of
# inserting them one by one (which fragments the frame). `from_product` yields
# the same order as `product` above and also works when there are no models.
target_columns = pd.MultiIndex.from_product(
    [selected_models, experiment_with_probe_type_order_list],
    names=check_runs_pivot.columns.names,
)
check_runs_pivot = check_runs_pivot.reindex(columns=target_columns).apply(pd.to_numeric, errors='coerce')

In [None]:
# One colormap per model table; cycled when there are more models than colormaps.
color_maps = [
    'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
    'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
    'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'
    ]
for i, model in enumerate(check_runs_pivot.columns.get_level_values(0).unique()):
    # This model's metric for every experiment (canonical order), scaled by 100.
    tmp = check_runs_pivot.loc[:, [(model, col) for col in experiment_with_probe_type_order_list]].copy()
    tmp *= 100

    # Per-dataset gain of each experiment over the model's last-layer baseline.
    # NOTE(review): models whose name contains 'mae' use 'AP last layer' as the
    # baseline, all others 'CLS last layer' — presumably because MAE backbones are
    # probed on average-pooled rather than CLS features; confirm.
    baseline_col = (model, 'AP last layer') if 'mae' in model else (model, 'CLS last layer')
    tmp2 = tmp.sub(tmp[baseline_col], axis=0)

    # Summary rows appended under the per-dataset rows (NaN entries are skipped).
    tmp.loc["min perf. gain", :] = tmp2.min(skipna=True, axis=0)
    tmp.loc["median perf. gain", :] = tmp2.median(skipna=True, axis=0)
    tmp.loc["max perf. gain", :] = tmp2.max(skipna=True, axis=0)
    tmp.loc["mean perf. gain", :] = tmp2.mean(skipna=True, axis=0)
    tmp.loc["std perf. gain", :] = tmp2.std(skipna=True, axis=0)

    # Wrap around the colormap list: a plain slice would be empty past the last entry.
    styled_df = style_multimodel_heatmap(tmp, color_maps=[color_maps[i % len(color_maps)]])

    if SAVE:
        fn = base_storing_path / "per_model_all_performances" / f'perf_table_{model}.csv'
        fn.parent.mkdir(parents=True, exist_ok=True)
        tmp.to_csv(fn)
        print(f"Stored {model=} performance table at {fn=}.")
        print()
    else:
        display(styled_df)