In [None]:
import torch
import torch.nn as nn
from typing import Tuple, Optional
from abc import ABC, abstractmethod
import math

# ============================================================================
# Model API - This is what you need to implement
# ============================================================================

class DiffusionModel(nn.Module):
    """
    Interface for the diffusion model (e.g., U-Net)
    """
    @abstractmethod
    def forward(self, x: torch.Tensor, t: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Predict the noise added to x at timestep t
        
        Args:
            x: noisy images, shape (batch_size, c, h, w)
            t: timestep indices, shape (batch_size,), values in [0, T-1]
            labels: optional class labels for conditional generation, shape (batch_size,)
        
        Returns:
            predicted_noise: shape (batch_size, c, h, w)
        """
        pass


# ============================================================================
# Diffusion Process
# ============================================================================

class DiffusionProcess:
    """
    Manages the forward and reverse diffusion process (DDPM formulation)
    """
    def __init__(self, num_timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
        self.num_timesteps = num_timesteps
        
        # Linear noise schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        
        # Precompute values for sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        
        # Posterior variance for reverse process
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
    
    def to(self, device):
        """Move all tensors to device"""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(device)
        self.posterior_variance = self.posterior_variance.to(device)
        return self
    
    def add_noise(self, x_0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        Forward diffusion: q(x_t | x_0)
        
        Args:
            x_0: clean images, shape (batch_size, c, h, w)
            t: timestep indices, shape (batch_size,)
            noise: Gaussian noise, shape (batch_size, c, h, w)
        
        Returns:
            x_t: noisy images at timestep t
        """
        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        
        return sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise


# ============================================================================
# Training
# ============================================================================

def train_diffusion(
    model: DiffusionModel,
    data_sampler: Sampleable,
    diffusion_process: DiffusionProcess,
    num_epochs: int = 100,
    batch_size: int = 128,
    learning_rate: float = 2e-4,
    device: str = "cuda",
    conditional: bool = False,
):
    """
    Train the diffusion model
    """
    model = model.to(device)
    diffusion_process = diffusion_process.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    num_samples = len(data_sampler.dataset)
    steps_per_epoch = num_samples // batch_size
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        
        for step in range(steps_per_epoch):
            # Sample batch of clean images
            x_0, labels = data_sampler.sample(batch_size)
            x_0 = x_0.to(device)
            if conditional:
                labels = labels.to(device)
            
            # Sample random timesteps
            t = torch.randint(0, diffusion_process.num_timesteps, (batch_size,), device=device)
            
            # Sample noise
            noise = torch.randn_like(x_0)
            
            # Create noisy images
            x_t = diffusion_process.add_noise(x_0, t, noise)
            
            # Predict noise
            if conditional:
                predicted_noise = model(x_t, t, labels)
            else:
                predicted_noise = model(x_t, t)
            
            # Compute loss (simple MSE between predicted and actual noise)
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)
            
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / steps_per_epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")


# ============================================================================
# Sampling
# ============================================================================

@torch.no_grad()
def sample_diffusion(
    model: DiffusionModel,
    diffusion_process: DiffusionProcess,
    num_samples: int,
    image_shape: Tuple[int, int, int] = (1, 32, 32),
    device: str = "cuda",
    labels: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Sample from the diffusion model using DDPM sampling
    
    Args:
        model: trained diffusion model
        diffusion_process: diffusion process configuration
        num_samples: number of samples to generate
        image_shape: (c, h, w)
        device: device to run on
        labels: optional class labels for conditional generation, shape (num_samples,)
    
    Returns:
        samples: generated images, shape (num_samples, c, h, w)
    """
    model.eval()
    
    # Start from pure noise
    x_t = torch.randn(num_samples, *image_shape, device=device)
    
    # Reverse diffusion process
    for t in reversed(range(diffusion_process.num_timesteps)):
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
        
        # Predict noise
        if labels is not None:
            predicted_noise = model(x_t, t_batch, labels)
        else:
            predicted_noise = model(x_t, t_batch)
        
        # Compute coefficients
        alpha_t = diffusion_process.alphas[t]
        alpha_cumprod_t = diffusion_process.alphas_cumprod[t]
        beta_t = diffusion_process.betas[t]
        sqrt_one_minus_alpha_cumprod_t = diffusion_process.sqrt_one_minus_alphas_cumprod[t]
        
        # Compute mean of reverse process
        model_mean = (1.0 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise
        )
        
        if t > 0:
            # Add noise (except for final step)
            noise = torch.randn_like(x_t)
            posterior_variance_t = diffusion_process.posterior_variance[t]
            x_t = model_mean + torch.sqrt(posterior_variance_t) * noise
        else:
            x_t = model_mean
    
    return x_t


# ============================================================================
# Example Usage
# ============================================================================

if __name__ == "__main__":
    # Initialize components
    data_sampler = MNISTSampler()
    diffusion_process = DiffusionProcess(num_timesteps=1000)
    
    # TODO: Implement your model (e.g., U-Net)
    # model = YourDiffusionUNet(in_channels=1, num_classes=10)
    
    # Train
    # train_diffusion(
    #     model=model,
    #     data_sampler=data_sampler,
    #     diffusion_process=diffusion_process,
    #     num_epochs=100,
    #     batch_size=128,
    #     device="cuda",
    #     conditional=True  # Set to True for class-conditional generation
    # )
    
    # Sample
    # samples = sample_diffusion(
    #     model=model,
    #     diffusion_process=diffusion_process,
    #     num_samples=16,
    #     image_shape=(1, 32, 32),
    #     device="cuda",
    #     labels=torch.arange(16) % 10  # Generate one of each digit
    # )