# Sparsity-Aware Diffusion Model for CIFAR-10

This notebook implements a sparsity-aware diffusion model that learns to reconstruct full images from sparse observations.

**Key Innovation:**
- Provides 20% of pixels as conditioning
- Computes the training loss mainly on a different, non-overlapping 20% of pixels (conditioning pixels keep a small 0.05 weight)
- Model learns to reconstruct the full image (100%)

**Dataset:** CIFAR-10 (32x32 RGB images)

## 1. Setup and Imports

In [None]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image

from einops import rearrange

# Select the compute device: CUDA when PyTorch reports it as available, otherwise the CPU.
# Later cells move the model and every batch onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the global PyTorch and NumPy RNGs so repeated runs start from the same state.
# SparsityController draws its per-call seeds from np.random when mode != 'random_epoch'.
torch.manual_seed(42)
np.random.seed(42)

## 2. Sparsity Controller

Controls the generation of sparse masks for conditioning and target loss computation.

In [None]:
class SparsityController:
    """
    Manages sparsity patterns for training.

    For every sample it produces two binary float masks of shape (C, H, W):
    a conditioning mask and a target (loss) mask that never overlaps it.
    Pixel positions are drawn independently for each channel.

    Key behaviors:
    - random_epoch: Generate same masks for same sample_id within epoch, new masks next epoch
    - random: Completely random masks every time

    Args:
        image_size: Side length of the square images; masks are image_size x image_size.
        mode: 'random_epoch' seeds each mask pair from (sample_id, epoch) and caches it;
            any other value draws a fresh seed from np.random on every call.
        pattern: Mask layout. Only 'random' is implemented.
        sparsity: Fraction of pixels per channel selected by each mask, in [0, 1].
        block_size: Stored for future block patterns; unused by the 'random' pattern.
        num_blocks: Stored for future block patterns; unused by the 'random' pattern.

    Raises:
        ValueError: If sparsity lies outside [0, 1].
    """
    def __init__(self, image_size, mode='random_epoch', pattern='random',
                 sparsity=0.2, block_size=4, num_blocks=5):
        if not 0.0 <= sparsity <= 1.0:
            raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

        self.image_size = image_size
        self.mode = mode
        self.pattern = pattern
        self.sparsity = sparsity
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Store masks per epoch: (sample_id, epoch) -> (cond_mask, target_mask)
        self.epoch_cache = {}
        self.current_epoch = 0

    def new_epoch(self):
        """Call this at the start of each epoch to regenerate masks"""
        self.current_epoch += 1
        self.epoch_cache = {}

    def _generate_random_mask(self, C, H, W, rng):
        """
        Generate random binary mask with given sparsity.

        Each channel gets int(H * W * sparsity) ones at positions drawn
        without replacement from `rng` (a np.random.RandomState).

        Returns:
            Float tensor of shape (C, H, W) holding 0.0 / 1.0.
        """
        total_pixels = H * W
        num_sparse = int(total_pixels * self.sparsity)

        mask = torch.zeros(C, total_pixels)
        for c in range(C):
            indices = rng.choice(total_pixels, size=num_sparse, replace=False)
            # Explicit long indices: NumPy may hand back int32 on some platforms
            mask[c, torch.as_tensor(indices, dtype=torch.long)] = 1.0

        return mask.reshape(C, H, W)

    def _generate_masks_for_sample(self, C, sample_id):
        """
        Generate both conditioning and target masks for a sample.

        Args:
            C: Number of channels.
            sample_id: Hashable identifier; only used for seeding in 'random_epoch' mode.

        Returns:
            (cond_mask, target_mask), each a (C, H, W) float tensor. The target mask
            only selects pixels the conditioning mask left unselected. If fewer
            than the requested number remain (sparsity > 0.5), all remaining
            pixels are used instead of leaving the target mask empty.

        Raises:
            NotImplementedError: If self.pattern is not 'random'.
        """
        H, W = self.image_size, self.image_size

        # Use sample_id + epoch as seed for reproducibility within epoch.
        # hash() of an int tuple is stable across runs; str IDs would depend on PYTHONHASHSEED.
        if self.mode == 'random_epoch':
            seed = hash((sample_id, self.current_epoch)) % (2**32)
        else:
            # dtype=int64 keeps the 2**32 bound valid where NumPy's default int is 32-bit
            seed = int(np.random.randint(0, 2**32, dtype=np.int64))

        rng = np.random.RandomState(seed)

        if self.pattern != 'random':
            raise NotImplementedError(f"Pattern {self.pattern} not implemented")

        # Generate conditioning mask
        cond_mask = self._generate_random_mask(C, H, W, rng)

        # Generate target mask (non-overlapping with cond_mask)
        available_pixels = (1 - cond_mask).bool()
        requested = int(H * W * self.sparsity)
        target_mask = torch.zeros(C, H * W)

        for c in range(C):
            available_indices = torch.where(available_pixels[c].flatten())[0].numpy()
            # Never ask for more pixels than are left after conditioning
            num_target = min(requested, len(available_indices))
            if num_target == 0:
                continue
            target_indices = rng.choice(available_indices, size=num_target, replace=False)
            target_mask[c, torch.as_tensor(target_indices, dtype=torch.long)] = 1.0

        return cond_mask, target_mask.reshape(C, H, W)

    def get_masks(self, batch_size, num_channels, sample_ids):
        """
        Get masks for a batch of samples.

        Args:
            batch_size: Number of mask pairs to produce.
            num_channels: Channel count C of every mask.
            sample_ids: Indexable sequence with at least batch_size hashable IDs.

        Returns:
            cond_masks: List of conditioning masks
            target_masks: List of target masks

        In 'random_epoch' mode the same tensor objects are returned for a repeated
        sample_id within an epoch, so callers should not modify them in place.
        """
        cond_masks = []
        target_masks = []

        for i in range(batch_size):
            sample_id = sample_ids[i]

            if self.mode == 'random_epoch':
                # Check cache for random_epoch mode
                cache_key = (sample_id, self.current_epoch)
                pair = self.epoch_cache.get(cache_key)
                if pair is None:
                    pair = self._generate_masks_for_sample(num_channels, sample_id)
                    self.epoch_cache[cache_key] = pair
            else:
                pair = self._generate_masks_for_sample(num_channels, sample_id)

            cond_mask, target_mask = pair
            cond_masks.append(cond_mask)
            target_masks.append(target_mask)

        return cond_masks, target_masks

