# Does Model Size or Training Objective Impact Intermediate Layer Benefits?

This notebook examines whether model size and training objective affect the performance gains from using intermediate representations.

- **Data**
    - **Input**: `complete_set_of_run.pkl` with performance metrics
    - **Models**: 9 vision transformers across 3 sizes (Small, Base, Large) and 3 training objectives (CLIP, DINOv2, supervised ViT)
    - **Datasets**: 20 classification tasks
- **Methods Compared**
    - **CLS last layer**: Baseline linear probe on final layer CLS token
    - **All tokens last layer (attentive)**: Attention over all final layer tokens  
    - **CLS+AP layers from all blocks (linear)**: Linear probe on CLS + Average Pooling from all layers
    - **CLS+AP layers from all blocks (attentive)**: Attention probe on CLS + Average Pooling from all layers
- **Visualization**: Dual-panel boxplot figure:
    - **Left panel**: Raw balanced accuracy for baseline method across models
    - **Right panel**: Performance gains over baseline for three intermediate representation methods
    - **Color coding**: Each model has distinct color, grouped by architecture family

In [None]:
import sys
sys.path.append("..")
sys.path.append("../..")
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from constants import BASE_PATH_PROJECT, FOLDER_SUBSTRING
from helper import init_plotting_params, save_or_show

In [2]:
init_plotting_params()

