# Data preparation

In [None]:
# TBD: the data preparation steps have not been written yet.

# Training

### 1. Import dependencies

In [None]:
import torch
import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from models import build_model
from datasets import build_dataset
from ttt.utils.utils import set_seed
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
)
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
import uuid

### 2. Compose the hydra config

In [None]:
# Compose the "training" config from the ``configs`` directory. The Hydra
# context only needs to stay open while ``compose`` runs.
with initialize(
    version_base=None,
    config_path="configs",
    job_name="training",
):
    cfg = compose(config_name="training")

# Switch struct mode off on the composed config, then merge the ``method``
# node into the top level so that its keys are also reachable directly on
# ``cfg`` (presumably struct mode is disabled so the merge may add new keys).
OmegaConf.set_struct(cfg, value=False)
cfg = OmegaConf.merge(cfg, cfg.method)

### 3. Build the model

In [None]:
# Instantiate the model from the composed training config via the project's
# ``build_model`` factory.
model = build_model(cfg)

### 4. Build the data loaders

In [None]:

# Build the training and validation datasets from the composed config.
train_set = build_dataset(cfg)
val_set = build_dataset(cfg, val=True)

# The configured batch sizes are divided by the number of entries in
# ``cfg.devices`` (floor division, never below 1).
# NOTE(review): ``len(cfg.devices)`` assumes a list-like value — confirm the
# config never supplies a plain int here.
train_batch_size = max(cfg.method["train_batch_size"] // len(cfg.devices), 1)
eval_batch_size = max(cfg.method["eval_batch_size"] // len(cfg.devices), 1)

train_loader = DataLoader(
    train_set,
    batch_size=train_batch_size,
    num_workers=cfg.train_load_num_workers,
    shuffle=cfg.shuffle,
    drop_last=False,
    pin_memory=cfg.pin_memory,
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request persistent workers if workers exist.
    persistent_workers=(
        bool(cfg.persistent_workers) and cfg.train_load_num_workers > 0
    ),
)

val_loader = DataLoader(
    val_set,
    batch_size=eval_batch_size,
    num_workers=cfg.val_load_num_workers,
    shuffle=False,
    drop_last=False,
    pin_memory=cfg.pin_memory,
    # Same guard as for the training loader.
    persistent_workers=(
        bool(cfg.persistent_workers) and cfg.val_load_num_workers > 0
    ),
)

### 5. Prepare PyTorch Lightning callbacks

In [None]:
# A short random suffix keeps runs that share ``cfg.exp_name`` apart.
exp_name = f"{cfg.exp_name}-{str(uuid.uuid4())[:5]}"

# Keep the 5 checkpoints with the lowest validation dice loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val/dice_loss",
    # The metric name contains "/". If Lightning auto-inserted it into the
    # filename, the slash would act as a path separator and every checkpoint
    # would end up in its own sub-folder. Spell the labels out explicitly and
    # disable auto-insertion instead.
    filename="epoch={epoch}-val_dice_loss={val/dice_loss:.2f}",
    auto_insert_metric_name=False,
    save_top_k=5,
    mode="min",  # 'min' for loss/error, 'max' for accuracy
    # NOTE(review): machine-specific absolute path — consider moving it into
    # the config.
    dirpath=f"/mnt/hdd_pool_zion/userdata/diyor/ttt_ckpt/{exp_name}",
)

# Log the learning rate once per epoch.
learning_rate_monitor = LearningRateMonitor(logging_interval="epoch")

call_backs = [checkpoint_callback, learning_rate_monitor]


### 6. Instantiate the Trainer

In [None]:
# Debug runs pass ``logger=None`` and use a single CPU device; regular runs
# log to Weights & Biases and train on the configured GPU devices.
if cfg.debug:
    run_logger = None
    run_devices = 1
    run_accelerator = "cpu"
else:
    run_logger = WandbLogger(
        project=cfg.wandb_project_name,
        name=exp_name,
        id=exp_name,
        config=OmegaConf.to_container(cfg),
    )
    run_devices = cfg.devices
    run_accelerator = "gpu"

# All remaining Trainer options are forwarded from the config unchanged.
trainer = pl.Trainer(
    max_epochs=cfg.method.max_epochs,
    logger=run_logger,
    devices=run_devices,
    gradient_clip_val=cfg.gradient_clip_val,
    accumulate_grad_batches=cfg.accumulate_grad_batches,
    accelerator=run_accelerator,
    profiler=cfg.profiler,
    strategy=cfg.strategy,
    callbacks=call_backs,
    check_val_every_n_epoch=cfg.check_val_every_n_epochs,
    log_every_n_steps=cfg.log_every_n_steps,
    num_sanity_val_steps=cfg.num_sanity_val_steps,
    enable_progress_bar=cfg.enable_progress_bar,
)

### 7. Train the model

In [None]:
# Start training with the loaders built above. ``cfg.ckpt_path`` is forwarded
# to Lightning as ``ckpt_path`` (the checkpoint to resume from, if set).
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    ckpt_path=cfg.ckpt_path,
)

# Test-time training

### 1. Compose the hydra config

In [None]:
# Compose the "ttt" config from the ``configs`` directory. The Hydra context
# only needs to stay open while ``compose`` runs. This rebinds ``cfg``.
with initialize(
    version_base=None,
    config_path="configs",
    job_name="ttt",
):
    cfg = compose(config_name="ttt")

