# DMD2 Distillation Training - MNIST

This notebook distills a pretrained MNIST diffusion teacher into a fast feedforward generator using DMD2 (Improved Distribution Matching Distillation).


In [None]:
# Install dependencies
!pip install torch torchvision tqdm -q


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os
import numpy as np
import copy


## Model Definitions


In [None]:
# Model definition (same as teacher training)
def get_sigmas_karras(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Return ``n`` noise levels on the Karras rho-schedule.

    The levels are evenly spaced in ``sigma ** (1 / rho)`` space. The first
    entry corresponds to ``sigma_max`` and the last to ``sigma_min``.
    """
    inv_rho = 1 / rho
    start = sigma_max ** inv_rho
    stop = sigma_min ** inv_rho
    # Interpolate linearly between the warped endpoints, then undo the warp.
    t = torch.linspace(0, 1, n)
    return (start + t * (stop - start)) ** rho


class TimeEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of a per-sample scalar.

    For an input of shape ``(B,)`` the output has shape ``(B, 2 * (dim // 2))``:
    the first half holds sines, the second half cosines, over geometrically
    spaced frequencies from 1 down to 1/10000.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        n_freqs = self.dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = np.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        # Outer product: one row of phases per input value.
        phases = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([phases.sin(), phases.cos()], dim=-1)


class ResBlock(nn.Module):
    """Two-convolution residual block conditioned on an embedding vector.

    The embedding is projected to ``out_channels`` and added as a per-channel
    bias between the two convolutions. The skip path is a 1x1 convolution
    when the channel count changes, and an identity otherwise.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        # Submodule names and creation order are kept stable so that
        # state_dict keys and seeded initialisation stay the same.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )
        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb):
        hidden = self.block1(x)
        # Broadcast the projected embedding over the spatial dimensions.
        bias = self.time_mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        hidden = self.block2(hidden + bias)
        return hidden + self.res_conv(x)


class AttentionBlock(nn.Module):
    """Single-head self-attention over spatial positions, with a residual add."""
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        batch, chans, height, width = x.shape
        n_tokens = height * width

        # One 1x1 convolution produces queries, keys and values together.
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)

        # Flatten the spatial grid: q and v become (B, HW, C), k stays (B, C, HW).
        q = q.view(batch, chans, n_tokens).transpose(1, 2)
        k = k.view(batch, chans, n_tokens)
        v = v.view(batch, chans, n_tokens).transpose(1, 2)

        # Scaled dot-product attention across the HW positions.
        weights = torch.softmax((q @ k) * chans ** -0.5, dim=-1)
        mixed = (weights @ v).transpose(1, 2).view(batch, chans, height, width)

        return x + self.proj(mixed)