{
  "agg.path.chunksize": 0,
  "axes.labelsize": 13.0,
  "axes.titlesize": 14.0,
  "axes3d.trackballsize": 0.667,
  "boxplot.flierprops.markersize": 6.0,
  "boxplot.meanprops.markersize": 6.0,
  "errorbar.capsize": 0.0,
  "figure.figsize": [
    6.4,
    4.8
  ],
  "figure.labelsize": "large",
  "figure.titlesize": "large",
  "font.cursive": [
    "Apple Chancery",
    "Textile",
    "Zapf Chancery",
    "Sand",
    "Script MT",
    "Felipa",
    "Comic Neue",
    "Comic Sans MS",
    "cursive"
  ],
  "font.family": [
    "sans-serif"
  ],
  "font.fantasy": [
    "Chicago",
    "Charcoal",
    "Impact",
    "Western",
    "xkcd script",
    "fantasy"
  ],
  "font.monospace": [
    "DejaVu Sans Mono",
    "Bitstream Vera Sans Mono",
    "Computer Modern Typewriter",
    "Andale Mono",
    "Nimbus Mono L",
    "Courier New",
    "Courier",
    "Fixed",
    "Terminal",
    "monospace"
  ],
  "font.sans-serif": [
    "DejaVu Sans",
    "Bitstream Vera Sans",
    "Computer Modern Sans Serif

In [None]:
SAVE = 'both'

base_storing_path = BASE_PATH_PROJECT / f"results_{FOLDER_SUBSTRING}_rebuttal/plots/per_model_performance_gain_dist" 

if SAVE:
    base_storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
all_runs_path = BASE_PATH_PROJECT / f"results_{FOLDER_SUBSTRING}_rebuttal/aggregated/complete_set_of_run.pkl"
all_runs = pd.read_pickle(all_runs_path)

In [5]:
all_runs = all_runs[~all_runs['dataset'].isin(['imagenet-subset-50k'])].reset_index(drop=True)

In [6]:
metrics_cols = [
    'abs_perf_gain_train_lp_bal_acc1',
    'abs_perf_gain_test_lp_bal_acc1',
    'test_lp_bal_acc1'
]

In [7]:
all_runs[metrics_cols] = all_runs[metrics_cols].astype(float)
all_runs[metrics_cols] *= 100

In [8]:
grouping_cols = ["base_model_fmt", "Experiment"]

In [9]:
curr_order = [
    "CLS last layer",
    'All tokens last layer (attentive)',
    'CLS+AP layers from all blocks (linear)',
    'CLS+AP layers from all blocks (attentive)',
]
subset_runs = all_runs[all_runs['Experiment'].isin(curr_order)].copy().reset_index()
subset_runs.shape

(1152, 88)

In [10]:
subset_runs = subset_runs[subset_runs['probe_type'].isin(['cae', 'linear'])].copy().reset_index(drop=True)
idx_to_drop = subset_runs[(subset_runs['nr_layers'] == 1) & (subset_runs['contains_intermediate'])].index.tolist()
subset_runs = subset_runs.drop(index=idx_to_drop).copy().reset_index(drop=True)
subset_runs = subset_runs[~subset_runs['base_model'].str.startswith('mae-')].copy().reset_index(drop=True)
subset_runs.shape

(720, 88)

In [11]:
model_oder = [
    'CLIP-B-32', 
    'CLIP-B-16',
    'CLIP-L-14',
    'DINOv2-S-14',
    'DINOv2-B-14',
    'DINOv2-L-14',
    'ViT-S-16',
    'ViT-B-16',
    'ViT-L-16', 
]

exp_name_mapping = {
'CLS last layer': "Last layer\n(CLS, linear)",
'CLS+AP last layer (linear)': "Last layer\n(CLS+AP, linear)",
'CLS+AP layers from middle & last blocks (linear)': "Two layers\n(CLS+AP, linear)",
'CLS+AP layers from quarterly blocks (linear)': "Four layers\n(CLS+AP, linear)",
'CLS+AP layers from all blocks (linear)': "All layers\n(CLS+AP, linear)",
'CLS+AP last layer (attentive)': "Last layer\n(CLS+AP, attentive)",
'CLS+AP layers from middle & last blocks (attentive)': "Two layers\n(CLS+AP, attentive)",
'CLS+AP layers from quarterly blocks (attentive)': "Four layers\n(CLS+AP, attentive)",
'CLS+AP layers from all blocks (attentive)': "All layers\n(CLS+AP, attentive)",
'All tokens last layer (attentive)': "Last layer\n(all tokens, attentive)",
}

In [12]:
# Create figure with custom grid layout
fig = plt.figure(figsize=(3 + len(curr_order) * 2.5, 4))
gs = gridspec.GridSpec(
    1, 2, 
    figure=fig,
    width_ratios=[1.3, len(curr_order)-0.5], 
    wspace=0.2  
)

reversed_palette = []
palette_list = list(plt.cm.tab20c.colors)
for group_start in [8, 12]:
    group = palette_list[group_start:group_start+3]
    reversed_group = group[::-1]
    reversed_palette.extend(reversed_group)

palette_list = list(plt.cm.tab20b.colors)
for group_start in [12]:
    group = palette_list[group_start:group_start+3]
    reversed_group = group[::-1]
    reversed_palette.extend(reversed_group)

o = 1.15
reversed_palette = [tuple([o*c for c in rgb]) for rgb in reversed_palette]

ax_raw = fig.add_subplot(gs[0, 0])

leftmost_exp = curr_order[0]
leftmost_data = subset_runs[subset_runs['Experiment'] == leftmost_exp]

sns.boxplot(
    data=leftmost_data,
    x='base_model_fmt',
    y='test_lp_bal_acc1',
    hue = 'base_model_fmt',
    hue_order=model_oder,
    order=model_oder,
    fliersize=2,
    palette=reversed_palette,
    ax=ax_raw,
    gap=0,
    width=1
)

ax_raw.set_title("")
ax_raw.set_xlabel("")
ax_raw.set_ylabel("Balanced accuracy [%]")
ax_raw.spines['top'].set_visible(False)
ax_raw.spines['right'].set_visible(False) 

middle_pos = (len(model_oder) - 1) / 2
ax_raw.set_xticks([middle_pos])
ax_raw.set_xticklabels([exp_name_mapping[leftmost_exp]])
ax_raw.margins(x=0.1)


ax_main = fig.add_subplot(gs[0, 1])
g = sns.boxplot(
    subset_runs,
    x='Experiment',
    y='abs_perf_gain_test_lp_bal_acc1',
    hue="base_model_fmt",
    hue_order=model_oder,
    fliersize=2,
    showfliers=False,
    order=curr_order[1:],
    palette=reversed_palette,
    ax=ax_main
)

ax_main.set_xlabel("")
ax_main.set_ylabel("Absolute accuracy gain [pp]")

ax_main.axhline(0, ls=':', color="grey", zorder=-1)
ax_main.spines['top'].set_visible(False)
ax_main.spines['right'].set_visible(False)

custom_labels = [exp_name_mapping[val.get_text()] for val in ax_main.get_xticklabels()]
ax_main.set_xticklabels(custom_labels)

ax_main.legend().remove()
ax_raw.legend().remove()


handles, labels = ax_main.get_legend_handles_labels()

fig.legend(handles, labels, bbox_to_anchor=(0.9,0.75), loc='upper left', 
          ncols=1, frameon=False, fontsize=11)

plt.tight_layout()

fn = base_storing_path / f'boxplot_gridspec_dual_v8.pdf'
save_or_show(fig, fn, SAVE, show_path=False)

  ax_main.set_xticklabels(custom_labels)
  ax_raw.legend().remove()
  plt.tight_layout()


stored img at.
