In [3]:
import sys
# Put ./src on the import path so the project-local `models` and `utils`
# packages imported below can be resolved. The path is relative, so it is
# resolved against the kernel's current working directory.
sys.path.append("./src")

from models.vdm import VDM
from models.vdm_unet import UNetVDM
from models.encoder import Encoder
from utils.training import Trainer
from utils.evaluation import Evaluator
from utils.logging import init_logger
from utils.utils import (
    Config,
    make_cifar,
    make_mnist,
)
from accelerate import Accelerator
from accelerate.utils import set_seed
import yaml
import torch

# Load config file, datasets, and model

In [2]:
# Paths are relative to the kernel's working directory.
CONFIG_FILE = "examples/config.yaml"
DATA_PATH = "data"

# Read the YAML run configuration. The encoding is pinned to UTF-8 so the file
# is decoded identically on every platform rather than with the locale default.
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
    cfg = Config(**yaml.safe_load(f))

# Pick the dataset from the config flag and record the matching channel count
# on cfg, which is passed to the model constructors further down.
if cfg.use_mnist:
    cfg.input_channels = 1
    shape = (cfg.input_channels, 28, 28)
    train_set = make_mnist(train=True, download=True, root_path=DATA_PATH)
    validation_set = make_mnist(train=False, download=False, root_path=DATA_PATH)
else:
    cfg.input_channels = 3
    shape = (cfg.input_channels, 32, 32)
    train_set = make_cifar(train=True, download=True, root_path=DATA_PATH)
    validation_set = make_cifar(train=False, download=False, root_path=DATA_PATH)

# Seed before the networks are constructed so that anything random in their
# construction is repeatable across runs.
set_seed(cfg.seed)

model = UNetVDM(cfg)
# The encoder is optional; VDM receives None when cfg.use_encoder is falsy.
encoder = Encoder(shape, cfg) if cfg.use_encoder else None
diffusion = VDM(model, cfg, image_shape=shape, encoder=encoder)

# Training a model

In [None]:
# Newer accelerate releases take dataloader options through
# DataLoaderConfiguration and no longer accept the bare `split_batches`
# keyword on Accelerator; older releases only know the keyword. Support both
# so this cell runs regardless of the installed version.
try:
    from accelerate.utils import DataLoaderConfiguration
except ImportError:
    accelerator = Accelerator(split_batches=True)
else:
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(split_batches=True)
    )
init_logger(accelerator)

Trainer(
    diffusion,
    train_set,
    validation_set,
    accelerator,
    # Optimizer factory: called with the parameters to optimise, returns AdamW
    # configured from the loaded config.
    make_opt=lambda params: torch.optim.AdamW(
        params, cfg.lr, betas=(0.9, 0.99), weight_decay=cfg.weight_decay, eps=1e-8
    ),
    config=cfg,
).train()

# Evaluating a model

In [None]:
# Build the evaluator over the same model and data splits, then run it.
evaluator = Evaluator(diffusion, train_set, validation_set, config=cfg)
evaluator.eval()