# train_enc

> Training script for the midi_rae encoder and its MAE decoder

In [None]:
#| default_exp train_enc

In [None]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import os
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
from itertools import chain
import multiprocessing as mp
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
import wandb
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
import hydra
from midi_rae.core import *
from midi_rae.vit import ViTEncoder, LightweightMAEDecoder
from midi_rae.swin import SwinEncoder, SwinMAEDecoder
from midi_rae.data import PRPairDataset
from midi_rae.losses import calc_enc_loss, calc_mae_loss, calc_enc_loss_multiscale
from midi_rae.utils import save_checkpoint, load_checkpoint
from midi_rae.viz import make_emb_viz, viz_mae_recon
from tqdm.auto import tqdm

# Global, speed-oriented backend settings applied at import time.
torch.set_float32_matmul_precision('high')  # allow faster, lower-precision internal algorithms for float32 matmuls
torch.backends.cudnn.benchmark = True       # let cuDNN benchmark and cache the fastest convolution algorithms

## Curriculum Learning

In [4]:
#| export
def curr_learn(shared_ct_dict, epoch, interval=100, verbose=False, start_shift=6, max_shift=12):
    """UNUSED/UNNECESSARY: curriculum learning: increase difficulty (random-shift range) with epoch.

    Before `interval` epochs the 'training' sub-config is returned unchanged. From then on both
    `max_shift_x` and `max_shift_y` are set to `min(max_shift, start_shift + epoch // interval)`.

    Args:
        shared_ct_dict: mapping with a 'training' sub-dict (in `train` this is a
            `multiprocessing.Manager().dict`, whose nested values come back as copies, so the
            caller must assign the returned dict back for changes to be shared).
        epoch: current epoch number.
        interval: number of epochs per +1 step of shift.
        verbose: print the new shift limits when they are updated.
        start_shift: base shift the per-interval increment is added to (default keeps the old 6).
        max_shift: upper cap on the shift (default keeps the old 12).

    Returns:
        The 'training' sub-dict, updated in place when `epoch >= interval`.
    """
    training = shared_ct_dict['training']
    if epoch < interval:
        return training
    shift = min(max_shift, start_shift + epoch // interval)
    training['max_shift_x'] = shift
    training['max_shift_y'] = shift
    if verbose:
        print(f"curr_learn: max_shift_x = {training['max_shift_x']}, max_shift_y = {training['max_shift_y']}")
    return training

## Compute Loss On Batch

In [5]:
#| export
@torch.compiler.disable
def compute_batch_loss(batch, encoder, cfg, global_step, mae_decoder=None, debug=False):
    """Compute loss and return other auxiliary variables (for train or val).

    Args:
        batch: dict with 'img1', 'img2' and 'deltas' entries (moved to the encoder's device).
        encoder: model called as `encoder(img, mask_ratio=...)` / `encoder(img, mae_mask=...)`;
            its outputs expose `.mae_mask` and `.patches`.
        cfg: config providing `training.mask_ratio`, `training.lambd`, `data.image_size` and the
            optional `training.lambda_visible` (default 0.1) / `training.mae_lambda` (default 1.0).
        global_step: passed through to `calc_enc_loss_multiscale`.
        mae_decoder: optional decoder; when given, an MAE reconstruction loss is added.
        debug: print per-level NaN/min/max diagnostics of view-2 embeddings.

    Returns:
        (loss_dict, (z1, z2), (enc_out1, enc_out2), recon_patches). `loss_dict` holds the entries
        from `calc_enc_loss_multiscale` (including 'loss'), plus 'mae' when a decoder is used, in
        which case 'loss' also includes `mae_lambda * mae`. `recon_patches` is None without a decoder.
    """
    device = next(encoder.parameters()).device
    img1, img2, deltas = batch['img1'].to(device), batch['img2'].to(device), batch['deltas'].to(device)
    # View 1 draws a fresh MAE mask; view 2 reuses view 1's mask.
    enc_out1 = encoder(img1, mask_ratio=cfg.training.mask_ratio)
    enc_out2 = encoder(img2, mae_mask=enc_out1.mae_mask)
    loss_dict = {}
    recon_patches = None
    if mae_decoder is not None:  # recon at finest scale, from the view-2 encoder output
        recon_patches = mae_decoder(enc_out2)
        if debug:
            for i, lv in enumerate(enc_out2.patches.levels):
                print(f"level {i}: {lv.emb.isnan().any()}, min={lv.emb.min():.4f}, max={lv.emb.max():.4f}")
        loss_dict['mae'] = calc_mae_loss(recon_patches, img2, enc_out2, lambda_visible=cfg.training.get('lambda_visible', 0.1))

    if isinstance(enc_out1.patches, HierarchicalPatchState):
        # one embedding tensor (and one non-empty mask pair) per hierarchy level
        z1 = [lvl.emb for lvl in enc_out1.patches.levels]
        z2 = [lvl.emb for lvl in enc_out2.patches.levels]
        non_emptys = [(l1.non_empty, l2.non_empty) for l1, l2 in zip(enc_out1.patches.levels, enc_out2.patches.levels)]
    else:
        # flat patch state: flatten all embeddings to (-1, dim)
        z1 = enc_out1.patches.all_emb.reshape(-1, enc_out1.patches[1].dim)
        z2 = enc_out2.patches.all_emb.reshape(-1, enc_out2.patches[1].dim)
        non_emptys = (enc_out1.patches.all_non_empty, enc_out2.patches.all_non_empty)

    loss_dict = loss_dict | calc_enc_loss_multiscale(z1, z2, global_step, img_size=cfg.data.image_size, deltas=deltas, lambd=cfg.training.lambd, non_emptys=non_emptys)

    if 'mae' in loss_dict:
        # Out-of-place add: `+=` on a tensor mutates it in place, which would also change any other
        # loss_dict entry aliasing the same tensor and can fail autograd's saved-tensor version check.
        loss_dict['loss'] = loss_dict['loss'] + cfg.training.get('mae_lambda', 1.0) * loss_dict['mae']

    if torch.isnan(loss_dict['loss']):
        print("NaN detected!", {k: v.item() if hasattr(v, 'item') else v for k, v in loss_dict.items()})
        breakpoint()  # interactive stop for debugging; PYTHONBREAKPOINT=0 turns this into a no-op
    return loss_dict, (z1, z2), (enc_out1, enc_out2), recon_patches

## Main Training Loop

In [None]:
#| export
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def train(cfg: DictConfig):
    """Train the encoder and its MAE decoder from a Hydra config.

    Builds train/val `PRPairDataset` loaders, a Swin (`model.encoder == 'swin'`) or ViT encoder with
    the matching MAE decoder, one AdamW optimizer over both models and a per-epoch OneCycleLR.
    Optionally resumes from `checkpoint` (pass "+checkpoint=<path>" on the CLI) and logs to W&B
    unless `no_wandb` is set. Each epoch: train, validate, log/visualize (when a W&B run is active),
    save encoder + decoder checkpoints, then step the scheduler.
    """
    print("config:", cfg)
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print("device = ", device)

    train_ds = PRPairDataset(image_dataset_dir=cfg.data.path, split='train', max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y)
    val_ds   = PRPairDataset(image_dataset_dir=cfg.data.path, split='val',  max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y)
    train_dl = DataLoader(train_ds, batch_size=cfg.training.batch_size, num_workers=8, drop_last=True, pin_memory=True, persistent_workers=True)
    val_dl   = DataLoader(val_ds,   batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, pin_memory=True, persistent_workers=True)

    # Curriculum-learning plumbing: a Manager dict shared with DataLoader workers. worker_init_fn
    # copies the current shift limits into each worker's dataset; it is only passed to the
    # (currently disabled) per-epoch DataLoader rebuild below. The shift limits are also logged.
    manager = mp.Manager()
    shared_ct_dict = manager.dict(OmegaConf.to_container(cfg))
    def worker_init_fn(worker_id):
        ds = torch.utils.data.get_worker_info().dataset
        ds.max_shift_x = shared_ct_dict['training']['max_shift_x']
        ds.max_shift_y = shared_ct_dict['training']['max_shift_y']

    # setup models
    patch_size = cfg.model.get('patch_size', cfg.model.get('patch_h', 16))
    dim = cfg.model.get('dim', cfg.model.get('embed_dim', 256))
    if cfg.model.get('encoder', 'vit') == 'swin':
        model = SwinEncoder(img_height=cfg.data.image_size, img_width=cfg.data.image_size,
                            patch_h=cfg.model.patch_h, patch_w=cfg.model.patch_w,
                            embed_dim=cfg.model.embed_dim, depths=cfg.model.depths,
                            num_heads=cfg.model.num_heads, window_size=cfg.model.window_size,
                            mlp_ratio=cfg.model.mlp_ratio, drop_path_rate=cfg.model.drop_path_rate).to(device)
        # one decoder dim per stage, largest first: embed_dim * 2**(n-1), ..., embed_dim
        dims = tuple(cfg.model.embed_dim * 2**i for i in range(len(cfg.model.depths)-1, -1, -1))
        mae_decoder = SwinMAEDecoder(patch_size=patch_size, dims=dims).to(device)
    else:
        model = ViTEncoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size), cfg.model.patch_size,
              cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)
        mae_decoder = LightweightMAEDecoder(patch_size=patch_size, dim=dim).to(device)

    print("model       =", model.__class__.__name__)
    print("mae_decoder =", mae_decoder.__class__.__name__)

    optimizer = torch.optim.AdamW(chain(model.parameters(), mae_decoder.parameters()), lr=cfg.training.lr)
    epoch_start = 1
    ckpt_path = cfg.get('checkpoint', None)
    if ckpt_path:  # use "+checkpoint=<path>" from CLI
        model, ckpt = load_checkpoint(model, ckpt_path, return_all=True)
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        epoch_start = ckpt['epoch'] + 1  # next epoch after the one completed in the checkpoint
        try:  # decoder checkpoint is optional (best effort), but say why it was skipped
            mae_decoder = load_checkpoint(mae_decoder, ckpt_path.replace('enc_', 'maedec_'), return_all=False)
        except Exception as e:
            print(f"Warning: could not load MAE decoder checkpoint ({e!r}); decoder starts from scratch")

    if False:  # skip compilation
        #model = torch.compile(model, dynamic=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.expanduser("~/.cache/torch/inductor")
        os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
        torch._inductor.config.fx_graph_cache = True
        torch._inductor.config.compile_threads = 8
        torch._dynamo.config.cache_size_limit = 16
        #model = torch.compile(model, mode="default", fullgraph=False)
        for i, stage in enumerate(model.stages):
            model.stages[i] = torch.compile(stage, mode="default")

    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    # The scheduler constructor performs one initial step(), incrementing last_epoch. In a fresh run
    # that leaves last_epoch == epoch-1 during each epoch, so resuming at epoch_start needs
    # last_epoch=epoch_start-2. Passing epoch_start-1 shifts the schedule by one epoch and makes the
    # final scheduler.step() exceed total_steps (ValueError).
    scheduler = OneCycleLR(optimizer, max_lr=cfg.training.lr, steps_per_epoch=1, epochs=cfg.training.epochs, div_factor=5,
                           **({'last_epoch': epoch_start - 2} if epoch_start > 1 else {}))
    #scaler = torch.amp.GradScaler(device)

    if not cfg.get('no_wandb', False):
        # plain, resolved container so W&B receives ordinary dicts/lists rather than DictConfig nodes
        wandb.init(project=cfg.wandb.project, config=OmegaConf.to_container(cfg, resolve=True),
                   settings=wandb.Settings(start_method="fork", _disable_stats=True))
        wandb.define_metric("epoch")
        wandb.define_metric("*", step_metric="epoch")

    # Training loop
    global_step = (epoch_start - 1) * len(train_dl)
    #viz_every = 10
    viz_every = 1  # make it fail early for debugging
    mae_viz_every = max(1, viz_every // 5)  # guard: viz_every < 5 would otherwise give `epoch % 0`
    for epoch in range(epoch_start, cfg.training.epochs + 1):
        if wandb.run is not None: wandb.log({"epoch": epoch})
        model.train()
        mae_decoder.train()
        train_loss = 0
        if False and epoch > 1:  # curriculum learning, easily turned off by setting this to False. DL's re-defined each epoch to init workers
            shared_ct_dict['training'] = curr_learn(shared_ct_dict, epoch)
            train_dl = DataLoader(train_ds, batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, worker_init_fn=worker_init_fn, pin_memory=True)
            val_dl   = DataLoader(val_ds,   batch_size=cfg.training.batch_size, num_workers=4, drop_last=True, worker_init_fn=worker_init_fn, pin_memory=True)
        for batch in tqdm(train_dl, desc=f"Epoch {epoch}/{cfg.training.epochs}"):
            global_step += 1
            optimizer.zero_grad(set_to_none=True)
            if True:  # with torch.autocast('cuda'):
                loss_dict, zs, enc_outs, recon_patches = compute_batch_loss(batch, model, cfg, global_step, mae_decoder=mae_decoder)
            #scaler.scale(loss_dict['loss']).backward()
            #scaler.step(optimizer)
            #scaler.update()
            loss_dict['loss'].backward()
            optimizer.step()
            train_loss += loss_dict['loss'].item()

        # At end of Epoch: validation, viz, etc
        model.eval()
        mae_decoder.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_dl:
                val_loss_dict, zs, enc_outs, recon_patches = compute_batch_loss(batch, model, cfg, global_step, mae_decoder=mae_decoder)
                val_loss += val_loss_dict['loss'].item()

            train_loss /= len(train_dl)
            val_loss /= len(val_dl)
            print(f"Epoch {epoch}/{cfg.training.epochs}: train_loss={train_loss:.3f} val_loss={val_loss:.3f}")

            if wandb.run is not None:
                # train_* components come from the last training batch; val_* from the last val batch
                wandb.log({
                    "train_loss":train_loss, "train_sim":loss_dict['sim'], "train_sigreg":loss_dict['sigreg'], "train_anchor":loss_dict['anchor'], "train_mae":loss_dict['mae'],
                    "val_loss":val_loss, "val_sim":val_loss_dict['sim'], "val_sigreg":val_loss_dict['sigreg'], "val_anchor":val_loss_dict['anchor'], "val_mae": val_loss_dict['mae'],
                    "max_shift_x":shared_ct_dict['training']['max_shift_x'], "max_shift_y":shared_ct_dict['training']['max_shift_y'],
                    "lr": optimizer.param_groups[0]['lr'], "epoch": epoch})

                if epoch % viz_every == 0:
                    make_emb_viz(enc_outs, epoch=epoch, model=model, batch=batch)
                if mae_decoder is not None and epoch % mae_viz_every == 0:
                    # enc_outs is the tuple (enc_out1, enc_out2) and the decoder reconstructs img2 from
                    # enc_out2, so pass enc_out2 (a tuple has no `.finest`).
                    # NOTE(review): confirm viz_mae_recon takes the full encoder output, as calc_mae_loss does.
                    viz_mae_recon(recon_patches, batch['img2'], enc_out=enc_outs[1], epoch=epoch)

        save_checkpoint(model, optimizer, epoch, val_loss, cfg, tag="enc_")
        save_checkpoint(mae_decoder, optimizer, epoch, val_loss, cfg, tag="maedec_")

        scheduler.step()  # per-epoch OneCycleLR step

## CLI Entry Point

In [None]:
#| export
#| eval: false
# Launch training only when run as a script: skip on import, and skip when this cell
# is executed inside a Jupyter (ipykernel) session.
if __name__ == "__main__":
    if "ipykernel" not in __import__("sys").modules:
        train()

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # run nbdev's export of `#| export` cells into the package modules