<a href="https://colab.research.google.com/github/hmeyer/flow_matching/blob/main/Mnist_Flow_Matching.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
!pip install -q mediapy

The operation couldn’t be completed. Unable to locate a Java Runtime.
Please visit http://www.java.com for information on installing Java.

zsh:1: command not found: pip


## Imports

In [6]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm.auto import tqdm
import math
from mediapy import show_image, show_images, show_video
import time
from collections import defaultdict
import einops

## Setup Accelerator

In [7]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

using device: mps


## Prepare dataset

In [8]:
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transforms.ToTensor()
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transforms.ToTensor()
)
test_dataset = Subset(test_dataset, range(1_000))

batch_size = 32

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=2, pin_memory=True
)

print("train")
show_images(next(iter(train_loader))[0].squeeze(1))
print("test")
show_images(next(iter(test_loader))[0].squeeze(1))

100.0%
100.0%
100.0%
100.0%

train





test


## Flow Matching Model

In [10]:
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        if self.dim % 2 == 1: # zero pad if dim is odd for a final concat
            pad_tensor = torch.zeros((embeddings.shape[0], 1), device=device)
            embeddings = torch.cat((embeddings.sin(), embeddings.cos(), pad_tensor), dim=-1)
        else:
            embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class MLP(nn.Module):
    def __init__(self, time_embedding_dim, hidden_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(time_embedding_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, t_emb):
        return self.mlp(t_emb)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim=None, num_groups=8): # Made time_emb_dim optional
        super().__init__()
        # Ensure num_groups is valid
        effective_num_groups1 = min(num_groups, in_channels // 4 if in_channels >= 4 else 1) if in_channels > 0 else 1 # Use 1 if in_channels is 1, 2, 3
        if in_channels > 0 and in_channels < effective_num_groups1 : # ensure num_channels >= num_groups
            effective_num_groups1 = in_channels

        effective_num_groups2 = min(num_groups, out_channels // 4 if out_channels >= 4 else 1) if out_channels > 0 else 1
        if out_channels > 0 and out_channels < effective_num_groups2:
            effective_num_groups2 = out_channels


        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # GroupNorm needs num_channels > 0 and num_groups <= num_channels
        self.norm1 = nn.GroupNorm(num_groups=effective_num_groups2, num_channels=out_channels) if out_channels > 0 else nn.Identity()
        self.act1 = nn.SiLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=effective_num_groups2, num_channels=out_channels) if out_channels > 0 else nn.Identity()
        self.act2 = nn.SiLU()

        # Time projection only if time_emb_dim is provided and > 0
        self.time_proj = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None and time_emb_dim > 0 else None

        self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb=None): # Made t_emb optional
        h = self.conv1(x)
        h = self.norm1(h)

        # Add time embedding if available and projected
        if self.time_proj is not None and t_emb is not None:
            time_cond = self.time_proj(t_emb)
            # Ensure time_cond is not None (it shouldn't be if self.time_proj exists)
            if time_cond is not None:
                h = h + time_cond[:, :, None, None] # Expand dims for broadcasting

        h = self.act1(h)

        h = self.conv2(h)
        h = self.norm2(h)
        h = self.act2(h)

        return h + self.res_conv(x) # Residual connection


class SelfAttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)

    def forward(self, x):
      s = x.shape
      t = einops.rearrange(x, 'b c h w -> b (h w) c')
      t = self.attn(t, t, t)[0]
      return x + einops.rearrange(t, 'b (h w) c -> b c h w', c=s[1], h=s[2])


class DownBlock(nn.Module):
    # Modified to make time_emb_dim optional for ConvBlock
    def __init__(self, in_channels, out_channels, time_emb_dim=None, num_groups_norm=8):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels, time_emb_dim, num_groups_norm)
        # MaxPool before Conv is common in some UNets, after in others. Let's keep it before like typical ResNets.
        # Or, is the intention from the original UNet structure Pool -> Conv? Let's assume Pool -> Conv
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t_emb=None): # Made t_emb optional
        x = self.pool(x)
        x = self.conv(x, t_emb)
        return x

