# Sparsity-Aware Flow Matching Model for CIFAR-10 (v2)

This notebook implements a **Flow Matching** approach (instead of diffusion) for sparsity-aware image reconstruction.

**Key Innovation:**
- Uses **Conditional Flow Matching** with optimal transport paths
- Provides 20% of pixels as conditioning
- Trains on a different 20% of pixels
- Model learns to reconstruct the full image (100%)
- Supports both **ODE (deterministic)** and **SDE (stochastic)** sampling

**Advantages over Diffusion:**
- Faster sampling (10-50 steps vs 1000 steps)
- Straighter probability paths (optimal transport)
- More stable training
- Flexible sampling (deterministic or stochastic)

**Dataset:** CIFAR-10 (32x32 RGB images)

## 1. Setup and Imports

In [None]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image

from einops import rearrange

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 2. Sparsity Controller (Same as v1)

In [None]:
class SparsityController:
    """
    Manages sparsity patterns for training.
    
    Key behaviors:
    - random_epoch: Generate same masks for same sample_id within epoch, new masks next epoch
    - random: Completely random masks every time
    """
    def __init__(self, image_size, mode='random_epoch', pattern='random', 
                 sparsity=0.2, block_size=4, num_blocks=5):
        self.image_size = image_size
        self.mode = mode
        self.pattern = pattern
        self.sparsity = sparsity
        self.block_size = block_size
        self.num_blocks = num_blocks
        
        # Store masks per epoch
        self.epoch_cache = {}
        self.current_epoch = 0
        
    def new_epoch(self):
        """Call this at the start of each epoch to regenerate masks"""
        self.current_epoch += 1
        self.epoch_cache = {}
        
    def _generate_random_mask(self, C, H, W, rng):
        """Generate random binary mask with given sparsity"""
        total_pixels = H * W
        num_sparse = int(total_pixels * self.sparsity)
        
        mask = torch.zeros(C, H, W)
        for c in range(C):
            indices = rng.choice(total_pixels, size=num_sparse, replace=False)
            flat_mask = torch.zeros(total_pixels)
            flat_mask[indices] = 1.0
            mask[c] = flat_mask.reshape(H, W)
        
        return mask
    
    def _generate_masks_for_sample(self, C, sample_id):
        """Generate both conditioning and target masks for a sample"""
        H, W = self.image_size, self.image_size
        
        # Use sample_id + epoch as seed for reproducibility within epoch
        if self.mode == 'random_epoch':
            seed = hash((sample_id, self.current_epoch)) % (2**32)
        else:
            seed = np.random.randint(0, 2**32)
        
        rng = np.random.RandomState(seed)
        
        if self.pattern == 'random':
            # Generate conditioning mask
            cond_mask = self._generate_random_mask(C, H, W, rng)
            
            # Generate target mask (non-overlapping with cond_mask)
            available_pixels = (1 - cond_mask).bool()
            target_mask = torch.zeros(C, H, W)
            
            for c in range(C):
                available_indices = torch.where(available_pixels[c].flatten())[0].numpy()
                num_target = int(H * W * self.sparsity)
                if len(available_indices) >= num_target:
                    target_indices = rng.choice(available_indices, size=num_target, replace=False)
                    flat_mask = torch.zeros(H * W)
                    flat_mask[target_indices] = 1.0
                    target_mask[c] = flat_mask.reshape(H, W)
        
        else:
            raise NotImplementedError(f"Pattern {self.pattern} not implemented")
        
        return cond_mask, target_mask
    
    def get_masks(self, batch_size, num_channels, sample_ids):
        """
        Get masks for a batch of samples.
        
        Returns:
            cond_masks: List of conditioning masks
            target_masks: List of target masks
        """
        cond_masks = []
        target_masks = []
        
        for i in range(batch_size):
            sample_id = sample_ids[i]
            
            # Check cache for random_epoch mode
            if self.mode == 'random_epoch':
                cache_key = (sample_id, self.current_epoch)
                if cache_key in self.epoch_cache:
                    cond_mask, target_mask = self.epoch_cache[cache_key]
                else:
                    cond_mask, target_mask = self._generate_masks_for_sample(num_channels, sample_id)
                    self.epoch_cache[cache_key] = (cond_mask, target_mask)
            else:
                cond_mask, target_mask = self._generate_masks_for_sample(num_channels, sample_id)
            
            cond_masks.append(cond_mask)
            target_masks.append(target_mask)
        
        return cond_masks, target_masks

