# train_enc

> Encoder training script for midi_rae 

In [None]:
#| default_exp train_enc

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
from itertools import chain
import multiprocessing as mp
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
import hydra
from midi_rae.vit import ViTEncoder, LightweightMAEDecoder
from midi_rae.data import PRPairDataset
from midi_rae.losses import calc_enc_loss, calc_mae_loss
from midi_rae.utils import save_checkpoint
from midi_rae.viz import make_emb_viz
from tqdm.auto import tqdm

torch.set_float32_matmul_precision('high')  # permit faster reduced-precision internals for fp32 matmuls (per PyTorch docs)
torch.backends.cudnn.benchmark = True       # let cuDNN benchmark and cache the fastest conv algorithms

## Curriculum Learning

In [None]:
#| export
def curr_learn(shared_ct_dict, epoch, interval=100, verbose=False, base_shift=6, max_shift=12):
    """UNUSED/UNNECESSARY: curriculum learning: increase difficulty with epoch.

    For `epoch >= interval`, both `max_shift_x` and `max_shift_y` are set to
    `min(max_shift, base_shift + epoch // interval)`; earlier epochs leave the
    settings unchanged.

    Args:
        shared_ct_dict: mapping whose 'training' entry holds the training settings
            (in `train` this is a `multiprocessing.Manager().dict`).
        epoch: current epoch number.
        interval: number of epochs per +1 increase of the shift limits.
        verbose: print the new shift limits after updating them.
        base_shift: offset added to `epoch // interval` (default 6, the original constant).
        max_shift: upper bound for both shift limits (default 12, the original constant).

    Returns:
        A new dict with the (possibly updated) training settings. The caller's
        mapping is never modified in place; store the result back explicitly.
    """
    training = dict(shared_ct_dict['training'])  # copy so the caller's mapping is not aliased/mutated
    if epoch < interval: return training
    shift = min(max_shift, base_shift + epoch // interval)
    training['max_shift_x'] = shift
    training['max_shift_y'] = shift
    if verbose:
        print(f"curr_learn: max_shift_x = {training['max_shift_x']}, max_shift_y = {training['max_shift_y']}")
    return training

## Compute Loss On Batch

In [None]:
#| export
def _check_token_alignment(z1, z2, pmask1, pmask2):
    "Raise ValueError unless the flattened token rows of z1/z2 match each other and the pmask entry counts."
    n1, n2 = z1.shape[0], z2.shape[0]
    m1 = None if pmask1 is None else pmask1.numel()
    m2 = None if pmask2 is None else pmask2.numel()
    if n1 != n2 or (m1 is not None and m1 != n1) or (m2 is not None and m2 != n2):
        raise ValueError(
            f"Token count mismatch: z1 has {n1} token rows, z2 has {n2}, pmask1 has {m1} entries, "
            f"pmask2 has {m2}. If the encoder applies MAE masking it returns only the visible tokens, "
            "so z no longer lines up with the full-size pmasks (and two independent random masks "
            "do not line up z1 with z2 either); reuse one MAE mask for both views or compute the "
            "encoder loss on unmasked encodings.")


def compute_batch_loss(batch, encoder, cfg, global_step, mae_decoder=None):
    """Compute the loss for one batch of image pairs, plus auxiliary outputs (used for train and val).

    Args:
        batch: dict with 'img1', 'img2' and 'deltas' tensors; each is moved to the encoder's device.
        encoder: called as `encoder(img, return_cls_only=False)`, returning `(z, pmask, pos, mae_mask)`.
        cfg: config; reads `cfg.training.lambd` and optionally `cfg.training.mae_lambda` (default 1.0).
        global_step: forwarded to `calc_enc_loss`.
        mae_decoder: optional decoder; when given, an MAE reconstruction loss on the second view is
            stored as `loss_dict['mae']` and added (weighted by `mae_lambda`) to `loss_dict['loss']`.

    Returns:
        `(loss_dict, z1, z2, pmask1, pmask2, num_tokens)`, where z1/z2 are flattened to
        `(-1, dim)` and `num_tokens` is the number of token rows per sample.

    Raises:
        ValueError: if the flattened z1/z2 rows do not line up with each other or with pmask1/pmask2,
            instead of an opaque IndexError from inside `calc_enc_loss`.
    """
    device = next(encoder.parameters()).device
    img1, img2, deltas = batch['img1'].to(device), batch['img2'].to(device), batch['deltas'].to(device)
    # NOTE(review): if the encoder applies random MAE masking, these two calls draw independent masks,
    # so z1 and z2 may contain different visible-token subsets. TODO: reuse mae_mask1 for the second call
    # once ViTEncoder.forward accepts a mask argument.
    z1, pmask1, pos1, mae_mask1 = encoder(img1, return_cls_only=False)
    z2, pmask2, pos2, mae_mask2 = encoder(img2, return_cls_only=False)
    loss_dict = {}
    if mae_decoder is not None:
        recon_patches = mae_decoder(z2, pos2, mae_mask2)  # reconstruct from view 2 only; z1 is not used here
        loss_dict['mae'] = calc_mae_loss(recon_patches, img2, pos2, mae_mask2)

    z1 = z1.reshape(-1, z1.shape[-1])
    z2 = z2.reshape(-1, z2.shape[-1])
    _check_token_alignment(z1, z2, pmask1, pmask2)  # fail fast with a readable message
    num_tokens = z1.shape[0] // len(deltas)  # token rows per sample (fewer than the full grid if MAE-masked)
    deltas = deltas.repeat_interleave(num_tokens, dim=0)  # one copy of each sample's delta per token row
    loss_dict = loss_dict | calc_enc_loss(z1, z2, global_step, deltas=deltas, lambd=cfg.training.lambd, pmasks=(pmask1, pmask2))
    if 'mae' in loss_dict:
        # Out-of-place add: never mutate a tensor that autograd may have saved or that another key may alias.
        loss_dict['loss'] = loss_dict['loss'] + cfg.training.get('mae_lambda', 1.0) * loss_dict['mae']

    return loss_dict, z1, z2, pmask1, pmask2, num_tokens

Getting an error:
```
Traceback (most recent call last):
  File "/home/shawley/github/midi-rae/midi_rae/train_enc.py", line 107, in train
    loss_dict, z1, z2, pmask1, pmask2, num_tokens = compute_batch_loss(batch, model, cfg, global_step, mae_decoder=mae_decoder)
  File "/home/shawley/github/midi-rae/midi_rae/train_enc.py", line 56, in compute_batch_loss
    loss_dict = loss_dict | calc_enc_loss(z1, z2, global_step, deltas=deltas, lambd=cfg.training.lambd, pmasks=(pmask1,pmask2))
  File "/home/shawley/github/midi-rae/midi_rae/losses.py", line 60, in calc_enc_loss
    loss_dict = LeJEPA(z1[valid], z2[valid], global_step, deltas=deltas[valid], lambd=lambd)
IndexError: The shape of the mask [49920] at index 0 does not match the shape of the indexed tensor [6912, 256] at index 0

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
```
You may need to look at `/app/data/midi-rae/midi_rae/losses.py` and `/app/data/midi-rae/midi_rae/vit.py`, but I think the issue is the handling of the shapes right here in `compute_batch_loss`. Help?

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ ğŸ§ 

Looking at the error, the issue is a shape mismatch between the `valid` mask (49920 elements) and `z1`/`z2` (6912 rows).

In `compute_batch_loss`, you reshape `z1` and `z2`:
```python
z1 = z1.reshape(-1, z1.shape[-1])  # becomes [6912, 256]
z2 = z2.reshape(-1, z2.shape[-1])
```

But `pmask1` and `pmask2` are passed directly to `calc_enc_loss` **without** being reshaped to match. That's likely where the `valid` mask is constructed.

Can you check: what are the shapes of `pmask1` and `pmask2` before they get passed to `calc_enc_loss`? You could add a quick print:

```python
print(f"z1: {z1.shape}, pmask1: {pmask1.shape}, pmask2: {pmask2.shape}")
```

That'll confirm whether the pmasks need the same flattening treatment as z1/z2.

Hmm, that seems unlikely. That part of the code used to work fine. What broke it is the new decoder and MAE-mask code. Check the ViT code. Note that pmask is still the full set of non-empty patches; we really tried to ensure backwards compatibility.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠



<details class='tool-usage-details'>
<summary>view(path='/app/data/midi-rae/midi_rae/vit.py')</summary>

```json
{
  "id": "toolu_012XfSQwaas4LURPjTSVDeR7",
  "call": {
    "function": "view",
    "arguments": {
      "path": "/app/data/midi-rae/midi_rae/vit.py"
    }
  },
  "result": "'\"\"\"Components & defs for ViT-based Encoder & Decoder\"\"\"\\n\\n# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_vit.ipynb.\\n\\n# %% auto #0\\n__all__ = [\\'RoPE2D\\', \\'Attention\\', \\'TransformerBlock\\', \\'PatchEmbedding\\', \\'apply_mae_mask\\', \\'ViTEncoder\\', \\'Unpatchify\\',\\n           \\'ViTDecoder\\', \\'LightweightMAEDecoder\\']\\n\\n# %% ../nbs/02_vit.ipynb #b96051a7\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F \\n\\n# %% ../nbs/02_vit.ipynb #624570b6\\nclass RoPE2D(nn.Module):\\n    def __init__(self, head_dim):\\n        super().__init__()\\n        i = torch.arange(0, head_d<TRUNCATED>"
}
```

</details>

I see the issue. Look at what `ViTEncoder.forward` returns:

```python
return (x[:, 0] if return_cls_only else x), pmask, pos, mae_mask
```

It's returning the **original full** `pmask` and `pos`, not the masked versions (`pmask_visible`, `pos_visible`). But `x` has been masked down to only visible tokens.

So when `return_cls_only=False`, you get:
- `x` with shape `(B, N_visible, dim)` — only the visible (unmasked) tokens
- `pmask` with shape `(B, 65)` — the **full** mask before MAE masking
- `pos` with shape `(65, 2)` — the **full** positions

The shapes don't match. 

The fix should be in `ViTEncoder.forward` â€” it should return `pmask_visible` and `pos_visible` instead of `pmask` and `pos`. Does that match your intent for how the encoder output should be used?

Well, if I do that, then the code elsewhere fails.
This is probably because, if we do that, two subsequent calls to the encoder get two different random MAE masks, don't they?

```
    loss_dict, z1, z2, pmask1, pmask2, num_tokens = compute_batch_loss(batch, model, cfg, global_step, mae_decoder=mae_decoder)
  File "/home/shawley/github/midi-rae/midi_rae/train_enc.py", line 56, in compute_batch_loss
    loss_dict = loss_dict | calc_enc_loss(z1, z2, global_step, deltas=deltas, lambd=cfg.training.lambd, pmasks=(pmask1,pmask2))
  File "/home/shawley/github/midi-rae/midi_rae/losses.py", line 58, in calc_enc_loss
    pmask = pmask1 & pmask2  # both non-empty
RuntimeError: The size of tensor a (5) must match the size of tensor b (7) at non-singleton dimension 1
```
Perhaps we should make it so that you can pass the MAE mask from the first call into the second call? In that case the encoder won't generate a new random mask; it will just reuse the one from the previous call.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠

Yes, exactly right! With random MAE masking, `z1` and `z2` will have different subsets of visible tokens, so you can't directly compare them in the contrastive loss.

Your suggestion makes sense: add an optional `mae_mask` parameter to the encoder's forward method. If provided, use it instead of generating a new random one.

So the flow in `compute_batch_loss` would become:

1. First call: `z1, pmask1, pos1, mae_mask1 = encoder(img1, ...)` — generates random mask
2. Second call: `z2, pmask2, pos2, mae_mask2 = encoder(img2, mae_mask=mae_mask1, ...)` — reuses same mask

Does that match what you're thinking? If so, you'd just need to modify `ViTEncoder.forward` to accept an optional `mae_mask` argument and skip the random generation if one is provided.

## Main Training Loop

In [None]:
#| export
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def train(cfg: DictConfig):
    """Train the ViT encoder together with a lightweight MAE decoder on `PRPairDataset` image pairs.

    Configured by Hydra (`../configs/config`). Each epoch:
    - runs a training pass (mixed precision on CUDA);
    - runs a validation pass;
    - optionally logs to W&B (skipped when `cfg.no_wandb` is set);
    - every `viz_every` epochs, renders an embedding visualization;
    - saves a checkpoint.

    NOTE(review): only the encoder is passed to `save_checkpoint`; the MAE decoder's weights are not saved.
    """
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print("device = ",device)

    train_ds = PRPairDataset(split='train', max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y) 
    val_ds   = PRPairDataset(split='val',  max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y) 
    train_dl = DataLoader(train_ds, batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, pin_memory=True)
    val_dl   = DataLoader(val_ds,   batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, pin_memory=True)

    # Curriculum-learning plumbing: when enabled, DataLoaders are re-created each epoch and their workers
    # read the current shift limits from a process-shared dict (hence the Manager and nested worker_init_fn).
    manager = mp.Manager()
    shared_ct_dict = manager.dict(OmegaConf.to_container(cfg))
    def worker_init_fn(worker_id):
        ds = torch.utils.data.get_worker_info().dataset
        ds.max_shift_x = shared_ct_dict['training']['max_shift_x']
        ds.max_shift_y = shared_ct_dict['training']['max_shift_y']

    model = ViTEncoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size), cfg.model.patch_size, 
              cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)
    model = torch.compile(model)
    mae_decoder = LightweightMAEDecoder(patch_size=cfg.model.patch_size, dim=cfg.model.dim).to(device)

    # One optimizer updates both the encoder and the MAE decoder.
    #optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.training.lr)
    optimizer = torch.optim.AdamW(chain(model.parameters(), mae_decoder.parameters()), lr=cfg.training.lr)

    # steps_per_epoch=1: the one-cycle schedule advances once per epoch (scheduler.step() at epoch end).
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=cfg.training.lr, steps_per_epoch=1, epochs=cfg.training.epochs)
    scaler = torch.amp.GradScaler(device)
    # autocast('cuda') is disabled (with a warning) when CUDA is unavailable; make that explicit.
    use_amp = device == 'cuda'
    if not(cfg.get('no_wandb', False)):
        # W&B expects plain containers; dict(cfg) would leave nested DictConfig values inside.
        wandb.init(project=cfg.wandb.project, config=OmegaConf.to_container(cfg, resolve=True),
                   settings=wandb.Settings(start_method="fork", _disable_stats=True))

    # Training loop
    global_step = 0
    viz_every = 10
    for epoch in range(1, cfg.training.epochs+1):
        model.train()
        mae_decoder.train()  # the decoder is trained too, so it must be put back into training mode each epoch
        train_loss = 0
        # Curriculum learning: currently disabled by the leading `False`; remove it to enable.
        # DLs are re-defined each epoch so that new workers pick up the updated shift limits.
        if False and epoch > 1:
            shared_ct_dict['training'] = curr_learn(shared_ct_dict, epoch)
            train_dl = DataLoader(train_ds, batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, worker_init_fn=worker_init_fn, pin_memory=True)
            val_dl   = DataLoader(val_ds,   batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, worker_init_fn=worker_init_fn, pin_memory=True)
        for batch in tqdm(train_dl, desc=f"Epoch {epoch}/{cfg.training.epochs}"):
            global_step += 1
            optimizer.zero_grad()
            with torch.autocast('cuda', enabled=use_amp):
                loss_dict, z1, z2, pmask1, pmask2, num_tokens = compute_batch_loss(batch, model, cfg, global_step, mae_decoder=mae_decoder)
            scaler.scale(loss_dict['loss']).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss_dict['loss'].item()
            
        # At end of Epoch: validation, viz, etc
        model.eval()
        mae_decoder.eval()  # also disable any train-only behavior (e.g. dropout) in the decoder during validation
        val_loss = 0
        with torch.no_grad():
            for batch in val_dl:
                val_loss_dict, z1, z2, pmask1, pmask2, num_tokens = compute_batch_loss(batch, model, cfg, global_step, mae_decoder=mae_decoder)
                val_loss += val_loss_dict['loss'].item()

        train_loss /= len(train_dl)
        val_loss /= len(val_dl)
        print(f"Epoch {epoch}/{cfg.training.epochs}: train_loss={train_loss:.3f} val_loss={val_loss:.3f}")
        
        # Component losses come from the last train batch and the last val batch of this epoch.
        if wandb.run is not None: wandb.log({"train_loss": train_loss, "val_loss": val_loss, 
           "train_sim": loss_dict['sim'], "train_sigreg": loss_dict['sigreg'], "train_anchor":loss_dict['anchor'], "train_mae":loss_dict['mae'],
           "val_sim": val_loss_dict['sim'], "val_sigreg": val_loss_dict['sigreg'], "val_anchor": val_loss_dict['anchor'], "val_mae": val_loss_dict['mae'],
           "max_shift_x":shared_ct_dict['training']['max_shift_x'], "max_shift_y":shared_ct_dict['training']['max_shift_y'], 
           "lr": optimizer.param_groups[0]['lr'], "epoch": epoch}, step=epoch)

        if epoch % viz_every == 0: 
            # Visualize embeddings from the last validation batch.
            zs_stacked = torch.cat((z1, z2), dim=0).reshape(-1, z1.shape[-1])
            make_emb_viz(zs_stacked, num_tokens, epoch, model=model, pmasks=(pmask1,pmask2), file_idx=batch['file_idx'])

        save_checkpoint(model, optimizer, epoch, val_loss, cfg, tag="enc_")
        scheduler.step()# val_loss)

## CLI Entry Point

In [None]:
#| export
#| eval: false
if "ipykernel" not in __import__("sys").modules and __name__ == "__main__":
    # Start Hydra-driven training only when run as a script, never inside a Jupyter kernel.
    train()

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # run nbdev's export step so `#| export` cells are written to the library