# Do Intermediate Representations Help?
This notebook compares vision transformer representation methods to determine whether using intermediate layers improves classification performance over using only the final layer.

- **Data**
    - **Input**: aggregated results `complete_set_of_run.pkl` with performance metrics
    - **Models**:  9/11 vision transformers (OpenCLIP ViT, DINOv2, standard ViT, and potentially MAE)
    - **Datasets**: 20 classification tasks
- **Representation Sources:**
    - Last layer: CLS token or all tokens
    - Multi-layer: CLS + Average Pooling (AP) from:
      - Middle & last blocks
      - Quarterly blocks (1/4, 3/4 positions)
      - All blocks
- **Probing:**
    - Linear classifiers on frozen features
    - Attention mechanisms over token sequences
- **Outputs**
    1. **Figure 2**: Boxplots of accuracy gains across *base models*
    2. **Statistical tests**: Wilcoxon signed-rank tests with FDR correction

In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from scipy.stats import wilcoxon

sys.path.append('..')
sys.path.append('../..')

from constants import BASE_PATH_PROJECT, FOLDER_SUBSTRING, experiment_with_probe_type_order_list, experiment_order_list
from helper import init_plotting_params, save_or_show

In [None]:
# Initialise the plotting parameters through the project helper before any
# figure is created.
init_plotting_params()

In [None]:
# Save mode forwarded to `save_or_show`; any truthy value also triggers the
# creation of the output directory below and the CSV export of the test results.
SAVE = 'both'

# Directory that receives the figures and the statistical-test table.
base_storing_path = BASE_PATH_PROJECT / f"results_{FOLDER_SUBSTRING}_rebuttal/plots/do_intermediate_reps_help"
if SAVE:
    # Create the directory (including parents) up front; a pre-existing
    # directory is not an error.
    base_storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the aggregated run table.
# NOTE(review): `pd.read_pickle` can execute arbitrary code while unpickling;
# only load this file from a trusted location.
all_runs= pd.read_pickle(BASE_PATH_PROJECT / f'results_{FOLDER_SUBSTRING}_rebuttal/aggregated/complete_set_of_run.pkl')

In [None]:
# Drop every row that has `nr_layers == 1` while being flagged as
# `contains_intermediate`.
single_layer_with_intermediate = (all_runs['nr_layers'] == 1) & all_runs['contains_intermediate']
rows_to_discard = all_runs[single_layer_with_intermediate].index
all_runs = all_runs.drop(index=rows_to_discard).copy().reset_index(drop=True)

# Keep only the 'cae' and 'linear' probe types.
uses_selected_probe = all_runs['probe_type'].isin(['cae', 'linear'])
all_runs = all_runs[uses_selected_probe].copy().reset_index(drop=True)

In [None]:
# Sorted list of the distinct base models present after filtering; the bare
# expression on the last line displays it in the notebook.
selected_models = sorted(all_runs['base_model'].unique())
selected_models

## Figure 2

In [None]:
# Every dataset found in the runs except 'imagenet-subset-50k' is eligible.
allowed_ds = list(set(all_runs['dataset'].unique()) - {'imagenet-subset-50k'})

# Restrict the table to runs with model_size 'base' on an eligible dataset.
is_base_size = all_runs['model_size'] == 'base'
on_allowed_dataset = all_runs['dataset'].isin(allowed_ds)
base_runs = all_runs[is_base_size & on_allowed_dataset].reset_index(drop=True).copy()
base_runs.shape

In [None]:
# Scale the gain column by 100 (the Figure 2 y-axis is labelled "[pp]", so the
# stored values are presumably fractions — TODO confirm).
# NOTE(review): this cell is not idempotent; re-running it scales the column again.
base_runs['abs_perf_gain_test_lp_bal_acc1'] = base_runs['abs_perf_gain_test_lp_bal_acc1'] * 100

In [None]:
# Experiments considered from here on: every entry of `experiment_order_list`
# except the first one. The bare expression displays the list.
curr_order = experiment_order_list[1:]
curr_order

In [None]:
# Keep only the rows whose `experiment` is listed in `curr_order`, then show
# the resulting table shape.
base_runs = base_runs[base_runs['experiment'].isin(curr_order)]
base_runs.shape