## 3. Utility Functions

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for time t ∈ [0, 1]"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


def num_to_groups(num, divisor):
    """Split num into groups of size divisor"""
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def visualize_samples(images, sparse_inputs, masks, predictions, target_masks, nrow=4, title="Samples"):
    """Visualize original, sparse input, masks, and predictions"""
    fig, axes = plt.subplots(5, 1, figsize=(15, 15))
    
    # Denormalize from [-1, 1] to [0, 1]
    def denorm(x):
        return (x + 1) / 2
    
    # Original images
    grid = make_grid(denorm(images[:nrow*nrow]), nrow=nrow)
    axes[0].imshow(grid.permute(1, 2, 0).cpu())
    axes[0].set_title("Original Images")
    axes[0].axis('off')
    
    # Sparse inputs (conditioning)
    grid = make_grid(denorm(sparse_inputs[:nrow*nrow]), nrow=nrow)
    axes[1].imshow(grid.permute(1, 2, 0).cpu())
    axes[1].set_title("Sparse Conditioning (20%)")
    axes[1].axis('off')
    
    # Conditioning masks
    grid = make_grid(masks[:nrow*nrow], nrow=nrow)
    axes[2].imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')
    axes[2].set_title("Conditioning Mask")
    axes[2].axis('off')
    
    # Target masks
    grid = make_grid(target_masks[:nrow*nrow], nrow=nrow)
    axes[3].imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')
    axes[3].set_title("Target Loss Mask (different 20%)")
    axes[3].axis('off')
    
    # Predictions
    grid = make_grid(denorm(predictions[:nrow*nrow]), nrow=nrow)
    axes[4].imshow(grid.permute(1, 2, 0).cpu())
    axes[4].set_title("Reconstructed Images")
    axes[4].axis('off')
    
    plt.suptitle(title)
    plt.tight_layout()
    return fig

## 4. U-Net Architecture Components (Same as v1)

In [None]:
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=0.0, groups=32):
        super().__init__()
        self.dim, self.dim_out = dim, dim_out
        dim_out = dim if dim_out is None else dim_out
        
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=dim)
        self.activation1 = nn.SiLU()
        self.conv1 = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        
        self.mlp = nn.Sequential(
            nn.SiLU(), 
            nn.Linear(time_emb_dim, dim_out)
        ) if time_emb_dim is not None else None
        
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.activation2 = nn.SiLU()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1)
        
        self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.conv1(self.activation1(self.norm1(x)))
        
        if time_emb is not None and self.mlp is not None:
            h = h + self.mlp(time_emb)[..., None, None]
        
        h = self.conv2(self.dropout(self.activation2(self.norm2(h))))
        return h + self.residual_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, groups=32):
        super().__init__()
        self.dim = dim
        self.scale = dim ** (-0.5)
        
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.to_out = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(self.norm(x)).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), qkv)
        
        similarity = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale
        attention_score = torch.softmax(similarity, dim=-1)
        attention = torch.einsum('b i j, b j c -> b i c', attention_score, v)
        
        out = rearrange(attention, 'b (h w) c -> b c h w', h=h, w=w)
        return self.to_out(out) + x


class ResnetAttentionBlock(nn.Module):
    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=0.0, groups=32):
        super().__init__()
        self.resnet = ResnetBlock(dim, dim_out, time_emb_dim, dropout, groups)
        self.attention = Attention(dim_out if dim_out else dim, groups)

    def forward(self, x, time_emb=None):
        x = self.resnet(x, time_emb)
        return self.attention(x)


class DownSample(nn.Module):
    def __init__(self, dim_in):
        super().__init__()
        self.downsample = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.downsample(x)


class UpSample(nn.Module):
    def __init__(self, dim_in):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1)
        )

    def forward(self, x):
        return self.upsample(x)

## 5. Velocity Network (U-Net for Flow Matching)

