# 0. Imports

In [None]:
# Move the notebook's working directory three levels up.
# NOTE(review): presumably this lands on the repository root so that the
# project packages imported below and the relative paths used later
# ('dataset/images', 'training/experiments/...') resolve — confirm.
%cd ../../..

In [2]:
import torch

from sampling.conditional_probability_path import GaussianConditionalProbabilityPath
from sampling.sampleable import PixelArtSampler
from sampling.noise_scheduling import LinearAlpha, LinearBeta
from models.unet import PixelArtUNet
from training.trainer import UnguidedTrainer
from training.evaluation import FID
from training.ema import EMA
from diff_eq.ode_sde import UnguidedVectorFieldODE
from diff_eq.simulator import EulerSimulator
from utils.visualization import visualize_training_logs, plot_generated_images
from utils.helpers import clear_cuda, tensor_to_rgba_image, rgba_to_rgb, save_generated_assets, normalize_to_unit


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 1. Train

In [None]:
clear_cuda()

# --- Probability path -------------------------------------------------------
# The data sampler reads from 'dataset/images'; the Gaussian conditional path
# is built on top of it with the LinearAlpha / LinearBeta schedules.
sampler = PixelArtSampler('dataset/images').to(device)
path_config = dict(
    p_data=sampler,
    p_simple_shape=[4, 128, 128],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
)
path = GaussianConditionalProbabilityPath(**path_config).to(device)

# --- Model and its EMA tracker ----------------------------------------------
unet_config = dict(
    channels=[128, 256, 512],
    num_residual_layers=2,
    t_embed_dim=128,
)
unet = PixelArtUNet(**unet_config)
ema = EMA(model=unet, max_decay=0.9998)

# --- Evaluation metric ------------------------------------------------------
metric = FID(feature=2048, normalize=True, image_size=(299, 299))

# --- Trainer ----------------------------------------------------------------
trainer = UnguidedTrainer(
    path=path,
    model=unet,
    experiment_dir="training/experiments/unet",
    eval_metric=metric,
    ema=ema,
)

# --- Training run -----------------------------------------------------------
# All hyper-parameters of the run are collected in one place and then
# forwarded to the trainer as keyword arguments.
train_config = dict(
    device=device,
    num_epochs=17000,
    batch_size=64,
    lr=1e-4,
    lr_warmup_steps_frac=0.1,
    val_batch_size=128,
    num_val_batches=4,
    validate_every=250,
    val_timesteps=100,
    resume=False,
    num_images_to_save=10,
    save_images_every=250,
)
trainer.train(**train_config)

# 2. Visualize model performance over time

In [None]:
# Plot the metrics recorded in the experiment's training log CSV.
training_log_csv = "training/experiments/unet/training_log.csv"
visualize_training_logs(log_path=training_log_csv)

# 3. Visualize results

In [None]:
num_samples = 4
num_timesteps = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load model from checkpoint ---
cp = torch.load("training/experiments/unet/best_model.pt", map_location=device)

model = PixelArtUNet(
    channels=[128, 256, 512],
    num_residual_layers=2,
    t_embed_dim=128,
).to(device)
model.eval()

# Restore EMA weights: load the saved EMA state for this model instance,
# then copy the shadow parameters into the model itself.
ema = EMA(model)
ema.load_state_dict(cp["ema_state"])
ema.apply_shadow()   # now model has EMA weights

# --- Setup path ---
sampler = PixelArtSampler('dataset/images').to(device)
path = GaussianConditionalProbabilityPath(
    p_data=sampler,
    p_simple_shape=[4, 128, 128],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
).to(device)
path.eval()

# --- Simulate ---
# One shared time grid from 0 to 1, broadcast across the batch.
ts = (
    torch.linspace(0, 1, num_timesteps)
    .view(1, -1, 1, 1, 1)
    .expand(num_samples, -1, 1, 1, 1)
    .to(device)
)  # (num_samples, num_timesteps, 1, 1, 1)
x0 = path.p_simple.sample(num_samples).to(device)  # (num_samples, 4, 128, 128)
ode = UnguidedVectorFieldODE(model)
simulator = EulerSimulator(ode)
# Sampling is inference only: disable autograd so no graph is recorded for
# the model calls made during the simulation.
with torch.no_grad():
    x1 = simulator.simulate(x0, ts)  # (num_samples, 4, 128, 128)

x1 = normalize_to_unit(x1) # [-1, 1] -> [0, 1]

# Save the RGBA versions to assets
imgs = tensor_to_rgba_image(x1)
save_generated_assets(images=imgs, num_timesteps=num_timesteps)

# Convert to RGB and plot the RGB versions (previously the RGB images were
# computed but the RGBA ones were plotted, leaving `imgs_rgb` unused).
x1_rgb = rgba_to_rgb(x1)
imgs_rgb = tensor_to_rgba_image(x1_rgb)
plot_generated_images(imgs_rgb)

# 4. Evaluate on the test set

In [None]:
# Initialize model
model = PixelArtUNet(
    channels=[128, 256, 512],
    num_residual_layers=2,
    t_embed_dim=128,
).to(device)
model.eval()

# --- Load checkpoint with EMA ---
checkpoint_path = "training/experiments/unet_mid_ema/best_model.pt"
cp = torch.load(checkpoint_path, map_location=device)

# Build the EMA tracker for *this* model before restoring its state.
# Without this, `ema` would be whatever object an earlier cell left behind
# (bound to a different model instance), or undefined on a fresh kernel.
ema = EMA(model)
ema.load_state_dict(cp["ema_state"])

# Path
sampler = PixelArtSampler('dataset/images').to(device)
path = GaussianConditionalProbabilityPath(
    p_data=sampler,
    p_simple_shape=[4, 128, 128],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
).to(device)
path.eval()

# Evaluation metric
metric = FID(
    feature=2048,
    normalize=True,
    image_size=(299, 299),
)

# Trainer (used here only for its evaluation routine)
trainer = UnguidedTrainer(
    path=path,
    model=model,
    experiment_dir="training/experiments/unet_mid_ema",
    eval_metric=metric,
    ema=ema,
)

# Evaluate on test set
test_fid = trainer.evaluate(
    batch_size=128,
    device=device,
    num_timesteps=100,
    mode="test",
)

print(f"FID on the test set (100 timesteps): {test_fid:.4f}")