## Attention heatmaps for trained attentive probes across layers for Masked Autoencoders
This notebook visualizes the attention weights of trained attentive probes across the layers of Masked Autoencoder (MAE) models. For each dataset, it computes the probe's attention on the test split, averages it over attention heads, and plots heatmaps of how that attention is distributed across layers, giving insight into which layers the trained probe relies on.

In [None]:
%load_ext autoreload
%autoreload 2
import sys
import torch
import pandas as pd
sys.path.append('..')
sys.path.append('../..')
import re
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
import pickle
from functools import partial
from src.data.feature_combiner import StackedZeroPadFeatureCombiner
from src.data.data_loader import get_combined_feature_dl
from src.utils.utils import load_model_from_file
from src.utils.attention_utils import get_attention_weights
from constants import base_model_name_mapping

from constants import BASE_PATH_PROJECT, FOLDER_SUBSTRING, ds_info_file
from helper import load_ds_info

In [None]:
# Font-size presets: `fontsizes` is the small preset, `fontsizes_cols` the
# large one.
fontsizes = {'title': 14, 'legend': 13, 'label': 13, 'ticks': 12}
fontsizes_cols = {'title': 24, 'legend': 23, 'label': 23, 'ticks': 23}

# Active preset applied to matplotlib below; set to `fontsizes` for smaller text.
FS = fontsizes_cols

mpl.rcParams.update({
    "axes.titlesize": FS["title"],
    "axes.labelsize": FS["label"],
    "xtick.labelsize": FS["ticks"],
    "ytick.labelsize": FS["ticks"],
    "legend.fontsize": FS["legend"],
    "figure.titlesize": FS["title"],   # for plt.suptitle
})

# Device the probe models are loaded onto.
device = "cuda" if torch.cuda.is_available() else "cpu"

base_feature_dir = BASE_PATH_PROJECT / "features"
# Directory containing the trained probe models (the non-rebuttal variant is
# BASE_PATH_PROJECT / f"models_{FOLDER_SUBSTRING}").
base_model_dir = BASE_PATH_PROJECT / f"models_{FOLDER_SUBSTRING}_rebuttal"
def extract_block_key(s):
    """Return the block index written as ``blocks.<N>`` inside *s*.

    Names without such a segment (e.g. the final ``norm`` layer) map to 999,
    so they sort after every numbered block.
    """
    found = re.search(r'blocks\.(\d+)', s)
    if found is None:
        return 999
    return int(found.group(1))

def layer_sort_key(x):
    """Sort key for layer labels: CLS columns first, then AP columns.

    Within each group, labels are ordered by the integer after the last ``@``.
    Labels containing ``"last"`` (or whose suffix is not an integer) get 999,
    so they come last in their group.
    """
    text = str(x)
    group = 0 if "cls" in text else 1
    if "last" in text:
        return (group, 999)
    try:
        return (group, int(text.split("@")[-1]))
    except ValueError:
        return (group, 999)

def format_layer_label(label: str) -> str:
    """Turn ``"cls@7"`` / ``"ap@last"`` into ``"CLS 7"`` / ``"AP Last"``.

    Labels without a ``cls@`` or ``ap@`` prefix are returned unchanged
    (after conversion to ``str``).
    """
    text = str(label)
    for prefix, pretty in (("cls@", "CLS"), ("ap@", "AP")):
        if text.startswith(prefix):
            suffix = "Last" if "last" in text else text.split("@")[-1]
            return f"{pretty} {suffix}"
    return text

def format_layer_number(label: str) -> str:
    """Return the short tick text for a layer label such as ``"cls@7"``.

    Labels containing ``"last"`` become ``"Last"``. Otherwise, the text after
    the final ``"@"`` is returned. The label is first coerced with ``str`` so
    non-string column labels work, matching ``layer_sort_key`` and
    ``format_layer_label``.
    """
    label = str(label)
    if "last" in label:
        return "Last"
    return label.split("@")[-1]

    
from itertools import accumulate