In [None]:
class VelocityNet(nn.Module):
    """U-Net that predicts velocity field v_t for flow matching"""
    def __init__(self, dim=64, image_size=32, dim_multiply=(1, 2, 4, 8), 
                 channel=3, num_res_blocks=2, attn_resolutions=(16,), 
                 dropout=0.0, groups=32):
        super().__init__()
        assert dim % groups == 0
        
        self.dim = dim
        self.channel = channel
        self.time_emb_dim = 4 * dim
        self.num_resolutions = len(dim_multiply)
        self.resolution = [int(image_size / (2 ** i)) for i in range(self.num_resolutions)]
        self.hidden_dims = [dim, *map(lambda x: x * dim, dim_multiply)]
        
        # Time embedding (t ∈ [0, 1])
        positional_encoding = PositionalEncoding(dim)
        self.time_mlp = nn.Sequential(
            positional_encoding,
            nn.Linear(dim, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim)
        )
        
        # Initial convolution (3x channels for concatenated input)
        self.init_conv = nn.Conv2d(channel * 3, dim, kernel_size=3, padding=1)
        
        # Downward path
        self.down_path = nn.ModuleList([])
        concat_dims = [dim]
        
        for level in range(self.num_resolutions):
            d_in = self.hidden_dims[level]
            d_out = self.hidden_dims[level + 1]
            
            for block in range(num_res_blocks):
                d_in_ = d_in if block == 0 else d_out
                
                if self.resolution[level] in attn_resolutions:
                    self.down_path.append(
                        ResnetAttentionBlock(d_in_, d_out, self.time_emb_dim, dropout, groups)
                    )
                else:
                    self.down_path.append(
                        ResnetBlock(d_in_, d_out, self.time_emb_dim, dropout, groups)
                    )
                concat_dims.append(d_out)
            
            if level != self.num_resolutions - 1:
                self.down_path.append(DownSample(d_out))
                concat_dims.append(d_out)
        
        # Middle
        mid_dim = self.hidden_dims[-1]
        self.middle_resnet_attention = ResnetAttentionBlock(
            mid_dim, mid_dim, self.time_emb_dim, dropout, groups
        )
        self.middle_resnet = ResnetBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups)
        
        # Upward path
        self.up_path = nn.ModuleList([])
        
        for level in reversed(range(self.num_resolutions)):
            d_out = self.hidden_dims[level + 1]
            
            for block in range(num_res_blocks + 1):
                d_in = self.hidden_dims[level + 2] if block == 0 and level != self.num_resolutions - 1 else d_out
                d_in = d_in + concat_dims.pop()
                
                if self.resolution[level] in attn_resolutions:
                    self.up_path.append(
                        ResnetAttentionBlock(d_in, d_out, self.time_emb_dim, dropout, groups)
                    )
                else:
                    self.up_path.append(
                        ResnetBlock(d_in, d_out, self.time_emb_dim, dropout, groups)
                    )
            
            if level != 0:
                self.up_path.append(UpSample(d_out))
        
        # Output (predict velocity)
        final_ch = self.hidden_dims[1]
        self.final_norm = nn.GroupNorm(groups, final_ch)
        self.final_activation = nn.SiLU()
        self.final_conv = nn.Conv2d(final_ch, channel, kernel_size=3, padding=1)

    def forward(self, x, time):
        """Forward pass: predict velocity field v_t"""
        t = self.time_mlp(time)
        
        # Downward
        concat = []
        x = self.init_conv(x)
        concat.append(x)
        
        for layer in self.down_path:
            if isinstance(layer, (UpSample, DownSample)):
                x = layer(x)
            else:
                x = layer(x, t)
            concat.append(x)
        
        # Middle
        x = self.middle_resnet_attention(x, t)
        x = self.middle_resnet(x, t)
        
        # Upward
        for layer in self.up_path:
            if not isinstance(layer, UpSample):
                x = torch.cat((x, concat.pop()), dim=1)
                x = layer(x, t)
            else:
                x = layer(x)
        
        # Final
        x = self.final_activation(self.final_norm(x))
        return self.final_conv(x)

## 6. Flow Matching Model

Flow matching learns a velocity field that transports samples from noise to data.

### Theory:
- **Forward path**: $x_t = t \cdot x_1 + (1-t) \cdot x_0$, where $t \in [0,1]$
- **Velocity**: $v_t = x_1 - x_0$ (target velocity)
- **Training loss**: $\mathcal{L} = \mathbb{E}_{t, x_0, x_1}[\|v_\theta(x_t, t) - v_t\|^2]$
- **Sampling**: Solve ODE $\frac{dx}{dt} = v_\theta(x_t, t)$ from $t=0$ to $t=1$

