# Attention Rank Visualizer — per head or whole layer

This notebook is an implementation of concepts in the paper [When Attention Collapses: How Degenerate Layers in LLMs Enable Smaller, Stronger Models](https://arxiv.org/html/2404.08634v3).

**Methodology**
1. Build the attention tensor `A` of shape `(H, T, T)` (one `T×T` matrix per head) from the `CausalSelfAttention` module of a chosen layer ([see attn_rank.py line 14](attn_rank.py#L14))
2. For each head:
    - calculate an **effective rank**: the smallest `k` whose top singular values explain **90%** of the matrix energy ([see attn_rank.py line 45](attn_rank.py#L45))
    - measure **single‑column‑ness**: the fewest columns needed to cover **90%** of the squared mass (the sum of squared entries) of the head's matrix `A[h]` ([see attn_rank.py line 55](attn_rank.py#L55))
3. Plot a heatmap of each head's attention matrix `A[h]` and the per‑head rank profile ([Go to heatmaps section](#heatmaps-of-attention-matrices-per-layer))
4. Average the metrics over `N` sequences of length `T` ([Go to batch section](#batch-over-multiple-sequences)); the paper uses `N=100` and `T=100` to obtain a more stable estimate of the rank profile


**Technical Notes**
- `models.gpt2.attention.CausalSelfAttention` uses merged QKVO weights in `qkvo_w` of shape `(4, num_heads*head_dim, dim)`
- hidden states `X` of shape `(B, T, dim)` that enter an attention block (from the previous layer) are exposed in `CausalSelfAttention` as `in_t` for convenience

**Note**: Start with short sequences (e.g., `T ≤ 128`): the SVD (singular value decomposition) of each `T×T` attention matrix costs `O(T³)`.


# Helpers

In [1]:
import torch
import matplotlib.pyplot as plt

import models.gpt2.block

# Print tensors in scientific notation with four digits after the decimal point.
torch.set_printoptions(sci_mode=True, precision=4)

In [2]:
def show_attention_heatmap(A: torch.Tensor, head: int=0, title: "str | None"=None):
    """Draw one attention matrix as a heatmap (key position on x, query position on y).

    Args:
        A: attention weights, either ``(H, T, T)`` with one matrix per head,
            or a single ``(T, T)`` matrix.
        head: index of the head to draw when ``A`` is 3-D; ignored otherwise.
        title: optional figure title; no title is set when empty or ``None``.
    """
    M = A[head] if A.dim() == 3 else A
    # Tensor.numpy() does not support bfloat16, so convert to float32 before
    # leaving torch; the precision change is irrelevant for a colour map.
    M = M.detach().float().cpu().numpy()

    plt.figure()
    plt.imshow(M, aspect='auto')
    plt.colorbar()
    plt.xlabel('Key position j')
    plt.ylabel('Query position i')
    if title:
        plt.title(title)
    plt.show()

def _plot_per_head(values, ylabel: str, title: str):
    """Line-plot one value per attention head against its head index."""
    from matplotlib.ticker import MaxNLocator

    plt.figure()
    plt.plot(list(range(len(values))), values, marker='o')
    # Head indices are integers; keep matplotlib from placing ticks at 0.5, 2.5, ...
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.xlabel('Head index')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()

def plot_head_ranks(ranks, title: str='Effective rank (90%) per head'):
    """Plot the per-head effective rank (90% energy) profile of one layer.

    Args:
        ranks: sequence of effective ranks, one entry per head.
        title: figure title.
    """
    _plot_per_head(ranks, 'Rank-90%', title)

def plot_head_masses(masses, title: str='Fewest columns for 90% mass per head'):
    """Plot, per head, the fewest columns needed to cover 90% of the squared mass.

    Args:
        masses: sequence of column counts, one entry per head.
        title: figure title.
    """
    _plot_per_head(masses, '#columns for 90% mass', title)


In [3]:
from models.gpt2 import GPT2Core
from tools.metrics.attn_rank import attention_matrix_from_attn, per_head_metrics, average_per_head_over_sequences
from training.data_gen import DistributedDataGenerator
from tools.checkpoint import model_from_checkpoint

import os

# Checkpoint and data paths default to the author's local layout. Set the
# ATTN_RANK_CHECKPOINT / ATTN_RANK_DATA environment variables to point the
# notebook at other files without editing it.
pt = os.environ.get(
    "ATTN_RANK_CHECKPOINT",
    "/Users/jonathanmiddleton/models/checkpoints/350m-instruct/20251013T1953-val1.600-step000850-run1-best.pt",
)
device = 'cpu'
seq_len = 100 # T; 100 in paper

data_loader = DistributedDataGenerator(
    os.environ.get("ATTN_RANK_DATA", "../../data/fineweb/fineweb_val_000000.bin"),
    1 * seq_len,  # presumably tokens per next(): a single sequence of seq_len — TODO confirm
    rank = 0,
    world_size=1,
    device=device,
)
# noinspection PyTypeChecker
model: GPT2Core = model_from_checkpoint(pt, device=device, map_location=device).eval()
p = next(data_loader)[0][None,:] # one sample, with a leading batch dimension of 1
with torch.no_grad():
    # Forward pass that populates each attention block's `in_t`, which
    # build_attention_matrix reads.
    model.prefill_batch(p, 256)

def build_attention_matrix(layer_id, inputs=None):
    """Return the attention weights of every head in block ``layer_id``.

    Runs a prefill on ``inputs`` (default: the notebook's sample ``p``). It then
    passes the block's input hidden states, exposed as ``attn.in_t``, to
    ``attention_matrix_from_attn``.

    The prefill is re-run here because ``in_t`` only reflects the most recent
    forward pass. Without it, running ``compute_batch`` (which prefills many
    other sequences) would silently switch these heatmaps to the last sequence
    of that batch.

    Args:
        layer_id: index into ``model.blocks``.
        inputs: optional input batch laid out like ``p`` (batch dimension of 1);
            defaults to ``p``.
    """
    if inputs is None:
        inputs = p
    with torch.no_grad():
        model.prefill_batch(inputs, 256)  # same call used for the initial prefill of `p`
        # noinspection PyTypeChecker
        attn: models.gpt2.attention.CausalSelfAttention = model.blocks[layer_id].attn
        X: torch.Tensor = attn.in_t.squeeze(0) # (T, dim)
        return attention_matrix_from_attn(attn, X)

def compute_and_visualize(layer_id: int):
    """Show every head's attention heatmap for one layer, then its metrics.

    Draws one heatmap per head. It then prints the per-head effective ranks,
    the per-head column counts and the layer maximum returned by
    ``per_head_metrics``, and plots both per-head series.
    """
    with torch.no_grad():
        attn_matrix = build_attention_matrix(layer_id)
        num_heads = attn_matrix.shape[0]
        for head_idx in range(num_heads):
            heading = f'Attention matrix — head {head_idx}'
            show_attention_heatmap(attn_matrix, head=head_idx, title=heading)

        head_ranks, head_masses, layer_max_rank = per_head_metrics(attn_matrix, device=device)
        print('Per‑head effective ranks:', head_ranks)
        print('Per‑head columns@90%:', head_masses)
        print('MaxRank(layer) =', layer_max_rank)

        plot_head_ranks(head_ranks)
        plot_head_masses(head_masses)

# Heatmaps of Attention matrices per layer

In [4]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Controls for the single-sequence heatmap view: choose a layer and press Run;
# the figures are rendered into the `_out` output area.
_layer_dropdown = widgets.Dropdown(
    options=[*range(len(model.blocks))],
    value=0,
    description='Layer:',
)
_run_button = widgets.Button(button_style='primary', description='Run')
_out = widgets.Output()

display(widgets.HBox(children=[_layer_dropdown, _run_button]), _out)

def _on_run_clicked(_):
    """Run-button callback for the heatmap view.

    Locks the button while running and renders the selected layer into
    `_out`. The button is restored even if the computation raises.
    """
    def _set_button(disabled, description, style):
        _run_button.disabled = disabled
        _run_button.description = description
        _run_button.button_style = style

    _set_button(True, "Running...", 'info')
    try:
        with _out:
            clear_output(wait=True)
            selected_layer = int(_layer_dropdown.value)
            compute_and_visualize(selected_layer)
    finally:
        _set_button(False, "Run", 'primary')

_run_button.on_click(_on_run_clicked)


HBox(children=(Dropdown(description='Layer:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), …

Output()

# Batch over multiple sequences
Metrics averaged over `N` sequences of length `T`.
- Mirroring the paper's setup: sample `N=100` sequences with `T=100`, then compute per‑head averages and finally `MaxRank(l)` as the maximum head rank per layer.
- To inspect a **single‑column** pattern directly, sort the columns of `A[h]` by their squared mass and see if the first one dominates.


In [5]:
def list_f(l: list) -> str:
    """Format numbers as a bracketed, comma-separated list with two decimals each."""
    body = ", ".join(map("{:.2f}".format, l))
    return f"[{body}]"

def compute_batch(layer_id: int, num_sequences=None):
    """Average the per-head attention metrics of ``layer_id`` over several sequences.

    Draws ``num_sequences`` fresh sequences from ``data_loader`` and prefills
    each one. For every sequence it snapshots the hidden states entering the
    layer's attention block, then prints the averages computed by
    ``average_per_head_over_sequences``.

    Side effects: advances ``data_loader`` and leaves the blocks' ``in_t``
    tensors set from the last sampled sequence.

    Args:
        layer_id: index into ``model.blocks``.
        num_sequences: number of sequences N to average over. Defaults to
            ``seq_len``, which reproduces the paper's N = T = 100 setup.

    Raises:
        ValueError: if ``num_sequences`` is smaller than 1.
    """
    if num_sequences is None:
        num_sequences = seq_len
    if num_sequences < 1:
        raise ValueError(f"num_sequences must be >= 1, got {num_sequences}")

    attn = model.blocks[layer_id].attn
    with torch.no_grad():
        Xs = []
        for _ in range(num_sequences):
            inputs, _targets = next(data_loader)
            model.prefill_batch(inputs[None, :], seq_len)
            # in_t is re-exposed on every forward pass; clone so each entry keeps
            # the hidden states of its own sequence rather than a shared buffer.
            Xs.append(attn.in_t[0].clone())

        avgs = average_per_head_over_sequences(attn, Xs, device=device)
        print("avg_ranks_per_head: ", list_f(avgs['avg_ranks_per_head']))
        print("avg_columns90_per_head: ", list_f(avgs['avg_columns90_per_head']))
        print(f"MaxRank_layer: {avgs['MaxRank_layer']:.2f}")

In [6]:
import ipywidgets as widgets

# Controls for the multi-sequence (batch) view; results are printed into `_out_batch`.
_layer_dropdown_batch = widgets.Dropdown(
    options=[*range(len(model.blocks))],
    value=0,
    description='Layer:',
)
_run_button_batch = widgets.Button(button_style='primary', description='Run')
_out_batch = widgets.Output()

def _on_run_clicked_batch(_):
    """Run-button callback for the batch view.

    Locks the button while running and prints the averaged metrics of the
    selected layer into `_out_batch`. The button is restored even on error.
    """
    def _mark(style, label, locked):
        _run_button_batch.button_style = style
        _run_button_batch.description = label
        _run_button_batch.disabled = locked

    _mark('info', "Running...", True)
    try:
        with _out_batch:
            clear_output(wait=True)
            layer = int(_layer_dropdown_batch.value)
            compute_batch(layer)
    finally:
        _mark('primary', "Run", False)

# `_run_button_batch` is created fresh each time this cell runs, so it has no
# handler yet; register the callback exactly once.
_run_button_batch.on_click(_on_run_clicked_batch)

display(widgets.VBox([
    widgets.HBox([_layer_dropdown_batch, _run_button_batch]),
    _out_batch
]))

VBox(children=(HBox(children=(Dropdown(description='Layer:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12…