class SimpleUNet(nn.Module):
    """Small label- and noise-conditioned UNet for MNIST.

    The network embeds the noise level ``sigma`` sinusoidally, adds a learned
    class embedding, and feeds the sum to every residual block. The encoder
    applies three residual stages, each followed by 2x average pooling; the
    decoder upsamples with nearest-neighbour interpolation back to each skip
    tensor's spatial size and concatenates it.
    """
    def __init__(self, img_channels=1, label_dim=10, time_emb_dim=128):
        super().__init__()
        self.time_emb_dim = time_emb_dim
        self.time_embed = TimeEmbedding(time_emb_dim)

        # Label embedding
        self.label_embed = nn.Embedding(label_dim, time_emb_dim)

        # Downsampling
        self.conv_in = nn.Conv2d(img_channels, 64, 3, padding=1)
        self.down1 = ResBlock(64, 128, time_emb_dim)
        self.down2 = ResBlock(128, 256, time_emb_dim)
        self.down3 = ResBlock(256, 512, time_emb_dim)

        # Middle
        self.mid_block1 = ResBlock(512, 512, time_emb_dim)
        self.mid_attn = AttentionBlock(512)
        self.mid_block2 = ResBlock(512, 512, time_emb_dim)

        # Upsampling (input channels = previous output + skip connection)
        self.up1 = ResBlock(512 + 256, 256, time_emb_dim)
        self.up2 = ResBlock(256 + 128, 128, time_emb_dim)
        self.up3 = ResBlock(128 + 64, 64, time_emb_dim)

        # Output
        self.norm_out = nn.GroupNorm(8, 64)
        self.conv_out = nn.Conv2d(64, img_channels, 3, padding=1)

    @staticmethod
    def _as_sigma_vector(sigma, batch_size, device):
        """Normalise ``sigma`` to a 1-D tensor suitable for the time embedding.

        Python scalars and 0-d tensors are broadcast to ``(batch_size,)``.
        Tensors with more than one dimension (e.g. ``(B, 1, 1, 1)``) are
        flattened. ``reshape(-1)`` is used instead of ``squeeze()`` because
        ``squeeze()`` would also drop a batch dimension of size 1, leaving a
        0-d tensor that the embedding cannot index. 1-D tensors are returned
        unchanged.
        """
        if isinstance(sigma, (int, float)) or (isinstance(sigma, torch.Tensor) and sigma.dim() == 0):
            return torch.full((batch_size,), float(sigma), device=device)
        if sigma.dim() > 1:
            return sigma.reshape(-1)
        return sigma

    def forward(self, x, sigma, label, return_bottleneck=False):
        """Run the UNet.

        Args:
            x: image batch; the first dimension is the batch size.
            sigma: noise level as a Python scalar, a 0-d tensor, a 1-D tensor,
                or a tensor that flattens to one value per sample.
            label: class indices, or a tensor with more than one dimension
                (e.g. one-hot rows) that is reduced with ``argmax(dim=1)``.
            return_bottleneck: if True, return the middle-block features
                instead of running the decoder.

        Returns:
            The decoder output with ``img_channels`` channels, or the
            bottleneck feature map when ``return_bottleneck`` is True.
        """
        sigma = self._as_sigma_vector(sigma, x.shape[0], x.device)

        # Accept one-hot (or otherwise 2-D) labels by taking the arg-max class.
        if label.dim() > 1:
            label = label.argmax(dim=1)

        # Conditioning vector: noise-level embedding plus class embedding.
        time_emb = self.time_embed(sigma)
        label_emb = self.label_embed(label)
        time_emb = time_emb + label_emb

        # Downsampling
        h1 = self.conv_in(x)
        h2 = self.down1(h1, time_emb)
        h2_down = nn.functional.avg_pool2d(h2, 2)
        h3 = self.down2(h2_down, time_emb)
        h3_down = nn.functional.avg_pool2d(h3, 2)
        h4 = self.down3(h3_down, time_emb)
        h4_down = nn.functional.avg_pool2d(h4, 2)

        # Middle
        h = self.mid_block1(h4_down, time_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, time_emb)

        if return_bottleneck:
            return h

        # Upsampling: resize to the skip tensor's spatial size so odd sizes
        # produced by pooling still line up for concatenation.
        h = nn.functional.interpolate(h, size=h3.shape[2:], mode='nearest')
        h = torch.cat([h, h3], dim=1)
        h = self.up1(h, time_emb)

        h = nn.functional.interpolate(h, size=h2.shape[2:], mode='nearest')
        h = torch.cat([h, h2], dim=1)
        h = self.up2(h, time_emb)

        h = nn.functional.interpolate(h, size=h1.shape[2:], mode='nearest')
        h = torch.cat([h, h1], dim=1)
        h = self.up3(h, time_emb)

        # Output
        h = self.norm_out(h)
        h = nn.functional.silu(h)
        out = self.conv_out(h)

        return out


