In [None]:

'''
Import and hyperparams
'''
import math
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import navit_rf as nrf
from navit_rf import (
    make_packing_collate,
    make_reflow_collate,
    sample_rectified_flow,
    random_resized_transform,
    build_or_load_reflow_dataset,
)

# --- paths ---
DATA_ROOT = Path("../../../cat/")        # change me
CKPT_DIR = Path("experiments/navit_rf/outputs2/checkpoints")
SAMPLE_DIR = Path("experiments/navit_rf/outputs2/samples")
CKPT_DIR.mkdir(parents=True, exist_ok=True)
SAMPLE_DIR.mkdir(parents=True, exist_ok=True)
# Reflow pairs are stored as a sibling of DATA_ROOT, e.g. ../../../cat_reflow
REFLOW_DIR = DATA_ROOT.parent / f"{DATA_ROOT.name}_reflow"

# --- hyperparameters ---
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 1
PRINT_EVERY = 50         # batches between loss prints
SAMPLE_EVERY = 1         # epochs between generations
NOISE_STD = 1.0          # matches Gaussian anchor
N_SAMPLE_IMAGES = 8
SAMPLER_STEPS = 100
PATCH_SIZE = 4
MAX_TOKENS_PER_PACK = 512

# --- reproducibility ---
# torch.manual_seed seeds the RNGs of all devices, so noise/timestep draws repeat run-to-run
# (kernels themselves may still be non-deterministic on GPU/MPS).
SEED = 0
torch.manual_seed(SEED)

# --- device ---
# Prefer Apple-silicon MPS, then CUDA, then CPU so the notebook runs on any machine
# (this used to be hard-coded to 'mps'). Kept as a plain string because every consumer
# in this notebook (`.to(...)`, `device=...`, `map_location=...`) accepts strings.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


In [None]:
'''
Dataset, Dataloader, Optimizer
'''
# Settings that used to be inline literals; named here so they are easy to find and tune.
NUM_WORKERS = 4                  # DataLoader worker processes (0 = load in the main process)
SCALE_RANGE = (0.0625, 0.125)    # scales H and W by the same factor, keeps aspect ratio
MODEL_KWARGS = dict(d_model=512, depth=8, n_head=8, mlp_ratio=4.0)

img_paths = nrf.gather_image_paths(DATA_ROOT)
transform = random_resized_transform(
    noise_std=NOISE_STD,
    scale_range=SCALE_RANGE,
)
dataset = nrf.ImageDataset(img_paths, transform=transform)
# Fail fast with an actionable message. Otherwise an empty or mistyped DATA_ROOT only shows up
# as RandomSampler's opaque "num_samples should be a positive integer" error below.
if len(dataset) == 0:
    raise ValueError(f"No training images found under {DATA_ROOT.resolve()} -- check DATA_ROOT")

# Packs variable-size images into token sequences of at most MAX_TOKENS_PER_PACK patches.
collate_fn = make_packing_collate(
    patch_size=PATCH_SIZE,
    max_tokens_per_pack=MAX_TOKENS_PER_PACK,
)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

