# 0. Imports

In [3]:
from sampling.conditional_probability_path import GaussianConditionalProbabilityPath
from sampling.sampleable import PixelArtSampler
from sampling.noise_scheduling import LinearAlpha, LinearBeta
from models.unet import PixelArtUNet
from training.trainer import UnguidedTrainer
from diff_eq.ode_sde import UnguidedVectorFieldODE
from diff_eq.simulator import EulerSimulator
from evaluation import FID, InceptionV3FeatureExtractor

import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Device used by every cell below. Currently pinned to CPU; to pick a GPU
# automatically when one is present, replace the value with:
#   torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

# 1. Train

In [4]:
# Drop the model left over from a previous run of this cell so its memory can
# be reclaimed before a new one is allocated.
if 'unet' in locals():
    del unet

# Only touch the CUDA allocator when CUDA is actually usable:
# torch.cuda.ipc_collect() initializes CUDA and can fail on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Initialize probability path
sampler = PixelArtSampler('../dataset/images').to(device)
path = GaussianConditionalProbabilityPath(
    p_data=sampler,
    p_simple_shape=[4, 128, 128],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
).to(device)

# Initialize model
unet = PixelArtUNet(
    channels=[32, 64, 128],
    num_residual_layers=2,
    t_embed_dim=40,
)

# Initialize evaluation metric.
# Build the feature extractor once and hand that same instance to FID instead
# of constructing a second, otherwise identical extractor.
inception_extractor = InceptionV3FeatureExtractor()
metric = FID(
    inception_extractor=inception_extractor,
    extractor_input_img_size=(299, 299),
)

# Initialize trainer
trainer = UnguidedTrainer(
    path=path,
    model=unet,
    experiment_dir="experiments/unet",
    eval_metric=metric,
)

# Train :D
trainer.train(
    device=device,
    num_epochs=3,
    batch_size=1,
    lr=1e-3,
    validate_every=1,
    resume=False,
    lr_warmup_steps_frac=0.1,
    num_images_to_save=5,
    save_images_every=1,
)

KeyboardInterrupt: 

# 2. Visualize results

In [None]:
# Parameters
num_samples = 10
num_timesteps = 100

# Time grid: num_timesteps evenly spaced values in [0, 1], broadcast so every
# sample shares the same schedule -> (num_samples, num_timesteps, 1, 1, 1).
ts = (
    torch.linspace(0, 1, num_timesteps)
    .view(1, -1, 1, 1, 1)
    .expand(num_samples, -1, 1, 1, 1)
    .to(device)
)
x0 = path.p_simple.sample(num_samples).to(device)  # (num_samples, 4, 128, 128)

# Run simulation.
# NOTE(review): assumes `unet` is an nn.Module; eval mode disables any
# training-only layer behavior while sampling.
unet.eval()
ode = UnguidedVectorFieldODE(unet)
simulator = EulerSimulator(ode)
# Sampling needs no gradients: without no_grad the autograd graph is kept for
# every Euler step, and .numpy() below would reject a tensor requiring grad.
with torch.no_grad():
    x1 = simulator.simulate(x0, ts)  # (num_samples, 4, 128, 128)

# Make grid from output, keeping only the first 3 of the 4 channels for display
img = x1[:, :3]  # (num_samples, 3, H, W)
grid = make_grid(img, nrow=num_samples, normalize=True, value_range=(-1, 1))

# Plot (matplotlib expects channels-last, hence the permute)
plt.figure(figsize=(20, 2))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis("off")
plt.tight_layout()
plt.show()

True
4