In [None]:
# Guidance Model
class GuidanceModel(nn.Module):
    """Holds the two score networks used to guide the DMD2 generator.

    ``real_unet`` is created with gradients disabled. ``fake_unet`` starts as
    a deep copy of it with gradients enabled and is trained by
    ``compute_loss_fake`` to denoise the images passed to it.
    ``compute_distribution_matching_loss`` compares the two networks'
    predictions to build the generator's training signal.
    """
    def __init__(self, num_train_timesteps=1000, sigma_min=0.002, sigma_max=80.0,
                 sigma_data=0.5, rho=7.0, min_step_percent=0.02, max_step_percent=0.98):
        super().__init__()

        # Frozen network ("real" score).
        teacher = SimpleUNet(img_channels=1, label_dim=10)
        teacher.requires_grad_(False)
        self.real_unet = teacher

        # Trainable copy ("fake" score), initialised with identical weights.
        fake = copy.deepcopy(teacher)
        fake.requires_grad_(True)
        self.fake_unet = fake

        # Schedule and loss hyper-parameters.
        self.num_train_timesteps = num_train_timesteps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho

        # Karras schedule flipped so that index 0 holds the smallest sigma.
        schedule = get_sigmas_karras(
            num_train_timesteps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho
        )
        self.register_buffer("karras_sigmas", schedule.flip(0))

        # Timestep index range sampled for the distribution matching loss.
        self.min_step = int(min_step_percent * num_train_timesteps)
        self.max_step = int(max_step_percent * num_train_timesteps)

    def compute_distribution_matching_loss(self, latents, labels):
        """Return the generator loss and a dict of detached diagnostics.

        Both networks denoise the same noised copy of ``latents`` under
        ``torch.no_grad``. Their disagreement, normalised per sample by the
        mean absolute residual of ``real_unet``, is turned into an MSE
        against a detached target so that gradients reach ``latents`` only.
        """
        n = latents.shape[0]
        upper = min(self.max_step + 1, self.num_train_timesteps)

        with torch.no_grad():
            # Sampling order (timesteps first, then noise) is kept fixed.
            timesteps = torch.randint(
                self.min_step, upper, [n, 1, 1, 1],
                device=latents.device, dtype=torch.long,
            )
            eps = torch.randn_like(latents)
            timestep_sigma = self.karras_sigmas[timesteps.squeeze()]

            noisy_latents = latents + timestep_sigma.reshape(-1, 1, 1, 1) * eps

            pred_real_image = self.real_unet(noisy_latents, timestep_sigma, labels)
            pred_fake_image = self.fake_unet(noisy_latents, timestep_sigma, labels)

            p_real = latents - pred_real_image
            p_fake = latents - pred_fake_image

            # Per-sample normaliser; the epsilon guards against division by zero.
            scale = p_real.abs().mean(dim=[1, 2, 3], keepdim=True)
            grad = torch.nan_to_num((p_real - p_fake) / (scale + 1e-8))

        # d(loss)/d(latents) is proportional to ``grad`` because the target is detached.
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents, target, reduction="mean")

        loss_dict = {"loss_dm": loss}
        log_dict = {
            "dmtrain_noisy_latents": noisy_latents.detach(),
            "dmtrain_pred_real_image": pred_real_image.detach(),
            "dmtrain_pred_fake_image": pred_fake_image.detach(),
            "dmtrain_grad": grad.detach(),
            "dmtrain_gradient_norm": torch.norm(grad).item(),
            "dmtrain_timesteps": timesteps.detach(),
        }
        return loss_dict, log_dict

    def compute_loss_fake(self, latents, labels):
        """Return the weighted denoising loss that trains ``fake_unet``.

        ``latents`` is detached first, so nothing flows back to whatever
        produced it. Timesteps are drawn from the full schedule.
        """
        n = latents.shape[0]
        target = latents.detach()

        # Sampling order (noise first, then timesteps) is kept fixed.
        eps = torch.randn_like(target)
        timesteps = torch.randint(
            0, self.num_train_timesteps, [n, 1, 1, 1],
            device=target.device, dtype=torch.long,
        )
        timestep_sigma = self.karras_sigmas[timesteps.squeeze()]

        noisy_latents = target + timestep_sigma.reshape(-1, 1, 1, 1) * eps

        # fake_unet predicts the clean image directly.
        fake_x0_pred = self.fake_unet(noisy_latents, timestep_sigma, labels)

        # Per-sample weight: 1 / sigma^2 + 1 / sigma_data^2.
        weights = timestep_sigma ** -2 + 1.0 / (self.sigma_data ** 2)
        loss_fake = (weights.reshape(-1, 1, 1, 1) * (fake_x0_pred - target) ** 2).mean()

        loss_dict = {"loss_fake_mean": loss_fake}
        log_dict = {
            "faketrain_latents": target.detach(),
            "faketrain_noisy_latents": noisy_latents.detach(),
            "faketrain_x0_pred": fake_x0_pred.detach()
        }
        return loss_dict, log_dict

    def forward(self, generator_turn=False, guidance_turn=False,
                generator_data_dict=None, guidance_data_dict=None):
        """Dispatch to one of the two losses; ``generator_turn`` takes priority.

        Raises:
            ValueError: if neither flag is set.
        """
        if generator_turn:
            assert generator_data_dict is not None
            return self.compute_distribution_matching_loss(
                generator_data_dict['image'], generator_data_dict['label']
            )
        if guidance_turn:
            assert guidance_data_dict is not None
            return self.compute_loss_fake(
                guidance_data_dict['image'], guidance_data_dict['label']
            )
        raise ValueError("Either generator_turn or guidance_turn must be True")