class UpBlock(nn.Module):

    def __init__(self, in_channels_prev_up, channels_from_skip, out_channels_conv, time_emb_dim, num_groups_norm):
        super().__init__()
        # Upsample the features from the previous (lower) level
        self.up_conv_transpose = nn.ConvTranspose2d(in_channels_prev_up, in_channels_prev_up // 2, kernel_size=2, stride=2)
        # The ConvBlock will take the concatenated (upsampled features + skip features)
        # channels_from_skip are the channels from the encoder path
        # in_channels_prev_up // 2 are the channels after upsampling
        self.conv = ConvBlock( (in_channels_prev_up // 2) + channels_from_skip, out_channels_conv, time_emb_dim, num_groups_norm)

    def forward(self, x_prev_up, x_skip, t_emb):
        x_upsampled = self.up_conv_transpose(x_prev_up)
        # Ensure spatial dimensions match for concatenation (can be an issue with odd dimensions if not handled by padding in ConvTranspose2d or pool)
        # For standard MNIST 28->14->7, then 7->14->28, this should be fine.
        # If there's a mismatch:
        # diffY = x_skip.size()[2] - x_upsampled.size()[2]
        # diffX = x_skip.size()[3] - x_upsampled.size()[3]
        # x_upsampled = F.pad(x_upsampled, [diffX // 2, diffX - diffX // 2,
        #                                   diffY // 2, diffY - diffY // 2])

        x_cat = torch.cat([x_upsampled, x_skip], dim=1)
        x = self.conv(x_cat, t_emb)
        return x


class UNetFlowMatcherMNIST(nn.Module):
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 base_channels=32, # C
                 time_embedding_dim=128, # D_t
                 time_mlp_hidden_dim=512,
                 num_groups_norm=8,
                 use_attention_bottleneck=True):
        super().__init__()

        if base_channels <= 0:
            raise ValueError("base_channels must be positive.")
        if time_embedding_dim <= 0 and time_mlp_hidden_dim > 0 : # If no time embedding, MLP for it is not needed.
             time_embedding_dim = 0 # ensure consistency
             time_mlp_hidden_dim = 0


        self.time_embedding = SinusoidalTimeEmbedding(time_embedding_dim) if time_embedding_dim > 0 else nn.Identity()
        self.time_mlp = MLP(time_embedding_dim, time_mlp_hidden_dim, time_embedding_dim) if time_embedding_dim > 0 else nn.Identity()


        # Initial convolution
        self.conv_in = ConvBlock(in_channels, base_channels, time_embedding_dim, num_groups_norm) # HxW -> HxW (C)

        # Encoder
        self.down1 = DownBlock(base_channels, base_channels * 2, time_embedding_dim, num_groups_norm)     # HxW (C) -> H/2 x W/2 (2C)
        self.down2 = DownBlock(base_channels * 2, base_channels * 4, time_embedding_dim, num_groups_norm) # H/2 x W/2 (2C) -> H/4 x W/4 (4C)

        # Bottleneck
        # It operates on the output of down2 (4C channels)
        self.bottleneck_conv1 = ConvBlock(base_channels * 4, base_channels * 8, time_embedding_dim, num_groups_norm) # H/4 x W/4 (4C) -> H/4 x W/4 (8C)
        if use_attention_bottleneck:
            self.attention = SelfAttentionBlock(base_channels * 8, num_heads=4) # Operates on 8C channels
        else:
            self.attention = nn.Identity()
        self.bottleneck_conv2 = ConvBlock(base_channels * 8, base_channels * 4, time_embedding_dim, num_groups_norm) # H/4 x W/4 (8C) -> H/4 x W/4 (4C)

        # Decoder
        # Takes output from bottleneck (4C) and skip from down1 (2C)
        self.up1 = UpBlock(in_channels_prev_up=base_channels * 4, channels_from_skip=base_channels * 2, out_channels_conv=base_channels * 2, time_emb_dim=time_embedding_dim, num_groups_norm=num_groups_norm) # H/4 x W/4 (4C from bottleneck) + H/2 x W/2 (2C skip from s2) -> H/2 x W/2 (2C)

        # Takes output from up1 (2C) and skip from conv_in (C)
        self.up2 = UpBlock(in_channels_prev_up=base_channels * 2, channels_from_skip=base_channels, out_channels_conv=base_channels, time_emb_dim=time_embedding_dim, num_groups_norm=num_groups_norm)    # H/2 x W/2 (2C from up1) + H x W (C skip from s1) -> H x W (C)

        # Output
        self.conv_out = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x, t):
        # x: (batch_size, in_channels, H, W)
        # t: (batch_size,) scalar time values

        if isinstance(self.time_embedding, nn.Identity):
            t_emb = None # No time embedding used
        else:
            t_emb_sin = self.time_embedding(t)
            t_emb = self.time_mlp(t_emb_sin)

        # Encoder
        s1 = self.conv_in(x, t_emb)       # (B, C, H, W)
        s2 = self.down1(s1, t_emb)      # (B, 2C, H/2, W/2)
        s3 = self.down2(s2, t_emb)      # (B, 4C, H/4, W/4)

        # Bottleneck
        b = self.bottleneck_conv1(s3, t_emb)
        b = self.attention(b)
        b = self.bottleneck_conv2(b, t_emb) # Output channels 4C

        # Decoder
        u1 = self.up1(b, s2, t_emb)         # Input to up1: b (4C), skip s2 (2C) -> Output 2C
        u2 = self.up2(u1, s1, t_emb)        # Input to up2: u1 (2C), skip s1 (C)  -> Output C

        out = self.conv_out(u2)
        return out

## Flow Matching Loss and sampling

In [11]:
def flow_matching_loss(model, batch):
    x0 = batch[0]
    batch_size = x0.shape[0]
    t = torch.rand((batch_size,), device=x0.device)
    noise = torch.randn_like(x0)
    x_t = (1 - t[..., None, None, None]) * x0 + t[..., None, None, None] * noise
    # Forward pass
    x_hat = model(x_t, t)

    # Compute the loss
    loss = F.mse_loss(x_hat, x0)

    return {"loss": loss}


@torch.no_grad()
def sample_batch(model, noise, num_steps=10):
    model.eval()
    device = next(model.parameters()).device
    x = torch.zeros_like(noise)
    for i in range(num_steps):
        t = torch.tensor([(num_steps - i) / num_steps] * x.shape[0], device=device)
        x = (1 - t[..., None, None, None]) * x + t[..., None, None, None] * noise
        x = model(x, t)
    return x

## Trainer

In [13]:
class TrainLossLogger:
  def __init__(self, log_every_sec=10.0) -> None:
    self.running = defaultdict(list)
    self.log_every_sec = log_every_sec
    self.last_log = time.time()

  def log(self, data) -> None:
    for k, v in data.items():
      self.running[k].append(v)
    if self.last_log + self.log_every_sec < time.time():
      self.last_log = time.time()
      print("mean train:", {k: {float(np.array(v).mean())} for k, v in self.running.items()})
      self.running = defaultdict(list)


class Trainer:

    def __init__(self, *, model, train_loader, test_loader, loss_fn, extra_eval_fn=None, num_epochs=5, batch_size=32):
        self.model = model
        self.device = next(model.parameters()).device
        self.loss_fn = loss_fn
        self.extra_eval_fn = extra_eval_fn
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.train_loader = train_loader
        self.test_loader = test_loader

    def train(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

        loss_logger = TrainLossLogger()

        for epoch in range(self.num_epochs):
            print(f"{epoch=} test {self.eval()}")
            if self.extra_eval_fn:
                self.extra_eval_fn(self.model)
            self.model.train()
            for i, batch in tqdm(enumerate(self.train_loader), desc=f"Epoch {epoch + 1}/{self.num_epochs}", total=len(self.train_loader), unit=" batches"):
                batch = [x.to(self.device) for x in batch]
                optimizer.zero_grad()
                loss_dict = self.loss_fn(self.model, batch)
                loss = loss_dict["loss"]
                loss.backward()
                optimizer.step()
                loss_logger.log({k: v.item() for k, v in loss_dict.items()})
        print(f"{epoch=} test {self.eval()}")
        if self.extra_eval_fn:
            self.extra_eval_fn(self.model)

    @torch.no_grad()
    def eval(self) -> float:
        torch.random.manual_seed(17)
        self.model.eval()
        total = defaultdict(list)
        for batch in self.test_loader:
            batch = [x.to(self.device) for x in batch]
            loss = self.loss_fn(self.model, batch)
            for k, v in loss.items():
              total[k].append(v.item())
        torch.random.seed()
        return {k: float(np.array(v).mean()) for k, v in total.items()}

## Train the Flow matching model

In [14]:
fm_model = UNetFlowMatcherMNIST()
fm_model.to(device)

def sample_and_viz(fm_model):
  torch.random.manual_seed(1)
  imgs = sample_batch(fm_model, noise=torch.randn((20, 1, 28, 28), device=device), num_steps=10).squeeze(1).cpu()
  show_images(imgs)
  _ = torch.random.seed()


trainer = Trainer(model=fm_model,
                  train_loader=train_loader,
                  test_loader=test_loader,
                  loss_fn=flow_matching_loss,
                  extra_eval_fn=sample_and_viz,
                  num_epochs=5)
trainer.train()

epoch=0 test {'loss': 0.21940060798078775}


Epoch 1/5:   2%|▏         | 40/1875 [00:05<02:34, 11.90 batches/s] 

mean train: {'loss': {0.05819222283286926}}


Epoch 1/5:   8%|▊         | 158/1875 [00:16<02:49, 10.13 batches/s]

mean train: {'loss': {0.038378454012493804}}


Epoch 1/5:  15%|█▍        | 276/1875 [00:26<02:14, 11.90 batches/s]

mean train: {'loss': {0.03610059860462354}}


Epoch 1/5:  21%|██        | 394/1875 [00:36<02:01, 12.16 batches/s]

mean train: {'loss': {0.034521285522760466}}


Epoch 1/5:  27%|██▋       | 512/1875 [00:46<01:53, 11.99 batches/s]

mean train: {'loss': {0.033901799564897}}


Epoch 1/5:  34%|███▍      | 634/1875 [00:56<01:43, 11.95 batches/s]

mean train: {'loss': {0.032935560011297216}}


Epoch 1/5:  40%|████      | 752/1875 [01:06<01:32, 12.20 batches/s]

mean train: {'loss': {0.03322646613785271}}


Epoch 1/5:  46%|████▌     | 861/1875 [01:16<01:41,  9.95 batches/s]

mean train: {'loss': {0.032564700452969705}}


Epoch 1/5:  52%|█████▏    | 979/1875 [01:26<01:14, 12.04 batches/s]

mean train: {'loss': {0.03292764743821601}}


Epoch 1/5:  59%|█████▊    | 1099/1875 [01:36<01:03, 12.19 batches/s]

mean train: {'loss': {0.03212654661417993}}


Epoch 1/5:  65%|██████▌   | 1221/1875 [01:46<00:54, 11.97 batches/s]

mean train: {'loss': {0.03200764110511985}}


Epoch 1/5:  72%|███████▏  | 1343/1875 [01:56<00:43, 12.09 batches/s]

mean train: {'loss': {0.03175225450855787}}


Epoch 1/5:  78%|███████▊  | 1463/1875 [02:06<00:34, 11.93 batches/s]

mean train: {'loss': {0.032098726838950285}}


Epoch 1/5:  84%|████████▍ | 1575/1875 [02:16<00:29, 10.05 batches/s]

mean train: {'loss': {0.03150306376669787}}


Epoch 1/5:  90%|█████████ | 1692/1875 [02:26<00:15, 12.17 batches/s]

mean train: {'loss': {0.031558683093894146}}


Epoch 1/5:  96%|█████████▋| 1806/1875 [02:36<00:06, 11.13 batches/s]

mean train: {'loss': {0.031331162883249004}}


Epoch 1/5: 100%|██████████| 1875/1875 [02:42<00:00, 11.51 batches/s]


epoch=1 test {'loss': 0.030956411734223366}


Epoch 2/5:   0%|          | 2/1875 [00:01<17:49,  1.75 batches/s]

mean train: {'loss': {0.031077024087102875}}


Epoch 2/5:   6%|▋         | 120/1875 [00:11<02:40, 10.91 batches/s]

mean train: {'loss': {0.031183274115546274}}


Epoch 2/5:  12%|█▏        | 233/1875 [00:21<02:19, 11.80 batches/s]

mean train: {'loss': {0.03086968371644616}}


Epoch 2/5:  18%|█▊        | 342/1875 [00:31<02:32, 10.07 batches/s]

mean train: {'loss': {0.030334588345311105}}


Epoch 2/5:  24%|██▍       | 449/1875 [00:41<02:36,  9.12 batches/s]

mean train: {'loss': {0.03080903123029404}}


Epoch 2/5:  30%|███       | 565/1875 [00:51<01:49, 11.99 batches/s]

mean train: {'loss': {0.03146296204870631}}


Epoch 2/5:  37%|███▋      | 687/1875 [01:01<01:38, 12.03 batches/s]

mean train: {'loss': {0.031189970610555538}}


Epoch 2/5:  43%|████▎     | 807/1875 [01:11<01:28, 12.02 batches/s]

mean train: {'loss': {0.03127956844496826}}


Epoch 2/5:  50%|████▉     | 929/1875 [01:21<01:17, 12.23 batches/s]

mean train: {'loss': {0.030798884963647265}}


Epoch 2/5:  56%|█████▌    | 1049/1875 [01:31<01:09, 11.89 batches/s]

mean train: {'loss': {0.030769849994352887}}


Epoch 2/5:  62%|██████▏   | 1169/1875 [01:41<00:58, 12.05 batches/s]

mean train: {'loss': {0.03098975297463827}}


Epoch 2/5:  69%|██████▉   | 1291/1875 [01:51<00:48, 12.04 batches/s]

mean train: {'loss': {0.030489741508995204}}


Epoch 2/5:  75%|███████▌  | 1411/1875 [02:01<00:38, 12.07 batches/s]

mean train: {'loss': {0.030555212331458557}}


Epoch 2/5:  82%|████████▏ | 1533/1875 [02:11<00:28, 12.21 batches/s]

mean train: {'loss': {0.03105195178115954}}


Epoch 2/5:  88%|████████▊ | 1655/1875 [02:21<00:18, 12.16 batches/s]

mean train: {'loss': {0.030658983259049595}}


Epoch 2/5:  95%|█████████▍| 1775/1875 [02:31<00:08, 11.93 batches/s]

mean train: {'loss': {0.030224927614967364}}


Epoch 2/5: 100%|██████████| 1875/1875 [02:40<00:00, 11.66 batches/s]


epoch=2 test {'loss': 0.0302952048368752}


Epoch 3/5:   0%|          | 3/1875 [00:01<12:34,  2.48 batches/s]

mean train: {'loss': {0.03030514727952411}}


Epoch 3/5:   7%|▋         | 123/1875 [00:11<02:21, 12.37 batches/s]

mean train: {'loss': {0.030406779817436353}}


Epoch 3/5:  13%|█▎        | 239/1875 [00:21<02:31, 10.81 batches/s]

mean train: {'loss': {0.030326393683967383}}


Epoch 3/5:  19%|█▊        | 351/1875 [00:31<02:12, 11.48 batches/s]

mean train: {'loss': {0.030026573687791824}}


Epoch 3/5:  25%|██▍       | 463/1875 [00:41<02:05, 11.27 batches/s]

mean train: {'loss': {0.029661762079090943}}


Epoch 3/5:  31%|███       | 572/1875 [00:51<01:57, 11.10 batches/s]

mean train: {'loss': {0.030033923858017834}}


Epoch 3/5:  36%|███▌      | 677/1875 [01:01<01:48, 11.03 batches/s]

mean train: {'loss': {0.030388783997381635}}


Epoch 3/5:  42%|████▏     | 780/1875 [01:11<01:42, 10.72 batches/s]

mean train: {'loss': {0.030631926239413375}}


Epoch 3/5:  47%|████▋     | 887/1875 [01:21<01:30, 10.89 batches/s]

mean train: {'loss': {0.030292791549429716}}


Epoch 3/5:  53%|█████▎    | 995/1875 [01:31<01:21, 10.85 batches/s]

mean train: {'loss': {0.030344602582907235}}


Epoch 3/5:  59%|█████▊    | 1101/1875 [01:41<01:10, 10.99 batches/s]

mean train: {'loss': {0.0302577660915172}}


Epoch 3/5:  64%|██████▎   | 1192/1875 [01:51<01:12,  9.49 batches/s]

mean train: {'loss': {0.029622078515016116}}


Epoch 3/5:  69%|██████▉   | 1290/1875 [02:01<00:58, 10.03 batches/s]

mean train: {'loss': {0.030043115288171232}}


Epoch 3/5:  74%|███████▍  | 1390/1875 [02:12<00:47, 10.22 batches/s]

mean train: {'loss': {0.029332432132214308}}


Epoch 3/5:  80%|███████▉  | 1491/1875 [02:22<00:37, 10.34 batches/s]

mean train: {'loss': {0.030403238048057744}}


Epoch 3/5:  84%|████████▍ | 1580/1875 [02:32<00:50,  5.80 batches/s]

mean train: {'loss': {0.029480941426218224}}


Epoch 3/5:  88%|████████▊ | 1659/1875 [02:42<00:23,  9.13 batches/s]

mean train: {'loss': {0.029542612383448626}}


Epoch 3/5:  91%|█████████ | 1697/1875 [02:52<00:41,  4.24 batches/s]

mean train: {'loss': {0.030598632370432217}}


Epoch 3/5:  93%|█████████▎| 1743/1875 [03:02<00:30,  4.35 batches/s]

mean train: {'loss': {0.030006587545832863}}


Epoch 3/5:  95%|█████████▌| 1790/1875 [03:12<00:19,  4.36 batches/s]

mean train: {'loss': {0.028611737917712394}}


Epoch 3/5:  98%|█████████▊| 1833/1875 [03:22<00:09,  4.64 batches/s]

mean train: {'loss': {0.029561487611296566}}


Epoch 3/5: 100%|██████████| 1875/1875 [03:31<00:00,  8.88 batches/s]


epoch=3 test {'loss': 0.029870674130506814}


Epoch 4/5:   0%|          | 2/1875 [00:02<38:10,  1.22s/ batches]  

mean train: {'loss': {0.03062945470040621}}


Epoch 4/5:   3%|▎         | 58/1875 [00:12<03:43,  8.11 batches/s]

mean train: {'loss': {0.030110373993271163}}


Epoch 4/5:   8%|▊         | 152/1875 [00:22<02:50, 10.08 batches/s]

mean train: {'loss': {0.029941567660946595}}


Epoch 4/5:  13%|█▎        | 252/1875 [00:33<02:35, 10.44 batches/s]

mean train: {'loss': {0.029273687671793967}}


Epoch 4/5:  19%|█▉        | 354/1875 [00:43<02:24, 10.53 batches/s]

mean train: {'loss': {0.03047730748369856}}


Epoch 4/5:  24%|██▍       | 456/1875 [00:53<02:19, 10.18 batches/s]

mean train: {'loss': {0.029452128787931887}}


Epoch 4/5:  29%|██▉       | 545/1875 [01:03<02:58,  7.47 batches/s]

mean train: {'loss': {0.02947328961454332}}


Epoch 4/5:  34%|███▎      | 629/1875 [01:13<02:07,  9.75 batches/s]

mean train: {'loss': {0.030349227111964}}


Epoch 4/5:  39%|███▉      | 729/1875 [01:23<01:56,  9.83 batches/s]

mean train: {'loss': {0.030617228355258704}}


Epoch 4/5:  44%|████▍     | 822/1875 [01:33<01:53,  9.27 batches/s]

mean train: {'loss': {0.030730500136331846}}


Epoch 4/5:  49%|████▉     | 915/1875 [01:43<01:38,  9.73 batches/s]

mean train: {'loss': {0.029444346404684487}}


Epoch 4/5:  54%|█████▍    | 1012/1875 [01:53<01:27,  9.89 batches/s]

mean train: {'loss': {0.030752361531263775}}


Epoch 4/5:  59%|█████▉    | 1113/1875 [02:03<01:16, 10.02 batches/s]

mean train: {'loss': {0.02971779116161979}}


Epoch 4/5:  64%|██████▎   | 1195/1875 [02:13<01:20,  8.48 batches/s]

mean train: {'loss': {0.030329627030324643}}


Epoch 4/5:  68%|██████▊   | 1282/1875 [02:23<00:58, 10.19 batches/s]

mean train: {'loss': {0.03058696618346974}}


Epoch 4/5:  73%|███████▎  | 1376/1875 [02:33<00:55,  8.94 batches/s]

mean train: {'loss': {0.02983856597229054}}


Epoch 4/5:  78%|███████▊  | 1470/1875 [02:43<00:47,  8.61 batches/s]

mean train: {'loss': {0.02892620231401413}}


Epoch 4/5:  83%|████████▎ | 1554/1875 [02:53<00:41,  7.78 batches/s]

mean train: {'loss': {0.030252248957930578}}


Epoch 4/5:  87%|████████▋ | 1640/1875 [03:03<00:30,  7.78 batches/s]

mean train: {'loss': {0.0308230901353581}}


Epoch 4/5:  91%|█████████ | 1706/1875 [03:13<00:28,  6.03 batches/s]

mean train: {'loss': {0.0294125169713163}}


Epoch 4/5:  96%|█████████▌| 1794/1875 [03:23<00:08,  9.29 batches/s]

mean train: {'loss': {0.02965455064126714}}


Epoch 4/5: 100%|██████████| 1875/1875 [03:33<00:00,  8.80 batches/s]


epoch=4 test {'loss': 0.02968601498287171}


Epoch 5/5:   0%|          | 2/1875 [00:01<25:11,  1.24 batches/s]

mean train: {'loss': {0.030817628849342645}}


Epoch 5/5:   5%|▌         | 99/1875 [00:11<03:13,  9.18 batches/s]

mean train: {'loss': {0.029858467236314853}}


Epoch 5/5:  10%|▉         | 180/1875 [00:21<03:05,  9.13 batches/s]

mean train: {'loss': {0.02893874927619357}}


Epoch 5/5:  15%|█▍        | 273/1875 [00:32<02:56,  9.06 batches/s]

mean train: {'loss': {0.02924711910146539}}


Epoch 5/5:  19%|█▉        | 358/1875 [00:42<02:37,  9.65 batches/s]

mean train: {'loss': {0.02937477084205431}}


Epoch 5/5:  24%|██▍       | 452/1875 [00:52<02:26,  9.69 batches/s]

mean train: {'loss': {0.030083878263038523}}


Epoch 5/5:  29%|██▉       | 545/1875 [01:02<02:23,  9.29 batches/s]

mean train: {'loss': {0.029026582676877258}}


Epoch 5/5:  34%|███▍      | 637/1875 [01:12<02:06,  9.79 batches/s]

mean train: {'loss': {0.029635130954177483}}


Epoch 5/5:  39%|███▉      | 731/1875 [01:22<02:05,  9.10 batches/s]

mean train: {'loss': {0.02943089485485503}}


Epoch 5/5:  44%|████▍     | 823/1875 [01:32<01:55,  9.11 batches/s]

mean train: {'loss': {0.028786982186948477}}


Epoch 5/5:  49%|████▊     | 911/1875 [01:42<01:42,  9.36 batches/s]

mean train: {'loss': {0.030044430210678416}}


Epoch 5/5:  54%|█████▎    | 1004/1875 [01:52<01:38,  8.85 batches/s]

mean train: {'loss': {0.030305616998223848}}


Epoch 5/5:  58%|█████▊    | 1094/1875 [02:02<02:02,  6.37 batches/s]

mean train: {'loss': {0.028955938604970772}}


Epoch 5/5:  63%|██████▎   | 1187/1875 [02:12<01:14,  9.18 batches/s]

mean train: {'loss': {0.028403331014898517}}


Epoch 5/5:  68%|██████▊   | 1280/1875 [02:22<01:00,  9.79 batches/s]

mean train: {'loss': {0.02952472091482211}}


Epoch 5/5:  73%|███████▎  | 1367/1875 [02:32<00:53,  9.54 batches/s]

mean train: {'loss': {0.029112403762751614}}


Epoch 5/5:  78%|███████▊  | 1454/1875 [02:42<00:43,  9.77 batches/s]

mean train: {'loss': {0.02813858255187328}}


Epoch 5/5:  82%|████████▏ | 1546/1875 [02:52<00:36,  9.01 batches/s]

mean train: {'loss': {0.029417393243183262}}


Epoch 5/5:  87%|████████▋ | 1640/1875 [03:02<00:24,  9.57 batches/s]

mean train: {'loss': {0.029897154625901517}}


Epoch 5/5:  92%|█████████▏| 1733/1875 [03:12<00:14,  9.72 batches/s]

mean train: {'loss': {0.029416676490537582}}


Epoch 5/5:  97%|█████████▋| 1825/1875 [03:23<00:05,  9.21 batches/s]

mean train: {'loss': {0.029350985853892307}}


Epoch 5/5: 100%|██████████| 1875/1875 [03:29<00:00,  8.96 batches/s]


epoch=4 test {'loss': 0.02945508004631847}


## More Sampling fun

In [15]:
def lerp(a, b, t):
  x = a * (1 -t) + b * t
  return x

def animate(fm_model, points=10, samples_between=100, cols=5, rows=5):
  torch.random.manual_seed(1)
  noises = [torch.randn((cols * rows, 1, 28, 28)) for _ in range(points)]
  noises.append(noises[0])  # loop
  images = []
  for a, b in tqdm(zip(noises[:-1], noises[1:]), unit=" points"):
    for t in tqdm(range(samples_between), unit=" samples"):
      n = lerp(a, b, t / samples_between)
      images.append(sample_batch(fm_model, noise=n.to(device), num_steps=10).squeeze(1).cpu())
  i = torch.stack(images, axis=0).clip(0, 1)
  i = einops.rearrange(i, "n (gw gh) w h -> n (gw w) (gh h)", gw=cols, gh=rows)
  show_video(np.array(i[...] * 255).astype(np.uint8), fps=20, codec='gif')

animate(fm_model, points=5, samples_between=100)


100%|██████████| 100/100 [00:16<00:00,  6.07 samples/s]
100%|██████████| 100/100 [00:15<00:00,  6.27 samples/s]
100%|██████████| 100/100 [00:16<00:00,  6.19 samples/s]
100%|██████████| 100/100 [00:15<00:00,  6.26 samples/s]
100%|██████████| 100/100 [00:15<00:00,  6.38 samples/s]
5 points [01:20, 16.05s/ points]
  show_video(np.array(i[...] * 255).astype(np.uint8), fps=20, codec='gif')


RuntimeError: Program 'ffmpeg' is not found; perhaps install ffmpeg using 'apt install ffmpeg'.

## Simple Guidance (just brightness)

In [None]:
@torch.no_grad()
def sample_with_guidance(model, noise, guidance_fn, guidance_scale, num_steps=10):
    model.eval()
    device = next(model.parameters()).device
    x = torch.zeros_like(noise)
    for i in range(num_steps):
      t = torch.tensor([(num_steps - i) / num_steps] * x.shape[0], device=device)
      x = (1 - t[..., None, None, None]) * x + t[..., None, None, None] * noise
      x = model(x, t)

      x_for_guidance = x.detach().clone().requires_grad_(True)
      with torch.enable_grad():
        scores = guidance_fn(x_for_guidance)
        grad = torch.autograd.grad(scores, x_for_guidance)[0]

      x += guidance_scale * grad
    return x


def intensity_guidance(num_examples=10, steps=30):
  def intensity(x):
    return x.mean()

  noise = torch.randn((num_examples, 1, 28, 28), device=device)

  images = []
  for s in tqdm(range(steps), total=steps):
    scale = torch.tensor(s / (steps - 1)) * 2 - 1
    scale *= 1e3

    images.append(sample_with_guidance(fm_model, noise=noise, num_steps=100, guidance_fn=intensity, guidance_scale=scale).squeeze(1).cpu())
  i = torch.stack(images, axis=0).clip(0, 1)
  i = einops.rearrange(i, "r c w h -> (c w) (r h)")
  show_image(i)


intensity_guidance()

100%|██████████| 30/30 [00:40<00:00,  1.35s/it]


## Classifier Model

In [18]:
class MNISTClassifier(nn.Module):
    def __init__(self,
                 in_channels=1,         # MNIST is grayscale
                 num_classes=10,        # MNIST has 10 digits
                 base_channels=32,      # Starting number of channels
                 num_groups_norm=8,
                 use_attention_bottleneck=True):
        super().__init__()

        if base_channels <= 0:
            raise ValueError("base_channels must be positive.")

        # --- Feature Extractor using UNet Encoder/Bottleneck Structure ---
        # Time embedding components are not needed for standard classification
        # We pass time_emb_dim=None to the blocks

        # Initial convolution (maintains size: 28x28 -> 28x28)
        # Input: 1x28x28
        self.conv_in = ConvBlock(in_channels, base_channels, time_emb_dim=None, num_groups=num_groups_norm)
        # Output: C x 28 x 28

        # Encoder Path
        # Downsample 1 (28x28 -> 14x14)
        self.down1 = DownBlock(base_channels, base_channels * 2, time_emb_dim=None, num_groups_norm=num_groups_norm)
        # Output: 2C x 14 x 14
        # Downsample 2 (14x14 -> 7x7)
        self.down2 = DownBlock(base_channels * 2, base_channels * 4, time_emb_dim=None, num_groups_norm=num_groups_norm)
        # Output: 4C x 7 x 7

        # Bottleneck (operates at 7x7 resolution)
        # Increase channels
        self.bottleneck_conv1 = ConvBlock(base_channels * 4, base_channels * 8, time_emb_dim=None, num_groups=num_groups_norm)
        # Output: 8C x 7 x 7

        # Optional Self-Attention
        if use_attention_bottleneck:
            self.attention = SelfAttentionBlock(base_channels * 8, num_heads=4)
        else:
            self.attention = nn.Identity()
        # Output: 8C x 7 x 7

        # Decrease channels back (ready for potential decoder, but we stop here)
        # We will use the output of attention (or bottleneck_conv1 if no attention) for classification
        # Let's use the output AFTER attention (8C channels) for max feature richness at bottleneck
        bottleneck_out_channels = base_channels * 8

        # --- Classification Head ---
        # Global Average Pooling reduces spatial dimensions (7x7) to 1x1
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Output: 8C x 1 x 1

        # Final Linear layer to map features to class scores
        self.classifier_head = nn.Linear(bottleneck_out_channels, num_classes)
        # Output: num_classes

    def forward(self, x):
        # x: (batch_size, in_channels, H, W) = (B, 1, 28, 28)

        # Feature Extraction (Encoder + Bottleneck)
        # Time embedding `t_emb` is None throughout
        h = self.conv_in(x, t_emb=None)      # (B, C, 28, 28)
        h = self.down1(h, t_emb=None)        # (B, 2C, 14, 14)
        h = self.down2(h, t_emb=None)        # (B, 4C, 7, 7)
        h = self.bottleneck_conv1(h, t_emb=None) # (B, 8C, 7, 7)
        h = self.attention(h)              # (B, 8C, 7, 7)

        # Classification Head
        h = self.global_avg_pool(h)         # (B, 8C, 1, 1)
        h = torch.flatten(h, 1)             # (B, 8C)
        out_logits = self.classifier_head(h) # (B, num_classes)

        return out_logits


## CE Loss

In [21]:
def ce_loss(model, batch):
    x = batch[0]
    y = batch[1]
    # Forward pass
    y_hat = model(x)

    # Compute the loss
    loss = nn.CrossEntropyLoss()(y_hat, y)

    predicted_classes = torch.argmax(y_hat, dim=1) # Shape: [batch_size]
    accuracy = (predicted_classes == y).to(torch.float32).mean().detach().cpu()
    return dict(loss=loss, accuracy=accuracy)

## Train the classifier

In [None]:
classifier = MNISTClassifier()
classifier.to(device)


trainer = Trainer(model=classifier,
                  train_loader=train_loader,
                  test_loader=test_loader,
                  loss_fn=ce_loss,
                  extra_eval_fn=None,
                  num_epochs=1)
trainer.train()



epoch=0 test {'loss': 2.332376942038536, 'accuracy': 0.10527343768626451}


Epoch 1/1:   6%|▌         | 114/1875 [00:07<01:28, 19.84 batches/s]

mean train: {'loss': {1.4783587349312646}, 'accuracy': {0.4609375}}


Epoch 1/1:  18%|█▊        | 330/1875 [00:17<01:11, 21.70 batches/s]

mean train: {'loss': {0.3385149829469069}, 'accuracy': {0.9013310185185185}}


Epoch 1/1:  29%|██▉       | 549/1875 [00:27<01:00, 22.06 batches/s]

mean train: {'loss': {0.18490778050527046}, 'accuracy': {0.9447004608294931}}


Epoch 1/1:  40%|███▉      | 748/1875 [00:37<00:50, 22.12 batches/s]

mean train: {'loss': {0.15773484058242607}, 'accuracy': {0.9544465174129353}}


Epoch 1/1:  52%|█████▏    | 967/1875 [00:47<00:41, 21.66 batches/s]

mean train: {'loss': {0.12380649167604418}, 'accuracy': {0.9605414746543779}}


Epoch 1/1:  56%|█████▋    | 1059/1875 [00:52<00:38, 21.14 batches/s]


## Classifier Guidance

In [30]:
def class_scorer(classifier, x, class_weights):
  logits = classifier(x)
  log_probs = F.log_softmax(logits, dim=-1)
  return (log_probs * class_weights[None, :]).sum(dim=-1).mean()




def number_guidance(num_examples=10):
  noise = torch.randn((num_examples, 1, 28, 28), device=device)

  images = []
  for n in tqdm(range(10)):
    one_hot = F.one_hot(torch.tensor(n), 10).to(device)
    def scorer(x):
      return class_scorer(classifier, x, one_hot)
    images.append(sample_with_guidance(fm_model, noise=noise, num_steps=100, guidance_fn=scorer, guidance_scale=20).squeeze(1).cpu())
  i = torch.stack(images, axis=0).clip(0, 1)
  i = einops.rearrange(i, "r c w h -> (c w) (r h)")
  show_image(i)

number_guidance()

  0%|          | 0/10 [00:00<?, ?it/s]