<a href="https://colab.research.google.com/github/dongzhuoyao/minimal-dmd/blob/main/train_teacher_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DMD2 Teacher Training - MNIST

This notebook trains a teacher diffusion model on MNIST that will be used for DMD2 distillation.


In [1]:
# Install dependencies
!pip install torch torchvision tqdm -q


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os
import numpy as np


## Model Definition


In [3]:
# Model definition
def get_sigmas_karras(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Build a Karras-style noise-level schedule.

    The schedule interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` and raises the result back to the power
    ``rho``.

    Args:
        n: Number of noise levels to generate.
        sigma_min: Noise level at the end of the schedule.
        sigma_max: Noise level at the start of the schedule.
        rho: Exponent controlling how the levels are spaced.

    Returns:
        A 1-D tensor of ``n`` sigmas running from ``sigma_max`` (index 0)
        towards ``sigma_min`` (last index).
    """
    inv_rho = 1 / rho
    low = sigma_min ** inv_rho
    high = sigma_max ** inv_rho
    # Position along the schedule: 0 -> sigma_max, 1 -> sigma_min.
    position = torch.linspace(0, 1, n)
    return (high + position * (low - high)) ** rho


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of one scalar per sample.

    Each scalar is multiplied by ``dim // 2`` geometrically spaced
    frequencies (from 1 down to 1/10000); the sines and cosines of those
    products are concatenated along the last axis.

    Args:
        dim: Width of the returned embedding.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a 1-D tensor of scalars.

        Args:
            time: Tensor of shape ``(B,)``.

        Returns:
            Tensor of shape ``(B, dim)``: sines first, then cosines, plus one
            trailing zero column when ``dim`` is odd.
        """
        device = time.device
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero;
        # dividing by it made every frequency NaN. Clamp the divisor instead.
        emb = np.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = time[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill an even number of columns; zero-pad so the
            # output width always equals ``dim``.
            emb = nn.functional.pad(emb, (0, 1))
        return emb


class ResBlock(nn.Module):
    """Two-convolution residual block conditioned on an embedding vector.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        time_emb_dim: Width of the conditioning vector passed to ``forward``.
        dropout: Dropout probability used before the second convolution.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        # Projects the conditioning vector to one bias per output channel.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))
        # Norm -> activation -> 3x3 conv; changes the channel count.
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        )
        # Norm -> activation -> dropout -> 3x3 conv at the output width.
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )
        # The skip path only needs a 1x1 conv when the channel count changes.
        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb):
        """Apply the block.

        Args:
            x: Feature map with ``in_channels`` channels.
            time_emb: Conditioning vectors of shape ``(B, time_emb_dim)``.

        Returns:
            Feature map with ``out_channels`` channels and the spatial size
            of ``x``.
        """
        skip = self.res_conv(x)
        # Broadcast the per-channel conditioning bias over height and width.
        cond = self.time_mlp(time_emb)[:, :, None, None]
        features = self.block1(x) + cond
        features = self.block2(features)
        return features + skip


class AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Args:
        channels: Channel count of the input and output feature maps.
    """
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.norm = nn.GroupNorm(8, channels)
        # One 1x1 conv produces queries, keys and values stacked on channels.
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Attend across all ``H * W`` positions and add the result to ``x``.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, chans, height, width = x.shape
        tokens = height * width

        query, key, value = self.qkv(self.norm(x)).chunk(3, dim=1)

        # Flatten the spatial grid: queries/values as (B, HW, C), keys as (B, C, HW).
        query = query.view(batch, chans, tokens).transpose(1, 2)
        key = key.view(batch, chans, tokens)
        value = value.view(batch, chans, tokens).transpose(1, 2)

        # Scaled dot-product attention with a single head of width C.
        scores = torch.matmul(query, key) * chans ** -0.5
        probs = torch.softmax(scores, dim=-1)
        attended = torch.matmul(probs, value)

        # Back to (B, C, H, W) before the output projection and residual add.
        attended = attended.transpose(1, 2).view(batch, chans, height, width)
        return x + self.proj(attended)


class SimpleUNet(nn.Module):
    """Simple UNet for MNIST, conditioned on a noise level and a class label.

    The noise level is embedded with ``TimeEmbedding``, a learned label
    embedding is added to it, and the sum conditions every ``ResBlock``.
    The encoder runs three residual stages (64 -> 128 -> 256 -> 512
    channels), each followed by 2x average pooling. The middle is
    ResBlock / self-attention / ResBlock, and the decoder mirrors the
    encoder with nearest-neighbour upsampling and skip connections.

    Args:
        img_channels: Channels of the input image and of the output.
        label_dim: Number of class labels (rows of the label embedding).
        time_emb_dim: Width of the shared sigma/label conditioning vector.
    """
    def __init__(self, img_channels=1, label_dim=10, time_emb_dim=128):
        super().__init__()
        self.time_emb_dim = time_emb_dim
        self.time_embed = TimeEmbedding(time_emb_dim)

        # Label embedding
        self.label_embed = nn.Embedding(label_dim, time_emb_dim)

        # Downsampling
        self.conv_in = nn.Conv2d(img_channels, 64, 3, padding=1)
        self.down1 = ResBlock(64, 128, time_emb_dim)
        self.down2 = ResBlock(128, 256, time_emb_dim)
        self.down3 = ResBlock(256, 512, time_emb_dim)

        # Middle
        self.mid_block1 = ResBlock(512, 512, time_emb_dim)
        self.mid_attn = AttentionBlock(512)
        self.mid_block2 = ResBlock(512, 512, time_emb_dim)

        # Upsampling (input widths include the concatenated skip features)
        self.up1 = ResBlock(512 + 256, 256, time_emb_dim)
        self.up2 = ResBlock(256 + 128, 128, time_emb_dim)
        self.up3 = ResBlock(128 + 64, 64, time_emb_dim)

        # Output
        self.norm_out = nn.GroupNorm(8, 64)
        self.conv_out = nn.Conv2d(64, img_channels, 3, padding=1)

    def forward(self, x, sigma, label, return_bottleneck=False):
        """Run the network on a batch of noisy images.

        Args:
            x: Image batch; presumably ``(B, img_channels, H, W)``.
            sigma: Noise level. A Python number or 0-dim tensor is broadcast
                to the whole batch; a tensor with more than one dimension
                (e.g. ``(B, 1, 1, 1)``) is flattened to ``(B,)``.
            label: Integer class indices of shape ``(B,)``, or a tensor with
                more than one dimension (e.g. one-hot), which is reduced
                with ``argmax`` over dim 1.
            return_bottleneck: If True, skip the decoder and return the
                512-channel feature map produced by the middle blocks.

        Returns:
            The ``conv_out`` output with ``img_channels`` channels and the
            spatial size of ``x``, or the bottleneck features when
            ``return_bottleneck`` is True.
        """
        # Handle sigma
        if isinstance(sigma, (int, float)) or (isinstance(sigma, torch.Tensor) and sigma.dim() == 0):
            sigma = torch.full((x.shape[0],), float(sigma), device=x.device)
        elif sigma.dim() > 1:
            # reshape(-1), not squeeze(): squeeze() drops *every* size-1 axis,
            # so a batch of one (e.g. shape (1, 1, 1, 1)) collapsed to a 0-dim
            # tensor and broke the ``time[:, None]`` indexing in the embedding.
            sigma = sigma.reshape(-1)

        # Handle label
        if label.dim() > 1:
            label = label.argmax(dim=1)

        # Time embedding from sigma, shifted by the label embedding
        time_emb = self.time_embed(sigma)
        label_emb = self.label_embed(label)
        time_emb = time_emb + label_emb

        # Downsampling (pre-pooling activations are kept as skip connections)
        h1 = self.conv_in(x)
        h2 = self.down1(h1, time_emb)
        h2_down = nn.functional.avg_pool2d(h2, 2)
        h3 = self.down2(h2_down, time_emb)
        h3_down = nn.functional.avg_pool2d(h3, 2)
        h4 = self.down3(h3_down, time_emb)
        h4_down = nn.functional.avg_pool2d(h4, 2)

        # Middle
        h = self.mid_block1(h4_down, time_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, time_emb)

        if return_bottleneck:
            return h

        # Upsampling: resize to each skip's exact spatial size so odd sizes
        # produced by the pooling still line up for the concatenation.
        h = nn.functional.interpolate(h, size=h3.shape[2:], mode='nearest')
        h = torch.cat([h, h3], dim=1)
        h = self.up1(h, time_emb)

        h = nn.functional.interpolate(h, size=h2.shape[2:], mode='nearest')
        h = torch.cat([h, h2], dim=1)
        h = self.up2(h, time_emb)

        h = nn.functional.interpolate(h, size=h1.shape[2:], mode='nearest')
        h = torch.cat([h, h1], dim=1)
        h = self.up3(h, time_emb)

        # Output
        h = self.norm_out(h)
        h = nn.functional.silu(h)
        out = self.conv_out(h)

        return out


## Configuration


In [4]:
# Hyper-parameters and paths shared by the setup, data and training cells.
config = dict(
    data_dir='./data',                  # where the MNIST files are stored
    output_dir='./checkpoints/teacher',  # where checkpoints are written
    batch_size=128,
    lr=1e-4,                            # optimizer learning rate
    num_epochs=100,
    num_train_timesteps=1000,           # number of discrete sigma levels
    sigma_min=0.002,                    # smallest noise level in the schedule
    sigma_max=80.0,                     # largest noise level in the schedule
    sigma_data=0.5,                     # enters the loss weighting
    rho=7.0,                            # Karras schedule exponent
    max_grad_norm=1.0,                  # gradient-norm clipping threshold
    save_every=5000,                    # optimizer steps between checkpoints
)


## Setup


In [5]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Make sure the checkpoint directory exists (no error if it already does).
os.makedirs(config['output_dir'], exist_ok=True)


Using device: cuda


## Load MNIST Dataset


In [6]:
# MNIST training split, converted to tensors and normalized to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Downloads the data into config['data_dir'] on first use.
train_dataset = datasets.MNIST(
    root=config['data_dir'], train=True, transform=transform, download=True
)

# Reshuffled every epoch; two worker processes with pinned host memory.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

print(f"Dataset loaded: {len(train_dataset)} training samples")


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 479kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.48MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.0MB/s]

Dataset loaded: 60000 training samples





## Initialize Model


In [7]:
# Build the teacher network and move its parameters to the chosen device.
model = SimpleUNet(img_channels=1, label_dim=10)
model = model.to(device)

# AdamW with a small weight decay over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=config['lr'],
    weight_decay=0.01,
)

# Discrete table of Karras noise levels that training indexes into,
# kept on the same device as the model.
sigmas = get_sigmas_karras(
    config['num_train_timesteps'],
    config['sigma_min'],
    config['sigma_max'],
    config['rho'],
).to(device)

# Data standard deviation used by the loss weighting.
sigma_data = config['sigma_data']

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")


Model initialized with 18984897 parameters


## Training Loop


In [None]:
# Training loop
def _save_checkpoint(path, model, optimizer, step, epoch):
    """Write model and optimizer state plus the step/epoch counters to ``path``."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step,
        'epoch': epoch,
    }, path)


model.train()
global_step = 0

for epoch in range(config['num_epochs']):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")
    epoch_loss = 0.0

    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)

        # Draw one index into the sigma table per sample.
        timesteps = torch.randint(
            low=0,
            high=config['num_train_timesteps'],
            size=(images.shape[0],),
            device=device,
            dtype=torch.long,
        )
        timestep_sigma = sigmas[timesteps]

        # Corrupt the clean images with Gaussian noise of the sampled scale.
        noise = torch.randn_like(images)
        noisy_images = images + timestep_sigma[:, None, None, None] * noise

        # The network output is trained to match the clean image (x0).
        pred_x0 = model(noisy_images, timestep_sigma, labels)

        # Karras loss weighting: 1 / sigma^2 + 1 / sigma_data^2.
        snrs = timestep_sigma.pow(-2)
        weights = snrs + 1.0 / (sigma_data ** 2)

        # Weighted mean squared error over every pixel of the batch.
        squared_error = (pred_x0 - images) ** 2
        loss = (weights.reshape(-1, 1, 1, 1) * squared_error).mean()

        # Backward pass with gradient-norm clipping.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the bar.
        loss_value = loss.item()
        epoch_loss += loss_value
        global_step += 1

        # Update progress bar
        pbar.set_postfix({"loss": loss_value, "avg_loss": epoch_loss / (batch_idx + 1)})

        # Save checkpoint periodically
        if global_step % config['save_every'] == 0:
            checkpoint_path = os.path.join(
                config['output_dir'],
                f"teacher_checkpoint_step_{global_step}.pt"
            )
            _save_checkpoint(checkpoint_path, model, optimizer, global_step, epoch)
            print(f"\nSaved checkpoint to {checkpoint_path}")

# Save final checkpoint
final_checkpoint_path = os.path.join(config['output_dir'], "teacher_final.pt")
_save_checkpoint(final_checkpoint_path, model, optimizer, global_step, config['num_epochs'])
print(f"\nSaved final checkpoint to {final_checkpoint_path}")


Epoch 1/100: 100%|██████████| 469/469 [02:01<00:00,  3.85it/s, loss=4.95, avg_loss=72]
Epoch 2/100:  21%|██        | 98/469 [00:25<01:36,  3.85it/s, loss=13.2, avg_loss=7.2]