## Setup

In [15]:
!pip install -q mediapy

## Imports

In [16]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm.auto import tqdm
import math
from mediapy import show_image, show_images, show_video
import time
from collections import defaultdict
import einops

## Setup Accelerator

In [17]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

using device: cuda


## Number of Gray Levels (Pixel Classes)

In [18]:
N_GRAYSCALES: int = 4  # number of discrete gray levels each pixel is quantized to

## Prepare dataset

In [19]:
class QuantizeToNGrayscales:
    """Map a float image in [0, 1] to integer gray-level indices 0..num_grayscales-1.

    Meant to run after ``T.ToTensor()``. A leading size-1 axis (the single
    channel) is squeezed away, and the result is returned as int32.
    """

    def __init__(self, num_grayscales: int):
        self.num_grayscales = num_grayscales

    def __call__(self, tensor_image: torch.Tensor) -> torch.Tensor:
        max_level = self.num_grayscales - 1
        # Round to the nearest level (torch.round uses round-half-to-even).
        levels = (tensor_image * max_level).round()
        return levels.squeeze(0).to(torch.int32)


# ToTensor() yields floats in [0, 1]; QuantizeToNGrayscales turns them into integer
# gray levels 0..N_GRAYSCALES-1 (int32, single channel squeezed away).
transform_quantize = T.Compose([
    T.ToTensor(),
    QuantizeToNGrayscales(N_GRAYSCALES),
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_quantize
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform_quantize
)

# Evaluate on the first 1,000 test images only.
test_dataset = Subset(test_dataset, range(1_000))

batch_size = 32

# Pinned host memory only benefits host-to-CUDA copies here; request it only when
# CUDA is present (PyTorch warns when it is requested without an accelerator).
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory
)
# Evaluation has no backward pass, so a larger batch is used.
test_loader = DataLoader(
    test_dataset, batch_size=batch_size * 2, shuffle=False, pin_memory=pin_memory
)
print()
print("train - showing quantized images")
temp_batch = next(iter(train_loader))[0]
# mediapy only accepts unsigned-integer or float arrays, so cast the int32 gray
# levels (0..N_GRAYSCALES-1) to uint8, as the later visualization cells do.
show_images(temp_batch.to(torch.uint8))
print(f"Example pixel values from train_loader (quantized): {temp_batch.min()}, {temp_batch.max()}")

print("test - showing quantized images")
temp_batch_test = next(iter(test_loader))[0]
show_images(temp_batch_test.to(torch.uint8))


train - showing original continuous images


Example pixel values from train_loader (continuous): 0, 3
test - showing original continuous images


## Flow Matching Model

In [20]:
class SinusoidalTimeEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding of a batch of scalar times.

    Args:
        dim: output width. The first ``dim // 2`` columns are sines, the next
            ``dim // 2`` are cosines, and a zero column is appended when ``dim``
            is odd.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed a 1-D tensor ``t`` of shape (B,) into a (B, dim) tensor."""
        device = t.device
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 to 1/10000. The max() guard keeps
        # dim <= 3 (half_dim <= 1) from dividing by zero; larger dims are unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        args = t[:, None] * freqs[None, :]
        parts = [args.sin(), args.cos()]
        if self.dim % 2 == 1:
            # Zero-pad so the output width matches an odd ``dim``.
            parts.append(torch.zeros((args.shape[0], 1), device=device, dtype=args.dtype))
        return torch.cat(parts, dim=-1)

class MLP(nn.Module):
    """Two-layer perceptron (Linear -> SiLU -> Linear) applied to time embeddings.

    Args:
        time_embedding_dim: input width.
        hidden_dim: hidden width.
        out_dim: output width.
    """

    def __init__(self, time_embedding_dim, hidden_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(time_embedding_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t_emb):
        out = self.mlp(t_emb)
        return out

class ConvBlock(nn.Module):
    """Two (3x3 conv -> GroupNorm -> SiLU) stages with a residual connection.

    When time conditioning is enabled, the time embedding is projected to
    ``out_channels`` and added as a per-channel bias after the first norm.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        time_emb_dim: width of the time embedding; ``None`` or <= 0 disables
            time conditioning.
        num_groups: requested GroupNorm groups; reduced to the largest divisor
            of ``out_channels`` not exceeding it (at least 1).
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=None, num_groups=8):
        super().__init__()

        def largest_divisor_at_most(n, limit):
            # Largest g in [1, limit] with n % g == 0; degenerate inputs
            # (n <= 0 or limit <= 0) fall back to a single group.
            g = max(1, min(n, limit))
            while n % g:
                g -= 1
            return g

        groups = largest_divisor_at_most(out_channels, num_groups)

        def make_norm():
            if out_channels > 0:
                return nn.GroupNorm(num_groups=groups, num_channels=out_channels)
            return nn.Identity()

        # Creation order (conv1, conv2, time_proj, res_conv) fixes the order in
        # which parameters draw from the RNG during initialization.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = make_norm()
        self.act1 = nn.SiLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = make_norm()
        self.act2 = nn.SiLU()

        use_time = time_emb_dim is not None and time_emb_dim > 0
        self.time_proj = nn.Linear(time_emb_dim, out_channels) if use_time else None
        if in_channels == out_channels:
            self.res_conv = nn.Identity()
        else:
            # 1x1 conv so the skip path matches the output channel count.
            self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, t_emb=None):
        h = self.norm1(self.conv1(x))
        if self.time_proj is not None and t_emb is not None:
            # Broadcast the 2-D projection over the two spatial dims.
            h = h + self.time_proj(t_emb)[:, :, None, None]
        h = self.act1(h)
        h = self.act2(self.norm2(self.conv2(h)))
        return self.res_conv(x) + h


