In [None]:
'''
Import and hyperparams
'''
import math
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import navit_rf as nrf
from navit_rf import make_packing_collate, sample_rectified_flow, random_resized_transform

# --- paths ---
# DATA_ROOT may be overridden with the NAVIT_RF_DATA_ROOT environment variable,
# so the notebook can be run without editing this cell (otherwise change the default).
DATA_ROOT = Path(os.environ.get("NAVIT_RF_DATA_ROOT", "/path/to/your/images"))
CKPT_DIR = Path("experiments/navit_rf/outputs/checkpoints")
SAMPLE_DIR = Path("experiments/navit_rf/outputs/samples")
CKPT_DIR.mkdir(parents=True, exist_ok=True)
SAMPLE_DIR.mkdir(parents=True, exist_ok=True)

# --- reproducibility ---
# Training draws noise (torch.randn / torch.rand) and shuffles data; seed torch's
# global RNG so a Restart & Run All reproduces the same run on the same hardware.
SEED = 0
torch.manual_seed(SEED)

# --- hyperparameters ---
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 20
PRINT_EVERY = 50         # batches between loss prints
SAMPLE_EVERY = 5         # epochs between generated sample grids
NOISE_STD = 1.0          # std of the Gaussian noise endpoint x0; also passed to the transform
N_SAMPLE_IMAGES = 8      # images per sample grid
SAMPLER_STEPS = 100      # steps passed to sample_rectified_flow
PATCH_SIZE = 8           # shared by the packing collate and the ViT
MAX_TOKENS_PER_PACK = 512  # token budget passed to make_packing_collate
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
'''
Dataset, Dataloader, Optimizer
'''

# Materialize to a list so the emptiness check below works whatever iterable
# gather_image_paths returns.
img_paths = list(nrf.gather_image_paths(DATA_ROOT))
if not img_paths:
    # Fail fast with a clear message: an empty dataset otherwise surfaces as an
    # opaque sampler error from DataLoader(shuffle=True), or as a later failure
    # in the training cell.
    raise FileNotFoundError(
        f"No images found under {DATA_ROOT}; point DATA_ROOT in the config cell at your image folder."
    )

transform = random_resized_transform(
    noise_std=NOISE_STD,
    scale_range=(0.2, 1.0),  # scales H and W by the same factor, keeps aspect ratio
)
dataset = nrf.ImageDataset(img_paths, transform=transform)

# Batches are assembled by navit_rf's packing collate, configured with the same
# PATCH_SIZE as the model below and the MAX_TOKENS_PER_PACK budget.
collate_fn = make_packing_collate(
    patch_size=PATCH_SIZE,
    max_tokens_per_pack=MAX_TOKENS_PER_PACK,
)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Velocity-field network; shares PATCH_SIZE with the collate above.
model = nrf.ViTVelocity(
    patch=PATCH_SIZE,
    in_ch=3,
    d_model=256,
    depth=8,
    n_head=8,
    mlp_ratio=4.0,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Plain MSE between predicted and target velocity.
criterion = torch.nn.MSELoss()

In [None]:
'''
Training
'''
global_step = 0
for epoch in range(1, EPOCHS + 1):
    model.train()
    running = 0.0      # loss sum over the current PRINT_EVERY window
    epoch_loss = 0.0   # loss sum over the whole epoch
    n_steps = 0        # batches seen this epoch

    for step, batch in enumerate(dataloader, start=1):
        images = batch["images"].to(DEVICE)
        patch_hw = batch["patch_hw"].to(DEVICE)
        packs = batch["packs"]

        # x0 is Gaussian noise; t is one uniform [0, 1) time per sample.
        # linear_probability_path presumably interpolates x0 -> images at t
        # (NOTE(review): semantics live in navit_rf — confirm).
        x0 = torch.randn_like(images) * NOISE_STD
        t = torch.rand(images.size(0), device=DEVICE)
        xt = nrf.linear_probability_path(x0, images, t)
        target = nrf.velocity_target(x0, images)

        optimizer.zero_grad(set_to_none=True)
        preds = model(xt, t, patch_hw=patch_hw, packs=packs)
        loss = criterion(preds, target)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running += loss_value
        epoch_loss += loss_value
        n_steps += 1
        global_step += 1
        if step % PRINT_EVERY == 0:
            avg = running / PRINT_EVERY
            print(f"[epoch {epoch}/{EPOCHS}] step {step}: loss = {avg:.4f}")
            running = 0.0

    if n_steps == 0:
        # Without this, `images` below would be undefined (or stale) and the
        # failure would show up as a confusing NameError during sampling.
        raise RuntimeError(f"dataloader yielded no batches in epoch {epoch}; check the dataset")

    # Always report the epoch mean, even when the epoch is shorter than PRINT_EVERY.
    print(f"[epoch {epoch}/{EPOCHS}] mean loss over {n_steps} steps = {epoch_loss / n_steps:.4f}")

    ckpt_path = CKPT_DIR / f"vit_velocity_epoch{epoch:04d}.pth"
    torch.save(
        {
            "epoch": epoch,
            "global_step": global_step,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        ckpt_path,
    )
    print(f"Saved checkpoint to {ckpt_path}")

    if epoch % SAMPLE_EVERY == 0:
        # Sampling needs no gradients: no_grad stops the sampler's repeated model
        # calls from building an autograd graph (harmless if the sampler already
        # disables grad). Train mode is restored afterwards.
        model.eval()
        with torch.no_grad():
            samples = sample_rectified_flow(
                model,
                n=N_SAMPLE_IMAGES,
                device=DEVICE,
                # Trailing dim of the last batch's padded images — assumes square
                # samples (NOTE(review): confirm what img_size means to the sampler).
                img_size=images.shape[-1],
                steps=SAMPLER_STEPS,
            )
        model.train()
        grid_path = SAMPLE_DIR / f"samples_epoch{epoch:04d}.png"
        # nrow is the number of grid columns; isqrt gives a roughly square layout.
        # normalize=False assumes samples are already in display range — TODO confirm.
        save_image(samples, grid_path, nrow=max(1, math.isqrt(N_SAMPLE_IMAGES)), normalize=False)
        print(f"Saved samples to {grid_path}")