# Switch struct mode off on the composed config, then merge the ``method``
# node into the top level so that its keys are also reachable directly on
# ``cfg`` (presumably struct mode is disabled so the merge may add new keys).
OmegaConf.set_struct(cfg, value=False)
cfg = OmegaConf.merge(cfg, cfg.method)

### 2. Build the model

In [None]:
# Instantiate the model from the composed test-time-training config via the
# project's ``build_model`` factory.
model = build_model(cfg)

### 3. Build the data loaders

In [None]:
# Build the datasets used for test-time training and for validation.
train_set = build_dataset(cfg)
val_set = build_dataset(cfg, val=True)

# The configured batch sizes are divided by the number of entries in
# ``cfg.devices`` (floor division, never below 1).
# NOTE(review): ``len(cfg.devices)`` assumes a list-like value — confirm the
# config never supplies a plain int here.
train_batch_size = max(cfg.method["train_batch_size"] // len(cfg.devices), 1)
eval_batch_size = max(cfg.method["eval_batch_size"] // len(cfg.devices), 1)

train_loader = DataLoader(
    train_set,
    batch_size=train_batch_size,
    num_workers=cfg.train_load_num_workers,
    shuffle=cfg.shuffle,
    drop_last=False,
    pin_memory=cfg.pin_memory,
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request persistent workers if workers exist.
    persistent_workers=(
        bool(cfg.persistent_workers) and cfg.train_load_num_workers > 0
    ),
)

val_loader = DataLoader(
    val_set,
    batch_size=eval_batch_size,
    num_workers=cfg.val_load_num_workers,
    # Explicit for parity with the training section; this is also the
    # DataLoader default, so behaviour is unchanged.
    shuffle=False,
    drop_last=False,
    pin_memory=cfg.pin_memory,
    # Same guard as for the training loader.
    persistent_workers=(
        bool(cfg.persistent_workers) and cfg.val_load_num_workers > 0
    ),
)

### 4. Build PyTorch Lightning callbacks

In [None]:
# A short random suffix keeps runs that share ``cfg.exp_name`` apart.
exp_name = f"{cfg.exp_name}-{str(uuid.uuid4())[:5]}"

# Keep the 5 checkpoints with the lowest training dice loss.
checkpoint_callback = ModelCheckpoint(
    monitor="train/dice_loss",
    # The metric name contains "/". If Lightning auto-inserted it into the
    # filename, the slash would act as a path separator and every checkpoint
    # would end up in its own sub-folder. Spell the labels out explicitly and
    # disable auto-insertion instead.
    filename="epoch={epoch}-train_dice_loss={train/dice_loss:.2f}",
    auto_insert_metric_name=False,
    save_top_k=5,
    mode="min",  # 'min' for loss/error, 'max' for accuracy
    # NOTE(review): machine-specific absolute path — consider moving it into
    # the config.
    dirpath=f"/mnt/hdd_pool_zion/userdata/diyor/ttt_ckpt/{exp_name}",
)

# Log the learning rate once per epoch.
learning_rate_monitor = LearningRateMonitor(logging_interval="epoch")

call_backs = [checkpoint_callback, learning_rate_monitor]


### 5. Instantiate the Trainer

In [None]:
# Debug runs pass ``logger=None`` and use a single CPU device; regular runs
# log to Weights & Biases and run on the configured GPU devices.
if cfg.debug:
    run_logger = None
    run_devices = 1
    run_accelerator = "cpu"
else:
    run_logger = WandbLogger(
        project=cfg.wandb_project_name,
        name=exp_name,
        id=exp_name,
        config=OmegaConf.to_container(cfg),
    )
    run_devices = cfg.devices
    run_accelerator = "gpu"

# Test-time training is bounded by a step budget (``max_steps``); all other
# Trainer options are forwarded from the config unchanged.
trainer = pl.Trainer(
    max_steps=cfg.method.max_steps,
    logger=run_logger,
    devices=run_devices,
    gradient_clip_val=cfg.gradient_clip_val,
    accumulate_grad_batches=cfg.accumulate_grad_batches,
    accelerator=run_accelerator,
    profiler=cfg.profiler,
    strategy=cfg.strategy,
    callbacks=call_backs,
    check_val_every_n_epoch=cfg.check_val_every_n_epochs,
    log_every_n_steps=cfg.log_every_n_steps,
    num_sanity_val_steps=cfg.num_sanity_val_steps,
    enable_progress_bar=cfg.enable_progress_bar,
)

### 6. [IMPORTANT] Load the pre-trained checkpoint

In [None]:
# Replace the freshly built model with an instance of the same class restored
# from ``cfg.ckpt_path``. Tensors are mapped onto the CPU while loading, and
# the current config is handed to the constructor as ``config``.
model_cls = type(model)
model = model_cls.load_from_checkpoint(
    checkpoint_path=cfg.ckpt_path,
    map_location=torch.device("cpu"),
    config=cfg,
)

### 7. Run a single validation epoch to measure raw performance

In [None]:
# Run validation on ``val_loader`` with the current model before any
# test-time training takes place.
trainer.validate(
    model=model,
    dataloaders=val_loader,
)

### 8. Launch test-time training

In [None]:
# Launch test-time training. No ``ckpt_path`` is passed here: fitting starts
# from whatever weights ``model`` currently holds.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)