In [None]:
# Duplicate the 'CLS last layer' rows, tag the duplicates with the task
# "attentive_probe", and append them to the run table.
# NOTE(review): re-running this cell appends the duplicates again.
is_cls_last_layer = base_runs['experiment'] == 'CLS last layer'
copy_for_attentive = (
    base_runs[is_cls_last_layer]
    .reset_index(drop=True)
    .assign(task="attentive_probe")
)

base_runs = pd.concat([base_runs, copy_for_attentive], ignore_index=True)

In [None]:
# Sanity check: number of rows per (task, Experiment) combination.
base_runs[['task', 'Experiment']].value_counts().sort_index()

In [None]:
# Distinct formatted model names; the plotting cell below uses this column for
# the x-axis and selects models from it via `order`.
base_runs['base_model_fmt'].unique()

In [None]:
# Palette: entry 17 of tab20c first, followed by tab20c entries 0-3 and 4-7,
# each group of four in reversed order.
tab20c = plt.cm.tab20c.colors
reversed_palette = [tab20c[17], *tab20c[3::-1], *tab20c[7:3:-1]]

# Shortened legend texts for the multi-layer experiments; labels that are not
# listed keep their (already shortened) text.
legend_label_overrides = {
    "Layers from middle & last blocks": "+ middle block",
    "Layers from quarterly blocks": "+ 1/4th and 3/4th blocks",
    "Layers from all blocks": "All blocks + last layer",
}


def _spacer_handle():
    """Return an invisible legend handle used for header rows and empty slots."""
    return Line2D([0], [0], marker='', color='white', linestyle='', alpha=0)


# One figure with the MAE model and one without it.
for order in (
    ['CLIP-B-16', 'DINOv2-B-14', 'ViT-B-16', 'MAE-B-16'],
    ['CLIP-B-16', 'DINOv2-B-14', 'ViT-B-16'],
):
    plt.figure(figsize=(11, 5.5))

    ax = sns.boxplot(
        base_runs,
        x="base_model_fmt",
        y="abs_perf_gain_test_lp_bal_acc1",
        order=order,
        hue="Experiment",
        hue_order=experiment_with_probe_type_order_list[2:],
        palette=reversed_palette,
        fliersize=3,
        showfliers=False,
    )
    ax.set_xlabel("")
    ax.set_ylabel("Absolute accuracy gain [pp]")
    sns.move_legend(ax, bbox_to_anchor=(1.05, 1), loc='upper left', title=False)

    handles, labels = ax.get_legend_handles_labels()

    # Fixed leading legend entries: a header, the first plotted experiment
    # relabelled "Last layer", three empty slots, and the linear-probe header.
    # NOTE(review): the empty slots presumably pad the first legend column of
    # the 3-column layout — confirm against the rendered figure.
    new_handles = [_spacer_handle(), handles[0]] + [_spacer_handle() for _ in range(4)]
    new_labels = [
        "Attentive probe (all tokens)",
        "Last layer",
        "",
        "",
        "",
        "Linear probe (CLS & AP)",
    ]

    # Remaining experiments (the first one is already placed above). A header is
    # inserted where the labels switch from "linear" to "attentive".
    preceding_label = None
    for entry_handle, entry_label in list(zip(handles, labels))[1:]:
        switches_to_attentive = (
            preceding_label is not None
            and "linear" in preceding_label
            and "attentive" in entry_label
        )
        if switches_to_attentive:
            new_handles.append(_spacer_handle())
            new_labels.append("Attentive probe (CLS & AP)")

        # Strip the trailing " (probe type)" part and the "CLS+AP" prefix.
        short_label = entry_label.split(" (")[0].replace("CLS+AP l", "L")
        new_handles.append(entry_handle)
        new_labels.append(legend_label_overrides.get(short_label, short_label))
        preceding_label = entry_label

    legend = ax.legend(
        new_handles,
        new_labels,
        bbox_to_anchor=(0.5, 0.935),
        loc='lower center',
        ncols=3,
        frameon=False,
    )

    # Dotted zero line behind the boxes marks "no gain".
    ax.axhline(0, ls=':', color="grey", zorder=-1)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    plt.tight_layout()
    suffix = "_with_mae" if len(order) == 4 else ""
    fn = base_storing_path / f'boxplot_base_models{suffix}_v6.pdf'
    save_or_show(plt.gcf(), fn, SAVE, show_path=False)


