# Multi-Granularity Representation Extraction for HTS-AT

This notebook extracts attention representations at three granularities:

1. **Per-head**: $\text{Attn} \cdot V$ pre-projection $W^O$, for each of the 184 heads individually.
2. **Per-block**: mean of head representations within the same Swin block.
3. **Per-layer**: mean of block representations within the same HTS-AT stage.

## Architecture Recap

HTS-AT has 4 stages (layers), 12 blocks total, 184 heads total:

| Stage | Blocks | Heads/block | $d_h$ | $D_\ell = H_\ell \cdot d_h$ |
|-------|--------|-------------|-------|---------------------------|
| L0    | 2      | 4           | 24    | 96                        |
| L1    | 2      | 8           | 24    | 192                       |
| L2    | 6      | 16          | 24    | 384                       |
| L3    | 2      | 32          | 24    | 768                       |

Note that $d_h = 24$ is constant across all stages since $d_h = D_\ell / H_\ell = (96 \cdot 2^\ell) / (4 \cdot 2^\ell) = 24$.

## What We Extract (Pre-Projection)

For each head $h$ in block $b$ of layer $\ell$, we capture:

$$\mathbf{H}_{\ell,b,h} = \text{Attn}_{\ell,b,h} \cdot V_{\ell,b,h} \in \mathbb{R}^{N_W \cdot B \times M \times d_h}$$

where $N_W$ is the number of spatial windows and $M = w^2 = 64$ tokens per window.
We then spatial mean-pool to obtain a single vector per sample:

$$\mathbf{r}_{\ell,b,h} = \frac{1}{N_W M} \sum_{i=1}^{N_W} \sum_{j=1}^{M} \mathbf{H}_{\ell,b,h}[i,j,:] \in \mathbb{R}^{24}$$

## Cell 0 ‚Äî Imports & Configuration

In [1]:
import torch
import numpy as np
from tqdm.notebook import tqdm
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from CLAPWrapper import CLAPWrapper
from datasets.esc50 import ESC50
from datasets.tinysol import TinySOL
from datasets.vocalsound import VocalSound

# ‚îÄ‚îÄ Configuration ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
DATASET     = ESC50          # ‚Üê change to TinySOL or VocalSound as needed
DATA_ROOT   = '../data'
BATCH_SIZE  = 1              # process one sample at a time (safest for hook logic)
SAVE_DIR    = 'heads_representations'

# HTS-AT constants (fixed by architecture)
HTSAT_DEPTHS = [2, 2, 6, 2]   # blocks per stage
HTSAT_HEADS  = [4, 8, 16, 32] # attention heads per stage
HTSAT_EMBED  = 96             # base embedding dim
HEAD_DIM     = 24             # head_dim = layer_dim / n_heads = constant = 24

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"‚úÖ Device: {device}")

‚úÖ Device: cpu


## Cell 1 ‚Äî Load Dataset & Model

Access path through CLAP's module hierarchy:
`CLAPWrapper` $\to$ `CLAP` $\to$ `AudioEncoder.base` $\to$ `HTSATWrapper.htsat`

We also build here the `block_info_df` table that maps every global block index
$b \in \{0, \ldots, 11\}$ to its stage $\ell$ and intra-stage position.
This table is shared by all three extraction cells below.

In [None]:
import os
os.makedirs(SAVE_DIR, exist_ok=True)

# Dataset
dataset = DATASET(root=DATA_ROOT, download=False)
print(f"üìä Dataset : {len(dataset)} samples, {len(dataset.classes)} classes")
print(f"   Classes  : {dataset.classes[:5]}... (+{len(dataset.classes)-5} more)")

# CLAP model
print("\nLoading CLAP ...", end='')
wrapper       = CLAPWrapper(version='2023', use_cuda=torch.cuda.is_available())
clap_model    = wrapper.clap
clap_model.eval()