In [None]:
class ConditionalFlowMatching(nn.Module):
    """Flow Matching with sparse conditioning"""
    def __init__(self, velocity_net, image_size, loss_type='l2'):
        super().__init__()
        self.velocity_net = velocity_net
        self.channel = velocity_net.channel
        self.device = next(velocity_net.parameters()).device
        self.image_size = image_size
        self.loss_type = loss_type
    
    def forward(self, x0, sparse_input=None, mask=None, loss_mask=None):
        """
        Training forward pass.
        
        Args:
            x0: Data samples (B, C, H, W)
            sparse_input: Masked input for conditioning (B, C, H, W)
            mask: Binary mask showing conditioning pixels (B, C, H, W)
            loss_mask: Binary mask for target pixels (B, C, H, W)
        """
        b, c, h, w = x0.shape
        assert h == self.image_size and w == self.image_size
        
        # Sample time uniformly t ~ U[0, 1]
        t = torch.rand(b, device=x0.device)
        
        # Sample noise x1 ~ N(0, I)
        x1 = torch.randn_like(x0)
        
        # Interpolate: x_t = t * x_1 + (1-t) * x_0
        t_expanded = t[:, None, None, None]
        x_t = t_expanded * x1 + (1 - t_expanded) * x0
        
        # Target velocity: v_t = x_1 - x_0
        v_t = x1 - x0
        
        # Prepare model input
        if sparse_input is not None and mask is not None:
            model_input = torch.cat([x_t, sparse_input, mask], dim=1)
        else:
            model_input = x_t
        
        # Predict velocity
        v_pred = self.velocity_net(model_input, t)
        
        # Compute loss
        if self.loss_type == 'l1':
            raw_loss = F.l1_loss(v_pred, v_t, reduction='none')
        elif self.loss_type == 'l2':
            raw_loss = F.mse_loss(v_pred, v_t, reduction='none')
        elif self.loss_type == 'huber':
            raw_loss = F.smooth_l1_loss(v_pred, v_t, reduction='none')
        else:
            raise NotImplementedError()
        
        # Apply loss mask (focus on target pixels)
        if loss_mask is not None:
            # Weight target pixels highly, conditioning pixels slightly
            lambda_cond = 0.05
            combined_mask = loss_mask + lambda_cond * mask
            combined_mask = combined_mask.clamp(max=1.0)
            loss = (raw_loss * combined_mask).sum() / combined_mask.sum()
        else:
            loss = raw_loss.mean()
        
        return loss
    
    @torch.inference_mode()
    def sample_ode(self, batch_size, num_steps=50, sparse_input=None, mask=None, method='euler'):
        """
        ODE sampling (deterministic).
        
        Solves: dx/dt = v_θ(x_t, t) from t=0 to t=1
        
        Args:
            batch_size: Number of samples
            num_steps: Number of integration steps
            method: 'euler' or 'rk4'
        """
        assert sparse_input is not None and mask is not None
        
        # Start from noise x_0 ~ N(0, I)
        x = torch.randn([batch_size, self.channel, self.image_size, self.image_size], 
                       device=self.device)
        
        dt = 1.0 / num_steps
        
        for i in tqdm(range(num_steps), desc='ODE Sampling', leave=False):
            t = i / num_steps
            t_batch = torch.full((batch_size,), t, device=self.device)
            
            # Prepare input
            model_input = torch.cat([x, sparse_input, mask], dim=1)
            
            if method == 'euler':
                # Euler method: x_{t+dt} = x_t + dt * v_θ(x_t, t)
                v = self.velocity_net(model_input, t_batch)
                x = x + dt * v
            
            elif method == 'rk4':
                # RK4 (more accurate)
                k1 = self.velocity_net(model_input, t_batch)
                
                x2 = x + 0.5 * dt * k1
                t2 = torch.full((batch_size,), t + 0.5 * dt, device=self.device)
                k2 = self.velocity_net(torch.cat([x2, sparse_input, mask], dim=1), t2)
                
                x3 = x + 0.5 * dt * k2
                k3 = self.velocity_net(torch.cat([x3, sparse_input, mask], dim=1), t2)
                
                x4 = x + dt * k3
                t4 = torch.full((batch_size,), t + dt, device=self.device)
                k4 = self.velocity_net(torch.cat([x4, sparse_input, mask], dim=1), t4)
                
                x = x + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)
        
        x.clamp_(-1., 1.)
        return x
    
    @torch.inference_mode()
    def sample_sde(self, batch_size, num_steps=50, sparse_input=None, mask=None, 
                   noise_scale=0.1):
        """
        SDE sampling (stochastic).
        
        Solves: dx = v_θ(x_t, t)dt + σ(t)dW
        
        Args:
            batch_size: Number of samples
            num_steps: Number of integration steps
            noise_scale: Scaling factor for stochastic noise σ(t)
        """
        assert sparse_input is not None and mask is not None
        
        # Start from noise x_0 ~ N(0, I)
        x = torch.randn([batch_size, self.channel, self.image_size, self.image_size], 
                       device=self.device)
        
        dt = 1.0 / num_steps
        
        for i in tqdm(range(num_steps), desc='SDE Sampling', leave=False):
            t = i / num_steps
            t_batch = torch.full((batch_size,), t, device=self.device)
            
            # Prepare input
            model_input = torch.cat([x, sparse_input, mask], dim=1)
            
            # Predict velocity
            v = self.velocity_net(model_input, t_batch)
            
            # Time-dependent noise schedule: more noise early, less later
            sigma_t = noise_scale * (1 - t)
            
            # SDE update: dx = v*dt + σ*√dt*dW
            noise = torch.randn_like(x)
            x = x + dt * v + sigma_t * math.sqrt(dt) * noise
        
        x.clamp_(-1., 1.)
        return x
    
    @torch.inference_mode()
    def sample(self, batch_size=16, num_steps=50, sparse_input=None, mask=None, 
              method='ode', **kwargs):
        """Unified sampling interface"""
        if method == 'ode':
            return self.sample_ode(batch_size, num_steps, sparse_input, mask, 
                                  kwargs.get('ode_method', 'euler'))
        elif method == 'sde':
            return self.sample_sde(batch_size, num_steps, sparse_input, mask, 
                                  kwargs.get('noise_scale', 0.1))
        else:
            raise ValueError(f"Unknown sampling method: {method}")