# Velocity-field network; trained below with an MSE regression loss.
model = nrf.ViTVelocity(
    patch=PATCH_SIZE,
    in_ch=3,
    **MODEL_KWARGS,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.MSELoss()

In [None]:
# Optional warm start from an earlier run (note: outputs/, not outputs2/).
# NOTE(review): this absolute path is machine-specific; set RESUME_CKPT = None to train from scratch.
RESUME_CKPT = Path('/Users/giodegeronimo/Desktop/595CV/ece595-term-paper/experiments/navit_rf/experiments/navit_rf/outputs/checkpoints/vit_velocity_epoch0260.pth')
if RESUME_CKPT is not None:
    # map_location remaps stored tensors onto the current device, so a checkpoint written
    # on a different backend (e.g. CUDA vs MPS) still loads.
    ckpt = torch.load(RESUME_CKPT, map_location=DEVICE)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    print(f"Resumed from {RESUME_CKPT} (epoch {ckpt.get('epoch')})")

In [None]:

'''
Standard training (saves latest & best)
'''
best_loss = float('inf')
latest_ckpt = CKPT_DIR / 'latest.pth'
best_ckpt = CKPT_DIR / 'best.pth'
grid_nrow = int(math.sqrt(N_SAMPLE_IMAGES))

# Baseline samples before training. Index the dataset once: the transform is presumably random,
# so calling dataset[0] twice could yield two different image sizes.
first_item = dataset[0]
init_img_size = first_item.shape[-1] if isinstance(first_item, torch.Tensor) else first_item[0].shape[-1]
model.eval()
with torch.no_grad():
    init_samples = sample_rectified_flow(
        model,
        n=N_SAMPLE_IMAGES,
        device=DEVICE,
        img_size=init_img_size,
        steps=SAMPLER_STEPS,
    )
model.train()
init_grid = SAMPLE_DIR / 'samples_epoch0000.png'
save_image(init_samples, init_grid, nrow=grid_nrow, normalize=False)
print(f'Saved initial samples to {init_grid}')

for epoch in range(1, EPOCHS + 1):
    model.train()
    running = 0.0   # loss sum since the last progress print
    total = 0.0     # loss sum over the whole epoch
    steps = 0
    for step, batch in enumerate(dataloader, start=1):
        steps = step
        images = batch['images'].to(DEVICE)
        patch_hw = batch['patch_hw'].to(DEVICE)
        packs = batch['packs']

        # Regress the velocity target (images - x0) at a random time t on the path
        # built by nrf.linear_probability_path from noise x0 to data.
        x0 = torch.randn_like(images) * NOISE_STD
        t = torch.rand(images.size(0), device=DEVICE)
        xt = nrf.linear_probability_path(x0, images, t)
        target = images - x0

        optimizer.zero_grad(set_to_none=True)
        preds = model(xt, t, patch_hw=patch_hw, packs=packs)
        loss = criterion(preds, target)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # single host sync per step
        running += loss_value
        total += loss_value
        if step % PRINT_EVERY == 0:
            avg = running / PRINT_EVERY
            print(f"[epoch {epoch}/{EPOCHS}] step {step}: loss = {avg:.4f}")
            running = 0.0

    epoch_loss = total / max(steps, 1)
    print(f'[epoch {epoch}/{EPOCHS}] mean loss = {epoch_loss:.4f}')
    torch.save({'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, latest_ckpt)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save({'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, best_ckpt)
        print(f'best checkpoint updated at epoch {epoch}')

    if epoch % SAMPLE_EVERY == 0:
        # Generate in eval mode without autograd. Previously this ran with gradients enabled,
        # recording a graph across every sampler step, and with train-mode layers active.
        model.eval()
        with torch.no_grad():
            samples = sample_rectified_flow(
                model,
                n=N_SAMPLE_IMAGES,
                device=DEVICE,
                img_size=images.shape[-1],
                steps=SAMPLER_STEPS,
            )
        model.train()
        grid_path = SAMPLE_DIR / f'samples_epoch{epoch:04d}.png'
        save_image(samples, grid_path, nrow=grid_nrow, normalize=False)
        print(f'Saved samples to {grid_path}')


In [None]:

'''
Build/load reflow dataset
'''
REFLOW_PAIRS = 1000
REFLOW_STEPS = 10
REFLOW_TAG = 'reflow1'  # change or set to None for timestamped folders

# Resolve the checkpoint here rather than reusing `best_ckpt` from the training cell, so this cell
# does not depend on leftover kernel state (e.g. after a restart where training was skipped).
reflow_source_ckpt = CKPT_DIR / 'best.pth'
if not reflow_source_ckpt.exists():
    raise FileNotFoundError(
        f'No trained checkpoint at {reflow_source_ckpt}; run the standard training cell first.'
    )
model.load_state_dict(torch.load(reflow_source_ckpt, map_location=DEVICE)['model'])

# Generates (or reloads from REFLOW_DIR) REFLOW_PAIRS pairs using the loaded model
# and a REFLOW_STEPS-step sampler.
reflow_ds, reflow_file = build_or_load_reflow_dataset(
    model,
    dataset,
    device=DEVICE,
    noise_std=NOISE_STD,
    reflow_pairs=REFLOW_PAIRS,
    reflow_steps=REFLOW_STEPS,
    reflow_dir=REFLOW_DIR,
    tag=REFLOW_TAG,
)
print(f'Reflow dataset stored at {reflow_file}')
reflow_collate = make_reflow_collate(PATCH_SIZE)
reflow_loader = DataLoader(
    reflow_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=reflow_collate,
)


In [None]:

'''
Reflow-only training
'''
# Reflow samples are drawn with a single sampler step. Use a dedicated name instead of
# reassigning the global SAMPLER_STEPS, which silently changed the standard-training cell on re-run.
REFLOW_SAMPLER_STEPS = 1
REFLOW_EPOCHS = 10
reflow_optimizer = torch.optim.Adam(model.parameters(), lr=LR)
best_reflow_loss = float('inf')
reflow_latest = CKPT_DIR / 'reflow_latest.pth'
reflow_best = CKPT_DIR / 'reflow_best.pth'
n_init_samples = min(N_SAMPLE_IMAGES, len(reflow_ds))

model.eval()
with torch.no_grad():
    init_reflow_samples = sample_rectified_flow(
        model,
        n=n_init_samples,
        device=DEVICE,
        img_size=reflow_ds[0]['shape'][0],
        steps=REFLOW_SAMPLER_STEPS,
    )
reflow_grid0 = SAMPLE_DIR / 'reflow_samples_epoch0000.png'
save_image(init_reflow_samples, reflow_grid0, nrow=int(math.sqrt(n_init_samples)), normalize=False)
print(f'Saved initial reflow samples to {reflow_grid0}')

for epoch in range(1, REFLOW_EPOCHS + 1):
    # Explicitly enter train mode. Sampling (and possibly the reflow builder) may have left
    # the model in eval mode, and this loop previously never switched back.
    model.train()
    total = 0.0
    steps = 0
    for step, batch in enumerate(reflow_loader, start=1):
        steps = step
        x0 = batch['x0'].to(DEVICE)
        target = batch['target'].to(DEVICE)
        images = x0 + target  # endpoint implied by the stored (x0, target) pair
        patch_hw = batch['patch_hw'].to(DEVICE)
        orig_hw = batch['orig_hw'].to(DEVICE)
        packs = batch['packs']

        reflow_optimizer.zero_grad(set_to_none=True)
        t = torch.rand(x0.size(0), device=DEVICE)
        xt = nrf.linear_probability_path(x0, images, t)
        preds = model(xt, t, patch_hw=patch_hw, packs=packs, orig_hw=orig_hw)
        loss = criterion(preds, target)
        loss.backward()
        reflow_optimizer.step()

        total += loss.item()

    epoch_loss = total / max(steps, 1)
    print(f'[reflow epoch {epoch}/{REFLOW_EPOCHS}] loss = {epoch_loss:.4f}')
    torch.save({'epoch': epoch, 'model': model.state_dict(), 'optimizer': reflow_optimizer.state_dict()}, reflow_latest)
    if epoch_loss < best_reflow_loss:
        best_reflow_loss = epoch_loss
        torch.save({'epoch': epoch, 'model': model.state_dict(), 'optimizer': reflow_optimizer.state_dict()}, reflow_best)
        print('updated best reflow checkpoint')
    if epoch % SAMPLE_EVERY == 0:
        # Generate in eval mode without autograd, matching the initial sampling above.
        model.eval()
        with torch.no_grad():
            samples = sample_rectified_flow(
                model,
                n=N_SAMPLE_IMAGES,
                device=DEVICE,
                img_size=images.shape[-1],
                steps=REFLOW_SAMPLER_STEPS,
            )
        model.train()
        grid_path = SAMPLE_DIR / f'reflow_samples_epoch{epoch:04d}.png'
        save_image(samples, grid_path, nrow=int(math.sqrt(N_SAMPLE_IMAGES)), normalize=False)
        print(f'Saved reflow samples to {grid_path}')