# The HTS-AT Swin transformer inside CLAP
# Access path: CLAPWrapper.clap  ‚Üí  CLAP  ‚Üí  AudioEncoder.base  ‚Üí  HTSATWrapper.htsat
audio_encoder = clap_model.audio_encoder.base.htsat
audio_encoder.eval()

print('OK')

print(f"\n‚úÖ CLAP model loaded in eval mode")
print(f"   Stages  : {audio_encoder.num_layers}")
print(f"   Depths  : {audio_encoder.depths}")
print(f"   Heads   : {audio_encoder.num_heads}")
print(f"   Embed   : {audio_encoder.embed_dim}")

# Build global block index table (used by all three extractors)
block_info = []
global_block = 0
for layer_idx, (depth, n_heads) in enumerate(zip(HTSAT_DEPTHS, HTSAT_HEADS)):
    layer_dim = HTSAT_EMBED * (2 ** layer_idx)
    for block_idx in range(depth):
        block_info.append({
            'global_block': global_block,
            'layer': layer_idx,
            'block_in_layer': block_idx,
            'n_heads': n_heads,
            'layer_dim': layer_dim,
            'head_dim': HEAD_DIM,
        })
        global_block += 1

block_info_df = pd.DataFrame(block_info)
N_BLOCKS = len(block_info_df)
N_HEADS_TOTAL = sum(d * h for d, h in zip(HTSAT_DEPTHS, HTSAT_HEADS))
print(f"\n   Total blocks : {N_BLOCKS}")
print(f"   Total heads  : {N_HEADS_TOTAL}")
display(block_info_df)

Loading audio files


2000it [00:00, 11332.85it/s]

üìä Dataset : 2000 samples, 50 classes
   Classes  : ['airplane', 'breathing', 'brushing teeth', 'can opening', 'car horn']... (+45 more)
Loading CLAP ...




OK

‚úÖ CLAP model loaded in eval mode
   Stages  : 4
   Depths  : [2, 2, 6, 2]
   Heads   : [4, 8, 16, 32]
   Embed   : 96

   Total blocks : 12
   Total heads  : 184


Unnamed: 0,global_block,layer,block_in_layer,n_heads,layer_dim,head_dim
0,0,0,0,4,96,24
1,1,0,1,4,96,24
2,2,1,0,8,192,24
3,3,1,1,8,192,24
4,4,2,0,16,384,24
5,5,2,1,16,384,24
6,6,2,2,16,384,24
7,7,2,3,16,384,24
8,8,2,4,16,384,24
9,9,2,5,16,384,24


## Cell 2 ‚Äî Build Stratified Sample List

We collect exactly $\lfloor N / C \rfloor$ samples per class, where $N$ is the total
dataset size and $C$ the number of classes. Samples are sorted by class index to ensure
a deterministic ordering that is consistent across all three extraction cells.

In [None]:
# Stratified sampling: equal samples per class
samples_per_class = len(dataset) // len(dataset.classes)
print(f"Samples per class: {samples_per_class}")

class_buckets = defaultdict(list)
for idx in range(len(dataset)):
    audio_path, class_name, one_hot = dataset[idx]
    class_idx = torch.argmax(one_hot).item()
    if len(class_buckets[class_idx]) < samples_per_class:
        class_buckets[class_idx].append((audio_path, class_idx))
    if (len(class_buckets) == len(dataset.classes) and
            all(len(v) >= samples_per_class for v in class_buckets.values())):
        break

# Flatten sorted by class index ‚Üí deterministic ordering
sample_list = []

for class_idx in sorted(class_buckets.keys()):
    sample_list.extend(class_buckets[class_idx])

sample_labels = np.array([label for _, label in sample_list])
N_SAMPLES     = len(sample_list)

print(f"‚úÖ {N_SAMPLES} samples collected, {len(class_buckets)} classes covered")

Samples per class: 40
‚úÖ 2000 samples collected, 50 classes covered


## Cell 3 ‚Äî Extractor 1: Per-Head Representations