## 7. CIFAR-10 Dataset Setup

In [None]:
# Configuration
IMAGE_SIZE = 32
BATCH_SIZE = 128
LEARNING_RATE = 2e-4
EPOCHS = 50
SAVE_EVERY = 5
SAMPLE_STEPS = 50  # Flow matching uses fewer steps than diffusion!

# Create directories
os.makedirs('./results_v2', exist_ok=True)
os.makedirs('./checkpoints_v2', exist_ok=True)

# Data transforms (normalize to [-1, 1])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Load CIFAR-10
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                         num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                        num_workers=4, pin_memory=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 8. Initialize Model and Training Components

In [None]:
# Initialize Velocity Network
velocity_net = VelocityNet(
    dim=64,
    image_size=IMAGE_SIZE,
    dim_multiply=(1, 2, 4, 8),
    channel=3,
    num_res_blocks=2,
    attn_resolutions=(16,),
    dropout=0.1,
    groups=32
).to(device)

# Initialize Flow Matching Model
flow_matching = ConditionalFlowMatching(
    velocity_net=velocity_net,
    image_size=IMAGE_SIZE,
    loss_type='l2'
).to(device)

# Initialize Sparsity Controller
sparsity_controller = SparsityController(
    image_size=IMAGE_SIZE,
    mode='random_epoch',
    pattern='random',
    sparsity=0.2,
    block_size=4,
    num_blocks=5
)

# Optimizer
optimizer = Adam(flow_matching.parameters(), lr=LEARNING_RATE)

# Count parameters
num_params = sum(p.numel() for p in flow_matching.parameters())
print(f"Total parameters: {num_params:,}")
print(f"Sampling steps: {SAMPLE_STEPS} (vs 1000 for diffusion!)")

## 9. Training Loop

