# MNIST

This notebook contains working code for a class-conditional DDPM diffusion model that generates MNIST digits.

In [None]:
# ===========================
# 0. Imports
# ===========================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ===========================
# 1. Data (MNIST)
# ===========================
# MNIST images are 28x28, so Resize is a no-op at the default value; it is
# kept so the working resolution can be changed in one place.
image_size = 28
batch_size = 128

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),                  # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),   # (x - 0.5) / 0.5 -> [-1, 1]
])

train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; on CPU-only
    # machines it gives no benefit and PyTorch may emit a warning.
    pin_memory=device.type == "cuda",
)

# ===========================
# 2. Time Embeddings
# ===========================
class SinusoidalPosEmb(nn.Module):
    """
    Standard sinusoidal positional embedding for scalar timestep t.

    Uses ``dim // 2`` frequencies spaced geometrically from 1 down to 1/10000.
    The output is ``[sin(t * f), cos(t * f)]`` concatenated on the last axis.
    It is zero-padded by one column when ``dim`` is odd.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim) float embedding
        """
        half_dim = self.dim // 2
        # The divisor is the number of frequency steps. With a single
        # frequency (dim in {2, 3}), half_dim - 1 would be 0 and raise
        # ZeroDivisionError. Clamping it keeps that lone frequency at 1.
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        # Frequencies: exp(-k * emb_factor) for k = 0 .. half_dim-1.
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -emb_factor
        )
        # (B, 1) * (1, half_dim) -> (B, half_dim)
        angles = t.float().unsqueeze(1) * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # Odd dim: pad one zero column so the output width equals dim.
            emb = F.pad(emb, (0, 1))
        return emb