**What we capture.** For each head $h$ in block $b$ of layer $\ell$:

$$\mathbf{r}_{\ell,b,h} = \frac{1}{N_W M}\sum_{i,j} (\text{Attn}_{\ell,b,h} \cdot V_{\ell,b,h})[i,j,:] \in \mathbb{R}^{24}$$

This is computed *before* the output projection $W^O$. After $W^O$ the $H_\ell$ heads
are linearly mixed into a single $D_\ell$-dimensional vector, destroying individual
head structure. Pre-projection is therefore the only point where per-head subspaces
are still disentangled.

**Hook target.** We register one `register_forward_hook` on `WindowAttention` inside
each `SwinTransformerBlock`. The hook:

1. Reads `input[0]` ‚Äî the windowed token sequence $\in \mathbb{R}^{N_W B \times M \times D_\ell}$.
2. Re-computes $Q, K, V$ by passing the input through `module.qkv` (no grad).
3. Computes $\text{Attn}_{h} \cdot V_h$ for each head $h$ without going through `proj` or `proj_drop`.
4. Mean-pools over $(N_W B, M)$ to obtain $\mathbf{r}_{\ell,b,h} \in \mathbb{R}^{24}$.

**Output.** 184 tensors of shape $[N_{\text{samples}},\, 24]$, one per head.

In [6]:
class HeadLevelExtractor:
    """
    Registers one forward hook per SwinTransformerBlock attention module.
    Each hook:
      1. Re-computes QKV from the attention input (no grad).
      2. Computes attn @ V for each head ‚Üí shape [nW*B, N, head_dim].
      3. Mean-pools over (nW*B, N) ‚Üí shape [head_dim].
      4. Appends the pooled vector to head_outputs[head_id].
    """

    def __init__(self, model):
        self.model       = model
        self.head_outputs = defaultdict(list)  # head_id ‚Üí list of [head_dim] tensors
        self.hooks        = []

    # ‚îÄ‚îÄ hook factory ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    def _make_hook(self, layer_idx, block_idx, n_heads):
        def hook(module, input, output):
            # input[0]: [nW*B, N, C]  where C = layer_dim, N = window_size^2
            x_in        = input[0]                          # [nW*B, N, C]
            _, attn_w   = output                            # attn_w: [nW*B, n_heads, N, N]
            nWB, N, C   = x_in.shape
            head_dim    = C // n_heads

            # Re-compute QKV (no grad, same weights as forward pass)
            with torch.no_grad():
                qkv = module.qkv(x_in)                     # [nW*B, N, 3*C]
            qkv = qkv.reshape(nWB, N, 3, n_heads, head_dim).permute(2, 0, 3, 1, 4)
            # q, k, v: [nW*B, n_heads, N, head_dim]
            v = qkv[2]

            for h in range(n_heads):
                # attn_w[:, h]: [nW*B, N, N]
                # v[:, h]:      [nW*B, N, head_dim]
                head_out = torch.matmul(attn_w[:, h], v[:, h])  # [nW*B, N, head_dim]
                # Global spatial mean-pool ‚Üí [head_dim]
                pooled   = head_out.mean(dim=[0, 1]).detach().cpu()
                head_id  = f"L{layer_idx}_B{block_idx}_H{h}"
                self.head_outputs[head_id].append(pooled)

        return hook

    # ‚îÄ‚îÄ registration ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    def register_hooks(self):
        for layer_idx, layer in enumerate(self.model.layers):
            n_heads = self.model.num_heads[layer_idx]
            for block_idx, block in enumerate(layer.blocks):
                h = block.attn.register_forward_hook(
                    self._make_hook(layer_idx, block_idx, n_heads)
                )
                self.hooks.append(h)
        print(f"‚úÖ Registered {len(self.hooks)} hooks")

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()
        self.hooks.clear()

    def clear(self):
        self.head_outputs.clear()

    def finalize(self):
        """Stack per-head lists ‚Üí dict of tensors [N_samples, head_dim]."""
        return {hid: torch.stack(vecs) for hid, vecs in self.head_outputs.items()}