In [None]:
def train_epoch(epoch):
    flow_matching.train()
    sparsity_controller.new_epoch()  # Generate new masks for this epoch
    
    epoch_loss = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        B, C, H, W = images.shape
        
        # Generate sample IDs
        sample_ids = labels.tolist()
        
        # Get masks
        cond_masks, target_masks = sparsity_controller.get_masks(B, C, sample_ids)
        cond_mask = torch.stack(cond_masks).to(device)
        target_mask = torch.stack(target_masks).to(device)
        
        # Create sparse input
        sparse_input = images * cond_mask
        
        # Forward pass
        optimizer.zero_grad()
        loss = flow_matching(
            images,
            sparse_input=sparse_input,
            mask=cond_mask,
            loss_mask=target_mask
        )
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(flow_matching.parameters(), 1.0)
        optimizer.step()
        
        epoch_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return epoch_loss / len(train_loader)


@torch.no_grad()
def sample_images(epoch, num_samples=16, method='ode'):
    """Generate samples using ODE or SDE"""
    flow_matching.eval()
    
    # Get a batch from test set
    test_images, test_labels = next(iter(test_loader))
    test_images = test_images[:num_samples].to(device)
    
    # Generate masks
    sample_ids = test_labels[:num_samples].tolist()
    cond_masks, target_masks = sparsity_controller.get_masks(num_samples, 3, sample_ids)
    cond_mask = torch.stack(cond_masks).to(device)
    target_mask = torch.stack(target_masks).to(device)
    
    # Create sparse input
    sparse_input = test_images * cond_mask
    
    # Sample using specified method
    samples = flow_matching.sample(
        batch_size=num_samples,
        num_steps=SAMPLE_STEPS,
        sparse_input=sparse_input,
        mask=cond_mask,
        method=method,
        ode_method='euler',  # or 'rk4' for higher accuracy
        noise_scale=0.1  # for SDE
    )
    
    # Visualize
    fig = visualize_samples(
        test_images, sparse_input, cond_mask, samples, target_mask,
        nrow=4, title=f"Epoch {epoch+1} ({method.upper()})"
    )
    plt.savefig(f'./results_v2/samples_epoch_{epoch+1:03d}_{method}.png', 
               dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"Saved {method.upper()} samples for epoch {epoch+1}")

## 10. Run Training

In [None]:
# Training history
train_losses = []

print("Starting Flow Matching training...\n")

for epoch in range(EPOCHS):
    # Train
    avg_loss = train_epoch(epoch)
    train_losses.append(avg_loss)
    
    print(f"\nEpoch {epoch+1}/{EPOCHS} - Average Loss: {avg_loss:.4f}")
    
    # Sample images
    if (epoch + 1) % SAVE_EVERY == 0 or epoch == 0:
        # Generate both ODE and SDE samples for comparison
        sample_images(epoch, num_samples=16, method='ode')
        sample_images(epoch, num_samples=16, method='sde')
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': flow_matching.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'train_losses': train_losses
        }
        torch.save(checkpoint, f'./checkpoints_v2/checkpoint_epoch_{epoch+1:03d}.pt')
        print(f"Saved checkpoint at epoch {epoch+1}")

print("\nTraining completed!")