## 3. Utility Functions

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding for timesteps.

    Maps a 1-D tensor of timesteps of length B to a (B, dim) tensor whose first
    dim // 2 columns are sines and next dim // 2 columns are cosines of the
    timestep scaled by geometrically spaced frequencies. For odd `dim` a final
    zero column is appended so the output width always equals `dim`.

    Args:
        dim: Width of the embedding returned by forward().
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: 1-D tensor of timesteps, shape (B,).

        Returns:
            Tensor of shape (B, dim) on the same device as `time`.
        """
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1 (dim of 2 or 3)
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves cover only dim - 1 columns; pad so downstream Linear(dim, ...) fits
            padding = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat((embeddings, padding), dim=-1)
        return embeddings


def num_to_groups(num, divisor):
    """
    Break `num` into chunks of at most `divisor`.

    Returns a list with num // divisor entries equal to `divisor`, followed by
    the remainder as a last entry when it is positive.
    """
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor for _ in range(full_chunks)]
    if leftover > 0:
        sizes = sizes + [leftover]
    return sizes


def visualize_samples(images, sparse_inputs, masks, predictions, target_masks, nrow=4, title="Samples"):
    """
    Draw a five-row figure comparing originals, conditioning, masks and reconstructions.

    The first nrow * nrow items of every batch are tiled into one grid per row.
    Image batches (images, sparse_inputs, predictions) are shifted from [-1, 1]
    to [0, 1] before display; mask batches are shown as they are.

    Args:
        images: Original image batch.
        sparse_inputs: Conditioning images (originals multiplied by the conditioning mask).
        masks: Conditioning masks.
        predictions: Reconstructed images.
        target_masks: Masks of the pixels used for the loss.
        nrow: Images per grid row; nrow * nrow images are shown in total.
        title: Figure super-title.

    Returns:
        The matplotlib Figure that was created.
    """
    shown = nrow * nrow

    def to_unit_range(batch):
        # Undo the [-1, 1] normalisation for display
        return (batch + 1) / 2

    # One entry per subplot row: (batch, row title, rescale to [0, 1]?, extra imshow kwargs)
    panels = (
        (images, "Original Images", True, {}),
        (sparse_inputs, "Sparse Conditioning (20%)", True, {}),
        (masks, "Conditioning Mask", False, {'cmap': 'gray'}),
        (target_masks, "Target Loss Mask (different 20%)", False, {'cmap': 'gray'}),
        (predictions, "Reconstructed Images", True, {}),
    )

    fig, axes = plt.subplots(5, 1, figsize=(15, 15))

    for ax, (batch, row_title, rescale, imshow_kwargs) in zip(axes, panels):
        subset = batch[:shown]
        if rescale:
            subset = to_unit_range(subset)
        tiled = make_grid(subset, nrow=nrow)
        # make_grid returns (C, H, W); imshow expects (H, W, C)
        ax.imshow(tiled.permute(1, 2, 0).cpu(), **imshow_kwargs)
        ax.set_title(row_title)
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    return fig

## 4. U-Net Architecture Components

In [None]:
class ResnetBlock(nn.Module):
    """
    Residual block: GroupNorm -> SiLU -> 3x3 conv, twice, with an optional
    time-embedding bias added between the two convolutions.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels; defaults to `dim` when None.
        time_emb_dim: Width of the time embedding. When None the block has no
            time MLP and any time_emb passed to forward() is ignored.
        dropout: Dropout probability applied before the second convolution
            (no dropout layer is created when it is 0).
        groups: Number of GroupNorm groups; must divide both dim and dim_out.
    """
    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=0.0, groups=32):
        super().__init__()
        # Resolve the default first so self.dim_out always holds the real channel count
        dim_out = dim if dim_out is None else dim_out
        self.dim, self.dim_out = dim, dim_out

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=dim)
        self.activation1 = nn.SiLU()
        self.conv1 = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)

        # Projects the time embedding to one bias value per output channel
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out)
        ) if time_emb_dim is not None else None

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.activation2 = nn.SiLU()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1)

        # 1x1 conv matches channel counts on the skip path only when they differ
        self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """
        Args:
            x: Feature map of shape (B, dim, H, W).
            time_emb: Optional time embedding of shape (B, time_emb_dim).

        Returns:
            Feature map of shape (B, dim_out, H, W).
        """
        h = self.conv1(self.activation1(self.norm1(x)))

        if time_emb is not None and self.mlp is not None:
            # Broadcast the per-channel bias over the spatial dimensions
            h = h + self.mlp(time_emb)[..., None, None]

        h = self.conv2(self.dropout(self.activation2(self.norm2(h))))
        return h + self.residual_conv(x)


class Attention(nn.Module):
    """
    Single-head self-attention across all spatial positions of a feature map,
    added back onto the input as a residual.

    Args:
        dim: Channel count of the input and output.
        groups: Number of GroupNorm groups applied before the q/k/v projection.
    """
    def __init__(self, dim, groups=32):
        super().__init__()
        self.dim = dim
        # Dot products are scaled by 1 / sqrt(dim)
        self.scale = dim ** (-0.5)

        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.to_out = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Return attention output plus `x`; input and output are (B, dim, H, W)."""
        _, _, height, width = x.shape

        # One 1x1 conv yields q, k and v stacked along channels; split them and
        # flatten the spatial grid into a sequence of H*W tokens
        projected = self.to_qkv(self.norm(x))
        q, k, v = (
            rearrange(part, 'b c h w -> b (h w) c')
            for part in projected.chunk(3, dim=1)
        )

        logits = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale
        weights = torch.softmax(logits, dim=-1)
        mixed = torch.einsum('b i j, b j c -> b i c', weights, v)

        # Fold the token sequence back into a feature map
        restored = rearrange(mixed, 'b (h w) c -> b c h w', h=height, w=width)
        return self.to_out(restored) + x