# ‚îÄ‚îÄ Extraction loop ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
head_extractor = HeadLevelExtractor(audio_encoder)
head_extractor.register_hooks()

for audio_path, _ in tqdm(sample_list, desc="Extracting heads"):
    audio_tensor = wrapper.load_audio_into_tensor(
        audio_path, wrapper.args.duration, resample=True
    ).reshape(1, -1).to(device)

    with torch.no_grad():
        audio_encoder(audio_tensor)

head_extractor.remove_hooks()
head_outputs_final = head_extractor.finalize()

# ‚îÄ‚îÄ Verify ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
sample_shape = head_outputs_final[list(head_outputs_final.keys())[0]].shape
print(f"\n‚úÖ Head extraction complete")
print(f"   Heads extracted : {len(head_outputs_final)}")
print(f"   Shape per head  : {sample_shape}  (N_samples √ó head_dim)")
assert sample_shape[0] == N_SAMPLES, "Sample count mismatch!"
assert sample_shape[1] == HEAD_DIM,  "Head dim mismatch!"

# ‚îÄ‚îÄ Save ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
save_path = f"{SAVE_DIR}/{DATASET.__name__.lower()}_head_outputs_final.pt"
torch.save({"head_outputs_final": head_outputs_final, "labels": sample_labels}, save_path)
print(f"   Saved ‚Üí {save_path}")

‚úÖ Registered 12 hooks


Extracting heads:   0%|          | 0/2000 [00:00<?, ?it/s]


‚úÖ Head extraction complete
   Heads extracted : 184
   Shape per head  : torch.Size([2000, 24])  (N_samples √ó head_dim)
   Saved ‚Üí heads_representations/esc50_head_outputs_final.pt


## Cell 4 ‚Äî Extractor 2: Per-Block Representations

**What we compute.** For each `SwinTransformerBlock`, we capture the residual stream
at the point between the attention sub-layer and the MLP sub-layer. Concretely, from
the block forward pass:

```python
shortcut = x
x = self.norm1(x)
# ... window partition, W-MSA/SW-MSA, window reverse ...
x = shortcut + self.drop_path(x)   # ‚Üê we capture HERE
x = x + self.drop_path(self.mlp(self.norm2(x)))
```

This corresponds to:

$$\mathbf{z}_b = x_{b,\text{in}} + \text{DropPath}(W^O \cdot \text{concat}_h(\text{Attn}_h \cdot V_h)) \in \mathbb{R}^{B \times N_\ell \times D_\ell}$$

where $W^O$ is the output projection, $\text{DropPath}$ is identity in eval mode,
and $N_\ell$, $D_\ell$ are the token count and embedding dimension of stage $\ell$.
After mean-pooling over tokens:

$$\mathbf{r}_b^{\text{block}} = \frac{1}{N_\ell} \sum_{n=1}^{N_\ell} \mathbf{z}_b[:, n, :] \in \mathbb{R}^{D_\ell}$$

This is the most informative block-level representation because it includes $W^O$ and
the residual connection, but excludes the MLP contribution, isolating the attention
sub-layer's effect on the residual stream.

Note that like per-layer, per-block vectors have **different dimensionalities** per stage
($D_0=96$, $D_1=192$, $D_2=384$, $D_3=768$).

**Requires a second forward pass** with hooks on `SwinTransformerBlock`.

**Output.** 12 tensors of shape $[N_{\text{samples}},\, D_\ell]$, one per block.

In [10]:
# ‚îÄ‚îÄ Block-level extractor: hooks on SwinTransformerBlock ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
# We need to capture x AFTER the first residual addition but BEFORE the MLP.
# SwinTransformerBlock.forward does not expose this intermediate value as a
# module output, so we use a register_forward_hook on the norm2 layer, which
# receives x at exactly that point: norm2 is called as self.norm2(x) where x
# is already shortcut + drop_path(attn_out).
# norm2 input == residual stream after attention, before MLP.