## Statistical Tests

In [None]:
def fdr_bh(pvals):
    """Return Benjamini-Hochberg adjusted p-values.

    Each p-value is scaled by ``n_tests / rank`` (rank 1 = smallest p-value),
    capped at 1.0, and then made monotone so that a smaller raw p-value never
    receives a larger adjusted value than a bigger one.

    Args:
        pvals: One-dimensional sequence of raw p-values.

    Returns:
        A float ``numpy.ndarray`` of adjusted p-values in the same order as
        ``pvals``. NaN inputs still count towards the number of tests and are
        reported as 1.0.
    """
    # Always work on (and return) floats so that integer input cannot force an
    # integer output buffer.
    pvalues = np.asarray(pvals, dtype=float)
    n_tests = pvalues.size
    if n_tests == 0:
        return np.empty(0, dtype=float)

    sorted_indices = np.argsort(pvalues)
    ranks = np.arange(1, n_tests + 1)

    # `np.fmin` ignores NaN, so a NaN p-value is capped to 1.0 instead of
    # propagating through the cumulative minimum below.
    scaled = np.fmin(pvalues[sorted_indices] * n_tests / ranks, 1.0)

    # Enforce monotonicity: running minimum taken from the largest p-value down.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]

    corrected_pvals = np.empty(n_tests, dtype=float)
    corrected_pvals[sorted_indices] = monotone
    return corrected_pvals


# Pairs (exp1, exp2) to compare. Each pair is tested one-sided per base model
# with the alternative that exp1 yields larger gains than exp2.
exp_pairs = [
    ("CLS+AP layers from all blocks (attentive)", "CLS+AP last layer (attentive)"),
    ("CLS+AP layers from all blocks (linear)", "CLS+AP last layer (linear)"),
    ("CLS+AP layers from all blocks (attentive)", "CLS+AP layers from all blocks (linear)"),
    ("CLS+AP layers from all blocks (attentive)", "All tokens last layer (attentive)"),
]
# Significance level applied to the FDR-adjusted p-values.
alpha = 0.05

pvalues = []
comb_index = []

for mid, mid_data in base_runs.groupby("base_model"):
    for exp1, exp2 in exp_pairs:
        exp1_rows = mid_data[mid_data['Experiment'] == exp1].sort_values('dataset')
        exp2_rows = mid_data[mid_data['Experiment'] == exp2].sort_values('dataset')

        # The Wilcoxon signed-rank test is paired: sample i of exp1 must belong
        # to the same dataset as sample i of exp2. Fail loudly instead of
        # silently pairing results from different datasets.
        # NOTE(review): several rows per dataset within one experiment would
        # still be paired in arbitrary order — confirm there is one row each.
        if exp1_rows['dataset'].tolist() != exp2_rows['dataset'].tolist():
            raise ValueError(
                f"Cannot pair '{exp1}' with '{exp2}' for base model '{mid}': "
                "the two experiments do not cover the same datasets."
            )

        exp1_data = pd.to_numeric(exp1_rows["abs_perf_gain_test_lp_bal_acc1"]).values
        exp2_data = pd.to_numeric(exp2_rows["abs_perf_gain_test_lp_bal_acc1"]).values
        _, pval = wilcoxon(exp1_data, exp2_data, alternative='greater')

        pvalues.append(pval)
        comb_index.append((mid, (exp1, exp2)))

# Correct for multiple testing across all (base model, pair) combinations.
adj_pvalues = fdr_bh(pvalues)
statistical_testing = pd.DataFrame({
    "adj_pvalues": adj_pvalues,
    # NOTE(review): "regected" is a typo for "rejected"; the key is kept
    # unchanged because it is part of the exported CSV schema.
    "regected": adj_pvalues < alpha,
    "base_model": [val[0] for val in comb_index],
    "exp1": [val[1][0] for val in comb_index],
    "exp2": [val[1][1] for val in comb_index],
})
if SAVE:
    fn = base_storing_path / 'statistical_tests.csv'
    statistical_testing.sort_values('adj_pvalues').to_csv(fn)
    print(f"Stored statistical tests at {fn}")
    display(statistical_testing)
else:
    display(statistical_testing.sort_values('adj_pvalues'))