class ResnetAttentionBlock(nn.Module):
    """
    A ResnetBlock followed by an Attention layer on its output.

    Constructor arguments are forwarded to ResnetBlock; the attention layer
    works on dim_out channels, or on dim channels when dim_out is falsy.
    """
    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=0.0, groups=32):
        super().__init__()
        attention_channels = dim_out if dim_out else dim
        self.resnet = ResnetBlock(dim, dim_out, time_emb_dim, dropout, groups)
        self.attention = Attention(attention_channels, groups)

    def forward(self, x, time_emb=None):
        """Run the residual block (with optional time embedding), then self-attention."""
        return self.attention(self.resnet(x, time_emb))


class DownSample(nn.Module):
    """
    Spatial downsampling by a 3x3 convolution with stride 2 and padding 1.

    The channel count (dim_in) is unchanged; even spatial sizes are halved.
    """
    def __init__(self, dim_in):
        super().__init__()
        conv_options = {'kernel_size': 3, 'stride': 2, 'padding': 1}
        self.downsample = nn.Conv2d(dim_in, dim_in, **conv_options)

    def forward(self, x):
        """Apply the strided convolution to `x`."""
        reduced = self.downsample(x)
        return reduced


class UpSample(nn.Module):
    """
    Spatial upsampling: nearest-neighbour 2x resize followed by a 3x3 convolution.

    The channel count (dim_in) is unchanged.
    """
    def __init__(self, dim_in):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1),
        ]
        self.upsample = nn.Sequential(*stages)

    def forward(self, x):
        """Resize `x` to twice its height and width, then convolve."""
        enlarged = self.upsample(x)
        return enlarged

