# Model Family Performance Comparison: Comprehensive Heatmap Analysis

This notebook generates performance comparison heatmaps across model families (CLIP, DINOv2, MAE, ViT) and model sizes (Small, Base, Large). It processes pre-computed performance tables and renders them as styled visualizations for comparing experimental configurations.

It loads the results created by `create_tables_full_model_performances.ipynb` and generates one performance table per model family, covering all tasks and a subset of the probe-type combinations.
These are the tables shown in Appendix Figure 5.

The analysis loads the CSV performance table of each model and creates color-coded heatmaps of test balanced accuracy across datasets and experimental setups (layer combinations, pooling strategies and probe types).
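Balanced accuracy averages the recall of each class, so a classifier that always predicts the majority class does not get an inflated score on an imbalanced dataset. The sketch below assumes the standard definition (mean per-class recall, scaled to percent to match the tables); `balanced_accuracy` is a hypothetical helper, since the metric is computed upstream and only loaded here.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall in percent (hypothetical helper, standard definition)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Recall of every class that occurs in y_true
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return 100.0 * float(np.mean(recalls))


# Always predicting class 0 on a 3:1 imbalanced toy set:
# plain accuracy is 75%, balanced accuracy is (100% + 0%) / 2 = 50%.
print(balanced_accuracy([0, 0, 0, 1], [0, 0, 0, 0]))
```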

**Output**: One multi-model heatmap per model family, showing test balanced accuracy for each dataset (rows) and each model and setup (columns). Each model within a family gets its own colormap (Reds, Greens or Blues), and the shading is normalized per dataset row within each model, so it highlights the best setups for that model; the printed values show how model size affects performance across setups and datasets.
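Because the gradient is computed per row within each model's column block, a cell's shade only reflects how that setup compares with the same model's other setups on the same dataset; shades are not comparable across models or datasets. A minimal pandas sketch of that normalization follows, assuming plain min-max scaling (what a linear colormap gradient amounts to); mapping constant rows to the midpoint 0.5 is our choice, and `rowwise_block_norm` is a hypothetical helper.

```python
import pandas as pd


def rowwise_block_norm(df):
    """Min-max normalize each row separately inside every model (level-0) block.

    Present values in a constant row get 0.5; missing values stay NaN.
    """
    blocks = {}
    for model in df.columns.get_level_values(0).unique():
        block = df[model].astype(float)  # columns: the setups of this model
        vmin = block.min(axis=1)
        span = block.max(axis=1) - vmin
        # Divide by NaN instead of 0 for constant rows to avoid inf values
        normed = block.sub(vmin, axis=0).div(span.where(span != 0), axis=0)
        # Constant rows: place present values at the middle of the colormap
        normed = normed.mask(normed.isna() & block.notna(), 0.5)
        blocks[model] = normed
    # The dict keys become the outer (model) column level again
    return pd.concat(blocks, axis=1)


cols = pd.MultiIndex.from_product([["Model-A", "Model-B"], ["setup-1", "setup-2", "setup-3"]])
toy = pd.DataFrame([[1, 2, 3, 10, 10, 10]], index=["dataset-1"], columns=cols)
print(rowwise_block_norm(toy))
```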

In [None]:
import sys
import pandas as pd
from pathlib import Path
import matplotlib.cm as cm
import matplotlib.colors as mcolors

sys.path.append("..")
sys.path.append("../..")

from helper import init_plotting_params
from constants import base_model_name_mapping, BASE_PATH_PROJECT, FOLDER_SUBSTRING

In [3]:
init_plotting_params()

In [None]:
base_res_path = BASE_PATH_PROJECT / f"results_{FOLDER_SUBSTRING}_rebuttal/plots/per_model_all_performances"

In [5]:
color_maps = [
    'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
    'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
    'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'
    ]
model_name_mapping = {
    "mae-vit-base-p16_v2": "MAE-B-16",
    "mae-vit-large-p16_v2": "MAE-L-16",
    "mae-vit-base-p16": "MAE-B-16",
    "mae-vit-large-p16": "MAE-L-16",
    
}
exp_name_mapping = {
    'CLS last layer': "Last layer\n(CLS,\nlinear)",
    'AP last layer': "Last layer\n(AP,\nlinear)",
    'CLS+AP last layer (linear)': "Last layer\n(CLS+AP,\nlinear)",
    'CLS+AP layers from middle & last block (linear)': "Two layers\n(CLS+AP,\nlinear)",
    'CLS+AP layers from middle & last blocks (linear)': "Two layers\n(CLS+AP,\nlinear)",
    'CLS+AP layers from quarterly block (linear)': "Four layers\n(CLS+AP,\nlinear)",
    'CLS+AP layers from quarterly blocks (linear)': "Four layers\n(CLS+AP,\nlinear)",
    'CLS+AP layers from all blocks (linear)': "All layers\n(CLS+AP,\nlinear)",
    'CLS+AP last layer (attentive)': "Last layer\n(CLS+AP,\nattentive)",
    'CLS+AP layers from middle & last block (attentive)': "Two layers\n(CLS+AP,\nattentive)",
    'CLS+AP layers from middle & last blocks (attentive)': "Two layers\n(CLS+AP,\nattentive)",
    'CLS+AP layers from quarterly block (attentive)': "Four layers\n(CLS+AP,\nattentive)",
    'CLS+AP layers from quarterly blocks (attentive)': "Four layers\n(CLS+AP,\nattentive)",
    'CLS+AP layers from all blocks (attentive)': "All layers\n(CLS+AP,\nattentive)",
    'All tokens last layer (attentive)': "Last layer\n(all tokens,\nattentive)",
}
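Several raw setup names differ only in "block" vs "blocks" and deliberately map to the same display label. If one CSV ever contained both spellings, the renaming step would produce duplicate column labels, so it can be worth listing which labels collapse several raw names. A sketch, where `collapsed_labels` is a hypothetical helper and the excerpt copies three entries of `exp_name_mapping`:

```python
from collections import defaultdict


def collapsed_labels(mapping):
    """Return {display label: [raw names]} for labels shared by several raw names."""
    groups = defaultdict(list)
    for raw, label in mapping.items():
        groups[label].append(raw)
    return {label: raws for label, raws in groups.items() if len(raws) > 1}


excerpt = {
    'CLS last layer': "Last layer\n(CLS,\nlinear)",
    'CLS+AP layers from middle & last block (linear)': "Two layers\n(CLS+AP,\nlinear)",
    'CLS+AP layers from middle & last blocks (linear)': "Two layers\n(CLS+AP,\nlinear)",
}
for label, raws in collapsed_labels(excerpt).items():
    print(repr(label), "<-", raws)
```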

In [6]:
from collections import defaultdict

# Group the per-model CSV files by model family, based on the file name
model_path_dict = defaultdict(list)
for path in base_res_path.rglob("*.csv"):
    model_name = path.stem.split('perf_table_')[-1]
    if "_v2" in model_name:  # skip the "_v2" result variants
        continue
    if 'dinov2-' in model_name:
        model_path_dict['dinov2'].append(path)
    elif 'mae-' in model_name:
        model_path_dict['mae'].append(path)
    elif 'vit_' in model_name:
        model_path_dict['vit'].append(path)
    else:  # any other model name is treated as CLIP
        model_path_dict['clip'].append(path)
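The family assignment depends only on substrings of the file stem, so it can be pulled out into a pure function and checked without the results directory. A sketch follows; `model_family`, `group_by_family` and the file names are hypothetical. Note that any name matching none of the rules silently lands in the CLIP group, which matters when new models are added.

```python
from collections import defaultdict
from pathlib import Path


def model_family(model_name):
    """Family key for a model name, or None for skipped '_v2' variants.

    Same substring rules, in the same order, as the cell above.
    """
    if "_v2" in model_name:
        return None
    if "dinov2-" in model_name:
        return "dinov2"
    if "mae-" in model_name:  # checked before 'vit_', as in the cell above
        return "mae"
    if "vit_" in model_name:
        return "vit"
    return "clip"  # fallback for every other name


def group_by_family(paths):
    """Group CSV paths named 'perf_table_<model>.csv' by model family."""
    groups = defaultdict(list)
    for path in paths:
        family = model_family(path.stem.split("perf_table_")[-1])
        if family is not None:
            groups[family].append(path)
    return dict(groups)


# Hypothetical file names, for illustration only
paths = [
    Path("perf_table_dinov2-vits14.csv"),
    Path("perf_table_mae-vit-base-p16.csv"),
    Path("perf_table_mae-vit-base-p16_v2.csv"),
    Path("perf_table_vit_small.csv"),
    Path("perf_table_clip_b32.csv"),
]
print(group_by_family(paths))
```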

In [7]:
df_list = {}  # maps display model name -> per-model performance table
for model_family, models_paths in model_path_dict.items():
    for path in models_paths:
        model_name = base_model_name_mapping[path.stem.split("perf_table_")[1]]

        # header=1: the second CSV line holds the setup names, the first is discarded
        df = pd.read_csv(path, index_col=0, header=1)
        df.index.name = None
        df = df.drop(index='dataset_fmt')
        df.rename(index={'Diabetic Retinopathy': 'Diabetic\nRetinopathy',
                         "PASCAL VOC 2007": "PASCAL VOC07"}, inplace=True)
        # Map raw setup names to the multi-line display labels
        df.columns = [exp_name_mapping[col] for col in df.columns]
        # Coerce values to numbers and drop setups without any result
        for col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df = df.dropna(axis=1, how='all')

        # add model name as meta column
        df.insert(0, "Model", model_name)

        df_list[model_name] = df
df_list.keys()

dict_keys(['ViT-L-16', 'ViT-S-16', 'ViT-B-16', 'CLIP-B-32', 'CLIP-B-16', 'CLIP-L-14', 'MAE-B-16', 'MAE-L-16', 'DINOv2-L-14', 'DINOv2-S-14', 'DINOv2-B-14'])
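The loading steps above can be tried on a tiny in-memory file. The layout of `TOY_CSV` is an assumption inferred from the `read_csv` arguments (with `header=1`, pandas takes the column names from the second line and discards the first) and from the dropped `dataset_fmt` row; the real files written by `create_tables_full_model_performances.ipynb` may look different. `load_perf_table` is a hypothetical wrapper around the same steps.

```python
import io

import pandas as pd

# Hypothetical layout: one discarded line, the setup names, a non-numeric
# 'dataset_fmt' row, then one row per dataset (missing results left empty).
TOY_CSV = """\
,header,header
,CLS last layer,AP last layer
dataset_fmt,fmt_a,fmt_b
CIFAR-10,96.6,n/a
Diabetic Retinopathy,45.83,
"""


def load_perf_table(source, name_map):
    df = pd.read_csv(source, index_col=0, header=1)
    df.index.name = None
    df = df.drop(index="dataset_fmt")  # the only non-numeric row
    df = df.rename(index={"Diabetic Retinopathy": "Diabetic\nRetinopathy"})
    df.columns = [name_map[col] for col in df.columns]
    # Columns are object dtype because of 'dataset_fmt'; convert them now
    df = df.apply(pd.to_numeric, errors="coerce")
    return df.dropna(axis=1, how="all")  # drop setups without any result


name_map = {
    "CLS last layer": "Last layer\n(CLS,\nlinear)",
    "AP last layer": "Last layer\n(AP,\nlinear)",
}
print(load_perf_table(io.StringIO(TOY_CSV), name_map))
```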

In [8]:
tmp = pd.concat([df.loc['mean perf. gain'] for name, df in df_list.items() if 'MAE' not in name ], axis=1).T
tmp.drop(columns=['Model']).mean().sort_values()

Last layer\n(CLS,\nlinear)                   0.0
Last layer\n(CLS+AP,\nlinear)           1.305768
Last layer\n(CLS+AP,\nattentive)        1.842278
Two layers\n(CLS+AP,\nlinear)           2.583626
Four layers\n(CLS+AP,\nlinear)          3.291752
Two layers\n(CLS+AP,\nattentive)         3.68103
All layers\n(CLS+AP,\nlinear)           3.796811
Four layers\n(CLS+AP,\nattentive)       4.816793
Last layer\n(all tokens,\nattentive)    5.309449
All layers\n(CLS+AP,\nattentive)        5.538655
dtype: object
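The Series above has `dtype: object` because every `mean perf. gain` row also carries the string `Model` column, so the transposed frame keeps object columns even after `Model` is dropped. A sketch that casts to float before averaging, giving a numeric result; `mean_gain_across_models` is a hypothetical helper and the two toy rows are made up.

```python
import pandas as pd


def mean_gain_across_models(rows):
    """Average per-model 'mean perf. gain' rows and sort the setups ascending."""
    table = pd.concat(rows, axis=1).T  # one row per model
    # Drop the string column, then cast the remaining object columns to float
    table = table.drop(columns=["Model"], errors="ignore").astype(float)
    return table.mean().sort_values()


r1 = pd.Series({"Model": "A", "setup-1": 0.0, "setup-2": 2.0})
r2 = pd.Series({"Model": "B", "setup-1": 0.0, "setup-2": 4.0})
print(mean_gain_across_models([r1, r2]))
```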

In [9]:
model_lists_to_consider = [
    ['ViT-S-16', 'ViT-B-16', 'ViT-L-16'],
    ['CLIP-B-32', 'CLIP-B-16', 'CLIP-L-14',],
    ['MAE-B-16', 'MAE-L-16', ],
    ['DINOv2-S-14', 'DINOv2-B-14', 'DINOv2-L-14']
]

# Setup columns hidden in the heatmaps; drop_col_2 is the fallback for models
# whose tables lack some of the drop_col_1 columns
drop_col_1 = ['Last layer\n(CLS+AP,\nlinear)','Two layers\n(CLS+AP,\nlinear)','Four layers\n(CLS+AP,\nlinear)', 'Last layer\n(CLS+AP,\nattentive)']
drop_col_2 = ['Last layer\n(CLS+AP,\nlinear)','Last layer\n(CLS+AP,\nattentive)',]


def style_multimodel_heatmap(df, model_cmaps, precision=2):
    """
    df: pandas DataFrame with MultiIndex columns (level 0 = model, level 1 = setup)
    model_cmaps: dict of {model_name: colormap_name}

    Colors are normalized per row (dataset) within each model's column block.
    The gradient is applied with Styler.background_gradient, which avoids the
    deprecated matplotlib.cm.get_cmap.
    """
    styler = df.style.format(precision=precision, na_rep="")
    # Apply a row-wise (axis=1) background gradient separately to each model block
    for model in df.columns.get_level_values(0).unique():
        cmap_name = model_cmaps.get(model, "Reds")
        # select all columns under this model
        subset = [col for col in df.columns if col[0] == model]
        styler = styler.background_gradient(cmap=cmap_name, axis=1, subset=subset)
    return styler

    

def get_combined_df(models_to_consider):
    """Place the per-model tables side by side under a (model, setup) column MultiIndex."""
    dfs_multi = []
    for model in models_to_consider:
        df_copy = df_list[model].drop(columns="Model")
        if model == 'CLIP-B-32':
            df_copy = df_copy.drop(columns=['Two layers\n(CLS+AP,\nattentive)','Four layers\n(CLS+AP,\nattentive)'])
        try:
            df_copy = df_copy.drop(columns=drop_col_1)
        except KeyError:  # some drop_col_1 columns are missing for this model
            df_copy = df_copy.drop(columns=drop_col_2)
        df_copy.columns = pd.MultiIndex.from_product([[model], df_copy.columns])
        dfs_multi.append(df_copy)
    combined_df = pd.concat(dfs_multi, axis=1)
    return combined_df


def style_combined(combined_df):
    # One colormap per model; within each family the models get Reds, Greens, Blues
    model_cmaps = {
        'ViT-S-16': "Reds",
        'ViT-B-16': "Greens",
        'ViT-L-16': "Blues",
        'CLIP-B-32': "Reds",
        'CLIP-B-16': "Greens",
        'CLIP-L-14': "Blues",
        'MAE-B-16': "Greens",
        'MAE-L-16': "Blues",
        'DINOv2-S-14': "Reds",
        'DINOv2-B-14': "Greens",
        'DINOv2-L-14': "Blues",
    }

    styled = style_multimodel_heatmap(combined_df, model_cmaps=model_cmaps)
    col_width = "75px"
    data_font_size = "15px"
    header_font_size = "13px"
    index_font_size = "13px"
    # A single set_table_styles call: with the default overwrite=True, a second
    # call would silently discard the styles set by a first one
    return styled.set_table_styles([
            {'selector': 'table', 'props': [('table-layout', 'fixed')]},

            # Data cells - largest font
            {'selector': 'td:not(.row_heading)', 'props': [
                ('width', col_width),
                ('text-align', 'center'),
                ('font-size', data_font_size),
                ('font-weight', 'bold'),
                ('white-space', 'nowrap')
            ]},

            # Column headers (model and setup levels), top-aligned and centered
            {'selector': 'th.col_heading', 'props': [
                ('width', col_width),
                ('font-size', header_font_size),
                ('vertical-align', 'top'),
                ('text-align', 'center'),
            ]},

            # Row headers (dataset names), left-aligned
            {'selector': '.row_heading', 'props': [
                ('font-weight', 'bold'),
                ('font-size', index_font_size),
                ('width', 'auto'),
                ('text-align', 'left'),
            ]}
        ])



for models_to_consider in model_lists_to_consider:
    combined_df = get_combined_df(models_to_consider)
    display(style_combined(combined_df))
    



Model,ViT-S-16,ViT-S-16,ViT-S-16,ViT-S-16,ViT-B-16,ViT-B-16,ViT-B-16,ViT-B-16,ViT-B-16,ViT-B-16,ViT-L-16,ViT-L-16,ViT-L-16,ViT-L-16
Dataset,"Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","Two layers (CLS+AP, attentive)","Four layers (CLS+AP, attentive)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)"
CIFAR-10,96.6,96.7,96.91,97.05,97.72,97.87,97.91,97.89,98.12,98.14,98.96,98.96,99.01,98.99
CIFAR-100,84.67,85.38,86.7,86.82,88.22,89.6,90.01,89.41,89.86,90.09,91.85,92.86,93.07,93.03
Caltech-101,94.85,95.23,94.97,96.01,96.19,95.8,96.85,96.59,96.51,96.54,96.94,97.14,97.25,97.36
Country-211,15.49,12.64,18.22,18.81,16.01,15.27,20.35,17.73,19.99,20.7,19.52,19.56,23.2,23.81
DTD,73.35,73.88,76.91,77.34,73.78,77.61,79.84,77.98,79.63,78.83,77.82,80.21,81.06,80.0
Diabetic Retinopathy,45.83,46.52,51.04,51.42,47.24,46.86,51.09,50.22,51.79,52.14,47.27,46.92,51.58,52.26
Dmlab,43.27,52.44,49.86,50.54,44.0,54.98,51.23,49.05,51.37,52.49,46.74,57.58,54.21,55.11
EuroSAT,95.77,97.14,98.03,98.23,95.09,97.45,98.24,97.81,98.35,98.55,96.64,98.11,98.02,98.29
FER2013,53.22,62.9,60.91,65.04,54.89,64.71,61.28,60.49,65.83,67.15,58.5,66.49,65.77,69.95
FGVC Aircraft,42.29,49.36,44.89,48.53,52.38,61.38,53.52,52.59,56.91,60.35,46.31,65.3,50.35,58.81


Model,CLIP-B-32,CLIP-B-32,CLIP-B-32,CLIP-B-32,CLIP-B-16,CLIP-B-16,CLIP-B-16,CLIP-B-16,CLIP-B-16,CLIP-B-16,CLIP-L-14,CLIP-L-14,CLIP-L-14,CLIP-L-14
Dataset,"Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","Two layers (CLS+AP, attentive)","Four layers (CLS+AP, attentive)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)"
CIFAR-10,93.63,94.61,95.4,95.69,94.38,95.96,96.14,95.54,96.17,96.41,97.23,98.06,98.04,98.34
CIFAR-100,76.98,80.09,82.66,83.9,77.43,81.55,85.09,82.74,84.67,85.84,84.16,87.16,88.12,89.01
Caltech-101,92.69,93.02,93.2,94.81,94.15,94.53,93.85,93.61,94.9,94.83,96.7,96.45,96.33,96.73
Country-211,22.94,19.88,24.87,28.03,27.22,25.8,30.71,29.26,32.41,32.86,35.12,34.35,39.25,41.6
DTD,71.65,68.4,77.82,77.02,72.5,72.39,79.73,80.43,80.05,79.79,76.38,79.41,81.17,81.28
Diabetic Retinopathy,41.14,42.48,49.11,50.05,43.15,46.92,52.27,49.6,51.8,52.51,44.11,49.36,52.52,53.96
Dmlab,40.17,55.21,50.6,53.65,41.91,56.3,51.37,49.58,49.86,55.82,44.38,59.84,55.31,59.34
EuroSAT,89.96,96.75,97.69,98.03,90.0,96.52,97.72,97.05,97.72,97.89,92.48,97.51,98.51,98.58
FER2013,59.57,65.16,66.31,69.52,64.14,68.19,68.07,67.52,66.86,71.89,67.16,72.89,72.43,74.18
FGVC Aircraft,38.9,50.4,42.5,48.04,51.83,56.12,51.04,53.43,54.79,56.59,60.98,67.62,58.96,65.07


Model,MAE-B-16,MAE-B-16,MAE-B-16,MAE-B-16,MAE-B-16,MAE-B-16,MAE-B-16,MAE-L-16,MAE-L-16,MAE-L-16,MAE-L-16,MAE-L-16
Dataset,"Last layer (CLS, linear)","Last layer (AP, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","Two layers (CLS+AP, attentive)","Four layers (CLS+AP, attentive)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (AP, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)"
CIFAR-10,50.07,57.73,92.85,84.47,73.22,83.84,88.32,69.7,66.21,95.62,93.49,94.33
CIFAR-100,26.03,32.9,75.91,66.36,49.13,66.25,71.24,40.93,40.65,82.53,79.01,79.95
Caltech-101,62.65,71.17,93.97,87.68,67.57,88.15,91.34,75.9,73.78,95.12,91.05,94.1
Country-211,4.51,6.06,10.05,11.63,10.45,12.94,13.45,5.78,6.68,11.0,15.09,16.18
DTD,45.16,57.55,68.24,68.94,63.99,71.33,70.64,55.59,62.39,70.85,74.63,73.35
Diabetic Retinopathy,36.39,40.74,47.45,49.55,47.0,47.09,50.13,34.98,41.26,44.82,50.95,51.29
Dmlab,27.12,29.65,62.04,43.25,35.26,40.58,46.47,29.3,30.58,65.43,46.36,49.87
EuroSAT,75.06,81.5,98.0,96.23,93.76,96.34,97.5,79.42,82.36,98.12,97.42,97.93
FER2013,28.18,32.77,62.23,51.6,38.92,50.75,56.51,33.79,39.79,66.73,58.86,63.35
FGVC Aircraft,9.78,9.3,64.78,30.32,20.34,27.36,34.61,11.82,11.37,73.01,40.71,46.31


Model,DINOv2-S-14,DINOv2-S-14,DINOv2-S-14,DINOv2-S-14,DINOv2-B-14,DINOv2-B-14,DINOv2-B-14,DINOv2-B-14,DINOv2-B-14,DINOv2-B-14,DINOv2-L-14,DINOv2-L-14,DINOv2-L-14,DINOv2-L-14
Dataset,"Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","Two layers (CLS+AP, attentive)","Four layers (CLS+AP, attentive)","All layers (CLS+AP, attentive)","Last layer (CLS, linear)","Last layer (all tokens, attentive)","All layers (CLS+AP, linear)","All layers (CLS+AP, attentive)"
CIFAR-10,96.2,96.42,96.83,96.8,98.14,98.27,98.21,98.2,98.26,98.35,99.29,99.11,99.23,99.31
CIFAR-100,83.42,84.59,84.6,85.96,89.66,89.97,90.18,90.31,90.71,90.67,92.64,93.42,93.42,93.7
Caltech-101,96.09,95.64,95.83,96.08,96.11,96.85,97.36,97.0,97.44,97.58,96.43,97.52,97.73,98.13
Country-211,16.31,14.61,17.83,19.31,19.22,20.49,22.25,22.18,22.96,24.07,21.53,23.28,26.03,28.77
DTD,76.7,78.67,78.72,80.05,81.81,82.82,82.29,82.5,83.35,82.5,79.95,83.24,82.77,83.56
Diabetic Retinopathy,47.54,48.93,52.21,52.64,47.65,50.17,51.83,51.64,53.04,53.82,48.28,51.46,53.82,55.11
Dmlab,43.41,61.21,49.82,52.64,49.49,63.58,54.57,55.57,56.39,59.03,50.87,66.29,58.55,61.67
EuroSAT,94.64,96.69,97.61,98.02,94.31,97.56,98.14,97.92,97.8,98.31,96.12,97.65,97.75,98.37
FER2013,54.61,62.78,60.02,65.1,58.21,68.35,65.42,64.79,65.74,68.09,61.46,69.94,67.8,71.23
FGVC Aircraft,66.79,72.76,57.69,68.04,70.6,79.07,63.14,71.2,71.06,75.46,71.12,82.59,64.52,78.19