class BlockLevelExtractor:
    """
    Registers one forward hook per SwinTransformerBlock, targeting self.norm2.
    norm2 receives x = shortcut + drop_path(attn_windows), which is exactly the
    residual stream after the attention sub-layer and before the MLP sub-layer.
    Mean-pools over the token dimension to get one vector per sample.
    Output shape per block: [D_ell] = [96 * 2^layer_idx]
    """

    def __init__(self, model):
        self.model         = model
        self.block_outputs = defaultdict(list)  # block_id ‚Üí list of [D_ell] tensors
        self.hooks         = []

    def _make_hook(self, layer_idx, block_idx):
        def hook(module, input, output):
            # norm2 input[0]: [B, N_ell, D_ell]
            # This is x = shortcut + drop_path(attn_out), i.e. post-attention
            # pre-MLP residual stream. In eval mode drop_path is identity.
            x_pre_mlp = input[0]                          # [B, N_ell, D_ell]
            pooled = x_pre_mlp.mean(dim=1).squeeze(0).detach().cpu()  # [D_ell]
            block_id = f"B{block_idx}_L{layer_idx}"
            self.block_outputs[block_id].append(pooled)
        return hook

    def register_hooks(self):
        global_block = 0
        for layer_idx, layer in enumerate(self.model.layers):
            for block_idx, block in enumerate(layer.blocks):
                h = block.norm2.register_forward_hook(
                    self._make_hook(layer_idx, block_idx)
                )
                self.hooks.append(h)
                global_block += 1
        print(f"‚úÖ Registered {len(self.hooks)} block hooks (on norm2)")

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()
        self.hooks.clear()

    def finalize(self):
        """Stack per-block lists ‚Üí dict of tensors [N_samples, D_ell]."""
        return {bid: torch.stack(vecs) for bid, vecs in self.block_outputs.items()}


# ‚îÄ‚îÄ Extraction loop (second forward pass) ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
block_extractor = BlockLevelExtractor(audio_encoder)
block_extractor.register_hooks()

for audio_path, _ in tqdm(sample_list, desc="Extracting blocks"):
    audio_tensor = wrapper.load_audio_into_tensor(
        audio_path, wrapper.args.duration, resample=True
    ).reshape(1, -1).to(device)

    with torch.no_grad():
        audio_encoder(audio_tensor)

block_extractor.remove_hooks()
block_outputs_final = block_extractor.finalize()

# ‚îÄ‚îÄ Verify ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
print(f"\n‚úÖ Block extraction complete")
print(f"   Blocks extracted : {len(block_outputs_final)}")
for key, tensor in block_outputs_final.items():
    layer_idx = int(key.split("_L")[1])
    expected_dim = 96 * (2 ** layer_idx)
    ok = tensor.shape == (N_SAMPLES, expected_dim)
    print(f"   {key}: shape {tuple(tensor.shape)}  (expected D={expected_dim})  {'‚úÖ' if ok else '‚ùå'}")
    assert tensor.shape[0] == N_SAMPLES, f"Sample count mismatch for {key}"
    assert tensor.shape[1] == expected_dim, f"Dim mismatch: got {tensor.shape[1]}, expected {expected_dim}"

assert len(block_outputs_final) == N_BLOCKS
print("   ‚úÖ All shapes verified")

# ‚îÄ‚îÄ Save ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
save_path = f"{SAVE_DIR}/{DATASET.__name__.lower()}_block_outputs_final.pt"
torch.save({"block_outputs_final": block_outputs_final, "labels": sample_labels}, save_path)
print(f"   Saved ‚Üí {save_path}")

‚úÖ Registered 12 block hooks (on norm2)


Extracting blocks:   0%|          | 0/2000 [00:00<?, ?it/s]