## 5. U-Net Model

In [None]:
class Unet(nn.Module):
    """
    Time-conditioned U-Net whose first convolution takes channel * 3 input channels
    and whose last convolution produces `channel` output channels.

    The encoder stores the output of every layer (including the initial
    convolution and each DownSample) on a stack; the decoder pops one stored
    tensor per residual block and concatenates it along the channel axis.
    """
    def __init__(self, dim=64, image_size=32, dim_multiply=(1, 2, 4, 8), 
                 channel=3, num_res_blocks=2, attn_resolutions=(16,), 
                 dropout=0.0, groups=32):
        """
        U-Net for noise prediction with sparse conditioning.
        
        Input channels: channel * 3 (noised_image + sparse_input + mask)

        Args:
            dim: Base channel width; must be divisible by `groups`.
            image_size: Input resolution, used to compute the resolution of each level.
            dim_multiply: Channel multiplier per resolution level; its length sets
                the number of levels.
            channel: Number of image channels (output channels of the network).
            num_res_blocks: Residual blocks per level on the way down; the way up
                uses num_res_blocks + 1 per level.
            attn_resolutions: Resolutions at which blocks also apply self-attention.
            dropout: Dropout probability passed to every residual block.
            groups: GroupNorm group count passed to every block.
        """
        super().__init__()
        assert dim % groups == 0
        
        self.dim = dim
        self.channel = channel
        self.time_emb_dim = 4 * dim
        self.num_resolutions = len(dim_multiply)
        # Spatial size at each level: image_size, image_size / 2, image_size / 4, ...
        self.resolution = [int(image_size / (2 ** i)) for i in range(self.num_resolutions)]
        # hidden_dims[level] is the input width of a level, hidden_dims[level + 1] its output width
        self.hidden_dims = [dim, *map(lambda x: x * dim, dim_multiply)]
        
        # Time embedding
        positional_encoding = PositionalEncoding(dim)
        self.time_mlp = nn.Sequential(
            positional_encoding,
            nn.Linear(dim, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim)
        )
        
        # Initial convolution (3x channels for concatenated input)
        self.init_conv = nn.Conv2d(channel * 3, dim, kernel_size=3, padding=1)
        
        # Downward path
        self.down_path = nn.ModuleList([])
        # Channel count of every tensor forward() will push onto the skip stack,
        # in push order; starts with the init_conv output
        concat_dims = [dim]
        
        for level in range(self.num_resolutions):
            d_in = self.hidden_dims[level]
            d_out = self.hidden_dims[level + 1]
            
            for block in range(num_res_blocks):
                # Only the first block of a level changes the channel count
                d_in_ = d_in if block == 0 else d_out
                
                if self.resolution[level] in attn_resolutions:
                    self.down_path.append(
                        ResnetAttentionBlock(d_in_, d_out, self.time_emb_dim, dropout, groups)
                    )
                else:
                    self.down_path.append(
                        ResnetBlock(d_in_, d_out, self.time_emb_dim, dropout, groups)
                    )
                concat_dims.append(d_out)
            
            # Every level except the last ends with a downsample, whose output is also stored
            if level != self.num_resolutions - 1:
                self.down_path.append(DownSample(d_out))
                concat_dims.append(d_out)
        
        # Middle
        mid_dim = self.hidden_dims[-1]
        self.middle_resnet_attention = ResnetAttentionBlock(
            mid_dim, mid_dim, self.time_emb_dim, dropout, groups
        )
        self.middle_resnet = ResnetBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups)
        
        # Upward path
        self.up_path = nn.ModuleList([])
        
        for level in reversed(range(self.num_resolutions)):
            d_out = self.hidden_dims[level + 1]
            
            # One more block than on the way down, so every stored skip tensor is consumed
            for block in range(num_res_blocks + 1):
                # The first block below the deepest level receives the wider output of
                # the level beneath it; all other blocks receive d_out channels
                d_in = self.hidden_dims[level + 2] if block == 0 and level != self.num_resolutions - 1 else d_out
                # Add the channels of the skip tensor concatenated in forward()
                d_in = d_in + concat_dims.pop()
                
                if self.resolution[level] in attn_resolutions:
                    self.up_path.append(
                        ResnetAttentionBlock(d_in, d_out, self.time_emb_dim, dropout, groups)
                    )
                else:
                    self.up_path.append(
                        ResnetBlock(d_in, d_out, self.time_emb_dim, dropout, groups)
                    )
            
            if level != 0:
                self.up_path.append(UpSample(d_out))
        
        # Output
        final_ch = self.hidden_dims[1]
        self.final_norm = nn.GroupNorm(groups, final_ch)
        self.final_activation = nn.SiLU()
        self.final_conv = nn.Conv2d(final_ch, channel, kernel_size=3, padding=1)

    def forward(self, x, time):
        """
        Forward pass with concatenated [noised_image, sparse_input, mask]

        Args:
            x: Input with channel * 3 channels (what init_conv expects).
            time: 1-D tensor of timesteps, one per batch element.

        Returns:
            Tensor with `channel` channels produced by final_conv.
        """
        t = self.time_mlp(time)
        
        # Downward
        concat = []
        x = self.init_conv(x)
        concat.append(x)
        
        for layer in self.down_path:
            # Resampling layers take no time embedding
            if isinstance(layer, (UpSample, DownSample)):
                x = layer(x)
            else:
                x = layer(x, t)
            # Every encoder output becomes a skip connection
            concat.append(x)
        
        # Middle
        x = self.middle_resnet_attention(x, t)
        x = self.middle_resnet(x, t)
        
        # Upward
        for layer in self.up_path:
            if not isinstance(layer, UpSample):
                # Pop the most recently stored encoder output and join it on the channel axis
                x = torch.cat((x, concat.pop()), dim=1)
                x = layer(x, t)
            else:
                x = layer(x)
        
        # Final
        x = self.final_activation(self.final_norm(x))
        return self.final_conv(x)

