# core

> Core data structures for midi-rae: PatchState, HierarchicalPatchState, and EncoderOutput

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *


## Overview

These dataclasses bundle patch embeddings with their spatial and mask metadata,
replacing scattered positional return values and manual mask indexing throughout the codebase.

### `PatchState`
Holds a set of patch embeddings at a single spatial scale, along with their grid positions
and masks. Provides convenience properties for common operations like filtering visible patches.

### `HierarchicalPatchState`
A list of `PatchState` objects ordered **coarsest → finest** (index 0 = global/CLS level).
Currently used with two levels (CLS + patches), but designed to extend to Swin-style multi-scale hierarchies later.

### `EncoderOutput`
Full encoder output bundling the hierarchical patch states with the full (pre-MAE-masking)
positions and masks needed by the decoder for reconstruction.

In [None]:
#| export
import torch
from dataclasses import dataclass

In [None]:
#| export
@dataclass
class PatchState:
    """Patch embeddings for one spatial scale, packaged together with their metadata.

    Attributes:
        emb: (B, N, dim) patch embeddings
        pos: (N, 2) grid coordinates (row, col) of every patch
        non_empty: (B, N) content mask — 1 where the patch has content (e.g. notes), 0 where empty
        mae_mask: (N,) MAE visibility mask — 1=visible, 0=masked out for reconstruction
    """
    emb: torch.Tensor
    pos: torch.Tensor
    non_empty: torch.Tensor
    mae_mask: torch.Tensor

    def _select(self, keep):
        """Build a new PatchState containing only the patches where bool mask `keep` is True."""
        # `emb` and `non_empty` carry a leading batch axis, so the patch axis is axis 1;
        # `pos` and `mae_mask` are indexed along their first axis.
        return PatchState(
            emb=self.emb[:, keep],
            pos=self.pos[keep],
            non_empty=self.non_empty[:, keep],
            mae_mask=self.mae_mask[keep],
        )

    @property
    def visible(self):
        """New PatchState restricted to the patches the MAE mask marks as visible"""
        keep = self.mae_mask.bool()
        return self._select(keep)

    @property
    def masked(self):
        """New PatchState restricted to the patches the MAE mask hides"""
        hidden = torch.logical_not(self.mae_mask.bool())
        return self._select(hidden)

    @property
    def non_empty_flat(self):
        """`non_empty` as a flat 1-D bool mask — handy for loss computation"""
        return self.non_empty.bool().reshape(-1)

    @property
    def dim(self):
        """Size of the last axis of `emb` (embedding dimension)"""
        return self.emb.size(-1)

    @property
    def num_patches(self):
        """Size of axis 1 of `emb` (number of patches)"""
        return self.emb.size(1)

    @property
    def batch_size(self):
        """Size of axis 0 of `emb` (batch size)"""
        return self.emb.size(0)

    def to(self, device):
        """New PatchState whose four tensors are the result of calling `.to(device)` on each field."""
        field_names = ("emb", "pos", "non_empty", "mae_mask")
        moved = {name: getattr(self, name).to(device) for name in field_names}
        return PatchState(**moved)

In [None]:
#| export
@dataclass
class HierarchicalPatchState:
    """Stack of per-scale patch states, coarsest first (currently: [0]=CLS, [1]=spatial patches).

    Attributes:
        levels: List of `PatchState`, one entry per scale

    Note: indexing this object directly returns the same thing as indexing `levels`,
    i.e. enc_out.patches[i] and enc_out.patches.levels[i] are equivalent
    """
    levels: list

    def __getitem__(self, idx):
        # Indexing is delegated straight to the underlying `levels` list
        return self.levels[idx]

    def __len__(self):
        return len(self.levels)

    @property
    def coarsest(self):
        """The first entry of `levels`"""
        return self.levels[0]

    @property
    def finest(self):
        """The last entry of `levels`"""
        return self.levels[-1]

    @property
    def num_levels(self):
        """How many levels are stored"""
        return len(self.levels)

    @property
    def all_emb(self):
        """Embeddings of every level joined along the token axis (dim=1)"""
        per_level = [state.emb for state in self.levels]
        return torch.cat(per_level, dim=1)

    @property
    def all_non_empty(self):
        """`non_empty` masks of every level joined along dim=1"""
        per_level = [state.non_empty for state in self.levels]
        return torch.cat(per_level, dim=1)

    def to(self, device):
        """New HierarchicalPatchState built from calling `.to(device)` on each level."""
        moved = [state.to(device) for state in self.levels]
        return HierarchicalPatchState(levels=moved)