In [None]:
# Unified Model
class UnifiedModel(nn.Module):
    """Bundle the feedforward generator with the guidance model.

    ``feedforward_model`` is created as a deep copy of
    ``guidance_model.fake_unet``. ``forward`` runs either a generator turn
    (sample images and, optionally, compute the distribution matching loss)
    or a guidance turn (compute the fake-score loss on previously generated
    images).
    """
    def __init__(self, num_train_timesteps=1000, sigma_min=0.002, sigma_max=80.0,
                 sigma_data=0.5, rho=7.0, min_step_percent=0.02, max_step_percent=0.98,
                 conditioning_sigma=80.0):
        super().__init__()

        # Guidance model (contains real_unet and fake_unet)
        self.guidance_model = GuidanceModel(
            num_train_timesteps=num_train_timesteps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            sigma_data=sigma_data,
            rho=rho,
            min_step_percent=min_step_percent,
            max_step_percent=max_step_percent
        )

        # Feedforward generator model (initialized from fake_unet)
        self.feedforward_model = copy.deepcopy(self.guidance_model.fake_unet)
        self.feedforward_model.requires_grad_(True)

        self.conditioning_sigma = conditioning_sigma
        self.num_train_timesteps = num_train_timesteps

    def forward(self, scaled_noisy_image, timestep_sigma, labels,
                real_train_dict=None,
                compute_generator_gradient=False,
                generator_turn=False,
                guidance_turn=False,
                guidance_data_dict=None):
        """Run one generator turn or one guidance turn.

        Args:
            scaled_noisy_image, timestep_sigma, labels: inputs to the
                generator; only read during a generator turn.
            real_train_dict: passed through unchanged into the data dicts.
            compute_generator_gradient: on a generator turn, whether to run
                the generator with autograd and compute the distribution
                matching loss. When False the generator runs under
                ``torch.no_grad`` and ``loss_dict`` is empty.
            generator_turn, guidance_turn: exactly one must be truthy.
            guidance_data_dict: required on a guidance turn; normally the
                ``'guidance_data_dict'`` entry logged by a generator turn.

        Returns:
            ``(loss_dict, log_dict)``.
        """
        assert bool(generator_turn) != bool(guidance_turn), \
            "exactly one of generator_turn / guidance_turn must be set"

        if generator_turn:
            if compute_generator_gradient:
                generated_image = self.feedforward_model(
                    scaled_noisy_image, timestep_sigma, labels
                )
                generator_data_dict = {
                    "image": generated_image,
                    "label": labels,
                    "real_train_dict": real_train_dict
                }

                # Temporarily freeze the guidance parameters that are
                # currently trainable, and afterwards restore exactly those.
                # Toggling the whole guidance model would leave the frozen
                # real_unet with requires_grad=True after the first call.
                trainable = [p for p in self.guidance_model.parameters() if p.requires_grad]
                for param in trainable:
                    param.requires_grad_(False)
                try:
                    loss_dict, log_dict = self.guidance_model(
                        generator_turn=True,
                        guidance_turn=False,
                        generator_data_dict=generator_data_dict
                    )
                finally:
                    # Restore even if the guidance pass raises.
                    for param in trainable:
                        param.requires_grad_(True)
            else:
                # Samples are only needed as data for the guidance turn.
                with torch.no_grad():
                    generated_image = self.feedforward_model(
                        scaled_noisy_image, timestep_sigma, labels
                    )
                loss_dict = {}
                log_dict = {}

            log_dict['generated_image'] = generated_image.detach()
            log_dict['guidance_data_dict'] = {
                "image": generated_image.detach(),
                "label": labels.detach(),
                "real_train_dict": real_train_dict
            }

        elif guidance_turn:
            assert guidance_data_dict is not None
            loss_dict, log_dict = self.guidance_model(
                generator_turn=False,
                guidance_turn=True,
                guidance_data_dict=guidance_data_dict
            )

        return loss_dict, log_dict