‚úÖ Block extraction complete
   Blocks extracted : 12
   B0_L0: shape (2000, 96)  (expected D=96)  ‚úÖ
   B1_L0: shape (2000, 96)  (expected D=96)  ‚úÖ
   B0_L1: shape (2000, 192)  (expected D=192)  ‚úÖ
   B1_L1: shape (2000, 192)  (expected D=192)  ‚úÖ
   B0_L2: shape (2000, 384)  (expected D=384)  ‚úÖ
   B1_L2: shape (2000, 384)  (expected D=384)  ‚úÖ
   B2_L2: shape (2000, 384)  (expected D=384)  ‚úÖ
   B3_L2: shape (2000, 384)  (expected D=384)  ‚úÖ
   B4_L2: shape (2000, 384)  (expected D=384)  ‚úÖ
   B5_L2: shape (2000, 384)  (expected D=384)  ‚úÖ
   B0_L3: shape (2000, 768)  (expected D=768)  ‚úÖ
   B1_L3: shape (2000, 768)  (expected D=768)  ‚úÖ
   ‚úÖ All shapes verified
   Saved ‚Üí heads_representations/esc50_block_outputs_final.pt


## Cell 5 ‚Äî Extractor 3: Per-Layer Representations

**What we compute.** For stage $\ell$, we capture the output of `BasicLayer.forward`
*after* all blocks, residual additions, and patch merging have been applied:

$$x_\ell = \text{BasicLayer}_\ell(x_{\ell-1}) \in \mathbb{R}^{B \times N_\ell \times D_\ell}$$

where $N_\ell$ is the number of tokens after patch merging and $D_\ell = 96 \cdot 2^\ell$.
We then mean-pool over the token dimension to get a single vector per sample:

$$\mathbf{r}_\ell^{\text{layer}} = \frac{1}{N_\ell} \sum_{n=1}^{N_\ell} x_\ell[:, n, :] \in \mathbb{R}^{D_\ell}$$

Note that unlike the per-head and per-block representations, the per-layer vectors
have **different dimensionalities** across stages ($D_0=96$, $D_1=192$, $D_2=384$, $D_3=768$),
because they capture the full layer output including $W^O$, MLP, residuals, and patch merging.
This is the only granularity that captures the actual residual stream flowing between stages.

**Requires a second forward pass** with hooks on `BasicLayer`, not derivable from
`head_outputs_final` or `block_outputs_final`.

**Output.** 4 tensors of shape $[N_{\text{samples}},\, D_\ell]$, one per stage.

In [11]:
# ‚îÄ‚îÄ Layer-level extractor: hooks on BasicLayer ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
# BasicLayer.forward signature (from htsat.py):
#   def forward(self, x) -> (x, attn)
# where x after the call has shape [B, N_ell, D_ell], already including:
#   - all SwinTransformerBlock residual additions (attention + MLP)
#   - patch merging (PatchMerging) if layer_idx < 3
# We hook the output of each BasicLayer to get the true inter-stage residual.

class LayerLevelExtractor:
    """
    Registers one forward hook per BasicLayer (stage).
    Each hook captures the layer output x after all blocks and patch merging,
    mean-pools over the token dimension, and stores the result.
    Output shape per layer: [D_ell] = [96 * 2^layer_idx]
    """

    def __init__(self, model):
        self.model         = model
        self.layer_outputs = defaultdict(list)  # layer_id ‚Üí list of [D_ell] tensors
        self.hooks         = []

    def _make_hook(self, layer_idx):
        def hook(module, input, output):
            # BasicLayer returns (x, attn); x: [B, N_ell, D_ell]
            x_out, _ = output
            # Mean pool over token dimension ‚Üí [B, D_ell]
            # B=1 since we process one sample at a time ‚Üí squeeze ‚Üí [D_ell]
            pooled = x_out.mean(dim=1).squeeze(0).detach().cpu()
            self.layer_outputs[f"L{layer_idx}"].append(pooled)
        return hook

    def register_hooks(self):
        for layer_idx, layer in enumerate(self.model.layers):
            h = layer.register_forward_hook(self._make_hook(layer_idx))
            self.hooks.append(h)
        print(f"‚úÖ Registered {len(self.hooks)} layer hooks")

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()
        self.hooks.clear()

    def finalize(self):
        """Stack per-layer lists ‚Üí dict of tensors [N_samples, D_ell]."""
        return {lid: torch.stack(vecs) for lid, vecs in self.layer_outputs.items()}