In [None]:
#| export
@dataclass
class EncoderOutput:
    """Full encoder output.

    Attributes:
        patches: Encoded representations (visible patches only)
        full_pos: (N_full, 2) all grid positions before MAE masking (needed by decoder)
        full_non_empty: (B, N_full) all content masks before MAE masking
        mae_mask: (N_full,) the MAE mask applied (1=visible, 0=masked)
    """
    patches: HierarchicalPatchState
    full_pos: torch.Tensor
    full_non_empty: torch.Tensor
    mae_mask: torch.Tensor

    # NOTE: `to` must be indented inside the class body. At column 0 it becomes a
    # stray module-level function and `enc_out.to(device)` raises AttributeError.
    def to(self, device):
        """Return a new EncoderOutput with `.to(device)` applied to `patches` and to each tensor field.

        Args:
            device: forwarded unchanged to every `.to(...)` call

        Returns:
            A new `EncoderOutput`; this instance is not modified.
        """
        return EncoderOutput(
            patches=self.patches.to(device),
            full_pos=self.full_pos.to(device),
            full_non_empty=self.full_non_empty.to(device),
            mae_mask=self.mae_mask.to(device),
        )

### Sample usage

**Encoder returns `EncoderOutput` containing a `HierarchicalPatchState`:**
```python
enc_out = encoder(img, mask_ratio=0.5)

# Access the patch hierarchy
cls_state = enc_out.patches.coarsest    # PatchState with CLS token
patch_state = enc_out.patches.finest    # PatchState with patch embeddings
```

**Working with PatchState — filtering, shapes, masks:**
```python
ps = enc_out.patches.finest

ps.emb          # (B, N_visible, dim) — patch embeddings
ps.pos          # (N_visible, 2) — grid coordinates (row, col)
ps.non_empty    # (B, N_visible) — content mask (1=has notes)
ps.mae_mask     # (N_visible,) — all True for already-filtered patches
ps.dim          # embedding dimension
ps.num_patches  # number of patches

vis = ps.visible  # new PatchState with only MAE-visible patches
```

**In compute_batch_loss (encoder training):**
```python
# BEFORE: 8 positional return values
# loss_dict, z1, z2, non_emptys, pos2, mae_mask2, num_tokens, recon_patches = ...

# AFTER:
loss_dict, enc_out1, enc_out2, recon_patches = compute_batch_loss(...)
non_emptys = (enc_out1.patches.finest.non_empty, enc_out2.patches.finest.non_empty)
```

**In LightweightMAEDecoder:**
```python
# BEFORE:
# def forward(self, z, pos_full, mae_mask): ...

# AFTER:
# def forward(self, enc_out: EncoderOutput): ...
#   — gets visible embeddings, full positions, and mae_mask all from enc_out
```

**Future Swin hierarchy (coarsest → finest):**
```python
# levels[0] = global (like CLS), levels[1] = 4x4, levels[2] = 8x8, levels[3] = 16x16
h = enc_out.patches
h.coarsest          # global summary
h.finest            # finest-resolution patches  
h.levels[1]         # intermediate scale
h.levels[1].visible # visible patches at that scale
```

In [None]:
#| hide
# Trigger nbdev's export step (presumably writes the `#| export` cells above
# into the module named by `#| default_exp` — confirm against the nbdev docs).
import nbdev
nbdev.nbdev_export()