# swin

> Swin Transformer V2 Encoder for midi-rae — drop-in replacement for ViTEncoder

In [None]:
#| default_exp swin

In [None]:
#| hide
from nbdev.showdoc import *

## Implementation Plan: `SwinEncoder` — Drop-in Replacement for `ViTEncoder`

### Goal

Create a `SwinEncoder` class in `midi_rae.swin` that can replace `ViTEncoder` with **zero changes**
to calling code. It must accept the same constructor args and return an `EncoderOutput` with the
same structure. Import dataclasses from `midi_rae.core`.

**Note any upgrades or changes to the PatchState / HierarchicalPatchState / EncoderOutput dataclasses**
that would be beneficial for the Swin architecture (e.g. storing window sizes, extra scale metadata),
but do NOT implement them yet — just leave TODO comments describing what could change.

---

### Source Material

**Copy** (don't import) the needed components from the timm Swin V2 file at:
```
/app/data/pytorch-image-models/timm/models/swin_transformer_v2.py
```

Copy these 5 components into `#| export` cells in this notebook:

1. `window_partition(x, window_size)` and `window_reverse(windows, window_size, H, W)` — free functions
2. `WindowAttention` — the V2 version with scaled cosine attention and a log-spaced continuous position bias (CPB) MLP
3. `SwinTransformerV2Block` — one transformer block with windowed or shifted-windowed attention
4. `PatchMerging` — merges 2×2 patches and projects to higher dim (the downsampling layer between stages)
5. `SwinTransformerV2Stage` — a full stage = optional downsample + N blocks

**Do NOT copy** the top-level `SwinTransformerV2` class — we will write our own `SwinEncoder` wrapper.
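As a reference point while copying, the two free functions can be sketched as below. This is a minimal version that assumes `(B, H, W, C)` input and a `(wh, ww)` tuple for `window_size`; it is not timm's exact code, so compare signatures against the copied file before relying on it.

```python
import torch

def window_partition(x, window_size):
    """Split (B, H, W, C) into non-overlapping (wh, ww) windows."""
    B, H, W, C = x.shape
    wh, ww = window_size
    x = x.view(B, H // wh, wh, W // ww, ww, C)
    # bring the two window-index axes together -> (B * num_windows, wh, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, wh, ww, C)

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: reassemble windows into (B, H, W, C)."""
    wh, ww = window_size
    C = windows.shape[-1]
    x = windows.view(-1, H // wh, W // ww, wh, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)

# round trip on a 32x64 grid with 8x8 windows
x = torch.randn(2, 32, 64, 96)
wins = window_partition(x, (8, 8))
restored = window_reverse(wins, (8, 8), 32, 64)
```

A round-trip check like this is a cheap way to confirm the copied functions were not broken while stripping timm dependencies.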

---

### Handling timm Imports

The timm file uses these imports that need replacing:

| timm import | Replacement |
|---|---|
| `from timm.layers import DropPath` | Implement `DropPath` (stochastic depth): during training, drop the whole residual branch per sample with probability `drop_prob` and rescale the survivors by `1 / (1 - drop_prob)`. Or copy timm's `DropPath`, which is about 15 lines. Use `nn.Identity` where the drop rate is 0. |
| `from timm.layers import Mlp` | Copy or rewrite: it's just `Linear → Act → Dropout → Linear → Dropout`. |
| `from timm.layers import to_2tuple` | `def to_2tuple(x): return (x, x) if isinstance(x, int) else tuple(x)` |
| `from timm.layers import _assert` | Replace with plain `assert` statements |
| `from timm.layers import PatchEmbed` | Not needed. `SwinEncoder` takes the raw image and uses its own simple conv-based patch embed (see the constructor spec below). |
| `from timm.layers import ClassifierHead` | Not needed — we have no classification head |
| `from timm.data import IMAGENET_DEFAULT_MEAN/STD` | Not needed |
| All model registry decorators (`@register_model`, `generate_default_cfgs`, etc.) | Remove entirely |

**Alternatively**, if `timm` is installed in the environment, you may import these utilities directly
(`from timm.layers import DropPath, Mlp, to_2tuple`). Check with `import timm` first. If it works,
prefer importing over copying for `DropPath` and `Mlp` to reduce code. But `_assert` should still
become plain `assert`.
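One way to combine both options is an import with a local fallback. The sketch below assumes the `timm.layers` import path named above; the fallback classes are simplified rewrites, not copies of timm's code, and only cover the arguments this plan needs.

```python
import torch
import torch.nn as nn

try:
    from timm.layers import DropPath, Mlp, to_2tuple  # prefer timm if installed
except ImportError:
    def to_2tuple(x):
        return (x, x) if isinstance(x, int) else tuple(x)

    class DropPath(nn.Module):
        """Stochastic depth: drop the whole residual branch per sample."""
        def __init__(self, drop_prob: float = 0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # one Bernoulli draw per sample, broadcast over the other dims
            mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
            return x * mask / keep_prob

    class Mlp(nn.Module):
        """Linear -> activation -> dropout -> Linear -> dropout."""
        def __init__(self, in_features, hidden_features=None,
                     out_features=None, act_layer=nn.GELU, drop=0.0):
            super().__init__()
            hidden_features = hidden_features or in_features
            out_features = out_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))

# quick sanity check: DropPath is the identity in eval mode
tokens = torch.randn(4, 10, 8)
drop_path = DropPath(0.5).eval()
mlp = Mlp(8, hidden_features=16)
```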

---

### Existing Interface to Match

From `midi_rae/core.py`:

```python
@dataclass
class PatchState:
    emb: torch.Tensor       # (B, N, dim) patch embeddings
    pos: torch.Tensor       # (N, 2) grid coordinates (row, col)
    non_empty: torch.Tensor # (B, N) content mask — 1 where patch has content
    mae_mask: torch.Tensor  # (N,) MAE visibility mask — 1=visible, 0=masked

@dataclass
class HierarchicalPatchState:
    levels: list  # list of PatchState, coarsest-first

@dataclass
class EncoderOutput:
    latent: torch.Tensor              # (B, N, dim) final encoder output
    patch_state: PatchState            # final-scale patch state
    hierarchical: HierarchicalPatchState  # multi-scale states
```

(Double-check these against the actual `midi_rae/core.py` — the field names above are from memory
and may need adjustment. Read the file to confirm.)

From `midi_rae/vit.py`, the `ViTEncoder` has roughly this interface:

```python
class ViTEncoder(nn.Module):
    def __init__(self, img_height, img_width, patch_h, patch_w,
                 in_chans=1, embed_dim=256, depth=6, num_heads=8,
                 mlp_ratio=4.0, drop_rate=0.0, mae_ratio=0.0, ...):
    def forward(self, x) -> EncoderOutput:
        # x: (B, 1, H, W) piano roll image
        # returns EncoderOutput
```

(Again, read the actual file to confirm exact args. The key point: it takes an image and returns `EncoderOutput`.)

---

### `SwinEncoder` Class Design

```python
class SwinEncoder(nn.Module):
    def __init__(self,
        img_height: int,          # e.g. 128 (MIDI pitch range)
        img_width: int,           # e.g. 256 (time steps)
        patch_h: int = 4,         # patch height
        patch_w: int = 4,         # patch width
        in_chans: int = 1,        # piano rolls are single-channel
        embed_dim: int = 96,      # Swin default; ViT used 256
        depths: tuple = (2, 2, 6, 2),   # blocks per stage
        num_heads: tuple = (3, 6, 12, 24),  # heads per stage
        window_size: int = 7,     # attention window size
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        mae_ratio: float = 0.0,   # kept for interface compat; NOT used in Swin path
    ):
```

#### Constructor should:

1. **Patch embedding**: Use a `nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))` followed by a `nn.LayerNorm(embed_dim)`. This replaces timm's `PatchEmbed`.
   - After conv: convert `(B, C, H', W')` to `(B, H'*W', C)` for the transformer with `x.flatten(2).transpose(1, 2)`. A bare `reshape` would scramble channels and positions.
   - Store `self.grid_size = (img_height // patch_h, img_width // patch_w)` — needed by stages.

2. **Build stages**: Create `nn.ModuleList` of `SwinTransformerV2Stage` instances.
   - Stage 0: no downsampling, dim=embed_dim
   - Stages 1+: downsample=True (PatchMerging), dim doubles each stage
   - Use stochastic depth with linearly increasing drop path rates across all blocks.
   - `input_resolution` for stage 0 = `self.grid_size`; halves each subsequent stage.

3. **Final norm**: `nn.LayerNorm(final_dim)` where `final_dim = embed_dim * 2**(num_stages - 1)` (768 for `embed_dim=96` and four stages).

4. Store `self.num_stages = len(depths)`.
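The per-stage bookkeeping in step 2 can be worked out before any module is built. The helper below is a hypothetical sketch (`plan_stages` and its dictionary keys are illustrative names, not part of timm or `midi_rae`); `resolution` here means the grid the stage's blocks operate on, after any downsampling.

```python
def plan_stages(grid_size, embed_dim=96, depths=(2, 2, 6, 2), drop_path_rate=0.1):
    """Return per-stage dim, grid resolution, and drop-path rates."""
    total_blocks = sum(depths)
    # linearly increasing rates from 0 to drop_path_rate over all blocks
    dpr = [drop_path_rate * i / max(total_blocks - 1, 1)
           for i in range(total_blocks)]
    plan, start = [], 0
    h, w = grid_size
    for i, depth in enumerate(depths):
        if i > 0:  # PatchMerging halves the grid before stages 1+
            h, w = h // 2, w // 2
        plan.append({
            "dim": embed_dim * 2 ** i,      # dim doubles each stage
            "resolution": (h, w),
            "downsample": i > 0,
            "drop_path": dpr[start:start + depth],
        })
        start += depth
    return plan

# 128x256 image with 4x4 patches -> 32x64 initial grid
plan = plan_stages(grid_size=(128 // 4, 256 // 4))
for i, stage in enumerate(plan):
    print(i, stage["dim"], stage["resolution"], len(stage["drop_path"]))
```

Each entry maps onto one `SwinTransformerV2Stage` constructor call; check the copied stage's argument names (for example whether it expects the pre- or post-downsample resolution) before wiring it up.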

#### Forward should:

1. **Patch embed**: `x = self.patch_embed(img)` → reshape to `(B, N, C)`
2. **Apply dropout** if configured: `x = self.pos_drop(x)`. The constructor must also create `self.pos_drop = nn.Dropout(drop_rate)`.
3. **Run through stages**, collecting intermediate outputs:
   ```python
   intermediates = []
   for stage in self.stages:
       x = stage(x)
       intermediates.append(x)  # (B, N_i, C_i) at each scale
   ```
4. **Apply final norm**: `x = self.norm(x)`
5. **Build HierarchicalPatchState**: For each intermediate, create a `PatchState`:
   - `emb`: the intermediate tensor `(B, N_i, C_i)`
   - `pos`: grid coordinates for that scale. Stage i has grid `(H_i, W_i)` where `H_i = grid_H // 2^i`, etc.
     Generate with: `torch.stack(torch.meshgrid(torch.arange(H_i), torch.arange(W_i), indexing='ij'), dim=-1).reshape(-1, 2)`
   - `non_empty`: `torch.ones(B, N_i, device=x.device)` — we skip empty-patch handling for now
   - `mae_mask`: `torch.ones(N_i, device=x.device)` — no MAE masking in Swin path for now
   - **Order**: `levels` list should be **coarsest-first** (i.e. reverse of stage order, since stage 0 is finest)
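6. **Return** `EncoderOutput(latent=x, patch_state=<final-stage (coarsest) PatchState>, hierarchical=<HierarchicalPatchState>)`. `patch_state` is the final-scale state, so `patch_state.emb` has the same shape as `latent`, as the `EncoderOutput` field comments and the test cell expect.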
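Step 5 can be sketched as a standalone helper. The `PatchState` below is a local stand-in that mirrors the fields listed earlier so the snippet runs on its own; the real code should import it from `midi_rae.core`. `build_levels` is a hypothetical name.

```python
from dataclasses import dataclass
import torch

@dataclass
class PatchState:  # stand-in only; use midi_rae.core.PatchState in practice
    emb: torch.Tensor
    pos: torch.Tensor
    non_empty: torch.Tensor
    mae_mask: torch.Tensor

def build_levels(intermediates, grid_size):
    """Turn per-stage outputs (finest-first) into PatchStates, coarsest-first."""
    levels = []
    for i, emb in enumerate(intermediates):
        B, N, _ = emb.shape
        h, w = grid_size[0] // 2 ** i, grid_size[1] // 2 ** i
        assert N == h * w, f"stage {i}: expected {h * w} tokens, got {N}"
        rows, cols = torch.meshgrid(
            torch.arange(h, device=emb.device),
            torch.arange(w, device=emb.device),
            indexing="ij",
        )
        pos = torch.stack((rows, cols), dim=-1).reshape(-1, 2)  # (N, 2) as (row, col)
        levels.append(PatchState(
            emb=emb,
            pos=pos,
            non_empty=torch.ones(B, N, device=emb.device),  # no empty-patch handling yet
            mae_mask=torch.ones(N, device=emb.device),      # no MAE masking yet
        ))
    return levels[::-1]  # stage 0 is finest, so reverse for coarsest-first

# fake stage outputs for a 32x64 grid, embed_dim=96, four stages
intermediates = [torch.randn(2, (32 >> i) * (64 >> i), 96 * 2 ** i) for i in range(4)]
levels = build_levels(intermediates, (32, 64))
```

Building `pos` on the same device as `emb` avoids device mismatches later; whether the masks should be float or bool depends on what `midi_rae/core.py` and downstream code expect.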

---

### Important Notes

- **No CLS token**: Swin doesn't use one — the spatial output IS the representation.
- **No RoPE**: Swin V2 uses its own continuous position bias (CPB) MLP. Do not add RoPE.
- **No MAE masking during Swin stages**: The windowed attention makes token-level masking complex.
  Keep `mae_ratio` in the constructor for interface compatibility but ignore it.
  Leave a `# TODO: MAE masking not yet supported in Swin path` comment.
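- **Window size vs grid size**: If `grid_size` in any dimension is smaller than `window_size`,
  the timm code adjusts by shrinking the window to the grid in that dimension (verify this in the
  copied block code). Where the window is not shrunk, the grid should be divisible by it unless the
  copied code pads: the default `window_size=7` does not divide a 32×64 grid, while the
  `window_size=8` used in the test cell does. Document the constraint.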
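- **Data format**: this plan assumes timm's V2 stages take `(B, H*W, C)` input (NHWC flattened) and
  internally reshape to `(B, H, W, C)` for windowing, using the stored `input_resolution` for the
  spatial dims. Some timm versions instead pass `(B, H, W, C)` tensors between stages, so check
  `SwinTransformerV2Stage.forward()` in the copied file. If it expects 4-D input, skip the flatten
  after patch embed and flatten each intermediate to `(B, N_i, C_i)` before building `PatchState`s.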
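The window constraint can be checked per stage with plain arithmetic. This sketch assumes two behaviours that should be confirmed against the copied code: the window is clamped to the grid when the grid is smaller, and no padding is applied, so each grid dimension must be a multiple of the effective window. `check_windows` is a hypothetical helper.

```python
def check_windows(grid_size, window_size, num_stages=4):
    """Per stage: clamp the window to the grid, then test divisibility."""
    report = []
    for i in range(num_stages):
        grid = (grid_size[0] // 2 ** i, grid_size[1] // 2 ** i)
        win = tuple(min(g, window_size) for g in grid)      # assumed clamping
        ok = all(g % w == 0 for g, w in zip(grid, win))     # assumed no padding
        report.append((grid, win, ok))
    return report

report_8 = check_windows((32, 64), window_size=8)
report_7 = check_windows((32, 64), window_size=7)
for grid, win, ok in report_8:
    print(f"grid={grid} window={win} ok={ok}")
```

Under these assumptions every stage passes with `window_size=8`, and the last stage runs with a `(4, 8)` window because its grid is only 4 rows tall.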

---

### Notebook Cell Order

1. Title markdown (already done)
2. `#| default_exp swin` (already done)
3. `#| hide` + nbdev import (already done)
4. **This instruction note** (not exported)
5. `#| export` — Imports cell: `torch`, `torch.nn`, `torch.nn.functional`, `math`, `functools.partial`, `from midi_rae.core import PatchState, HierarchicalPatchState, EncoderOutput`
6. `#| export` — Helper utilities: `to_2tuple`, `DropPath` (or import from timm), `Mlp` (or import from timm)
7. `#| export` — `window_partition` and `window_reverse`
8. `#| export` — `WindowAttention`
9. `#| export` — `SwinTransformerV2Block`
10. `#| export` — `PatchMerging`
11. `#| export` — `SwinTransformerV2Stage`
12. `#| export` — `SwinEncoder`
13. Test cell (not exported): instantiate with typical MIDI piano roll dims and verify output shapes

---

### Test Cell

```python
# Test: verify SwinEncoder is a drop-in for ViTEncoder
B, C, H, W = 2, 1, 128, 256  # typical piano roll
enc = SwinEncoder(
    img_height=H, img_width=W,
    patch_h=4, patch_w=4,
    in_chans=C, embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    window_size=8,
)
x = torch.randn(B, C, H, W)
out = enc(x)

print(f'latent:       {out.latent.shape}')        # expect (2, 32, 768): N_final = 4*8
print(f'patch_state:  {out.patch_state.emb.shape}')  # same as latent
print(f'num levels:   {len(out.hierarchical.levels)}')
for i, ps in enumerate(out.hierarchical.levels):
    print(f'  level {i}: emb={ps.emb.shape}, pos={ps.pos.shape}, non_empty={ps.non_empty.shape}')

# Expected hierarchy (coarsest first):
#   level 0 (coarsest): emb=(2, 4*8, 768),  pos=(32, 2)  — grid 4×8
#   level 1:            emb=(2, 8*16, 384),  pos=(128, 2) — grid 8×16
#   level 2:            emb=(2, 16*32, 192), pos=(512, 2) — grid 16×32
#   level 3 (finest):   emb=(2, 32*64, 96),  pos=(2048, 2) — grid 32×64
```

Grid sizes above assume `img=128×256`, `patch=4×4` → initial grid `32×64`,
then halved at each stage after stage 0.