## 6. Gaussian Diffusion Model

In [None]:
class GaussianDiffusion(nn.Module):
    """
    Gaussian diffusion (noise-prediction) wrapper with optional sparse conditioning.

    Holds the noise schedule as buffers, computes the training loss in
    forward(), and runs the reverse process in sample().
    """
    def __init__(self, model, image_size, time_step=1000, loss_type='l2'):
        """
        Gaussian Diffusion with sparse conditioning.
        
        Args:
            model: U-Net denoising network; must expose a `channel` attribute and
                accept (input, timesteps)
            image_size: Image resolution
            time_step: Number of diffusion steps (T)
            loss_type: 'l1', 'l2', or 'huber'
        """
        super().__init__()
        self.unet = model
        self.channel = self.unet.channel
        self.image_size = image_size
        self.time_step = time_step
        self.loss_type = loss_type

        # Noise schedule
        beta = self.linear_beta_schedule()
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        # alpha_bar shifted right by one step, with alpha_bar_prev[0] = 1
        alpha_bar_prev = F.pad(alpha_bar[:-1], pad=(1, 0), value=1.)

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_bar', alpha_bar)
        self.register_buffer('alpha_bar_prev', alpha_bar_prev)

        # For q(x_t | x_0)
        self.register_buffer('sqrt_alpha_bar', torch.sqrt(alpha_bar))
        self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - alpha_bar))

        # For q(x_{t-1} | x_t, x_0)
        self.register_buffer('beta_tilde', beta * ((1. - alpha_bar_prev) / (1. - alpha_bar)))
        self.register_buffer('mean_tilde_x0_coeff', beta * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar))
        self.register_buffer('mean_tilde_xt_coeff', torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))

        # For predicted x0
        self.register_buffer('sqrt_recip_alpha_bar', torch.sqrt(1. / alpha_bar))
        self.register_buffer('sqrt_recip_alpha_bar_min_1', torch.sqrt(1. / alpha_bar - 1))

        # For sampling
        self.register_buffer('sqrt_recip_alpha', torch.sqrt(1. / alpha))
        self.register_buffer('beta_over_sqrt_one_minus_alpha_bar', beta / torch.sqrt(1. - alpha_bar))

    @property
    def device(self):
        """Device of the schedule buffers; stays correct after .to() / .cuda()."""
        return self.beta.device

    def linear_beta_schedule(self):
        """
        Linear beta schedule from DDPM paper.

        The end points are scaled by 1000 / time_step. Values are capped at 0.999
        so that short schedules (time_step < 20) cannot produce beta >= 1, which
        would make alpha non-positive and the derived buffers NaN.
        """
        scale = 1000 / self.time_step
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        betas = torch.linspace(beta_start, beta_end, self.time_step, dtype=torch.float32)
        return betas.clamp(max=0.999)

    def q_sample(self, x0, t, noise):
        """Sample x_t from q(x_t | x_0) using reparameterization trick"""
        return (self.sqrt_alpha_bar[t][:, None, None, None] * x0 +
                self.sqrt_one_minus_alpha_bar[t][:, None, None, None] * noise)

    def forward(self, img, sparse_input=None, mask=None, loss_mask=None):
        """
        Training forward pass.
        
        Args:
            img: Original images (B, C, H, W)
            sparse_input: Masked input for conditioning (B, C, H, W)
            mask: Binary mask showing conditioning pixels (B, C, H, W)
            loss_mask: Binary mask for target pixels (B, C, H, W)

        Returns:
            Scalar loss between the true and predicted noise. With `loss_mask`
            it is a weighted mean over the masked pixels, otherwise a plain mean.

        Raises:
            NotImplementedError: If loss_type is not 'l1', 'l2' or 'huber'.

        When sparse_input or mask is missing, only the noised image is passed to
        the model, so the model must then accept `C` input channels.
        """
        b, c, h, w = img.shape
        assert h == self.image_size and w == self.image_size

        # Sample random timestep
        t = torch.randint(0, self.time_step, (b,), device=img.device).long()

        # Add noise
        noise = torch.randn_like(img)
        noised_image = self.q_sample(img, t, noise)

        # Prepare model input
        if sparse_input is not None and mask is not None:
            model_input = torch.cat([noised_image, sparse_input, mask], dim=1)
        else:
            model_input = noised_image

        # Predict noise
        predicted_noise = self.unet(model_input, t)

        # Compute loss
        if self.loss_type == 'l1':
            raw_loss = F.l1_loss(noise, predicted_noise, reduction='none')
        elif self.loss_type == 'l2':
            raw_loss = F.mse_loss(noise, predicted_noise, reduction='none')
        elif self.loss_type == 'huber':
            raw_loss = F.smooth_l1_loss(noise, predicted_noise, reduction='none')
        else:
            raise NotImplementedError(f"Unknown loss_type: {self.loss_type!r}")

        if loss_mask is None:
            return raw_loss.mean()

        # Apply loss mask (focus on target pixels)
        combined_mask = loss_mask
        if mask is not None:
            # Weight target pixels highly, conditioning pixels slightly
            lambda_cond = 0.05
            combined_mask = (loss_mask + lambda_cond * mask).clamp(max=1.0)

        # Guard against an all-zero mask, which would otherwise yield 0 / 0 = NaN
        normalizer = combined_mask.sum().clamp(min=1e-8)
        return (raw_loss * combined_mask).sum() / normalizer

    @torch.inference_mode()
    def p_sample(self, xt, t, clip=True, sparse_input=None, mask=None):
        """
        Sample x_{t-1} from p(x_{t-1} | x_t)

        Args:
            xt: Current noisy batch (B, C, H, W).
            t: Integer timestep shared by the whole batch.
            clip: If True, reconstruct x0, clamp it to [-1, 1] and use the posterior
                mean; otherwise compute the mean directly from the predicted noise.
            sparse_input: Conditioning images, or None.
            mask: Conditioning mask, or None.

        Returns:
            The batch at timestep t - 1; no noise is added when t == 0.
        """
        # Build the timestep tensor next to xt rather than on a device stored at init
        batched_time = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)

        # Prepare input
        if sparse_input is not None and mask is not None:
            model_input = torch.cat([xt, sparse_input, mask], dim=1)
        else:
            model_input = xt

        pred_noise = self.unet(model_input, batched_time)

        # Compute mean
        if clip:
            x0 = self.sqrt_recip_alpha_bar[t] * xt - self.sqrt_recip_alpha_bar_min_1[t] * pred_noise
            x0.clamp_(-1., 1.)
            mean = self.mean_tilde_x0_coeff[t] * x0 + self.mean_tilde_xt_coeff[t] * xt
        else:
            mean = self.sqrt_recip_alpha[t] * (xt - self.beta_over_sqrt_one_minus_alpha_bar[t] * pred_noise)

        variance = self.beta_tilde[t]
        noise = torch.randn_like(xt) if t > 0 else 0.

        return mean + torch.sqrt(variance) * noise

    @torch.inference_mode()
    def sample(self, batch_size=16, sparse_input=None, mask=None, clip=True):
        """
        Generate samples via reverse diffusion

        Args:
            batch_size: Number of images to generate; should match the batch
                dimension of sparse_input and mask.
            sparse_input: Conditioning images (required).
            mask: Conditioning mask (required).
            clip: Forwarded to p_sample.

        Returns:
            Tensor (batch_size, channel, image_size, image_size) clamped to [-1, 1].
        """
        assert sparse_input is not None and mask is not None, "Must provide sparse_input and mask"

        # Start from random noise
        xT = torch.randn([batch_size, self.channel, self.image_size, self.image_size], device=self.device)

        xt = xT
        for t in tqdm(reversed(range(0, self.time_step)), desc='Sampling', total=self.time_step, leave=False):
            xt = self.p_sample(xt, t, clip, sparse_input, mask)

        # Clamp to valid range
        xt.clamp_(-1., 1.)
        return xt