## 11. Plot Training Loss

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Flow Matching Training Loss over Time')
plt.legend()
plt.grid(True)
plt.savefig('./results_v2/training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

## 12. Compare ODE vs SDE Sampling

Generate samples with both deterministic (ODE) and stochastic (SDE) methods.

In [None]:
flow_matching.eval()

# Get test samples
num_samples = 64
test_images, test_labels = next(iter(test_loader))
test_images = test_images[:num_samples].to(device)

sample_ids = test_labels[:num_samples].tolist()
cond_masks, target_masks = sparsity_controller.get_masks(num_samples, 3, sample_ids)
cond_mask = torch.stack(cond_masks).to(device)
target_mask = torch.stack(target_masks).to(device)
sparse_input = test_images * cond_mask

print("Generating ODE samples (deterministic)...")
samples_ode = flow_matching.sample(
    batch_size=num_samples,
    num_steps=SAMPLE_STEPS,
    sparse_input=sparse_input,
    mask=cond_mask,
    method='ode',
    ode_method='euler'
)

print("Generating SDE samples (stochastic)...")
samples_sde = flow_matching.sample(
    batch_size=num_samples,
    num_steps=SAMPLE_STEPS,
    sparse_input=sparse_input,
    mask=cond_mask,
    method='sde',
    noise_scale=0.1
)

# Visualize comparison
fig, axes = plt.subplots(2, 1, figsize=(20, 8))

def denorm(x):
    return (x + 1) / 2

# ODE samples
grid_ode = make_grid(denorm(samples_ode), nrow=8)
axes[0].imshow(grid_ode.permute(1, 2, 0).cpu())
axes[0].set_title("ODE Sampling (Deterministic)", fontsize=16)
axes[0].axis('off')

# SDE samples
grid_sde = make_grid(denorm(samples_sde), nrow=8)
axes[1].imshow(grid_sde.permute(1, 2, 0).cpu())
axes[1].set_title("SDE Sampling (Stochastic)", fontsize=16)
axes[1].axis('off')

plt.tight_layout()
plt.savefig('./results_v2/ode_vs_sde_comparison.png', dpi=200, bbox_inches='tight')
plt.show()

print("\n✅ Flow Matching training and sampling completed!")
print(f"Results saved in ./results_v2/")
print(f"Checkpoints saved in ./checkpoints_v2/")

## 13. Benchmark: Different Step Counts

Test how sampling quality varies with number of steps.

In [None]:
import time

step_counts = [10, 20, 50, 100]
num_test_samples = 16

test_images, test_labels = next(iter(test_loader))
test_images = test_images[:num_test_samples].to(device)
sample_ids = test_labels[:num_test_samples].tolist()
cond_masks, target_masks = sparsity_controller.get_masks(num_test_samples, 3, sample_ids)
cond_mask = torch.stack(cond_masks).to(device)
sparse_input = test_images * cond_mask

fig, axes = plt.subplots(len(step_counts), 1, figsize=(15, 4*len(step_counts)))

for idx, steps in enumerate(step_counts):
    print(f"\nSampling with {steps} steps...")
    start = time.time()
    
    samples = flow_matching.sample(
        batch_size=num_test_samples,
        num_steps=steps,
        sparse_input=sparse_input,
        mask=cond_mask,
        method='ode'
    )
    
    elapsed = time.time() - start
    print(f"Time: {elapsed:.2f}s ({elapsed/num_test_samples:.3f}s per sample)")
    
    grid = make_grid(denorm(samples), nrow=4)
    axes[idx].imshow(grid.permute(1, 2, 0).cpu())
    axes[idx].set_title(f"{steps} steps ({elapsed:.2f}s)", fontsize=14)
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig('./results_v2/step_count_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n💡 Flow matching is much faster than diffusion!")
print("   - Diffusion: 1000 steps")
print("   - Flow Matching: 10-100 steps for similar quality")

## Summary

This notebook demonstrates **Flow Matching** for sparsity-aware image reconstruction:

### Key Advantages over Diffusion (v1):

1. **⚡ Faster Sampling**: 10-100 steps vs 1000 steps
2. **📐 Straighter Paths**: Optimal transport interpolation
3. **🎯 Simpler Training**: Direct velocity prediction
4. **🔀 Flexible Sampling**: Both ODE (deterministic) and SDE (stochastic)

### Key Features:

- ✅ **Conditional Flow Matching** with sparse observations
- ✅ **ODE Solver**: Deterministic sampling (Euler or RK4)
- ✅ **SDE Solver**: Stochastic sampling with configurable noise
- ✅ **Weighted Loss**: Focus on target pixels (1.0) vs conditioning (0.05)
- ✅ **Same Architecture**: Can compare directly with diffusion baseline

### Sampling Methods:

**ODE (Deterministic)**:
- Solves: $dx/dt = v_\theta(x_t, t)$
- Same input → same output
- Best for reproducibility

**SDE (Stochastic)**:
- Solves: $dx = v_\theta(x_t, t)dt + \sigma(t)dW$
- Adds controlled randomness
- Can produce diverse samples

### Next Steps:

1. **Compare with v1**: Flow matching vs diffusion performance
2. **Tune sampling**: Try different step counts (10-100)
3. **Experiment with SDE noise**: Adjust `noise_scale`
4. **Try RK4**: More accurate ODE integration
5. **Optimal Transport**: Add minibatch OT for better paths