# 0. Imports

In [1]:
from sampling.conditional_probability_path import GaussianConditionalProbabilityPath
from sampling.sampleable import PixelArtSampler
from sampling.noise_scheduling import LinearAlpha, LinearBeta
from models.unet import PixelArtUNet
from training.trainer import UnguidedTrainer
from diff_eq.ode_sde import UnguidedVectorFieldODE
from diff_eq.simulator import EulerSimulator

import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1. Train

In [None]:
# Probability path: data samples are drawn by PixelArtSampler from the image
# folder, and alpha/beta follow the linear schedules. The whole path object is
# moved to the selected device.
pixel_art_data = PixelArtSampler('../dataset/images')
path = GaussianConditionalProbabilityPath(
    p_data=pixel_art_data,
    p_simple_shape=[4, 128, 128],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
)
path = path.to(device)

# Model: U-Net configured with four channel widths, four residual layers and
# a 40-dimensional time embedding.
unet_config = dict(
    channels=[64, 128, 256, 512],
    num_residual_layers=4,
    t_embed_dim=40,
)
unet = PixelArtUNet(**unet_config)

# Trainer ties the probability path and the model together.
trainer = UnguidedTrainer(model=unet, path=path)

# Run training (a single epoch with batch size 1).
trainer.train(num_epochs=1, batch_size=1, lr=1e-3, device=device)

Model size: 132.3578 MiB


  0%|          | 0/1 [00:00<?, ?it/s]

# 2. Visualize results

In [None]:
num_samples = 10
num_timesteps = 100

# Create the figure and the single axes the sample grid is drawn on.
# plt.subplots() returns a (Figure, Axes) pair; plt.plot() returns a list of
# Line2D artists and cannot be unpacked into (fig, ax).
fig, ax = plt.subplots()

# Setup ode and simulator
ode = UnguidedVectorFieldODE(unet)
simulator = EulerSimulator(ode)

# Sampling is inference only: disable autograd so no graph is accumulated
# across the integration steps and the result can be converted for plotting.
with torch.no_grad():
    # Sample initial conditions
    x0 = path.p_simple.sample(num_samples)  # (num_samples, 4, 128, 128)

    # Time grid on [0, 1], shaped (num_samples, num_timesteps, 1, 1, 1) so
    # every sample shares the same timesteps.
    ts = (
        torch.linspace(0, 1, num_timesteps)
        .view(1, -1, 1, 1, 1)
        .expand(num_samples, -1, 1, 1, 1)
        .to(device)
    )

    # Simulate
    x1 = simulator.simulate(x0, ts)

# Plot: lay all samples out in one row, rescaling values from [-1, 1] for
# display, then move channels last for imshow.
grid = make_grid(x1, nrow=num_samples, normalize=True, value_range=(-1, 1))
ax.imshow(grid.permute(1, 2, 0).cpu())
ax.axis("off")
plt.show()