# ===========================
# 3. UNet building blocks
# ===========================
class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with GroupNorm + SiLU, conditioned on a time embedding.

    The time embedding is linearly projected to ``out_channels`` and broadcast
    over H and W. It is added right after the first convolution. The skip
    path is a 1x1 conv when the channel count changes, otherwise the identity.

    GroupNorm uses 8 groups, so ``out_channels`` must be divisible by 8.
    """
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.time_dim = time_dim

        # Layers are created in this order on purpose: it fixes how the
        # random parameter initialisation is drawn under a given seed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_channels)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act = nn.SiLU()
        self.res_conv = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, t_emb):
        """
        x: (B, in_channels, H, W)
        t_emb: (B, time_dim)
        returns: (B, out_channels, H, W)
        """
        # (B, out_channels) -> (B, out_channels, 1, 1) for spatial broadcast.
        time_bias = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.act(self.norm1(self.conv1(x) + time_bias))
        out = self.act(self.norm2(self.conv2(out)))
        return out + self.res_conv(x)


class DownBlock(nn.Module):
    """
    Encoder stage: a ResidualBlock at the input resolution, followed by a
    4x4 stride-2 (padding 1) convolution that maps H, W to floor(H/2), floor(W/2).

    forward(x, t_emb) returns ``(downsampled, skip)``. ``skip`` is the
    ResidualBlock output before downsampling and is meant for the matching
    decoder stage.
    """
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.resblock = ResidualBlock(in_channels, out_channels, time_dim)
        self.down = nn.Conv2d(out_channels, out_channels, 4, stride=2, padding=1)

    def forward(self, x, t_emb):
        features = self.resblock(x, t_emb)
        return self.down(features), features


class UpBlock(nn.Module):
    """
    Decoder stage: 2x transposed-conv upsampling, channel concatenation
    with the encoder skip, then a ResidualBlock.
    """
    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        """
        in_channels: channels of the tensor coming from the stage below
        out_channels: channels after upsampling (and of this block's output)
        skip_channels: channels of the skip tensor concatenated after upsampling
        time_dim: width of the conditioning embedding passed to the ResidualBlock
        """
        super().__init__()
        # kernel 4, stride 2, padding 1 -> output spatial size is exactly 2x input.
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.resblock = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        # skip must have the same spatial size as the upsampled tensor.
        upsampled = self.up(x)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.resblock(merged, t_emb)

# ===========================
# 4. UNet model
# ===========================
class UNet(nn.Module):
    """
    Class-conditional noise-prediction UNet with two down and two up stages.

    Channel widths are ``base_channels * [1, 2, 4]``. There are two
    stride-2 downsamplings, so H and W should be divisible by 4
    (28 -> 14 -> 7 for MNIST); otherwise the decoder's skip shapes will
    not match.
    """
    def __init__(self, in_channels=1, base_channels=64, time_dim=128, num_classes=10):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim
        self.num_classes = num_classes

        # Submodules are created in this order on purpose (it determines the
        # seeded parameter initialisation and the parameter ordering).

        # Timestep -> sinusoidal features -> 2-layer MLP (hidden width 4 * time_dim).
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # One learned vector per class, added to the time embedding.
        self.label_emb = nn.Embedding(num_classes, time_dim)

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        width1, width2, width4 = base_channels, base_channels * 2, base_channels * 4
        self.down1 = DownBlock(width1, width2, time_dim)
        self.down2 = DownBlock(width2, width4, time_dim)
        self.bottleneck = ResidualBlock(width4, width4, time_dim)
        self.up2 = UpBlock(width4, width2, skip_channels=width4, time_dim=time_dim)
        self.up1 = UpBlock(width2, width1, skip_channels=width2, time_dim=time_dim)

        # 1x1 projection back to the image channel count (the noise estimate).
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t, y=None):
        """
        x: (B, in_channels, H, W) noisy image
        t: (B,) integer timesteps
        y: (B,) class labels in [0, num_classes), or None to skip class conditioning
        returns: (B, in_channels, H, W) predicted noise ε̂
        """
        cond = self.time_mlp(t)                  # (B, time_dim)
        if y is not None:
            cond = cond + self.label_emb(y)

        h = self.init_conv(x)                    # (B, base, H, W)
        h, skip_full = self.down1(h, cond)       # h at H/2; skip_full at H (2*base ch)
        h, skip_half = self.down2(h, cond)       # h at H/4; skip_half at H/2 (4*base ch)
        h = self.bottleneck(h, cond)             # (B, 4*base, H/4, W/4)
        h = self.up2(h, skip_half, cond)         # (B, 2*base, H/2, W/2)
        h = self.up1(h, skip_full, cond)         # (B, base, H, W)
        return self.out_conv(h)


# ===========================
# 5. Diffusion utilities (DDPM-style)
# ===========================
class Diffusion:
    """
    DDPM-style linear beta schedule with forward noising (q) and
    ancestral reverse sampling (p).

    Every schedule tensor is 1-D with length ``num_steps`` and lives on ``device``.
    """
    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        # beta_t rises linearly; alpha_t = 1 - beta_t; ᾱ_t = prod_{s<=t} alpha_s.
        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        # ᾱ_{t-1}, with ᾱ_{-1} defined as 1 so that t = 0 is well defined.
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        # Precomputed square roots for q_sample / p_sample.
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)

        # Posterior variance for q(x_{t-1} | x_t, x_0):
        #   beta_t * (1 - ᾱ_{t-1}) / (1 - ᾱ_t)   (exactly 0 at t = 0)
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # Log variance, clamped so t = 0 does not give log(0). Not used by p_sample.
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """
        Uniformly sample timesteps t in [0, num_steps-1].

        returns: (batch_size,) int64 tensor on self.device
        """
        return torch.randint(0, self.num_steps, (batch_size,), device=self.device).long()

    def q_sample(self, x0, t, noise=None):
        """
        Diffuse the data (forward process): sample x_t ~ q(x_t | x_0).

        x_t = sqrt(ᾱ_t) * x0 + sqrt(1 - ᾱ_t) * noise

        x0: (B, C, H, W) clean data
        t:  (B,) integer timesteps
        noise: optional pre-drawn noise shaped like x0 (drawn N(0, I) if None)
        returns: (B, C, H, W) noised data
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Per-sample coefficients reshaped to broadcast over (C, H, W).
        sqrt_alpha_cumprod_t = self.sqrt_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t, y=None):
        """
        One ancestral reverse step: draw x_{t-1} ~ p(x_{t-1} | x_t).

        x_t: (B, C, H, W) current noisy batch
        t: scalar int timestep shared by the whole batch (not a tensor)
        y: optional labels forwarded to the model
        returns: x_{t-1}, same shape as x_t (the predicted mean, with no noise, at t == 0)
        """
        B = x_t.size(0)
        # Put the timestep batch on the same device as the input the model receives.
        t_batch = torch.full((B,), t, device=x_t.device, dtype=torch.long)

        # Predicted noise at this timestep.
        eps_theta = model(x_t, t_batch, y)

        beta_t = self.betas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(self.alphas[t])

        # DDPM mean: μθ(x_t, t) = 1/sqrt(α_t) * (x_t - β_t / sqrt(1 - ᾱ_t) * εθ)
        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            # Last step: no noise is added (the posterior variance is 0 here anyway).
            return model_mean
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(self.posterior_variance[t]) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=16, y=None, channels=1, num_classes=10):
        """
        Generate samples by running the full reverse chain from pure noise.

        model: noise predictor called as model(x_t, t_batch, y)
        image_size: height and width of the square images
        batch_size: number of images to generate
        y: (batch_size,) class labels; if None, drawn uniformly from [0, num_classes)
        channels: image channel count (1 for MNIST)
        num_classes: label range, used only when y is None
        returns: (batch_size, channels, image_size, image_size), roughly in
                 [-1, 1] (values are not clamped)

        The model is switched to eval mode for sampling. Its previous
        train/eval mode is restored afterwards, even if sampling raises.
        """
        was_training = model.training
        model.eval()
        try:
            x = torch.randn(batch_size, channels, image_size, image_size, device=self.device)

            # Conditional model: pick random labels when none are given.
            if y is None:
                y = torch.randint(0, num_classes, (batch_size,), device=self.device)

            for t in reversed(range(self.num_steps)):
                x = self.p_sample(model, x, t, y=y)
        finally:
            model.train(was_training)

        return x


