In [None]:
import torch
from flow_matching.supervised.alphas_betas import LinearAlpha, LinearBeta
from flow_matching.supervised.prob_paths import GaussianConditionalProbabilityPath
from flow_matching.supervised.training import CFGTrainer
from flow_matching.mnist.sampler import MNISTSampler
from flow_matching.mnist.unet import MNISTUNet

In [2]:
# Run on the GPU when one is available; later cells move the path and model to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and all CUDA devices) before anything random happens,
# e.g. the UNet's weight initialisation in the next cell, so that
# "Restart & Run All" reproduces the run.
# NOTE(review): bit-exact reproducibility on CUDA may additionally require
# torch.use_deterministic_algorithms(True); MNISTSampler / CFGTrainer may also
# draw from other RNGs (numpy, random) — confirm if exact repeats are needed.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# Conditional probability path from a simple Gaussian of shape [1, 32, 32] to the
# MNIST data distribution, using linear alpha/beta schedules.
# NOTE(review): MNIST digits are 28x28, so this presumably relies on MNISTSampler
# padding/resizing images to 32x32 — confirm.
path = GaussianConditionalProbabilityPath(
    p_data=MNISTSampler(),
    p_simple_shape=[1, 32, 32],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
).to(device)

# Small class-conditional UNet. If more capacity is needed, a wider alternative
# is channels=[32, 64, 128] together with y_embed_dim=40.
unet = MNISTUNet(
    channels=[16, 32, 64],
    num_residual_layers=2,
    t_embed_dim=40,  # t (time) embedding dimension
    y_embed_dim=8,  # y (class label) embedding dimension
).to(device)

# MNIST labels are the digits 0-9, so index 10 is unused by real data and can
# act as the "no label" class for classifier-free guidance.
NUM_MNIST_CLASSES = 10
# NOTE(review): eta is presumably the probability of swapping a label for
# null_class during training — confirm against CFGTrainer.
trainer = CFGTrainer(path=path, model=unet, eta=0.1, null_class=NUM_MNIST_CLASSES)

In [4]:
# Train the conditional UNet with the hyperparameters below.
# NOTE(review): the progress bar shows 1000 iterations for num_epochs=1000, so an
# "epoch" here appears to be one optimisation step on a single batch of
# batch_size samples — confirm against CFGTrainer.train.
training_config = dict(
    num_epochs=1000,
    lr=1e-3,
    batch_size=32,
)
trainer.train(device=device, **training_config)

Training model with size: 1.231 MiB


Epoch 999, loss: 0.171: 100%|██████████| 1000/1000 [04:39<00:00,  3.57it/s]


In [5]:
# Persist the trained weights. Tensors are copied to CPU first so the checkpoint
# also loads on machines without CUDA; a CUDA-resident state_dict would
# otherwise need torch.load(..., map_location="cpu").
CHECKPOINT_PATH = "unet.pt"
unet_state = unet.state_dict()
unet_state_cpu = type(unet_state)(
    (name, tensor.cpu()) for name, tensor in unet_state.items()
)
# Carry over the per-module version metadata that load_state_dict consults.
unet_state_cpu._metadata = getattr(unet_state, "_metadata", None)
torch.save(unet_state_cpu, CHECKPOINT_PATH)