## Configuration


In [None]:
# Training configuration
config = dict(
    # Paths
    data_dir='./data',
    output_dir='./checkpoints/dmd2',
    teacher_checkpoint='./checkpoints/teacher/teacher_final.pt',  # Update this path
    # Optimisation
    batch_size=128,
    generator_lr=2e-6,
    guidance_lr=2e-6,
    num_epochs=50,
    # Noise schedule
    num_train_timesteps=1000,
    sigma_min=0.002,
    sigma_max=80.0,
    sigma_data=0.5,
    rho=7.0,
    min_step_percent=0.02,
    max_step_percent=0.98,
    conditioning_sigma=80.0,
    # DMD2 schedule: the generator is updated on every N-th step
    dfake_gen_update_ratio=10,
    dm_loss_weight=1.0,
    max_grad_norm=1.0,
    # Checkpoint interval, in global steps
    save_every=5000,
)


## Setup


In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Make sure the checkpoint directory exists before training starts.
os.makedirs(config['output_dir'], exist_ok=True)


## Load MNIST Dataset


In [None]:
# Load MNIST dataset
# ToTensor yields values in [0, 1]; subtracting 0.5 and dividing by 0.5 maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Training split, downloaded into the data directory if it is not already there.
train_dataset = datasets.MNIST(
    config['data_dir'], train=True, transform=transform, download=True
)

# Shuffled mini-batches, loaded by two worker processes into pinned memory.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    pin_memory=True,
    num_workers=2,
)

print(f"Dataset loaded: {len(train_dataset)} training samples")


## Initialize Model


In [None]:
# Initialize unified model
model = UnifiedModel(
    num_train_timesteps=config['num_train_timesteps'],
    sigma_min=config['sigma_min'],
    sigma_max=config['sigma_max'],
    sigma_data=config['sigma_data'],
    rho=config['rho'],
    min_step_percent=config['min_step_percent'],
    max_step_percent=config['max_step_percent'],
    conditioning_sigma=config['conditioning_sigma']
).to(device)

# Load the teacher checkpoint into all three networks.
# fake_unet and feedforward_model are deep copies of real_unet taken inside
# the constructors, i.e. BEFORE the checkpoint is read. Loading the weights
# into real_unet alone would leave both copies at their random
# initialisation, so the same state dict is loaded into them as well.
if config['teacher_checkpoint']:
    print(f"Loading teacher checkpoint from {config['teacher_checkpoint']}")
    checkpoint = torch.load(config['teacher_checkpoint'], map_location=device)
    teacher_state = checkpoint['model_state_dict']
    model.guidance_model.real_unet.load_state_dict(teacher_state)
    model.guidance_model.fake_unet.load_state_dict(teacher_state)
    model.feedforward_model.load_state_dict(teacher_state)
    print("Teacher model loaded successfully")

# Optimizers: one for the generator, one for the fake-score network.
# real_unet is frozen and therefore belongs to neither.
optimizer_generator = optim.AdamW(
    model.feedforward_model.parameters(),
    lr=config['generator_lr'],
    weight_decay=0.01
)

optimizer_guidance = optim.AdamW(
    model.guidance_model.fake_unet.parameters(),
    lr=config['guidance_lr'],
    weight_decay=0.01
)

# Eye matrix for one-hot encoding: row i is the one-hot vector of class i.
eye_matrix = torch.eye(10, device=device)

print("Model initialized")


## Training Loop


In [None]:
# Training loop
def _save_checkpoint(path, step, epoch):
    """Write generator, fake-score network and optimizer states to ``path``."""
    torch.save({
        'feedforward_model_state_dict': model.feedforward_model.state_dict(),
        'guidance_fake_unet_state_dict': model.guidance_model.fake_unet.state_dict(),
        'optimizer_generator_state_dict': optimizer_generator.state_dict(),
        'optimizer_guidance_state_dict': optimizer_guidance.state_dict(),
        'step': step,
        'epoch': epoch,
    }, path)


