# Train Decoder

> Training script for the ViT decoder, using reconstruction (L1 + LPIPS) and adversarial (PatchGAN) losses

## Context for LLM assistance

This is part of a Representation Autoencoder (RAE) project for MIDI piano roll images.

**What we have:**
- `ViTEncoder` in `vit.py`: produces (B, 65, 768) — 64 patch tokens + 1 CLS token from 128×128 images
- `ViTDecoder` in `vit.py`: takes (B, 65, 768), strips CLS, unpatchifies back to (B, 1, 128, 128)
- `PatchGANDiscriminator` in `losses.py`: for adversarial loss
- Optionally, embeddings pre-encoded by `07_preencode.ipynb` can be loaded instead of running the encoder on the fly (set `preencoded: true` in the config)

**What we're doing:**
- Train decoder to reconstruct images from frozen encoder embeddings
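- Losses: L1 + LPIPS + GAN (adversarial, hinge loss). Weights are currently fixed (`L1 + 0.1·LPIPS + 0.01·GAN`); adaptive weighting as in the RAE paper is still a TODO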
- Encoder is FROZEN — only decoder and discriminator train

**What's next:**
- After decoder training: `09_dit.ipynb` (diffusion transformer architecture)
- Then: `10_train_gen.ipynb` (train DiT for generation in latent space)

**Image specs:** 128×128 grayscale (1 channel), patch_size=16, giving 8×8=64 patches

In [None]:
#| default_exp train_dec

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import wandb
from omegaconf import DictConfig
import hydra
from tqdm.auto import tqdm
import lpips
from collections import namedtuple

from midi_rae.vit import ViTEncoder, ViTDecoder
from midi_rae.losses import PatchGANDiscriminator
from midi_rae.data import PRPairDataset  # note, we'll only use img2 and ignore img1
from midi_rae.utils import save_checkpoint, load_checkpoint

# Allow reduced-precision (TF32-style) float32 matmuls for speed.
torch.set_float32_matmul_precision('high')
# Let cuDNN benchmark and cache the fastest conv algorithms; this pays off when input shapes stay constant.
torch.backends.cudnn.benchmark = True

In [None]:
#| export
class PreEncodedDataset(Dataset):
    """Load pre-encoded embeddings + images from the ``.pt`` chunk files in a directory.

    Each ``*.pt`` file must contain a dict with ``'embeddings'`` and ``'images'`` tensors.
    Chunks are concatenated along dim 0 in sorted filename order and kept in (CPU) memory,
    so item ``i`` is ``(embeddings[i], images[i])``.

    Raises:
        FileNotFoundError: if ``encoded_dir`` contains no ``.pt`` files.
        ValueError: if the total number of embeddings and images differ.
    """
    def __init__(self, encoded_dir):
        self.encoded_dir = encoded_dir
        self.files = sorted(f for f in os.listdir(encoded_dir) if f.endswith('.pt'))
        if not self.files:
            raise FileNotFoundError(f"No .pt files found in {encoded_dir!r}")
        # Load all chunks into memory (adjust if too large)
        embeddings, images = [], []
        for fname in self.files:
            # map_location='cpu': chunks saved from GPU tensors would otherwise be restored onto the
            # GPU, both wasting GPU memory and breaking DataLoader worker processes.
            data = torch.load(os.path.join(encoded_dir, fname), map_location='cpu')
            embeddings.append(data['embeddings'])
            images.append(data['images'])
        self.embeddings = torch.cat(embeddings, dim=0)
        self.images = torch.cat(images, dim=0)
        if len(self.embeddings) != len(self.images):
            raise ValueError(f"Mismatched pre-encoded data in {encoded_dir!r}: "
                             f"{len(self.embeddings)} embeddings vs {len(self.images)} images")

    def __len__(self): return len(self.embeddings)
    def __getitem__(self, idx): return self.embeddings[idx], self.images[idx]

In [None]:
#| export
def get_embeddings_batch(batch, encoder=None, preencoded=False, device='cuda'):
    """Return ``(z, img)`` for one batch, both moved to ``device``.

    With ``preencoded`` true, ``batch`` is an ``(embeddings, images)`` pair (as yielded by
    ``PreEncodedDataset``) and is only transferred. Otherwise ``batch['img2']`` is moved to
    ``device`` and run through ``encoder`` under ``torch.no_grad()`` (the encoder is frozen).
    """
    if not preencoded:
        images = batch['img2'].to(device)  # adjust key based on your dataset
        with torch.no_grad():
            emb = encoder(images, return_cls_only=False)
        return emb, images

    emb, images = batch
    return emb.to(device), images.to(device)

In [None]:
#| export
def setup_dataloaders(cfg, preencoded=False):
    """Build the train/val DataLoaders.

    Args:
        cfg: Hydra config. Reads ``cfg.training.batch_size`` and optionally
            ``cfg.training.num_workers`` (default 4). Reads ``cfg.preencode.output_dir`` when
            ``preencoded``; otherwise reads ``cfg.training.max_shift_x`` / ``max_shift_y``.
        preencoded: if true, load ``PreEncodedDataset`` chunks from ``<output_dir>/train`` and
            ``<output_dir>/val``; otherwise use ``PRPairDataset`` splits.

    Returns:
        ``(train_dl, val_dl)``.
    """
    # --- Data ---
    if preencoded:
        enc_dir = cfg.preencode.output_dir
        train_ds = PreEncodedDataset(os.path.join(enc_dir, 'train'))
        val_ds = PreEncodedDataset(os.path.join(enc_dir, 'val'))
    else:
        shifts = dict(max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y)
        train_ds = PRPairDataset(split='train', **shifts)
        val_ds = PRPairDataset(split='val', **shifts)

    # --- Loaders ---
    loader_kw = dict(batch_size=cfg.training.batch_size,
                     num_workers=cfg.training.get('num_workers', 4),
                     pin_memory=True, drop_last=True)
    train_dl = DataLoader(train_ds, shuffle=True, **loader_kw)
    # NOTE(review): drop_last=True on validation skips the final partial batch; presumably kept so
    # batch shapes stay fixed for the torch.compile'd models — confirm this is intended.
    val_dl = DataLoader(val_ds, shuffle=False, **loader_kw)
    return train_dl, val_dl

In [None]:
#| export
def setup_models(cfg, device, preencoded):
    """Instantiate the encoder (frozen, optional), decoder and discriminator on ``device``.

    The encoder is only built (and loaded from ``cfg.encoder_ckpt``, default
    ``'checkpoints/enc_best.pt'``) when embeddings are computed on the fly; with
    ``preencoded`` it is ``None``. Decoder and discriminator are wrapped with ``torch.compile``.

    Returns:
        ``(encoder, decoder, discriminator)``.
    """
    mcfg = cfg.model
    frozen_enc = None
    if not preencoded:
        frozen_enc = ViTEncoder(cfg.data.in_channels, cfg.data.image_size, mcfg.patch_size,
                                mcfg.dim, mcfg.depth, mcfg.heads).to(device)
        ckpt_path = cfg.get('encoder_ckpt', 'checkpoints/enc_best.pt')
        frozen_enc = load_checkpoint(frozen_enc, ckpt_path)
        # Frozen: eval mode and no gradients for any encoder parameter.
        frozen_enc.eval()
        frozen_enc.requires_grad_(False)

    # Construction order (encoder, decoder, discriminator) is kept so random init is reproducible.
    square = (cfg.data.image_size, cfg.data.image_size)
    dec = ViTDecoder(cfg.data.in_channels, square, mcfg.patch_size, mcfg.dim,
                     mcfg.get('dec_depth', 4), mcfg.get('dec_heads', 8)).to(device)
    dec = torch.compile(dec)

    disc = PatchGANDiscriminator(in_ch=cfg.data.in_channels).to(device)
    disc = torch.compile(disc)
    return frozen_enc, dec, disc

In [None]:
#| export
def setup_tstate(cfg, device, decoder, discriminator):
    "Training_state: Losses, Optimizers, Schedulers, AMP Scalers"
    l1_loss = nn.L1Loss()
    lpips_loss = lpips.LPIPS(net='vgg').to(device)
    opt_dec = torch.optim.AdamW(decoder.parameters(), lr=cfg.training.lr)
    opt_disc = torch.optim.AdamW(discriminator.parameters(), lr=cfg.training.lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(opt_dec, max_lr=cfg.training.lr, steps_per_epoch=1, epochs=cfg.training.epochs)
    schedulerD = torch.optim.lr_scheduler.OneCycleLR(opt_disc, max_lr=cfg.training.lr, steps_per_epoch=1, epochs=cfg.training.epochs)   
    scaler_dec, scaler_disc = torch.amp.GradScaler(), torch.amp.GradScaler()
    return namedtuple('TrainState', ['opt_disc', 'opt_dec', 'scaler_disc', 'scaler_dec', 'l1_loss', 'lpips_loss'])

In [None]:
#| export
def train_step(z, img_real, decoder, discriminator, 
            tstate,  # named tuple containing optimizers, loss fns, scalers 
            ): 
    "training step for decoder (and discriminator)"
    # --- Discriminator step ---
    tstate.opt_disc.zero_grad()
    with torch.autocast('cuda'):
        img_recon = decoder(z)
        d_real = discriminator(img_real)
        d_fake = discriminator(img_recon.detach())
        loss_disc = (torch.relu(1 - d_real).mean() + torch.relu(1 + d_fake).mean()) / 2
    tstate.scaler_disc.scale(loss_disc).backward()
    tstate.scaler_disc.step(tstate.opt_disc)
    tstate.scaler_disc.update()
    
    # --- Decoder step ---
    tstate.opt_dec.zero_grad()
    with torch.autocast('cuda'):
        # img_recon = decoder(z)  # Don't need to recompute this.
        loss_l1 = tstate.l1_loss(img_recon, img_real)
        loss_lpips = tstate.lpips_loss(img_recon.repeat(1,3,1,1), img_real.repeat(1,3,1,1)).mean()  # LPIPS wants 3ch
        loss_gan = -discriminator(img_recon).mean()  # generator wants discriminator to say "real"
        loss_dec = loss_l1 + 0.1 * loss_lpips + 0.01 * loss_gan  # TODO: adaptive weighting as in RAE paper (??)
    tstate.scaler_dec.scale(loss_dec).backward()
    tstate.scaler_dec.step(tstate.opt_dec)
    tstate.scaler_dec.update()
    
    keys = ['disc', 'l1', 'lpips', 'gan', 'dec']
    vals = [loss_disc, loss_l1, loss_lpips, loss_gan, loss_dec]
    losses = { k:v.item() for k,v in zip(keys, vals) }
    return losses, img_recon.detach()


In [None]:
#| export
@hydra.main(version_base=None, config_path='../configs', config_name='config')
def train(cfg: DictConfig):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    preencoded = cfg.get('preencoded', False)
    
    train_dl, val_dl = setup_dataloaders(cfg, preencoded)
    encoder, decoder, discriminator = setup_models(cfg, device, preencoded) 
    tstate = setup_tstate(cfg, device, decoder, discriminator)
    wandb.init(project='dec-'+cfg.wandb.project, config=dict(cfg))
    
    for epoch in range(1, cfg.training.epochs + 1):
        decoder.train()
        discriminator.train()
        train_loss = 0
        
        for batch in tqdm(train_dl, desc=f'Epoch {epoch}/{cfg.training.epochs}'):
            z, img_real = get_embeddings_batch(batch, encoder, preencoded, device)
            losses, img_recon = train_step(z, img_real, decoder, discriminator, tstate)
            train_loss += losses['dec'].item()
        
        train_loss /= len(train_dl)
        print(f'Epoch {epoch}: train_loss={train_loss:.4f}')
        wandb.log({'train_loss': train_loss, 'loss_l1': losses['l1'], 'loss_lpips': losses['lpips'], 'loss_gan': losses['gan'],
                   'loss_disc': losses['disc'], 'epoch': epoch})
        
        # TODO: validation, checkpointing, visualization e.g. reconstruction comparison
        scheduler.step()
        schedulerD.step()
    
    wandb.finish()

In [None]:
#| export
#| eval: false
if __name__ == "__main__":
    # Script entry point only; skip when this code is executed inside a Jupyter kernel.
    if "ipykernel" not in __import__("sys").modules:
        train()

In [None]:
#| hide
# Regenerate the library module(s) from the `#| export` cells of this notebook (hidden from docs via `#| hide`).
import nbdev; nbdev.nbdev_export()