# ‚îÄ‚îÄ Extraction loop (second forward pass) ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
layer_extractor = LayerLevelExtractor(audio_encoder)
layer_extractor.register_hooks()

for audio_path, _ in tqdm(sample_list, desc="Extracting layers"):
    audio_tensor = wrapper.load_audio_into_tensor(
        audio_path, wrapper.args.duration, resample=True
    ).reshape(1, -1).to(device)

    with torch.no_grad():
        audio_encoder(audio_tensor)

layer_extractor.remove_hooks()
layer_outputs_final = layer_extractor.finalize()

# ‚îÄ‚îÄ Verify ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
print(f"\n‚úÖ Layer extraction complete")
print(f"   Layers extracted : {len(layer_outputs_final)}")
for key, tensor in layer_outputs_final.items():
    layer_idx = int(key[1])
    expected_dim = 96 * (2 ** layer_idx)
    print(f"   {key}: shape {tuple(tensor.shape)}  (expected D={expected_dim})")
    assert tensor.shape[0] == N_SAMPLES, f"Sample count mismatch for {key}"
    assert tensor.shape[1] == expected_dim, f"Dim mismatch for {key}: got {tensor.shape[1]}, expected {expected_dim}"

assert len(layer_outputs_final) == len(HTSAT_DEPTHS)
print("   ‚úÖ All shapes verified")

# ‚îÄ‚îÄ Save ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
save_path = f"{SAVE_DIR}/{DATASET.__name__.lower()}_layer_outputs_final.pt"
torch.save({"layer_outputs_final": layer_outputs_final, "labels": sample_labels}, save_path)
print(f"   Saved ‚Üí {save_path}")

‚úÖ Registered 4 layer hooks


Extracting layers:   0%|          | 0/2000 [00:00<?, ?it/s]


‚úÖ Layer extraction complete
   Layers extracted : 4
   L0: shape (2000, 192)  (expected D=96)


AssertionError: Dim mismatch for L0: got 192, expected 96

## Cell 6 ‚Äî Summary & Sanity Checks

We verify the following:

1. **Shape correctness**: each layer tensor has the expected dimensionality
   $D_\ell = 96 \cdot 2^\ell$ ($96, 192, 384, 768$ for $\ell = 0,1,2,3$).

2. **Fisher discriminability**: a quick per-key Fisher score
   $F = \overline{S_B / (S_W + \varepsilon)}$
   confirms that the extracted representations carry class-discriminative information
   at all three granularities. Note that layer Fisher scores are not directly comparable
   to head/block scores since the feature spaces have different dimensionalities.

In [None]:
print("=" * 60)
print("EXTRACTION SUMMARY")
print("=" * 60)

print(f"\nDataset  : {DATASET.__name__}")
print(f"Samples  : {N_SAMPLES}")
print(f"Classes  : {len(dataset.classes)}")
print(f"Device   : {device}")

print(f"\n{'Granularity':<12} {'Keys':<8} {'Shape per key'}")
print("-" * 42)
print(f"{'Head':<12} {len(head_outputs_final):<8} "
      f"{tuple(next(iter(head_outputs_final.values())).shape)}")
print(f"{'Block':<12} {len(block_outputs_final):<8} "
      f"{tuple(next(iter(block_outputs_final.values())).shape)}")