# Heatmap row order, grouped by dataset category. Names must match the "name"
# column of `data_info`, because the heatmap rows are renamed to those names.
# NOTE(review): PCAM was commented out of "Specialized" in the original list
# and is therefore not plotted; confirm this exclusion is intentional.
dataset_groups = {
    "Natural (MD)": ["STL-10", "CIFAR-10", "Caltech-101", "PASCAL VOC 2007", "CIFAR-100", "Country-211"],
    "Natural (SD)": ["Pets", "Flowers", "Stanford Cars", "FGVC Aircraft", "GTSRB", "SVHN"],
    "Specialized": ["EuroSAT", "RESISC45", "Diabetic Retinopathy"],
    "Structured": ["DTD", "FER2013", "Dmlab"],
}
dataset_order = [ds for group in dataset_groups.values() for ds in group]

# Row index at which each category ends (cumulative group sizes). It is derived
# from `dataset_groups` so it cannot drift out of sync with `dataset_order`.
category_breaks = dict(zip(
    dataset_groups,
    accumulate(len(group) for group in dataset_groups.values()),
))
# Colormap for the heatmaps: white for the lowest values, then four tab20c
# shades (indices 4-7) in reversed order.
palette_list = list(plt.cm.tab20c.colors[4:8])[::-1]
palette_with_white = [(1., 1., 1.)] + palette_list
multi_color_cmap = LinearSegmentedColormap.from_list(
    "multi_gradient_with_white",
    palette_with_white,
)

# Display names for the base models (used e.g. for the subplot titles) come
# from the shared mapping in `constants`.
model_name_mapping = base_model_name_mapping

In [None]:
data_info = load_ds_info(ds_info_file)

all_runs_path = (
    BASE_PATH_PROJECT
    / f'results_{FOLDER_SUBSTRING}_end2end_finetuning/aggregated/complete_set_of_run.pkl'
)
all_runs = pd.read_pickle(all_runs_path)

# Keep only attentive-probe runs on the MAE-B-16 / MAE-L-16 backbones whose
# `nr_layers` is 24 or 48.
is_mae = all_runs["base_model_fmt"].isin(["MAE-B-16", "MAE-L-16"])
is_attentive_probe = all_runs["task"].isin(["attentive_probe"])
has_layer_count = all_runs["nr_layers"].isin([24, 48])
all_runs_large = all_runs[is_mae & is_attentive_probe & has_layer_count]


In [None]:

attn_path = BASE_PATH_PROJECT / 'results/plots'


def _attention_for_run(row, ds):
    """Load one run's trained probe and return its test-split attention weights.

    Args:
        row: one row of ``all_runs_large``.
        ds: dataset name with ``/`` replaced by ``_`` (folder-style name).

    Returns:
        ``(sorted_names, weights)``, where the layer axis of ``weights`` is
        reordered to match ``sorted_names``. Returns ``None`` (after printing
        the row) when the run's ``dim`` is NaN.
    """
    if np.isnan(row["dim"]):
        print("Somehow Nan in row", row)
        return None

    # NOTE(review): `model_ids` is a string repr read from our own results
    # pickle. eval() executes arbitrary code, so never point this at untrusted
    # files (ast.literal_eval is safer if the repr contains only plain literals).
    names = eval(row["model_ids"])
    # Display order: non-"cls" (AP) names before "cls" names (False < True),
    # then by block number, with un-numbered names (e.g. "norm") last.
    sorted_names = sorted(names, key=lambda x: (("cls" in x), extract_block_key(x)))
    sorted_indices = [names.index(s) for s in sorted_names]

    # Features are loaded in the original `names` order; the attention columns
    # are reordered with `sorted_indices` afterwards.
    feature_dir = base_feature_dir / ds
    feature_dirs = [feature_dir / mid.replace('@', "/") for mid in names]
    _, feature_test_loader = get_combined_feature_dl(
        feature_dirs=feature_dirs,
        batch_size=2048,
        num_workers=0,
        fewshot_k=-1,
        feature_combiner_cls=partial(StackedZeroPadFeatureCombiner, shared_dim=int(row["dim"])),
        normalize=True,
        load_train=False,
    )

    model_path = base_model_dir / ds / row["model_id_n_hopt_slug"] / "model.pkl"
    model = load_model_from_file(model_path=model_path, device=device)
    model.eval()
    # NOTE(review): assumes axis 0 of the squeezed output indexes test samples,
    # leaving (heads, layers) — confirm against get_attention_weights.
    test_attn_weights = (
        get_attention_weights(model, feature_test_loader)
        .squeeze()
        .mean(axis=0)[:, sorted_indices]
    )
    return sorted_names, test_attn_weights