## 7. CIFAR-10 Dataset Setup

In [None]:
# Configuration
IMAGE_SIZE = 32        # side length passed to the U-Net, diffusion wrapper and mask controller
BATCH_SIZE = 128       # batch size of both data loaders
LEARNING_RATE = 2e-4   # Adam learning rate
EPOCHS = 50            # number of passes over the training set
SAVE_EVERY = 5         # sample images and write a checkpoint every this many epochs

# Create directories for sample figures and checkpoints (no error if they already exist)
os.makedirs('./results', exist_ok=True)
os.makedirs('./checkpoints', exist_ok=True)

# Data transforms (normalize to [-1, 1])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Load CIFAR-10 (downloaded into ./data when not already present)
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Training batches are reshuffled every epoch; the test loader keeps a fixed order,
# so its first batch (used for the sample figures) is always the same images
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                         num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                        num_workers=4, pin_memory=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 8. Initialize Model and Training Components

In [None]:
# Initialize U-Net
# Four resolution levels (32, 16, 8, 4) with widths 64, 128, 256, 512;
# self-attention is enabled at the 16x16 level
unet = Unet(
    dim=64,
    image_size=IMAGE_SIZE,
    dim_multiply=(1, 2, 4, 8),
    channel=3,
    num_res_blocks=2,
    attn_resolutions=(16,),
    dropout=0.1,
    groups=32
).to(device)