model.train()
# The teacher is frozen and only queried for predictions. Keep it in eval mode
# so the Dropout layers inside its residual blocks do not perturb its output.
model.guidance_model.real_unet.eval()
global_step = 0

for epoch in range(config['num_epochs']):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")
    epoch_loss_dm = 0.0
    epoch_loss_fake = 0.0
    # Counted explicitly: generator updates follow global_step, which carries
    # over between epochs, so they cannot be derived from batch_idx.
    num_generator_updates = 0

    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)

        # Convert labels to one-hot
        labels_onehot = eye_matrix[labels]

        # The generator is updated only on every dfake_gen_update_ratio-th
        # step; the fake-score network is updated on every step.
        compute_generator_gradient = (global_step % config['dfake_gen_update_ratio'] == 0)

        # ========== Generator Turn ==========
        # The generator input is pure noise at the conditioning sigma.
        scaled_noise = torch.randn_like(images) * config['conditioning_sigma']
        timestep_sigma = torch.ones(images.shape[0], device=device) * config['conditioning_sigma']

        # Random labels for generation
        gen_labels = torch.randint(0, 10, (images.shape[0],), device=device)
        gen_labels_onehot = eye_matrix[gen_labels]

        # Real training dict (for optional GAN loss)
        real_train_dict = {
            "real_image": images,
            "real_label": labels_onehot
        }

        # Forward pass through generator
        generator_loss_dict, generator_log_dict = model(
            scaled_noisy_image=scaled_noise,
            timestep_sigma=timestep_sigma,
            labels=gen_labels_onehot,
            real_train_dict=real_train_dict if compute_generator_gradient else None,
            compute_generator_gradient=compute_generator_gradient,
            generator_turn=True,
            guidance_turn=False
        )

        # Update generator if needed
        if compute_generator_gradient:
            generator_loss = generator_loss_dict["loss_dm"] * config['dm_loss_weight']

            optimizer_generator.zero_grad()
            generator_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.feedforward_model.parameters(),
                config['max_grad_norm']
            )
            optimizer_generator.step()
            optimizer_generator.zero_grad()
            optimizer_guidance.zero_grad()

            epoch_loss_dm += generator_loss.item()
            num_generator_updates += 1

        # ========== Guidance Turn ==========
        # Train fake_unet on the (detached) images generated above.
        guidance_loss_dict, guidance_log_dict = model(
            scaled_noisy_image=None,  # Not used in guidance turn
            timestep_sigma=None,  # Not used in guidance turn
            labels=None,  # Not used in guidance turn
            compute_generator_gradient=False,
            generator_turn=False,
            guidance_turn=True,
            guidance_data_dict=generator_log_dict['guidance_data_dict']
        )

        guidance_loss = guidance_loss_dict["loss_fake_mean"]

        optimizer_guidance.zero_grad()
        guidance_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.guidance_model.fake_unet.parameters(),
            config['max_grad_norm']
        )
        optimizer_guidance.step()
        optimizer_guidance.zero_grad()
        optimizer_generator.zero_grad()

        epoch_loss_fake += guidance_loss.item()

        global_step += 1

        # Running per-epoch means; max(1, ...) avoids dividing by zero before
        # the first generator update of the epoch.
        pbar.set_postfix({
            "loss_dm": epoch_loss_dm / max(1, num_generator_updates),
            "loss_fake": epoch_loss_fake / (batch_idx + 1),
        })

        # Save checkpoint periodically
        if global_step % config['save_every'] == 0:
            checkpoint_path = os.path.join(
                config['output_dir'],
                f"dmd2_checkpoint_step_{global_step}.pt"
            )
            _save_checkpoint(checkpoint_path, global_step, epoch)
            print(f"\nSaved checkpoint to {checkpoint_path}")

# Save final checkpoint
final_checkpoint_path = os.path.join(config['output_dir'], "dmd2_final.pt")
_save_checkpoint(final_checkpoint_path, global_step, config['num_epochs'])
print(f"\nSaved final checkpoint to {final_checkpoint_path}")