def get_attn_weights(model_list, save_name="total_attn_weights_large.pkl"):
    """Return the trained probes' attention weights per base model and dataset.

    Results are cached as a pickle at ``attn_path / save_name``. If that file
    exists and can be unpickled, its contents are returned unchanged. Note that
    the cache is keyed only by *save_name*, not by *model_list*. Otherwise the
    weights are computed from the runs in ``all_runs_large`` (ImageNet runs
    skipped) and written to the cache.

    Args:
        model_list: values of the ``base_model`` column to process.
        save_name: file name of the cache pickle inside ``attn_path``.

    Returns:
        ``{base_model: {dataset: (sorted_layer_names, attn_weights)}}``, with
        dataset names using ``_`` in place of ``/``.
    """
    cache_file = attn_path / save_name
    try:
        with open(cache_file, "rb") as f:
            cached = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        pass  # missing or corrupt cache: recompute below
    else:
        print("Just loaded weights")
        return cached

    total_attn_weights = {}
    for base_model in model_list:
        runs = all_runs_large[all_runs_large["base_model"] == base_model]
        full_attn_weights = {}
        for _, row in runs.iterrows():
            ds = row["dataset"].replace("/", "_")
            if "imagenet" in ds:
                continue
            result = _attention_for_run(row, ds)
            if result is not None:
                full_attn_weights[ds] = result
        total_attn_weights[base_model] = full_attn_weights

    # The cache directory may not exist yet on a fresh checkout.
    attn_path.mkdir(parents=True, exist_ok=True)
    with open(cache_file, "wb") as f:
        pickle.dump(total_attn_weights, f)
    return total_attn_weights


total_attn_weights_mae = get_attn_weights(["mae-vit-base-p16", "mae-vit-large-p16"],
                                          "total_attn_weights_mae.pkl")
    

In [None]:
all_dfs = {}  # NOTE(review): not used in this cell
records = []
for model_name, model_results in total_attn_weights_mae.items():
    for ds, (layer_names, attn) in model_results.items():
        if "imagenet" in ds:
            print("No Imagenet")
            continue
        # Mean over heads (axis 0), leaving one value per layer.
        attn_mean = attn.mean(axis=0)
        for i, lname in enumerate(layer_names):
            token_type = "CLS" if "cls" in lname else "AVG"   # split CLS vs AVG
            if "Clip" in model_name or "CLIP" in model_name:
                layer_id = lname.split("openai_")[-1].replace("visual.transformer.resblocks.","").replace(".ln_2","").replace("visual","last")
            else:
                layer_id = lname.split("_")[-1].replace(".norm2","").replace("norm","last").replace("blocks.","")
            # float() turns a NumPy/torch scalar into a plain Python float, so
            # the Attention column always gets a numeric dtype.
            records.append([model_name, ds, token_type, layer_id, float(attn_mean[i]), i])

df = pd.DataFrame(records, columns=["Model", "Dataset", "TokenType", "Layer", "Attention", "layer_idx"])


In [None]:

from scipy.cluster.hierarchy import linkage, leaves_list

plt.figure(figsize=(10, 6))

# Mean attention per (dataset, layer), pooled over all models.
mean_per_dataset = (
    df.groupby(["Dataset", "Layer"])["Attention"]
      .mean()
      .reset_index()
)
pivot = mean_per_dataset.pivot_table(index="Dataset", columns="Layer", values="Attention")
sorted_cols = sorted(pivot.columns, key=layer_sort_key)
pivot = pivot[sorted_cols]

# Order dataset rows by Ward hierarchical clustering (missing cells count as 0).
row_linkage = linkage(pivot.fillna(0), method="ward")
row_order = leaves_list(row_linkage)
pivot = pivot.iloc[row_order]

sns.heatmap(pivot, cmap=multi_color_cmap)
plt.suptitle("Layer Attention Heatmap (Datasets Clustered by Similarity)", y=1.02)
plt.show()

In [None]:
#size = "Large"
#df = all_dfs[size]
def _dataset_display_name(ds):
    """Map a folder-style dataset name (``_`` separators) to its ``data_info`` name."""
    key = (
        ds.replace("_", "/")
          .replace("fgvc/air", "fgvc_air")
          .replace("diabetic/retinopathy", "diabetic_retinopathy")
    )
    return data_info.loc[key, "name"]