print(f"{'Layer':<12} {len(layer_outputs_final):<8} "
      f"{tuple(next(iter(layer_outputs_final.values())).shape)}")

# ‚îÄ‚îÄ Shape checks for layer outputs ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
print("\nüîç Layer output shape checks:")
for layer_idx in range(len(HTSAT_DEPTHS)):
    key = f"L{layer_idx}"
    tensor = layer_outputs_final[key]
    expected_dim = 96 * (2 ** layer_idx)
    ok = tensor.shape == (N_SAMPLES, expected_dim)
    print(f"   {key}: {tuple(tensor.shape)}  expected ({N_SAMPLES}, {expected_dim})  {'‚úÖ' if ok else '‚ùå'}")

# ‚îÄ‚îÄ Quick Fisher score check across granularities ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
def fisher_score(X, y):
    """Diagonal Fisher criterion: mean(S_B / (S_W + eps))."""
    classes   = np.unique(y)
    mu_global = X.mean(axis=0)
    S_B = np.zeros(X.shape[1])
    S_W = np.zeros(X.shape[1])
    for c in classes:
        Xc   = X[y == c]
        mu_c = Xc.mean(axis=0)
        S_B += len(Xc) * (mu_c - mu_global) ** 2
        S_W += ((Xc - mu_c) ** 2).sum(axis=0)
    return float((S_B / (S_W + 1e-8)).mean())

print("\nüìä Quick Fisher discriminability per key:")
print("   Note: block/layer scores are not directly comparable to head scores")
print("   since they operate in different-dimensional spaces (D_ell vs 24).\n")

print(f"   {'Key':<20} {'F':>8}  {'dim':>6}")
print("   " + "-" * 38)

for hid in sorted(head_outputs_final.keys()):
    f = fisher_score(head_outputs_final[hid].numpy(), sample_labels)
    dim = head_outputs_final[hid].shape[1]
    print(f"   {hid:<20} {f:>8.4f}  {dim:>6}")

print()
for bk in sorted(block_outputs_final.keys()):
    f = fisher_score(block_outputs_final[bk].numpy(), sample_labels)
    dim = block_outputs_final[bk].shape[1]
    print(f"   {bk:<20} {f:>8.4f}  {dim:>6}")

print()
for lk in sorted(layer_outputs_final.keys()):
    f = fisher_score(layer_outputs_final[lk].numpy(), sample_labels)
    dim = layer_outputs_final[lk].shape[1]
    print(f"   {lk:<20} {f:>8.4f}  {dim:>6}")

print(f"\n‚úÖ All files saved in '{SAVE_DIR}/':")
print(f"   {DATASET.__name__.lower()}_head_outputs_final.pt")
print(f"   {DATASET.__name__.lower()}_block_outputs_final.pt")
print(f"   {DATASET.__name__.lower()}_layer_outputs_final.pt")

EXTRACTION SUMMARY

Dataset  : ESC50
Samples  : 2000
Classes  : 50
Device   : cpu

Granularity  Keys     Shape per key
------------------------------------------
Head         184      (2000, 24)
Block        12       (2000, 24)
Layer        4        (2000, 24)

üîç Consistency checks (layer mean ‚âà mean of block means):
   L0: max abs diff = 0.00e+00  ‚úÖ
   L1: max abs diff = 0.00e+00  ‚úÖ
   L2: max abs diff = 0.00e+00  ‚úÖ
   L3: max abs diff = 0.00e+00  ‚úÖ

üìä Quick Fisher discriminability (mean across keys per granularity):
   Head  ‚Äî mean F = 0.8749  min=0.3093  max=2.7675
   Block ‚Äî mean F = 0.8597  min=0.5664  max=1.4130
   Layer ‚Äî mean F = 0.9850  min=0.6843  max=1.2716

‚úÖ All files saved in 'heads_representations/':
   esc50_head_outputs_final.pt
   esc50_block_outputs_final.pt
   esc50_layer_outputs_final.pt