class SelfAttentionBlock(nn.Module):
    """Residual multi-head self-attention over all spatial positions of a feature map.

    Args:
        channels: feature channels, used as the attention embedding width
            (``nn.MultiheadAttention`` requires it to be divisible by ``num_heads``).
        num_heads: number of attention heads.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): every pixel becomes one token.
        tokens = x.flatten(2).transpose(1, 2)
        attended, _ = self.attn(tokens, tokens, tokens)
        # (B, H*W, C) -> (B, C, H, W)
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        return x + attended



class DownBlock(nn.Module):
    """Downsample with 2x2 max-pooling, then apply a ConvBlock.

    Args:
        in_channels: channels entering the ConvBlock.
        out_channels: channels produced by the ConvBlock.
        time_emb_dim: forwarded to ConvBlock (``None`` disables time conditioning).
        num_groups_norm: requested GroupNorm groups for the ConvBlock.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=None, num_groups_norm=8):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels, time_emb_dim, num_groups_norm)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t_emb=None):
        pooled = self.pool(x)
        return self.conv(pooled, t_emb)

class UpBlock(nn.Module):
    """Upsample 2x with a transposed conv, concatenate a skip connection, then ConvBlock.

    The transposed conv halves the channel count of ``x_prev_up``. If the
    upsampled map and the skip differ spatially, the upsampled map is padded
    (or cropped, for negative differences) to the skip's size first.

    Args:
        in_channels_prev_up: channels of the tensor being upsampled.
        channels_from_skip: channels of the skip connection.
        out_channels_conv: channels produced by the ConvBlock.
        time_emb_dim: forwarded to ConvBlock.
        num_groups_norm: requested GroupNorm groups for the ConvBlock.
    """

    def __init__(self, in_channels_prev_up, channels_from_skip, out_channels_conv, time_emb_dim, num_groups_norm):
        super().__init__()
        up_channels = in_channels_prev_up // 2
        self.up_conv_transpose = nn.ConvTranspose2d(in_channels_prev_up, up_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(up_channels + channels_from_skip, out_channels_conv, time_emb_dim, num_groups_norm)

    def forward(self, x_prev_up, x_skip, t_emb):
        up = self.up_conv_transpose(x_prev_up)
        target_hw = x_skip.shape[2:]
        if up.shape[2:] != target_hw:
            dy = target_hw[0] - up.shape[2]
            dx = target_hw[1] - up.shape[3]
            # F.pad takes (left, right, top, bottom) for the last two dims.
            up = F.pad(up, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        merged = torch.cat([up, x_skip], dim=1)
        return self.conv(merged, t_emb)


class UNetFlowMatcherMNIST(nn.Module):
    """Time-conditioned U-Net mapping a discrete image to per-pixel class logits.

    Integer pixel values are embedded with ``nn.Embedding``. The network returns
    logits of shape (B, num_pixel_classes, H, W); in this notebook they are
    trained to predict the clean image x0 from a noised x_t.

    Resolution path: conv_in (full) -> down1 (1/2) -> down2 (1/4) ->
    bottleneck [-> self-attention] -> up1 (1/2) -> up2 (full).

    Args:
        num_pixel_classes: number of discrete pixel values (input vocabulary and
            number of output logits per pixel).
        base_channels: channel width at full resolution, doubled per level
            down. Must be positive.
        time_embedding_dim: width of the sinusoidal time embedding; <= 0
            disables time conditioning.
        time_mlp_hidden_dim: hidden width of the MLP on the time embedding.
        num_groups_norm: requested GroupNorm groups in every ConvBlock.
        use_attention_bottleneck: add a self-attention block at the bottleneck.

    Raises:
        ValueError: if ``base_channels`` is not positive.
    """

    def __init__(self,
                 num_pixel_classes=1,
                 base_channels=32,
                 time_embedding_dim=128,
                 time_mlp_hidden_dim=512,
                 num_groups_norm=8,
                 use_attention_bottleneck=True):
        super().__init__()

        # Validate before building any layer: nn.Embedding with a negative width
        # would otherwise raise a RuntimeError before this check is reached.
        if base_channels <= 0:
            raise ValueError("base_channels must be positive.")

        self.num_pixel_classes = num_pixel_classes
        self.embedding = nn.Embedding(num_pixel_classes, base_channels)

        # A non-positive width disables time conditioning everywhere.
        if time_embedding_dim <= 0:
            time_embedding_dim = 0

        self.time_embedding = SinusoidalTimeEmbedding(time_embedding_dim) if time_embedding_dim > 0 else nn.Identity()
        self.time_mlp = MLP(time_embedding_dim, time_mlp_hidden_dim, time_embedding_dim) if time_embedding_dim > 0 else nn.Identity()

        self.conv_in = ConvBlock(base_channels, base_channels, time_embedding_dim, num_groups_norm)
        self.down1 = DownBlock(base_channels, base_channels * 2, time_embedding_dim, num_groups_norm)
        self.down2 = DownBlock(base_channels * 2, base_channels * 4, time_embedding_dim, num_groups_norm)
        self.bottleneck_conv1 = ConvBlock(base_channels * 4, base_channels * 8, time_embedding_dim, num_groups_norm)
        if use_attention_bottleneck:
            self.attention = SelfAttentionBlock(base_channels * 8, num_heads=4)
        else:
            self.attention = nn.Identity()
        self.bottleneck_conv2 = ConvBlock(base_channels * 8, base_channels * 4, time_embedding_dim, num_groups_norm)
        self.up1 = UpBlock(in_channels_prev_up=base_channels * 4, channels_from_skip=base_channels * 2, out_channels_conv=base_channels * 2, time_emb_dim=time_embedding_dim, num_groups_norm=num_groups_norm)
        self.up2 = UpBlock(in_channels_prev_up=base_channels * 2, channels_from_skip=base_channels, out_channels_conv=base_channels, time_emb_dim=time_embedding_dim, num_groups_norm=num_groups_norm)
        self.conv_out = nn.Conv2d(base_channels, self.num_pixel_classes, kernel_size=1)

    def forward(self, x, t):
        """Return per-pixel logits of shape (B, num_pixel_classes, H, W).

        Args:
            x: (B, H, W) integer pixel values in [0, num_pixel_classes).
            t: (B,) float times, or ``None`` to skip time conditioning.
        """
        h = einops.rearrange(self.embedding(x), "b h w d -> b d h w")
        if isinstance(self.time_embedding, nn.Identity) or t is None:
            t_emb = None
        else:
            t_emb = self.time_mlp(self.time_embedding(t))

        s1_skip = self.conv_in(h, t_emb)              # full resolution
        s2_skip = self.down1(s1_skip, t_emb)          # 1/2 resolution
        s3_features = self.down2(s2_skip, t_emb)      # 1/4 resolution

        b = self.bottleneck_conv1(s3_features, t_emb)
        b = self.attention(b)
        b = self.bottleneck_conv2(b, t_emb)

        u1 = self.up1(b, s2_skip, t_emb)
        u2 = self.up2(u1, s1_skip, t_emb)
        return self.conv_out(u2)

## Flow Matching Loss and Sampling

In [21]:
def flow_matching_loss(model, batch, K=N_GRAYSCALES):
    """Denoising cross-entropy loss for discrete flow matching with uniform noise.

    For each sample a time t ~ U[0, 1) is drawn. Each pixel of the clean image
    x0 is independently replaced, with probability t, by a uniformly random
    value in [0, K). The model sees the corrupted x_t together with t and is
    trained with cross-entropy to recover x0.

    Args:
        model: ``model(x_t, t) -> logits`` with logits of shape (B, K, H, W).
        batch: sequence whose first element is a (B, H, W) integer image tensor.
        K: number of gray levels.

    Returns:
        ``{"loss": scalar tensor}``.
    """
    device = next(model.parameters()).device
    # Cast once: torch.where and cross_entropy both want int64 class indices.
    x0_discrete = batch[0].to(device=device, dtype=torch.long)
    batch_size, H, W = x0_discrete.shape
    t_scalar = torch.rand((batch_size,), device=device)
    noise_discrete = torch.randint(0, K, size=(batch_size, H, W), device=device, dtype=torch.long)
    mask = torch.rand((batch_size, H, W), device=device) < t_scalar.view(batch_size, 1, 1)
    x_t_discrete = torch.where(mask, noise_discrete, x0_discrete)

    logits_pred_x0 = model(x_t_discrete, t_scalar)

    # x0 is already (B, H, W), matching the (B, K, H, W) logits; squeezing
    # would wrongly drop the height axis when H == 1.
    loss = F.cross_entropy(logits_pred_x0, x0_discrete)
    return {"loss": loss}


@torch.no_grad()
def sample_batch(model, initial_noise_discrete, num_steps=10, K=N_GRAYSCALES):
    """Sample discrete images by iterative re-noising and x0 prediction.

    At each of ``num_steps`` steps, t decreases linearly from 1 to 1/num_steps.
    Every pixel of the current x0 estimate is replaced by the fixed noise image
    with probability t. The model then predicts x0 logits, and a new estimate
    is drawn from the per-pixel categorical distribution.

    Args:
        model: ``model(x_t, t) -> logits`` with logits of shape (B, K, H, W).
        initial_noise_discrete: (B, H, W) integer noise image with values in
            [0, K); used as the fixed x1 sample at every step.
        num_steps: number of refinement steps.
        K: number of gray levels (kept for signature compatibility; the noise
            values come from ``initial_noise_discrete``).

    Returns:
        (B, H, W) long tensor of sampled gray levels on the model's device.
    """
    model.eval()
    device = next(model.parameters()).device
    # At t = 1 every pixel is masked, so the first model input is exactly x1.
    # Using the caller's noise as x1 (instead of a fresh random draw) is what
    # makes ``initial_noise_discrete`` actually determine the starting point.
    x1_noise = initial_noise_discrete.to(device=device, dtype=torch.long)
    x_est_x0_discrete = x1_noise.clone()

    for i in range(num_steps):
        t_model_scalar = (num_steps - i) / num_steps
        t_for_model_nn = torch.full((x_est_x0_discrete.shape[0],), t_model_scalar, device=device, dtype=torch.float)
        mask = torch.rand_like(x_est_x0_discrete, dtype=torch.float) < t_model_scalar
        x_t_formed_discrete = torch.where(mask, x1_noise, x_est_x0_discrete)
        logits_x0_pred = model(x_t_formed_discrete, t_for_model_nn)
        # Independent categorical draw per pixel from the predicted x0 logits.
        b, _, h, w = logits_x0_pred.shape
        flat_probs = F.softmax(einops.rearrange(logits_x0_pred, "b c h w -> (b h w) c"), dim=-1)
        flat_discrete = torch.multinomial(flat_probs, num_samples=1)
        x_est_x0_discrete = einops.rearrange(flat_discrete, "(b h w) 1 -> b h w", b=b, h=h, w=w)

    return x_est_x0_discrete

## Trainer

In [22]:
class TrainLossLogger:
    """Buffer per-step metrics and periodically print their running means.

    Values passed to ``log`` are appended per key. Once more than
    ``log_every_sec`` seconds have passed since the last print, the mean of
    every buffered metric is printed and the buffer is reset.
    """

    def __init__(self, log_every_sec=10.0) -> None:
        self.running = defaultdict(list)
        self.log_every_sec = log_every_sec
        self.last_log = time.time()

    def log(self, data) -> None:
        for key, value in data.items():
            self.running[key].append(value)
        if self.last_log + self.log_every_sec >= time.time():
            return
        self.last_log = time.time()
        means = {key: float(np.array(values).mean()) for key, values in self.running.items()}
        print("mean train:", means)
        self.running = defaultdict(list)


class Trainer:
    """Minimal Adam training loop with fixed-seed evaluation.

    ``loss_fn(model, batch)`` must return a dict of scalar tensors containing
    a ``"loss"`` entry. ``"loss"`` is back-propagated, and every entry is
    logged during training and averaged during evaluation.

    Args:
        model: module to train; the device of its first parameter is recorded.
        train_loader: iterable of training batches passed to ``loss_fn``.
        test_loader: iterable of evaluation batches passed to ``loss_fn``.
        loss_fn: see above.
        extra_eval_fn: optional ``fn(model)`` called after every evaluation.
        num_epochs: number of passes over ``train_loader``.
        batch_size: stored only; batching is done by the loaders.
        lr: Adam learning rate.
    """

    def __init__(self, *, model, train_loader, test_loader, loss_fn, extra_eval_fn=None,
                 num_epochs=5, batch_size=32, lr=1e-3):
        self.model = model
        self.device = next(model.parameters()).device
        self.loss_fn = loss_fn
        self.extra_eval_fn = extra_eval_fn
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.train_loader = train_loader
        self.test_loader = test_loader

    def _report(self, epochs_done):
        # Evaluate, then run the optional visualization, after `epochs_done` epochs.
        print(f"epoch={epochs_done} test {self.eval()}")
        if self.extra_eval_fn:
            self.extra_eval_fn(self.model)

    def train(self):
        """Train for ``num_epochs`` epochs, evaluating before each epoch and at the end."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        loss_logger = TrainLossLogger()

        for epoch in range(self.num_epochs):
            self._report(epoch)

            self.model.train()
            prog_bar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.num_epochs}", total=len(self.train_loader), unit=" batches")
            for batch in prog_bar:
                optimizer.zero_grad()
                loss_dict = self.loss_fn(self.model, batch)
                loss = loss_dict["loss"]
                loss.backward()
                optimizer.step()
                log_items = {k: v.item() for k, v in loss_dict.items()}
                loss_logger.log(log_items)
                prog_bar.set_postfix(log_items)

        # Label the final evaluation with the number of completed epochs. Reusing
        # the loop variable would repeat the last label and is unbound when
        # num_epochs == 0.
        self._report(self.num_epochs)

    @torch.no_grad()
    def eval(self) -> dict:
        """Average every ``loss_fn`` metric over ``test_loader`` with seed 17.

        The CPU and CUDA RNG states are saved and restored around the pass, so
        evaluation is reproducible without disturbing training randomness.

        Returns:
            Mapping from metric name to its mean over the test batches.
        """
        self.model.eval()
        total_loss_vals = defaultdict(list)
        # fork_rng restores the CPU (and any CUDA) generator state on exit; it
        # does not cover the MPS generator.
        with torch.random.fork_rng():
            torch.manual_seed(17)
            for batch in self.test_loader:
                loss_dict = self.loss_fn(self.model, batch)
                for k, v in loss_dict.items():
                    total_loss_vals[k].append(v.item())
        return {k: float(np.array(v).mean()) for k, v in total_loss_vals.items()}

## Train the Flow Matching Model

In [23]:
# x0-predicting U-Net over N_GRAYSCALES pixel values, placed on the selected device.
fm_model = UNetFlowMatcherMNIST(num_pixel_classes=N_GRAYSCALES, base_channels=32).to(device)

def sample_and_viz(model_to_eval):
    """Draw 20 unguided samples from a fixed seed and display them.

    Used as the Trainer's ``extra_eval_fn``. Sampling runs on the model's own
    device, and the global RNG state is saved and restored around it so
    visualization does not perturb training randomness.
    """
    model_device = next(model_to_eval.parameters()).device
    with torch.random.fork_rng():
        torch.manual_seed(1)
        initial_noise = torch.randint(0, N_GRAYSCALES, (20, 28, 28),
                                      dtype=torch.long, device=model_device)
        imgs = sample_batch(model_to_eval, initial_noise, num_steps=20, K=N_GRAYSCALES)
    show_images(imgs.cpu().to(torch.uint8))
    print(f"Generated sample images (min/max): {imgs.min()}/{imgs.max()}")

# Train the denoiser; sample_and_viz runs after every evaluation to show samples.
trainer_fm = Trainer(
    model=fm_model,
    train_loader=train_loader,
    test_loader=test_loader,
    loss_fn=flow_matching_loss,
    extra_eval_fn=sample_and_viz,
    num_epochs=5,
)
trainer_fm.train()


epoch=0 test {'loss': 1.6118702217936516}


Generated sample images (min/max): 0/3


Epoch 1/5:   0%|          | 0/1875 [00:00<?, ? batches/s]

mean train: {'loss': 0.2564704418182373}
mean train: {'loss': 0.2124959212730302}
mean train: {'loss': 0.20499897887074728}
mean train: {'loss': 0.20271398726279144}
mean train: {'loss': 0.2006598432109785}
epoch=1 test {'loss': 0.18892263527959585}


Generated sample images (min/max): 0/3


Epoch 2/5:   0%|          | 0/1875 [00:00<?, ? batches/s]

mean train: {'loss': 0.19699250350932818}
mean train: {'loss': 0.19588007116540868}
mean train: {'loss': 0.19193674649388145}
mean train: {'loss': 0.19061148583889007}
mean train: {'loss': 0.19277786860601254}
epoch=2 test {'loss': 0.18200428690761328}


Generated sample images (min/max): 0/3


Epoch 3/5:   0%|          | 0/1875 [00:00<?, ? batches/s]

mean train: {'loss': 0.1892879721323649}
mean train: {'loss': 0.1893943448728275}
mean train: {'loss': 0.190940869007917}
mean train: {'loss': 0.1888091646895934}
mean train: {'loss': 0.18848023620935586}
epoch=3 test {'loss': 0.18169338908046484}


Generated sample images (min/max): 0/3


Epoch 4/5:   0%|          | 0/1875 [00:00<?, ? batches/s]

mean train: {'loss': 0.18582893050018143}
mean train: {'loss': 0.1876825066461955}
mean train: {'loss': 0.1881301712499906}
mean train: {'loss': 0.18874692371308477}
mean train: {'loss': 0.1865629378475269}
epoch=4 test {'loss': 0.1775970133021474}


Generated sample images (min/max): 0/3


Epoch 5/5:   0%|          | 0/1875 [00:00<?, ? batches/s]

mean train: {'loss': 0.18887059558084401}
mean train: {'loss': 0.1839341744000375}
mean train: {'loss': 0.1829035377069827}
mean train: {'loss': 0.18409946641406497}
mean train: {'loss': 0.1847560100433742}
epoch=4 test {'loss': 0.1791500011458993}


Generated sample images (min/max): 0/3


## Simple Guidance (just brightness)

In [29]:
@torch.no_grad()
def sample_with_guidance(model, initial_noise_discrete, guidance_fn, guidance_scale,
                         num_steps=10, K=N_GRAYSCALES, num_guidance_samples=10):
    """Sample like ``sample_batch`` while tilting the x0 logits toward an objective.

    At each step the model's x0 logits are shifted by ``guidance_scale`` times a
    Monte-Carlo estimate of d guidance_fn(y) / d logits. Here y are
    Gumbel-softmax relaxed one-hot samples of shape (B, K, H, W) drawn from
    those logits.

    Args:
        model: ``model(x_t, t) -> logits`` with logits of shape (B, K, H, W).
        initial_noise_discrete: (B, H, W) integer noise image with values in
            [0, K); used as the fixed x1 sample at every step.
        guidance_fn: maps relaxed one-hots (B, K, H, W) to a scalar tensor.
        guidance_scale: gradient multiplier; its sign sets the direction.
        num_steps: number of refinement steps (t goes from 1 to 1/num_steps).
        K: number of gray levels (kept for signature compatibility).
        num_guidance_samples: Gumbel samples averaged per step; <= 0 disables
            guidance instead of dividing by zero.

    Returns:
        (B, H, W) long tensor of sampled gray levels on the model's device.
    """
    model.eval()
    device = next(model.parameters()).device
    # At t = 1 every pixel is masked, so the first model input is exactly x1.
    # Using the caller's noise as x1 lets callers compare settings (e.g.
    # guidance scales) from an identical starting point.
    x1_noise = initial_noise_discrete.to(device=device, dtype=torch.long)
    x_est_x0_discrete = x1_noise.clone()

    for i in range(num_steps):
        t_model_scalar = (num_steps - i) / num_steps
        t_for_model_nn = torch.full((x_est_x0_discrete.shape[0],), t_model_scalar, device=device, dtype=torch.float)
        mask = torch.rand_like(x_est_x0_discrete, dtype=torch.float) < t_model_scalar
        x_t_formed_discrete = torch.where(mask, x1_noise, x_est_x0_discrete)

        logits_pred_x0 = model(x_t_formed_discrete, t_for_model_nn)

        grad_wrt_logits = torch.zeros_like(logits_pred_x0)
        if num_guidance_samples > 0:
            with torch.enable_grad():
                # A single leaf serves every Gumbel sample: autograd.grad returns
                # fresh gradients and does not accumulate into .grad.
                logits_leaf = logits_pred_x0.clone().requires_grad_(True)
                for _ in range(num_guidance_samples):
                    relaxed_onehots = F.gumbel_softmax(logits_leaf, dim=1)
                    scores = guidance_fn(relaxed_onehots)
                    grad_wrt_logits += torch.autograd.grad(scores, logits_leaf)[0]
            grad_wrt_logits /= num_guidance_samples

        logits_guided = logits_pred_x0 + guidance_scale * grad_wrt_logits
        b, _, h, w = logits_guided.shape
        flat_probs = F.softmax(einops.rearrange(logits_guided, "b c h w -> (b h w) c"), dim=-1)
        flat_discrete = torch.multinomial(flat_probs, num_samples=1)
        x_est_x0_discrete = einops.rearrange(flat_discrete, "(b h w) 1 -> b h w", b=b, h=h, w=w)

    return x_est_x0_discrete


def intensity_guidance_viz(model_to_guide, steps=30, num_examples=10, K_val=N_GRAYSCALES,
                           max_scale=1e4, num_sampling_steps=20):
    """Display guided samples over a linear sweep of brightness-guidance scales.

    The objective is the mean expected gray level of the relaxed one-hot
    samples. The scale sweeps linearly from -max_scale (toward lower gray
    levels) to +max_scale (toward higher ones). Every scale is given the same
    seeded ``initial_noise``. The displayed grid has one row per example and
    one column per scale.

    Args:
        model_to_guide: x0-predicting model passed to ``sample_with_guidance``.
        steps: number of scales in the sweep (>= 1; a single step uses scale 0).
        num_examples: samples drawn per scale.
        K_val: number of gray levels.
        max_scale: magnitude of the extreme guidance scales.
        num_sampling_steps: sampler steps per image.

    Raises:
        ValueError: if ``steps`` < 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")
    device = next(model_to_guide.parameters()).device
    pixel_values = torch.arange(K_val, device=device, dtype=torch.float32).view(1, K_val, 1, 1)

    def intensity_objective(x):
        # x: (B, K, H, W) relaxed one-hots -> expected gray level, averaged.
        return torch.sum(x * pixel_values, dim=1).mean()

    images = []
    # Fixed seed for a reproducible figure; fork_rng restores the caller's RNG state.
    with torch.random.fork_rng():
        torch.manual_seed(1)
        initial_noise = torch.randint(0, K_val, (num_examples, 28, 28), dtype=torch.long, device=device)
        for s in tqdm(range(steps), total=steps, unit="guidance_steps", desc="Intensity Guidance"):
            frac = s / (steps - 1) if steps > 1 else 0.5
            scale = (2 * frac - 1) * max_scale
            sampled_imgs = sample_with_guidance(
                model_to_guide, initial_noise_discrete=initial_noise, guidance_fn=intensity_objective,
                guidance_scale=scale, num_steps=num_sampling_steps, K=K_val)
            images.append(sampled_imgs.cpu())

    grid = torch.stack(images, dim=0).to(torch.uint8)  # (steps, num_examples, H, W)
    grid = einops.rearrange(grid, "s n h w -> (n h) (s w)")
    show_image(grid)

intensity_guidance_viz(model_to_guide=fm_model)

Intensity Guidance:   0%|          | 0/30 [00:00<?, ?guidance_steps/s]

## Classifier Model

In [25]:
class MNISTClassifier(nn.Module):
    """Digit classifier over one-hot encoded discrete images.

    Uses the same encoder pieces as the flow-matching U-Net (ConvBlock /
    DownBlock / optional bottleneck attention), without time conditioning.
    These are followed by global average pooling and a linear head.
    ``forward_onehot`` accepts (relaxed) one-hot inputs of shape
    (B, K_pixel_input, H, W), which lets guidance differentiate through the
    classifier.

    Args:
        num_classes: number of output classes.
        base_channels: channel width at full resolution. Must be positive.
        num_groups_norm: requested GroupNorm groups.
        use_attention_bottleneck: add self-attention after the bottleneck conv.
        K_pixel_input: number of discrete pixel values (one-hot width).

    Raises:
        ValueError: if ``base_channels`` is not positive.
    """

    def __init__(self,
                 num_classes=10, base_channels=32,
                 num_groups_norm=8, use_attention_bottleneck=True,
                 K_pixel_input=N_GRAYSCALES):
        super().__init__()
        if base_channels <= 0:
            raise ValueError("base_channels must be positive.")
        c1, c2, c4, c8 = (base_channels * m for m in (1, 2, 4, 8))
        # A 1x1 conv on one-hot input acts like a learned per-value embedding (plus bias).
        self.embedding = nn.Conv2d(K_pixel_input, c1, kernel_size=1, padding=0)
        self.conv_in = ConvBlock(c1, c1, time_emb_dim=None, num_groups=num_groups_norm)
        self.down1 = DownBlock(c1, c2, time_emb_dim=None, num_groups_norm=num_groups_norm)
        self.down2 = DownBlock(c2, c4, time_emb_dim=None, num_groups_norm=num_groups_norm)
        self.bottleneck_conv1 = ConvBlock(c4, c8, time_emb_dim=None, num_groups=num_groups_norm)
        self.attention = SelfAttentionBlock(c8, num_heads=4) if use_attention_bottleneck else nn.Identity()
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier_head = nn.Linear(c8, num_classes)
        self.num_pixel_classes = K_pixel_input

    def forward(self, x):
        """Classify a (B, H, W) integer image; returns (B, num_classes) logits."""
        onehot = F.one_hot(x.long(), num_classes=self.num_pixel_classes).float()
        # (B, H, W, K) -> (B, K, H, W)
        return self.forward_onehot(onehot.permute(0, 3, 1, 2))

    def forward_onehot(self, x):
        """Classify one-hot (or relaxed one-hot) input of shape (B, K, H, W)."""
        h = self.embedding(x)
        for stage in (self.conv_in, self.down1, self.down2, self.bottleneck_conv1):
            h = stage(h, t_emb=None)
        h = self.global_avg_pool(self.attention(h))
        return self.classifier_head(h.flatten(1))


## Cross-Entropy Loss for the Classifier

In [26]:
def ce_loss_classifier(model, batch, K_pixel_val=N_GRAYSCALES):
    """Cross-entropy loss and accuracy of a digit classifier on one batch.

    Args:
        model: classifier returning (B, num_classes) logits.
        batch: ``(images, labels)`` pair as yielded by the loaders.
        K_pixel_val: unused; kept for signature compatibility.

    Returns:
        ``{"loss": ..., "accuracy": ...}`` as scalar tensors.
    """
    target_device = next(model.parameters()).device
    images = batch[0].to(target_device)
    labels = batch[1].to(target_device)
    logits = model(images)
    accuracy = (logits.argmax(dim=1) == labels).float().mean()
    return dict(loss=nn.CrossEntropyLoss()(logits, labels), accuracy=accuracy)


## Train the classifier

In [27]:
# One epoch is enough for a usable guidance classifier.
classifier_model = MNISTClassifier(K_pixel_input=N_GRAYSCALES).to(device)
trainer_classifier = Trainer(
    model=classifier_model,
    train_loader=train_loader,
    test_loader=test_loader,
    loss_fn=ce_loss_classifier,
    extra_eval_fn=None,
    num_epochs=1,
)
trainer_classifier.train()

epoch=0 test {'loss': 2.438400626182556, 'accuracy': 0.09121093759313226}


Epoch 1/1:   0%|          | 0/1875 [00:00<?, ? batches/s]

mean train: {'loss': 0.7561990917958592, 'accuracy': 0.7373930317848411}
mean train: {'loss': 0.17022757836530508, 'accuracy': 0.9499254473161034}
mean train: {'loss': 0.12287002402609883, 'accuracy': 0.9627647329650092}
epoch=0 test {'loss': 0.0838384751114063, 'accuracy': 0.9750000014901161}


## Classifier Guidance

In [31]:
def class_scorer(classifier_nn, x, class_weights):
    """Batch-mean of class-weighted classifier log-probabilities.

    Args:
        classifier_nn: object exposing ``forward_onehot(x) -> (B, C) logits``.
        x: input passed straight to ``forward_onehot``.
        class_weights: (C,) weights applied to each class's log-probability.

    Returns:
        Scalar tensor: mean over the batch of sum_c w_c * log p(c | x).
    """
    logits = classifier_nn.forward_onehot(x)
    log_probs = F.log_softmax(logits, dim=-1)
    weighted = log_probs * class_weights.unsqueeze(0)
    per_sample = weighted.sum(dim=-1)
    return per_sample.mean()


def number_guidance(model_to_guide, classifier_model, num_examples=10, K_val=N_GRAYSCALES,
                    guidance_scale=1e4, num_steps=100, num_classes=10):
    """Display classifier-guided samples, one column per target class.

    For each class n, sampling of ``model_to_guide`` is guided by
    ``class_scorer`` with a one-hot weight on class n. This is the
    classifier's batch-mean log-probability of n on the relaxed samples. Every
    class is given the same seeded initial noise. The grid has one row per
    example and one column per class.

    Args:
        model_to_guide: x0-predicting model passed to ``sample_with_guidance``.
        classifier_model: classifier exposing ``forward_onehot``; its output
            width must equal ``num_classes``.
        num_examples: samples per class.
        K_val: number of gray levels.
        guidance_scale: guidance multiplier.
        num_steps: sampler steps per image.
        num_classes: number of classes to sweep.
    """
    device = next(model_to_guide.parameters()).device
    images = []
    # Fixed seed for a reproducible figure; fork_rng restores the caller's RNG state.
    with torch.random.fork_rng():
        torch.manual_seed(1)
        initial_noise = torch.randint(0, K_val, (num_examples, 28, 28), dtype=torch.long, device=device)
        for n in tqdm(range(num_classes)):
            class_weights = F.one_hot(torch.tensor(n, device=device), num_classes).float()

            def scorer(x, class_weights=class_weights):
                return class_scorer(classifier_model, x, class_weights)

            # Guide the model that was passed in (not a global) with the
            # caller's gray-level count.
            sampled_imgs = sample_with_guidance(
                model_to_guide, initial_noise_discrete=initial_noise, guidance_fn=scorer,
                guidance_scale=guidance_scale, num_steps=num_steps, K=K_val)
            images.append(sampled_imgs.cpu())

    grid = torch.stack(images, dim=0).to(torch.uint8)  # (num_classes, num_examples, H, W)
    grid = einops.rearrange(grid, "c n h w -> (n h) (c w)")
    show_image(grid)


number_guidance(model_to_guide=fm_model, classifier_model=classifier_model)

  0%|          | 0/10 [00:00<?, ?it/s]