# ===========================
# 6. Model + Diffusion setup
# ===========================
time_dim = 128
num_steps = 1000

# Class-conditional noise predictor: 1-channel images, 10 digit classes.
model = UNet(in_channels=1, base_channels=64, time_dim=time_dim, num_classes=10).to(device)

# Linear DDPM beta schedule (1e-4 -> 0.02 over num_steps), on the training device.
diffusion = Diffusion(num_steps=num_steps, beta_start=1e-4, beta_end=0.02, device=device)

# Plain Adam at a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# ===========================
# 7. Training Loop
# ===========================
num_epochs = 10  # bump this up (e.g. 50+) for better samples

for epoch in range(num_epochs):
    model.train()
    # Sum of (batch loss * batch size), kept on-device so the inner loop does
    # not force a device->host sync every step. It is read back once per epoch.
    running_loss = torch.zeros((), device=device)

    for x, labels in train_loader:
        # non_blocking lets host->GPU copies overlap when the loader pins memory.
        x = x.to(device, non_blocking=True)        # already normalized to [-1, 1]
        labels = labels.to(device, non_blocking=True)

        b = x.size(0)
        t = diffusion.sample_timesteps(b)           # (B,) uniform timesteps

        # Noise x0 to x_t with known noise, then regress that noise (MSE).
        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)
        pred_noise = model(x_t, t, y=labels)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        running_loss += loss.detach() * b

    # train_loader keeps the final partial batch (drop_last is not set), so
    # every sample contributes once and dividing by the dataset size gives
    # the per-sample mean.
    epoch_loss = running_loss.item() / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training done.")

# ===========================
# 8. Sampling & Visualization
# ===========================
@torch.no_grad()
def show_samples(model, diffusion, n=16, class_label=None):
    """
    Generate ``n`` digits with the reverse diffusion chain and plot them as a grid.

    class_label: one digit 0..9 used for every image, or None to give each
    image a random label drawn uniformly from 0..9.

    Reads the module-level ``device`` and ``image_size`` globals.
    """
    model.eval()
    labels = (
        torch.randint(0, 10, (n,), device=device)
        if class_label is None
        else torch.full((n,), class_label, device=device, dtype=torch.long)
    )

    generated = diffusion.sample(model, image_size=image_size, batch_size=n, y=labels)
    # Clip to [-1, 1], then map to [0, 1] for display.
    pixels = generated.clamp(-1, 1).add(1).div(2.0)

    per_row = int(math.sqrt(n))  # square layout when n is a perfect square
    grid = utils.make_grid(pixels, nrow=per_row)

    plt.figure(figsize=(4, 4))
    plt.axis("off")
    # make_grid returns a (C, H, W) image; imshow wants (H, W, C) on the CPU.
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

# Example: generate random digits
show_samples(model, diffusion, n=16, class_label=None)

# Example: generate only "3"s
# show_samples(model, diffusion, n=16, class_label=3)
