# swin

> Swin Transformer V2 Encoder for midi-rae — drop-in replacement for ViTEncoder

In [None]:
#| default_exp swin

In [None]:
#| hide
from nbdev.showdoc import *

## Design Overview

### What this module does

`SwinEncoder` is a drop-in replacement for `ViTEncoder` that uses the **Swin Transformer V2**
architecture. It takes a piano roll image `(B, 1, 128, 128)` and returns an `EncoderOutput`
with hierarchical multi-scale patch states.

### Why Swin V2?

- **Hierarchical representation**: 7 levels from finest (64×64 grid, dim=4) down to a single
  CLS-like token (1×1, dim=256), compared to ViT's flat single-scale output
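- **Efficient attention**: Windowed attention with shifted windows — cost grows linearly with the number
  of patches N (for a fixed window size) instead of quadratically (O(N²)) as with global attention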
- **V2 improvements**: Cosine attention with learned log-scale temperature, continuous position
  bias via CPB MLP, res-post-norm for training stability

### Architecture

| Stage | Grid | Patch covers | Dim | Depths | Heads |
|-------|------|-------------|-----|--------|-------|
| 0 | 64×64 | 2×2 | 4 | 1 | 1 |
| 1 | 32×32 | 4×4 | 8 | 1 | 1 |
| 2 | 16×16 | 8×8 | 16 | 2 | 1 |
| 3 | 8×8 | 16×16 | 32 | 2 | 2 |
| 4 | 4×4 | 32×32 | 64 | 6 | 4 |
| 5 | 2×2 | 64×64 | 128 | 2 | 8 |
| 6 | 1×1 | 128×128 | 256 | 1 | 16 |

Config is in `configs/config_swin.yaml`.

### Implementation approach

We use **timm's `SwinTransformerV2Stage` directly** — no copied or modified Swin internals.
Our `SwinEncoder` wrapper handles only:

1. **Patch embedding** — `Conv2d(1, 4, kernel_size=2, stride=2)` + LayerNorm
2. **Empty patch detection** — patches with no pixel brighter than 0.2 (i.e. effectively black) get a
   learnable `empty_token`
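3. **MAE masking** (SimMIM-style) — masked patches get a learnable `mask_token`, and the grid stays
   intact so windowed attention works unmodified. Two-rate sampling: non-empty patches are masked
   at rate `mae_ratio`, empty patches at `mae_ratio × empty_mask_ratio` (`empty_mask_ratio`
   defaults to 0.05, i.e. 5% of the main rate). One mask is sampled per batch, based on the first
   sample's empty/non-empty pattern.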
4. **Hierarchical output** — collects each stage's output into `HierarchicalPatchState`
   (coarsest-first), packaged as `EncoderOutput`

### Key differences from ViTEncoder

- No CLS token (for a 128×128 input, stage 6's single 1×1 token serves as a global summary)
- No RoPE (Swin V2 uses its own continuous position bias)
- MAE masking keeps all tokens (SimMIM-style) — no compute savings but preserves spatial grid
- `empty_mask_ratio` controls how often trivial-to-reconstruct empty patches are masked

In [None]:
#| export
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Set, Type, Union
from functools import partial

from timm.models.swin_transformer_v2 import SwinTransformerV2Stage
from timm.layers import trunc_normal_, to_2tuple, calculate_drop_path_rates
from midi_rae.core import PatchState, HierarchicalPatchState, EncoderOutput

In [None]:
#| export
#| export
class SwinEncoder(nn.Module):
    """Swin Transformer V2 Encoder for midi-rae — drop-in replacement for ViTEncoder.

    Uses timm's SwinTransformerV2Stage directly (no copied/modified code).
    Adds custom patch embedding, learnable empty/mask tokens, and packages
    multi-scale output as EncoderOutput.

    Key differences from ViTEncoder:
    - No CLS token — Swin uses spatial features directly
    - Hierarchical multi-scale output (one level per stage; 7 levels by default)
    - Windowed attention with shifted windows (no global attention, no RoPE)
    - V2 features: cosine attention, log-scale temperature, continuous position bias

    # TODO: HierarchicalPatchState could store window_size per level
    # TODO: EncoderOutput could store scale metadata (downsample factors per level)
    """

    def __init__(self,
        img_height: int,          # e.g. 128 (MIDI pitch range)
        img_width: int,           # e.g. 256 (time steps)
        patch_h: int = 2,
        patch_w: int = 2,
        in_chans: int = 1,        # piano rolls are single-channel
        embed_dim: int = 4,       # 2×2 patch → 4 dims; doubles each stage to 256
        depths: tuple = (1, 1, 2, 2, 6, 2, 1),
        num_heads: tuple = (1, 1, 1, 2, 4, 8, 16),
        window_size: int = 8,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        norm_layer: type = nn.LayerNorm,
        mae_ratio: float = 0.0,
        empty_mask_ratio: float = 0.05,  # fraction of mae_ratio applied to empty patches
    ):
        super().__init__()
        # One entry per stage is required; a mismatch would otherwise surface as an
        # IndexError mid-construction (too few) or be silently ignored (too many).
        if len(num_heads) != len(depths):
            raise ValueError(
                f'num_heads has {len(num_heads)} entries but depths has {len(depths)}; '
                'both need one entry per stage')
        self.num_stages = len(depths)
        self.embed_dim = embed_dim
        self.num_features = int(embed_dim * 2 ** (self.num_stages - 1))
        self.patch_h, self.patch_w = patch_h, patch_w
        self.grid_size = (img_height // patch_h, img_width // patch_w)
        self.mae_ratio = mae_ratio
        self.empty_mask_ratio = empty_mask_ratio

        # --- Patch embedding: Conv2d + LayerNorm (replaces timm PatchEmbed) ---
        self.patch_embed = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w),
        )
        self.patch_norm = norm_layer(embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # --- Learnable replacement tokens ---
        self.empty_token = nn.Parameter(torch.zeros(embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(embed_dim))
        # TODO: MAE masking in Swin keeps all tokens (SimMIM-style) — no compute
        # savings, but preserves the regular spatial grid that windowed attention needs.

        # --- Build stages using timm's SwinTransformerV2Stage ---
        embed_dims = [int(embed_dim * 2 ** i) for i in range(self.num_stages)]
        dpr = calculate_drop_path_rates(drop_path_rate, list(depths), stagewise=True)

        self.stages = nn.ModuleList()
        in_dim = embed_dims[0]
        res_h, res_w = self.grid_size  # resolution fed into the current stage
        for i, out_dim in enumerate(embed_dims):
            # Stage 0 keeps the patch grid; every later stage starts with a 2×
            # patch-merging downsample (halving H, W and doubling channels).
            downsample = i > 0
            self.stages.append(SwinTransformerV2Stage(
                dim=in_dim,
                out_dim=out_dim,
                input_resolution=(res_h, res_w),
                depth=depths[i],
                num_heads=num_heads[i],
                window_size=window_size,
                downsample=downsample,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ))
            in_dim = out_dim
            if downsample:
                res_h, res_w = res_h // 2, res_w // 2

        # --- Final norm ---
        self.norm = norm_layer(self.num_features)

        # --- Weight init (matches timm SwinTransformerV2) ---
        self.apply(self._init_weights)
        for stage in self.stages:
            stage._init_respostnorm()

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal (std=.02) weights and zero biases for Linear and Conv2d layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set[str]:
        """Parameter names to exclude from weight decay.

        This covers the learnable replacement tokens and every parameter whose name contains
        'cpb_mlp' or 'logit_scale' (the Swin V2 position-bias MLP and attention temperature).
        """
        nod = {'empty_token', 'mask_token'}
        for n, _ in self.named_parameters():
            if any(kw in n for kw in ('cpb_mlp', 'logit_scale')):
                nod.add(n)
        return nod

    def _compute_non_empty(self, img: torch.Tensor) -> torch.Tensor:
        """Detect which patches have content (non-black). Matches PatchEmbedding in vit.py.

        A patch is non-empty when any of its pixels, in any channel, exceeds 0.2.

        Args:
            img: (B, C, H, W) image.

        Returns:
            (B, N) bool with N = (H // patch_h) * (W // patch_w), row-major over the patch grid.
        """
        # (B, C, H', W', patch_h, patch_w)
        patches = img.unfold(2, self.patch_h, self.patch_h).unfold(3, self.patch_w, self.patch_w)
        # Reduce over channels as well as pixels so in_chans > 1 still yields (B, H', W').
        # For single-channel input this is identical to squeezing the channel dim.
        return (patches.amax(dim=(1, -2, -1)) > 0.2).flatten(1)

    def _make_mae_mask(self, non_empty: torch.Tensor, device: torch.device,
                       mask_ratio: Optional[float] = None) -> torch.Tensor:
        """Generate MAE mask with two-rate sampling.

        Non-empty patches are masked at `ratio`. Empty patches are masked at
        `ratio * empty_mask_ratio`, a much lower rate because they are trivial to reconstruct.

        Args:
            non_empty: (B, N) bool. Only the first sample is used, so a single
                mask is shared across the whole batch.
            device: device for the returned mask.
            mask_ratio: masking rate to use; None means self.mae_ratio.

        Returns: (N,) bool — True=visible, False=masked for reconstruction.
        """
        ratio = self.mae_ratio if mask_ratio is None else mask_ratio
        n = non_empty.shape[1]
        is_nonempty = non_empty[0]  # (N,) — use first sample as representative
        rand = torch.rand(n, device=device)
        threshold = torch.where(
            is_nonempty.bool(),
            torch.full_like(rand, 1.0 - ratio),
            torch.full_like(rand, 1.0 - ratio * self.empty_mask_ratio),
        )
        return rand < threshold  # True = visible

    @staticmethod
    def _grid_positions(h: int, w: int, device: torch.device) -> torch.Tensor:
        """(h*w, 2) integer (row, col) coordinates of an h×w grid, in row-major order."""
        return torch.stack(torch.meshgrid(
            torch.arange(h, device=device),
            torch.arange(w, device=device),
            indexing='ij',
        ), dim=-1).reshape(-1, 2)

    def forward(self, x: torch.Tensor, mask_ratio: float = 0.0,
                mae_mask: Optional[torch.Tensor] = None) -> "EncoderOutput":
        """
        Args:
            x: (B, C, H, W) piano roll image. H and W are expected to match img_height and
                img_width, because the stages are built for that resolution.
            mask_ratio: masking rate for this call. A positive value overrides
                self.mae_ratio. 0 (the default) falls back to self.mae_ratio. To
                disable masking while self.mae_ratio > 0, pass an all-True mae_mask.
            mae_mask: (N,) bool, True=visible, shared across the batch. If None, a
                mask is sampled whenever the effective ratio is > 0.

        Returns:
            EncoderOutput matching ViTEncoder interface: hierarchical patch states
            (coarsest level first), full-grid positions (N_full, 2), per-sample
            non-empty flags (B, N_full) bool, and the MAE mask (N_full,) bool.
        """
        device = x.device
        grid_h, grid_w = self.grid_size
        n_full = grid_h * grid_w

        # --- Detect empty patches from raw input ---
        non_empty = self._compute_non_empty(x)  # (B, N_full) bool

        # --- Patch embed ---
        x = self.patch_embed(x)                        # (B, C, H', W')
        x = x.permute(0, 2, 3, 1).contiguous()        # → (B, H', W', C)  NHWC
        x = self.patch_norm(x)
        B, H, W, C = x.shape

        # --- Replace empty patches with learned empty_token ---
        x = torch.where(non_empty.view(B, H, W, 1), x,
                        self.empty_token.view(1, 1, 1, -1).expand_as(x))

        # --- MAE masking: replace masked positions with learned mask_token ---
        effective_ratio = mask_ratio if mask_ratio > 0 else self.mae_ratio
        if mae_mask is None and effective_ratio > 0:
            # Pass the effective ratio explicitly so a per-call override is honoured.
            mae_mask = self._make_mae_mask(non_empty, device, mask_ratio=effective_ratio)
        if mae_mask is not None:
            visible = mae_mask.view(1, H, W, 1).expand(B, -1, -1, -1)
            x = torch.where(visible, x, self.mask_token.view(1, 1, 1, -1).expand_as(x))
        else:
            mae_mask = torch.ones(n_full, device=device, dtype=torch.bool)

        x = self.pos_drop(x)

        # --- Run through stages, collect intermediates in NHWC ---
        intermediates = []
        for stage in self.stages:
            x = stage(x)                               # (B, H_i, W_i, C_i)
            intermediates.append(x)

        # --- Apply final norm to last (coarsest) stage ---
        intermediates[-1] = self.norm(intermediates[-1])

        # --- Build HierarchicalPatchState (coarsest first) ---
        levels = []
        for feat in reversed(intermediates):
            Bf, Hf, Wf, Cf = feat.shape
            n = Hf * Wf
            levels.append(PatchState(
                emb=feat.reshape(Bf, n, Cf),
                pos=self._grid_positions(Hf, Wf, device),
                # bool, consistent with full_non_empty and mae_mask
                non_empty=torch.ones(Bf, n, device=device, dtype=torch.bool),
                mae_mask=torch.ones(n, device=device, dtype=torch.bool),
            ))

        return EncoderOutput(
            patches=HierarchicalPatchState(levels=levels),
            full_pos=self._grid_positions(grid_h, grid_w, device),  # (N_full, 2)
            full_non_empty=non_empty,
            mae_mask=mae_mask,
        )

In [None]:
# Test: verify SwinEncoder output shapes
torch.manual_seed(0)
B, C, H, W = 2, 1, 128, 128
enc = SwinEncoder(img_height=H, img_width=W)
x = torch.randn(B, C, H, W)
out = enc(x)

print(f'mae_mask:        {out.mae_mask.shape}')
print(f'full_pos:        {out.full_pos.shape}')
print(f'full_non_empty:  {out.full_non_empty.shape}')
print(f'num levels:      {len(out.patches.levels)}')
for i, ps in enumerate(out.patches.levels):
    print(f'  level {i}: emb={ps.emb.shape}, pos={ps.pos.shape}')

# Expected hierarchy (coarsest first), B=2, 128×128 image, 2×2 patches:
#   level 0 (coarsest): emb=(2, 1,    256) — grid 1×1  (CLS-like)
#   level 1:            emb=(2, 4,    128) — grid 2×2
#   level 2:            emb=(2, 16,    64) — grid 4×4
#   level 3:            emb=(2, 64,    32) — grid 8×8
#   level 4:            emb=(2, 256,   16) — grid 16×16
#   level 5:            emb=(2, 1024,   8) — grid 32×32
#   level 6 (finest):   emb=(2, 4096,   4) — grid 64×64

# Check the same expectations programmatically so regressions fail loudly.
grid_h, grid_w = enc.grid_size
n_full = grid_h * grid_w
assert tuple(out.mae_mask.shape) == (n_full,)
assert tuple(out.full_pos.shape) == (n_full, 2)
assert tuple(out.full_non_empty.shape) == (B, n_full)
assert len(out.patches.levels) == enc.num_stages
for i, ps in enumerate(out.patches.levels):
    factor = 2 ** (enc.num_stages - 1 - i)   # downsample factor relative to the finest grid
    n = (grid_h // factor) * (grid_w // factor)
    assert tuple(ps.emb.shape) == (B, n, enc.embed_dim * factor), (i, ps.emb.shape)
    assert tuple(ps.pos.shape) == (n, 2), (i, ps.pos.shape)