# Initialize Diffusion Model
# 1000-step linear schedule, MSE loss on the predicted noise
diffusion = GaussianDiffusion(
    model=unet,
    image_size=IMAGE_SIZE,
    time_step=1000,
    loss_type='l2'
).to(device)

# Initialize Sparsity Controller
# 'random_epoch' keeps masks fixed per sample_id within an epoch; each mask selects
# 20% of the pixels per channel. block_size / num_blocks are stored but unused
# by the 'random' pattern.
sparsity_controller = SparsityController(
    image_size=IMAGE_SIZE,
    mode='random_epoch',
    pattern='random',
    sparsity=0.2,
    block_size=4,
    num_blocks=5
)

# Optimizer over all parameters of the diffusion wrapper (i.e. the wrapped U-Net)
optimizer = Adam(diffusion.parameters(), lr=LEARNING_RATE)

# Count parameters
num_params = sum(p.numel() for p in diffusion.parameters())
print(f"Total parameters: {num_params:,}")

## 9. Training Loop

In [None]:
def train_epoch(epoch):
    """
    Train the diffusion model for one pass over train_loader.

    Args:
        epoch: Zero-based epoch index (used for the progress bar and first-batch print).

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    diffusion.train()
    sparsity_controller.new_epoch()  # Drop cached masks so this epoch gets new ones

    running_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for step, (images, labels) in enumerate(progress):
        images = images.to(device)
        batch_size, channels, _, _ = images.shape

        # Class labels serve as the sample IDs handed to the mask controller.
        # NOTE(review): in 'random_epoch' mode this gives every image of a class the
        # same mask pair within an epoch (10 distinct pairs for CIFAR-10) - confirm
        # this is intended rather than one mask pair per image.
        ids = labels.tolist()

        cond_list, target_list = sparsity_controller.get_masks(batch_size, channels, ids)
        cond_mask = torch.stack(cond_list).to(device)
        target_mask = torch.stack(target_list).to(device)

        # Keep only the conditioning pixels of each image
        sparse_input = images * cond_mask

        optimizer.zero_grad()
        loss = diffusion(images, sparse_input=sparse_input, mask=cond_mask, loss_mask=target_mask)

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update
        torch.nn.utils.clip_grad_norm_(diffusion.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

        # Report mask densities once, on the very first batch of training
        if step == 0 and epoch == 0:
            print(f"\nCond mask mean: {cond_mask.mean():.4f}, Target mask mean: {target_mask.mean():.4f}")

    return running_loss / len(train_loader)


@torch.no_grad()
def sample_images(epoch, num_samples=16):
    """
    Generate samples and save visualizations

    Takes up to `num_samples` images from the first test batch, builds their
    conditioning/target masks, runs the reverse diffusion conditioned on the
    sparse pixels, and writes a comparison figure to ./results/.

    Args:
        epoch: Zero-based epoch index; used in the figure title and file name.
        num_samples: Number of test images to reconstruct. If the first test
            batch holds fewer images, all of them are used.
    """
    diffusion.eval()

    # Get a batch from test set for conditioning
    test_images, test_labels = next(iter(test_loader))
    test_images = test_images[:num_samples].to(device)

    # Use the actual batch and channel sizes instead of assuming num_samples x 3
    count, channels = test_images.shape[:2]

    # Generate masks
    sample_ids = test_labels[:count].tolist()
    cond_masks, target_masks = sparsity_controller.get_masks(count, channels, sample_ids)
    cond_mask = torch.stack(cond_masks).to(device)
    target_mask = torch.stack(target_masks).to(device)

    # Create sparse input
    sparse_input = test_images * cond_mask

    # Sample
    samples = diffusion.sample(
        batch_size=count,
        sparse_input=sparse_input,
        mask=cond_mask,
        clip=True
    )

    # Visualize: smallest square grid that fits every sample (4x4 for 16 samples)
    nrow = max(1, math.ceil(math.sqrt(count)))
    fig = visualize_samples(
        test_images, sparse_input, cond_mask, samples, target_mask,
        nrow=nrow, title=f"Epoch {epoch+1}"
    )
    fig.savefig(f'./results/samples_epoch_{epoch+1:03d}.png', dpi=150, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate open figures
    plt.close(fig)

    print(f"Saved samples for epoch {epoch+1}")

## 10. Run Training

In [None]:
# Training history: one mean loss per completed epoch
train_losses = []

print("Starting training...\n")

for epoch in range(EPOCHS):
    # Train
    avg_loss = train_epoch(epoch)
    train_losses.append(avg_loss)
    
    print(f"\nEpoch {epoch+1}/{EPOCHS} - Average Loss: {avg_loss:.4f}")
    
    # Sample images after the first epoch and then every SAVE_EVERY epochs
    if (epoch + 1) % SAVE_EVERY == 0 or epoch == 0:
        sample_images(epoch, num_samples=16)
        
        # Save checkpoint (model + optimizer state and the loss history so far);
        # 'epoch' stores the zero-based index while the file name uses epoch + 1
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': diffusion.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'train_losses': train_losses
        }
        torch.save(checkpoint, f'./checkpoints/checkpoint_epoch_{epoch+1:03d}.pt')
        print(f"Saved checkpoint at epoch {epoch+1}")

print("\nTraining completed!")

## 11. Plot Training Loss

In [None]:
# Plot the mean training loss per epoch (x axis is the zero-based epoch index)
# and save the figure next to the sample images
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Time')
plt.legend()
plt.grid(True)
# Save before show(): the figure is written to disk first, then displayed
plt.savefig('./results/training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

## 12. Generate Final Samples

In [None]:
# Optionally reload a saved checkpoint; the file named here is the one written after
# epoch 50 by the training loop (the last checkpoint, not one selected by loss)
# checkpoint = torch.load('./checkpoints/checkpoint_epoch_050.pt')
# diffusion.load_state_dict(checkpoint['model_state_dict'])

diffusion.eval()

# Generate more samples: reconstruct the first 64 images of the first test batch
num_samples = 64
test_images, test_labels = next(iter(test_loader))
test_images = test_images[:num_samples].to(device)

# Masks come from the shared controller, keyed by class label and its current epoch,
# so labels already seen this epoch reuse their cached mask pair
sample_ids = test_labels[:num_samples].tolist()
cond_masks, target_masks = sparsity_controller.get_masks(num_samples, 3, sample_ids)
cond_mask = torch.stack(cond_masks).to(device)
target_mask = torch.stack(target_masks).to(device)
sparse_input = test_images * cond_mask

samples = diffusion.sample(
    batch_size=num_samples,
    sparse_input=sparse_input,
    mask=cond_mask,
    clip=True
)

# Save final samples (nrow=8 shows all 64 images as an 8x8 grid per row of the figure)
fig = visualize_samples(
    test_images, sparse_input, cond_mask, samples, target_mask,
    nrow=8, title="Final Samples"
)
plt.savefig('./results/final_samples.png', dpi=200, bbox_inches='tight')
plt.show()

print("\n✅ Training and sampling completed!")
print(f"Results saved in ./results/")
print(f"Checkpoints saved in ./checkpoints/")

## Summary

This notebook demonstrates:

1. **Sparsity-Aware Training**: Model receives 20% of pixels as conditioning
2. **Target Loss Masking**: Loss computed mainly on a different, non-overlapping 20% of pixels, with a small (0.05) weight kept on the conditioning pixels
3. **Full Reconstruction**: Model learns to reconstruct complete images (100%)
4. **CIFAR-10 Application**: Trained on 32×32 RGB images

### Key Results to Observe:
- Loss should decrease over epochs
- Generated samples should progressively improve
- Model learns relationships between sparse observations and full images

### Next Steps:
- Experiment with different sparsity levels (10%, 30%, etc.)
- Try different sparsity patterns (blocks, grids, etc.)
- Increase model capacity for better quality
- Add DDIM sampling for faster inference