# Train Decoder

> Training script for the ViT decoder with reconstruction losses

## Context for LLM assistance

This is part of a Representation Autoencoder (RAE) project for MIDI piano roll images.

**What we have:**
- `ViTEncoder` in `vit.py`: produces (B, 65, 768), i.e. 64 patch tokens + 1 CLS token from 128×128 images
- `ViTDecoder` in `vit.py`: takes (B, 65, 768), strips CLS, unpatchifies back to (B, 1, 128, 128)
- `PatchGANDiscriminator` in `losses.py`: for adversarial loss
- Pre-encoded data optionally available via `07_preencode.ipynb`

**What we're doing:**
- Train decoder to reconstruct images from frozen encoder embeddings
- Losses: L1 + LPIPS + GAN (adversarial) with adaptive weighting (the weights in `train_step` are currently fixed; adaptive weighting is still a TODO)
- Encoder is FROZEN: only the decoder and discriminator train
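For reference, here are the losses as `train_step` currently computes them (hinge discriminator loss, fixed weights 0.1 and 0.01), followed by one possible reading of "adaptive weighting". The adaptive weight is an assumption borrowed from VQGAN (taming-transformers), where $G_L$ is the decoder's last layer and $\delta$ a small constant; it is not necessarily what the RAE paper uses.

$$
\begin{aligned}
\hat{x} &= \mathrm{Dec}(z) \\
\mathcal{L}_D &= \tfrac{1}{2}\Big(\mathbb{E}\big[\max(0,\ 1 - D(x))\big] + \mathbb{E}\big[\max(0,\ 1 + D(\hat{x}))\big]\Big) \\
\mathcal{L}_{\mathrm{GAN}} &= -\,\mathbb{E}\big[D(\hat{x})\big] \\
\mathcal{L}_{\mathrm{rec}} &= \operatorname{mean}\lvert \hat{x} - x \rvert + 0.1\,\mathcal{L}_{\mathrm{LPIPS}}(\hat{x}, x) \\
\mathcal{L}_{\mathrm{dec}} &= \mathcal{L}_{\mathrm{rec}} + 0.01\,\mathcal{L}_{\mathrm{GAN}} \qquad \text{(current, fixed)} \\
\lambda &= \frac{\lVert \nabla_{G_L} \mathcal{L}_{\mathrm{rec}} \rVert}{\lVert \nabla_{G_L} \mathcal{L}_{\mathrm{GAN}} \rVert + \delta}, \qquad \mathcal{L}_{\mathrm{dec}} = \mathcal{L}_{\mathrm{rec}} + \lambda\,\mathcal{L}_{\mathrm{GAN}} \qquad \text{(VQGAN-style, assumed)}
\end{aligned}
$$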

**What's next:**
- After decoder training: `09_dit.ipynb` (diffusion transformer architecture)
- Then: `10_train_gen.ipynb` (train DiT for generation in latent space)

**Image specs:** 128×128 grayscale (1 channel), patch_size=16, giving 8×8=64 patches
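A quick check of how these numbers fit together, in plain Python. `pixels_per_patch` assumes the decoder turns each patch token back into one 16×16×1 patch, which is what the (B, 1, 128, 128) output shape implies; it is not read from `vit.py`.

```python
# Token bookkeeping for the ViT encoder/decoder, using the image specs above
image_size, patch_size, in_channels, cls_tokens = 128, 16, 1, 1

patches_per_side = image_size // patch_size                # 128 // 16 = 8
num_patches = patches_per_side ** 2                        # 8 * 8 = 64
seq_len = num_patches + cls_tokens                         # 64 + 1 = 65, matching (B, 65, 768)
pixels_per_patch = patch_size * patch_size * in_channels   # 16 * 16 * 1 = 256 values per token

print(patches_per_side, num_patches, seq_len, pixels_per_patch)  # 8 64 65 256
```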

In [None]:
#| default_exp train_dec

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import wandb
from omegaconf import DictConfig
import hydra
from tqdm.auto import tqdm
import lpips
from collections import namedtuple

from midi_rae.vit import ViTEncoder, ViTDecoder
from midi_rae.losses import PatchGANDiscriminator
from midi_rae.data import PRPairDataset  # note, we'll only use img2 and ignore img1
from midi_rae.utils import save_checkpoint, load_checkpoint  # load_checkpoint is used in setup_models

torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

In [None]:
#| export
class PreEncodedDataset(Dataset):
    """Load pre-encoded embeddings + images from .pt files"""
    def __init__(self, encoded_dir):
        self.files = sorted([f for f in os.listdir(encoded_dir) if f.endswith('.pt')])
        self.encoded_dir = encoded_dir
        # Load all chunks into memory (adjust if too large)
        self.embeddings, self.images = [], []
        for f in self.files:
            data = torch.load(os.path.join(encoded_dir, f))
            self.embeddings.append(data['embeddings'])
            self.images.append(data['images'])
        self.embeddings = torch.cat(self.embeddings, dim=0)
        self.images = torch.cat(self.images, dim=0)
    
    def __len__(self): return len(self.embeddings)
    def __getitem__(self, idx): return self.embeddings[idx], self.images[idx]

In [None]:
#| export
def get_embeddings_batch(batch, encoder=None, preencoded=False, device='cuda'):
    """Get embeddings + images from batch, either pre-encoded or computed on-the-fly"""
    if preencoded:
        z, img = batch
        return z.to(device), img.to(device)
    else:
        img = batch['img2'].to(device)  # adjust key based on your dataset
        with torch.no_grad():
            z = encoder(img, return_cls_only=False)
        return z, img

In [None]:
#| export
def setup_dataloaders(cfg, preencoded=False):
    # --- Data ---
    if preencoded:
        train_ds = PreEncodedDataset(cfg.preencode.output_dir + '/train')
        val_ds = PreEncodedDataset(cfg.preencode.output_dir + '/val')
    else:
        train_ds = PRPairDataset(split='train', max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y) 
        val_ds   = PRPairDataset(split='val',  max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y) 
    
    train_dl = DataLoader(train_ds, batch_size=cfg.training.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size=cfg.training.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
    return train_dl, val_dl 

In [None]:
#| export
def setup_models(cfg, preencoded, device):
    "Build the frozen encoder (skipped for pre-encoded data), the decoder, and the discriminator"
    encoder = None
    if not preencoded:
        encoder = ViTEncoder(cfg.data.in_channels, cfg.data.image_size, cfg.model.patch_size,
                             cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)
        encoder = load_checkpoint(encoder, cfg.get('encoder_ckpt', 'checkpoints/enc_best.pt'))
        encoder.eval()  # frozen
        for p in encoder.parameters(): p.requires_grad = False
    
    decoder = ViTDecoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size),
                         cfg.model.patch_size, cfg.model.dim, 
                         cfg.model.get('dec_depth', 4), cfg.model.get('dec_heads', 8)).to(device)
    decoder = torch.compile(decoder)
    
    discriminator = PatchGANDiscriminator(in_ch=cfg.data.in_channels).to(device)
    discriminator = torch.compile(discriminator)
    return encoder, decoder, discriminator

In [None]:
#| export
TrainState = namedtuple('TrainState', ['opt_disc', 'opt_dec', 'scaler_disc', 'scaler_dec',
                                       'sched_disc', 'sched_dec', 'l1_loss', 'lpips_loss'])

def setup_tstate(cfg, decoder, discriminator, device):
    "Training state: losses, optimizers, schedulers, AMP scalers"
    l1_loss = nn.L1Loss()
    lpips_loss = lpips.LPIPS(net='vgg').to(device)
    opt_dec = torch.optim.AdamW(decoder.parameters(), lr=cfg.training.lr)
    opt_disc = torch.optim.AdamW(discriminator.parameters(), lr=cfg.training.lr)
    # steps_per_epoch=1: the schedulers are stepped once per epoch in train()
    sched_dec = torch.optim.lr_scheduler.OneCycleLR(opt_dec, max_lr=cfg.training.lr, steps_per_epoch=1, epochs=cfg.training.epochs)
    sched_disc = torch.optim.lr_scheduler.OneCycleLR(opt_disc, max_lr=cfg.training.lr, steps_per_epoch=1, epochs=cfg.training.epochs)
    scaler_dec, scaler_disc = torch.amp.GradScaler(), torch.amp.GradScaler()
    # Return an instance of TrainState, not the class itself
    return TrainState(opt_disc, opt_dec, scaler_disc, scaler_dec, sched_disc, sched_dec, l1_loss, lpips_loss)

In [None]:
#| export
def train_step(z, img_real, decoder, discriminator, tstate):
    "One training step: update the discriminator, then the decoder. `tstate` is a `TrainState`."
    # --- Discriminator step (hinge loss) ---
    tstate.opt_disc.zero_grad()
    with torch.autocast('cuda'):
        img_recon = decoder(z)
        d_real = discriminator(img_real)
        d_fake = discriminator(img_recon.detach())  # detach: the D loss must not update the decoder
        loss_disc = (torch.relu(1 - d_real).mean() + torch.relu(1 + d_fake).mean()) / 2
    tstate.scaler_disc.scale(loss_disc).backward()
    tstate.scaler_disc.step(tstate.opt_disc)
    tstate.scaler_disc.update()
    
    # --- Decoder step (reuses img_recon from above; no need to recompute decoder(z)) ---
    tstate.opt_dec.zero_grad()
    with torch.autocast('cuda'):
        loss_l1 = tstate.l1_loss(img_recon, img_real)
        loss_lpips = tstate.lpips_loss(img_recon.repeat(1,3,1,1), img_real.repeat(1,3,1,1)).mean()  # LPIPS wants 3ch
        loss_gan = -discriminator(img_recon).mean()  # generator wants discriminator to say "real"
        loss_dec = loss_l1 + 0.1 * loss_lpips + 0.01 * loss_gan  # TODO: adaptive weighting as in RAE paper (??)
    # This backward also writes grads into the discriminator; opt_disc.zero_grad() clears them next step
    tstate.scaler_dec.scale(loss_dec).backward()
    tstate.scaler_dec.step(tstate.opt_dec)
    tstate.scaler_dec.update()
    
    keys = ['disc', 'l1', 'lpips', 'gan', 'dec']
    vals = [loss_disc, loss_l1, loss_lpips, loss_gan, loss_dec]
    losses = {k: v.item() for k, v in zip(keys, vals)}  # plain Python floats
    return losses, img_recon.detach()


In [None]:
#| export
@hydra.main(version_base=None, config_path='../configs', config_name='config')
def train(cfg: DictConfig):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    preencoded = cfg.get('preencoded', False)
    
    train_dl, val_dl = setup_dataloaders(cfg, preencoded)
    encoder, decoder, discriminator = setup_models(cfg, preencoded, device)
    tstate = setup_tstate(cfg, decoder, discriminator, device)
    wandb.init(project='dec-'+cfg.wandb.project, config=dict(cfg))
    
    for epoch in range(1, cfg.training.epochs + 1):
        decoder.train()
        discriminator.train()
        train_loss = 0
        
        for batch in tqdm(train_dl, desc=f'Epoch {epoch}/{cfg.training.epochs}'):
            z, img_real = get_embeddings_batch(batch, encoder, preencoded, device)
            losses, img_recon = train_step(z, img_real, decoder, discriminator, tstate)
            train_loss += losses['dec']  # already a float, so no .item()
        
        train_loss /= len(train_dl)
        print(f'Epoch {epoch}: train_loss={train_loss:.4f}')
        # Component losses below are from the last batch of the epoch
        wandb.log({'train_loss': train_loss, 'loss_l1': losses['l1'], 'loss_lpips': losses['lpips'], 'loss_gan': losses['gan'],
                   'loss_disc': losses['disc'], 'epoch': epoch})
        
        # TODO: validation, checkpointing, visualization e.g. reconstruction comparison
        tstate.sched_dec.step()
        tstate.sched_disc.step()
    
    wandb.finish()

Hey, I know it's pretty normal for a training script to be one big, long serial block of code. But that style isn't well adapted to the solveit GUI. Are there any natural ways we might modularize some of the parts of the preceding cell?

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->


Great question! That monolithic `train()` function has several natural seams where you could split things. Looking at it, I'd suggest these groupings:

1. **Data loading**: `get_dataloaders(cfg, preencoded)` → returns `train_dl, val_dl`

2. **Model setup**: `get_models(cfg, device, preencoded)` → returns `encoder, decoder, discriminator`

3. **Losses & optimizers**: `get_training_components(decoder, discriminator, cfg)` → returns the loss fns, optimizers, schedulers, scalers

4. **Single training step**: `train_step(batch, ...)` → the discriminator + decoder update logic, returns a dict of losses

5. **Epoch loop**: `train_epoch(...)` → iterates batches, calls `train_step`, aggregates metrics

The biggest win for solveit would be extracting `train_step`: that way you can test one batch interactively without running a full epoch. You could load a small batch, call `train_step`, inspect gradients/outputs, tweak loss weights, etc.

Which of these would be most useful to break out first? Or is there a specific part you want to iterate on more quickly?

Okay, I broke off a couple things. What if I separated the Discriminator step and the decoder step? What would I need to pass back and forth? If it turns out I have to pass too many variables, it might not be worth it.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->


Looking at your code, the key coupling between the two steps is `img_recon`: it's generated in the discriminator step and reused in the decoder step (you even have a comment about not recomputing it).

**Discriminator step** needs: `z`, `img_real`, `decoder`, `discriminator`, `opt_disc`, `scaler_disc`  
**Returns**: `img_recon`, `loss_disc`

**Decoder step** needs: `img_recon`, `img_real`, `discriminator`, `l1_loss`, `lpips_loss`, `opt_dec`, `scaler_dec`  
**Returns**: dict of `loss_l1`, `loss_lpips`, `loss_gan`, `loss_dec`

That's not too bad; the main "handoff" is just `img_recon`. You could bundle the optimizer+scaler pairs into small objects/tuples if you wanted to reduce the arg count further.
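A minimal sketch of that bundling, assuming one `GradScaler` per optimizer as the notebook already does. `OptPair` and `amp_update` are hypothetical names, not part of `midi_rae`; the four calls mirror what each half of `train_step` does today.

```python
from collections import namedtuple

# Hypothetical bundle of one optimizer and its AMP GradScaler
OptPair = namedtuple('OptPair', ['opt', 'scaler'])

def amp_update(pair, loss):
    "Zero grads, backprop the scaled loss, step the optimizer via the scaler, then update the scale."
    pair.opt.zero_grad()
    pair.scaler.scale(loss).backward()
    pair.scaler.step(pair.opt)   # GradScaler skips the optimizer step if inf/NaN grads are found
    pair.scaler.update()
```

With this, each half of `train_step` would end in a single call such as `amp_update(disc_pair, loss_disc)`, where `disc_pair = OptPair(opt_disc, scaler_disc)` is likewise illustrative. One subtlety: `zero_grad()` now runs after the forward pass rather than before it, which is fine as long as it happens before `backward()`.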

Alternatively, you could keep them together as a single `train_step()` that returns all the losses, and the split is just internal organization. That might give you the readability benefit without the interface overhead.

What's your instinct â€” worth the split, or keep them together?

Okay, yeah, there's kind of a give and take on this. I guess maybe the train step could work. Would it look like this...?

```python
from collections import namedtuple
tstate = namedtuple('TrainState', ['opt_disc', 'opt_dec', 'scaler_disc', 'scaler_dec', 'l1_loss', 'lpips_loss'])
```

```python
def train_step(z, img_real, decoder, discriminator, tstate): 
    # --- Discriminator step ---
    tstate.opt_disc.zero_grad()
    with torch.autocast('cuda'):
        img_recon = decoder(z)
        d_real = discriminator(img_real)
        d_fake = discriminator(img_recon.detach())
        loss_disc = (torch.relu(1 - d_real).mean() + torch.relu(1 + d_fake).mean()) / 2
    tstate.scaler_disc.scale(loss_disc).backward()
    tstate.scaler_disc.step(tstate.opt_disc)
    tstate.scaler_disc.update()
    
    # --- Decoder step ---
    tstate.opt_dec.zero_grad()
    with torch.autocast('cuda'):
        # img_recon = decoder(z)  # Don't need to recompute this.
        loss_l1 = tstate.l1_loss(img_recon, img_real)
        loss_lpips = tstate.lpips_loss(img_recon.repeat(1,3,1,1), img_real.repeat(1,3,1,1)).mean()  # LPIPS wants 3ch
        loss_gan = -discriminator(img_recon).mean()  # generator wants discriminator to say "real"
        loss_dec = loss_l1 + 0.1 * loss_lpips + 0.01 * loss_gan  # TODO: adaptive weighting as in RAE paper (??)
    tstate.scaler_dec.scale(loss_dec).backward()
    tstate.scaler_dec.step(tstate.opt_dec)
    tstate.scaler_dec.update()
    
    keys = ['disc', 'l1', 'lpips', 'gan', 'dec']
    vals = [loss_disc, loss_l1, loss_lpips, loss_gan, loss_dec]
    losses = { k:v.item() for k,v in zip(keys, vals) }
    return losses, img_recon.detach()
```

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->


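Yeah, that's the right shape! The namedtuple bundles things nicely, and returning `(losses, img_recon.detach())` is handy for visualization during interactive testing.

One catch in the first snippet: `namedtuple(...)` returns a *class*, not a filled-in object, so `tstate.opt_disc` would give you a field accessor rather than your optimizer. Define the class once, then instantiate it: `TrainState = namedtuple('TrainState', [...])`, then `tstate = TrainState(opt_disc, opt_dec, scaler_disc, scaler_dec, l1_loss, lpips_loss)`.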

One small thing: you might want to return `img_recon` *before* detaching if you ever want to inspect gradients, but `.detach()` is the safer default to avoid holding onto the computation graph.

This should make the main loop much tidier: just iterate over batches, call `train_step`, and accumulate losses. Want to sketch out what `train_epoch` would look like with this?
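One possible `train_epoch`, sketched in plain Python so it can be tried without a GPU. It assumes `step_fn` takes a batch and returns the losses dict (the first value `train_step` returns); `train_epoch` and `step_fn` are illustrative names. Averaging every key, rather than logging only the last batch's component losses, gives per-epoch values for `l1`, `lpips`, `gan`, and `disc` as well.

```python
from collections import defaultdict

def train_epoch(batches, step_fn):
    "Call step_fn on each batch and return the per-key mean of the returned loss dicts."
    totals, n = defaultdict(float), 0
    for batch in batches:
        losses = step_fn(batch)  # e.g. {'disc': ..., 'l1': ..., 'lpips': ..., 'gan': ..., 'dec': ...}
        for k, v in losses.items():
            totals[k] += v
        n += 1
    return {k: v / n for k, v in totals.items()} if n else {}

# Toy check: fake per-batch loss dicts stand in for train_step's output
fake_batches = [{'dec': 1.0, 'l1': 0.5}, {'dec': 3.0, 'l1': 1.5}]
print(train_epoch(fake_batches, step_fn=lambda b: b))  # {'dec': 2.0, 'l1': 1.0}
```

In the notebook, `step_fn` could be `lambda b: train_step(*get_embeddings_batch(b, encoder, preencoded, device), decoder, discriminator, tstate)[0]`.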

I agree, that many positional arguments can be error-prone. Feel free to suggest an alternative. I'm not sure about the train state data class. I'm not accustomed to using data classes.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->


A few options, from simplest to most structured:

1. **Keyword-only args**: just add `*` as the first param. Callers must then use `train_step(z=z, img_real=img, ...)`, which catches mistakes at the call site. Zero new concepts.

2. **A plain dict** â€” pass `opts={'disc': opt_disc, 'dec': opt_dec}` and `scalers={...}`. Simple but no autocomplete/type hints.

3. **A namedtuple** â€” one line to define, immutable, gives you dot access:
   ```python
   TrainState = namedtuple('TrainState', ['opt_disc', 'opt_dec', 'scaler_disc', 'scaler_dec', 'l1_loss', 'lpips_loss'])
   ```
   Then `state = TrainState(opt_disc, opt_dec, ...)` and `state.opt_disc` in the function.
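Option 1 in practice, as a sketch: `train_step_kw` is a hypothetical stub with the same parameters as the notebook's `train_step`; its body is a placeholder, and only the signature matters.

```python
def train_step_kw(*, z, img_real, decoder, discriminator, tstate):
    "Hypothetical stub: the bare * makes every parameter keyword-only."
    return z, img_real  # placeholder body

# OK: every argument is named at the call site, so the order can't get mixed up
print(train_step_kw(z=1, img_real=2, decoder=None, discriminator=None, tstate=None))  # (1, 2)

# Positional calls are rejected outright instead of silently swapping arguments
try:
    train_step_kw(1, 2, None, None, None)
except TypeError as e:
    print('TypeError:', e)
```

Keyword-only parameters can still take defaults (e.g. `tstate=None`), so this combines freely with options 2 and 3.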

Which of these appeals to you?

oooo teach me about `namedtuple`! i've never seen that. is that a way to get easy ".attribute" types of things?  when was that introduced into Python? 

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->


Yes, exactly! `namedtuple` gives you dot-access attributes without writing a full class. It's been in Python since 2.6 (2008), so it's very well-established.

Quick example:

```python
from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])
p = Point(3, 4)
print(p.x, p.y)  # 3 4
```

It's essentially a tuple that lets you access elements by name instead of index. So `p.x` is the same as `p[0]`, but much more readable.

A few nice properties:
- Immutable (like regular tuples)
- Memory-efficient: no per-instance `__dict__`, so about as lightweight as a plain tuple
- Works with unpacking: `x, y = p`
- Has a nice `__repr__`: `Point(x=3, y=4)`
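A few more things worth knowing, shown on the same `Point` example: the class-vs-instance distinction (the easy mistake when building a `TrainState`), immutability, and the `_replace` / `_asdict` / `_fields` helpers that are part of the standard `namedtuple` API.

```python
from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])

# namedtuple(...) returns a *class*; you still have to instantiate it
p = Point(3, 4)
print(type(Point).__name__, type(p).__name__)  # type Point

# Immutable: assigning to a field raises AttributeError
try:
    p.x = 10
except AttributeError:
    print('immutable')

# _replace returns a modified copy; the original is unchanged
q = p._replace(x=10)
print(p, q)           # Point(x=3, y=4) Point(x=10, y=4)

# _asdict and _fields are handy for logging or debugging
print(p._asdict())    # {'x': 3, 'y': 4}
print(Point._fields)  # ('x', 'y')
```

For the training state, `_replace` is a quick way to swap a single component during interactive experiments, e.g. `tstate._replace(lpips_loss=new_fn)`, where `new_fn` is a placeholder for whatever replacement you want to try.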

Want to try creating one for the training state?

In [None]:
#| export
#| eval: false
if __name__ == "__main__" and "ipykernel" not in __import__("sys").modules:
    train()

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()