# --- Aggregate per Model × Dataset × Layer ---
mean_per_dataset = (
    df.groupby(["Model", "Dataset", "Layer"])["Attention"]
      .mean()
      .reset_index()
)

models = sorted(df["Model"].unique())

# --- Global vmin/vmax so all panels share one colour scale ---
# NOTE(review): computed before reindexing to `dataset_order`, so datasets that
# are not plotted still influence the colour range.
all_values = mean_per_dataset["Attention"].values
vmin, vmax = all_values.min(), all_values.max()

fig, axes = plt.subplots(1, len(models), figsize=(9 * len(models), 8), sharey=True)
fig.subplots_adjust(wspace=0.05)
if len(models) == 1:
    axes = [axes]

heatmaps = []
for ax, model_name in zip(axes, models):
    pivot = mean_per_dataset.query("Model == @model_name").pivot(
        index="Dataset", columns="Layer", values="Attention"
    )
    pivot = pivot.rename(index=_dataset_display_name)
    pivot = pivot[sorted(pivot.columns, key=layer_sort_key)]
    # Fixed, category-grouped row order. Datasets missing from the data become
    # empty rows; datasets not listed in `dataset_order` are dropped.
    pivot = pivot.reindex(dataset_order)
    n_cols = len(pivot.columns)

    # yticklabels=True forces one tick per row. With the default "auto",
    # seaborn may thin the ticks, and set_yticklabels below would then attach
    # dataset names to the wrong rows.
    hm = sns.heatmap(
        pivot, cmap=multi_color_cmap,
        vmin=vmin, vmax=vmax,
        cbar=False, ax=ax, yticklabels=True,
    )
    heatmaps.append(hm)

    ax.set_title(f"{model_name_mapping.get(model_name, model_name)}")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # --- Sparse x ticks, centred on their columns ---
    tick_labels = [format_layer_number(col) for col in pivot.columns]
    if n_cols <= 24:
        # Every second column: 1, 3, 5, ...
        tick_idx = list(range(1, n_cols, 2))
    else:
        # First column, then every fourth: 0, 3, 7, 11, ...
        tick_idx = [0] + list(range(3, n_cols, 4))
    ax.set_xticks(np.asarray(tick_idx) + 0.5)
    ax.set_xticklabels([tick_labels[i] for i in tick_idx], rotation=0)

    ax.set_yticklabels(pivot.index)
    # Rotate the "Last" tick labels vertically.
    for label in ax.get_xticklabels():
        if "Last" in label.get_text():
            label.set_rotation(90)

    n_cls = sum(col.startswith("cls@") for col in pivot.columns)
    n_ap = sum(col.startswith("ap@") for col in pivot.columns)

    # Dashed separators between dataset categories (none after the last one)
    # and between the CLS and AP column groups.
    for row_idx in list(category_breaks.values())[:-1]:
        ax.hlines(row_idx, *ax.get_xlim(), colors="grey", ls="--", linewidth=1)
    ax.vlines(n_cls, *ax.get_ylim(), colors="grey", ls="--", linewidth=1)

    # Group captions just below the x tick labels, centred over each column
    # group (x and y in axes-fraction coordinates).
    ypos = -0.08
    ax.text(n_cls / 2 / n_cols, ypos, "CLS", ha="center", va="top", fontsize=22, fontweight="bold", transform=ax.transAxes)
    ax.text((n_cls + n_ap / 2) / n_cols, ypos, "AP", ha="center", va="top", fontsize=22, fontweight="bold", transform=ax.transAxes)

# --- Single shared colorbar ---
cbar = fig.colorbar(heatmaps[-1].collections[0], ax=axes, location="right", fraction=0.02, pad=0.02)
cbar.set_label("Average Attention over Heads")

fn = BASE_PATH_PROJECT / f"results_{FOLDER_SUBSTRING}_rebuttal/plots"
fn.mkdir(parents=True, exist_ok=True)  # savefig fails if the folder is missing
plt.savefig(fn / "MAEAttentionMapOverDatasetPerModelswoIN.png", dpi=600, bbox_inches='tight')
plt.show()
