# CIFAR-10 Sparse Reconstruction: Flow Matching

## Overview

**Approach**: Conditional Flow Matching (Modern Approach!)

**Key Concept**: Learn a velocity field that transports the noise distribution to the data distribution, trained on straight paths between noise and data samples.

**Architecture**:
```
p_0 (Noise) ──────────> p_1 (Data)
    x_0          Flow φ_t       x_1
                  ↑
         Velocity Field v_θ(x_t, t | input)
```

## Theory: Conditional Flow Matching

### Flow Equation
$$\frac{dx_t}{dt} = v_\theta(x_t, t | x^{in})$$

where $v_\theta$ is the learned velocity field, conditioned on the observed sparse input pixels $x^{in}$.

### Conditional Flow (Straight Paths)
$$\phi_t(x_1) = (1-t) \cdot x_0 + t \cdot x_1$$

where $x_0 \sim \mathcal{N}(0, I)$ is a noise sample, $x_1$ is a data sample (the target pixel values), and $t \in [0, 1]$.

### Target Velocity
$$u_t(x_1) = \frac{d\phi_t}{dt} = x_1 - x_0$$

This is simply the difference between data and noise, and it does not depend on $t$: the velocity is constant along each path.
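A quick numerical check of the two facts above, using random toy tensors rather than the notebook's data: the time derivative of the straight path equals $x_1 - x_0$ at every $t$, so along a single conditional path one Euler step of size 1 is exact. The learned field is an average over many such paths and is not constant, which is why sampling still takes many steps.

```python
import torch

torch.manual_seed(0)
x_0 = torch.randn(5, 3, dtype=torch.float64)  # noise for 5 toy pixels
x_1 = torch.rand(5, 3, dtype=torch.float64)   # toy data values in [0, 1]


def phi(t):
    """Straight conditional path from x_0 (t=0) to x_1 (t=1)."""
    return (1 - t) * x_0 + t * x_1


# Central finite differences of the path should equal x_1 - x_0 at every t
h = 1e-4
checks = []
for t in (0.1, 0.5, 0.9):
    fd = (phi(t + h) - phi(t - h)) / (2 * h)
    checks.append(torch.allclose(fd, x_1 - x_0, atol=1e-8))
print(checks)

# Constant velocity: a single Euler step of size 1 from x_0 reaches x_1
x_end = x_0 + 1.0 * (x_1 - x_0)
print(torch.allclose(x_end, x_1))
```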

### Training Objective (Simple!)
$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[\| v_\theta(\phi_t(x_1), t | x^{in}) - (x_1 - x_0) \|^2 \right]$$

The network regresses its predicted velocity onto this target, with $t \sim \mathcal{U}[0, 1]$ and $x_0$ sampled fresh for every example.
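A minimal sketch of a Monte Carlo estimate of this loss, with the sparse conditioning $x^{in}$ dropped for brevity. `cfm_loss`, `oracle`, and `zero` are illustrative names, not part of the notebook. If the data distribution is a single point $c$, the exact field is $(c - x_t)/(1-t)$ and drives the loss to zero, while a field that always predicts zero gets a loss of about $1 + \mathbb{E}[c^2]$.

```python
import torch
import torch.nn.functional as F


def cfm_loss(velocity_fn, x_1):
    """Monte Carlo estimate of the conditional flow matching loss.

    velocity_fn(x_t, t) returns a prediction with the same shape as x_t.
    x_1: (B, N, C) data samples. Conditioning on sparse inputs is omitted here.
    """
    B = x_1.shape[0]
    t = torch.rand(B, 1, 1, dtype=x_1.dtype)  # t ~ U[0, 1), one per sample
    x_0 = torch.randn_like(x_1)                 # x_0 ~ N(0, I)
    x_t = (1 - t) * x_0 + t * x_1               # point on the straight path
    u_t = x_1 - x_0                             # target velocity
    return F.mse_loss(velocity_fn(x_t, t), u_t)


torch.manual_seed(0)
c = torch.rand(1, 4, 3, dtype=torch.float64)  # one toy "image" of 4 pixels
x_1 = c.repeat(256, 1, 1)                     # data distribution = point mass at c

oracle = lambda x_t, t: (c - x_t) / (1 - t)  # exact field for a point mass
zero = lambda x_t, t: torch.zeros_like(x_t)  # baseline that ignores its input

print("oracle loss:", cfm_loss(oracle, x_1).item())  # ~0 up to rounding
print("zero loss:  ", cfm_loss(zero, x_1).item())
```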

### Why Flow Matching?
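- ✅ **Simple training**: a plain MSE regression onto the target velocity, with no simulation during training
- ✅ **Efficient sampling**: the conditional paths are straight, which tends to let ODE solvers use fewer steps than diffusion samplers (the learned marginal paths are still curved in general)
- ✅ **Modern**: rectified-flow variants of flow matching are used in Stable Diffusion 3 and Flux
- ✅ **No score matching**: no need to estimate $\nabla \log p_t$, simpler than score-based models
- ✅ **ODE sampling**: deterministic given the initial noise $x_0$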
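Only the conditional paths are straight. The network learns the marginal field $v(x, t) = \mathbb{E}[x_1 - x_0 \mid x_t = x]$, whose trajectories bend wherever several data points are plausible. A minimal NumPy sketch, assuming a toy 1-D data distribution with two equally likely points at $\pm 1$ (not CIFAR-10), where this field has a closed form; the function names are illustrative:

```python
import numpy as np


def marginal_velocity(x, t, modes):
    """Exact marginal field E[x_1 - x_0 | x_t = x] for equally weighted point masses.

    With x_t = (1 - t) x_0 + t x_1 and x_0 ~ N(0, 1), x_t given x_1 = a_k is
    N(t * a_k, (1 - t)^2), so the posterior over modes is a softmax of
    Gaussian log-likelihoods.
    """
    s = 1.0 - t
    logw = -((x[:, None] - t * modes[None, :]) ** 2) / (2 * s**2)
    logw -= logw.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(logw)
    w /= w.sum(axis=1, keepdims=True)
    x1_hat = w @ modes                       # E[x_1 | x_t = x]
    return (x1_hat - x) / s                  # x_1 - x_0 = (x_1 - x_t) / (1 - t)


def euler_path(x0, modes, num_steps=100):
    """Euler integration on t = 0, dt, ..., 1 - dt (t = 1 is never evaluated)."""
    dt = 1.0 / num_steps
    x = x0.copy()
    path = [x.copy()]
    for k in range(num_steps):
        x = x + dt * marginal_velocity(x, k * dt, modes)
        path.append(x.copy())
    return x, np.stack(path)


rng = np.random.default_rng(0)
modes = np.array([-1.0, 1.0])  # toy "dataset": two points in 1-D
x0 = rng.standard_normal(2000)
x1, path = euler_path(x0, modes)

# Compare the state at t = 0.5 with the straight chord between each start and end
chord_mid = 0.5 * x0 + 0.5 * x1
print("max deviation from a straight line at t=0.5:", np.abs(path[50] - chord_mid).max())
print("distinct endpoints:", np.unique(np.round(x1, 3)))
```

The last Euler step lands exactly on a data point because at $t = 1 - \Delta t$ the update reduces to $x \leftarrow \mathbb{E}[x_1 \mid x_t = x]$. Points that started near 0 are still near 0 at $t = 0.5$, far from the straight chord toward $\pm 1$.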

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Import shared components
from core.neural_fields.perceiver import PerceiverIO, FourierFeatures
from core.sparse.cifar10_sparse import SparseCIFAR10Dataset
from core.sparse.metrics import MetricsTracker, print_metrics, visualize_predictions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## 1. Conditional Flow

Linear interpolation between noise and data

In [None]:
def conditional_flow(x_0, x_1, t):
    """
    Conditional flow: φ_t(x_1) = (1-t) * x_0 + t * x_1
    
    Args:
        x_0: (B, N, 3) noise
        x_1: (B, N, 3) data
        t: (B, 1, 1) time in [0, 1]
    
    Returns:
        x_t: (B, N, 3) interpolated state
    """
    return (1 - t) * x_0 + t * x_1


def target_velocity(x_0, x_1):
    """
    Target velocity: u_t = x_1 - x_0
    
    Args:
        x_0: (B, N, 3) noise
        x_1: (B, N, 3) data
    
    Returns:
        u_t: (B, N, 3) target velocity (constant along path!)
    """
    return x_1 - x_0


# Visualize conditional flow
x_0 = torch.randn(1, 1, 3)
x_1 = torch.rand(1, 1, 3)
ts = torch.linspace(0, 1, 100).view(-1, 1, 1)

x_t = conditional_flow(x_0, x_1, ts)

plt.figure(figsize=(10, 4))

# Plot each RGB channel separately
for i, (color, name) in enumerate(zip(['red', 'green', 'blue'], ['R', 'G', 'B'])):
    plt.plot(ts.squeeze().numpy(), x_t.squeeze().numpy()[:, i], 
             color=color, alpha=0.7, linewidth=2, label=f'{name} channel')
    plt.scatter([0], x_0.squeeze().numpy()[i], c=color, s=100, 
                marker='o', edgecolors='black', linewidths=2, zorder=10)
    plt.scatter([1], x_1.squeeze().numpy()[i], c=color, s=100, 
                marker='s', edgecolors='black', linewidths=2, zorder=10)

plt.xlabel('Time t')
plt.ylabel('Value')
plt.title('Conditional Flow: Straight Path from Noise to Data')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 2. Velocity Field Network

Learns to predict velocity: v_θ(x_t, t | input) → velocity

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding for t in [0, 1]"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        """t: (B,) time in [0, 1]"""
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb


class VelocityField(nn.Module):
    """
    Velocity Field: v_θ(x_t, coords, t | input) → velocity
    
    Predicts how to transport x_t toward data x_1
    """
    def __init__(
        self,
        num_latents=512,
        latent_dim=512,
        num_fourier_feats=256,
        num_blocks=6,
        num_heads=8,
        dropout=0.1
    ):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Perceiver IO backbone
        self.perceiver = PerceiverIO(
            input_channels=3,
            output_channels=3,
            num_latents=num_latents,
            latent_dim=latent_dim,
            num_fourier_feats=num_fourier_feats,
            num_blocks=num_blocks,
            num_heads=num_heads,
            dropout=dropout
        )
        
        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(latent_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim)
        )
    
    def forward(self, x_t, output_coords, t, input_coords, input_values):
        """
        Args:
            x_t: (B, N_out, 3) current state on flow path
            output_coords: (B, N_out, 2) output coordinates
            t: (B,) time in [0, 1]
            input_coords: (B, N_in, 2) input coordinates
            input_values: (B, N_in, 3) input values (conditioning)
        
        Returns:
            velocity: (B, N_out, 3) predicted velocity
        """
        B = x_t.shape[0]
        
        # Time embedding
        t_emb = self.time_mlp(self.time_embed(t))  # (B, latent_dim)
        
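        # Inject time by adding the first 3 dims of the time embedding to every
        # pixel value (the backbone is configured with input_channels=3).
        # Note: the same offset is added to conditioning and state tokens, and no
        # flag tells the backbone which tokens are observed pixels and which are x_t.
        time_signal = t_emb[:, :3].unsqueeze(1)  # (B, 1, 3)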
        
        input_values_t = input_values + time_signal
        x_t_emb = x_t + time_signal
        
        # Concatenate input (conditioning) and output (current state)
        all_coords = torch.cat([input_coords, output_coords], dim=1)
        all_values = torch.cat([input_values_t, x_t_emb], dim=1)
        
        # Predict velocity at output coordinates
        velocity = self.perceiver(all_coords, all_values, output_coords)
        
        return velocity


# Test velocity field
model = VelocityField().to(device)
test_x_t = torch.rand(4, 204, 3).to(device)
test_output_coords = torch.rand(4, 204, 2).to(device)
test_t = torch.rand(4).to(device)
test_input_coords = torch.rand(4, 204, 2).to(device)
test_input_values = torch.rand(4, 204, 3).to(device)

test_vel = model(test_x_t, test_output_coords, test_t, test_input_coords, test_input_values)
print(f"Velocity field test: {test_vel.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
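`SinusoidalTimeEmbedding` uses frequencies between 1 and 1/10000 and receives raw $t \in [0, 1]$, so many of its channels barely change over the whole time range (the learned `time_mlp` can partly compensate). A minimal check that re-implements the same formula with an optional input scale; the scale of 1000 is an assumption borrowed from DDPM-style integer timesteps, not something this notebook uses, and `sinusoidal_embedding` is an illustrative helper.

```python
import math

import torch


def sinusoidal_embedding(t, dim=512, scale=1.0):
    """Same formula as SinusoidalTimeEmbedding.forward, with an optional input scale."""
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
    args = (scale * t)[:, None] * freqs[None, :]
    return torch.cat([args.sin(), args.cos()], dim=-1)


t = torch.tensor([0.0, 1.0])
fractions = {}
for scale in (1.0, 1000.0):
    emb = sinusoidal_embedding(t, scale=scale)
    change = (emb[1] - emb[0]).abs()  # how much each channel moves from t=0 to t=1
    fractions[scale] = (change < 0.01).float().mean().item()
    print(f"scale={scale:g}: {fractions[scale]:.0%} of channels change by < 0.01")
```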

## 3. Sampling: ODE Solver

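The samplers are defined before training because the training loop calls them for evaluation and visualization.

Both integrate dx/dt = v_θ(x_t, t | input) from t=0 (noise) to t=1 (data) with a fixed step size.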

In [None]:
@torch.no_grad()
def euler_sample(
    model,
    output_coords,
    input_coords,
    input_values,
    num_steps=50,
    device='cuda'
):
    """
    Euler method for ODE: dx/dt = v_θ(x_t, t | input)
    
    Args:
        model: Trained velocity field
        output_coords: (B, N_out, 2)
        input_coords: (B, N_in, 2)
        input_values: (B, N_in, 3) conditioning
        num_steps: Number of integration steps
    
    Returns:
        x_1: (B, N_out, 3) predicted data at t=1
    """
    B = output_coords.shape[0]
    N_out = output_coords.shape[1]
    
    # Start from noise at t=0
    x_t = torch.randn(B, N_out, 3, device=device)
    
    # Time discretization
    dt = 1.0 / num_steps
    ts = torch.linspace(0, 1 - dt, num_steps)
    
    for t_val in tqdm(ts, desc="Sampling (Euler)", leave=False):
        t = torch.full((B,), t_val.item(), device=device)
        
        # Predict velocity
        velocity = model(x_t, output_coords, t, input_coords, input_values)
        
        # Euler update: x_{t+dt} = x_t + dt * v_θ(x_t, t)
        x_t = x_t + dt * velocity
    
    return torch.clamp(x_t, 0, 1)


@torch.no_grad()
def heun_sample(
    model,
    output_coords,
    input_coords,
    input_values,
    num_steps=50,
    device='cuda'
):
    """
    Heun's method (improved Euler): second-order accurate, two velocity evaluations per step
    
    Args:
        Same as euler_sample
    
    Returns:
        x_1: (B, N_out, 3) predicted data at t=1
    """
    B = output_coords.shape[0]
    N_out = output_coords.shape[1]
    
    x_t = torch.randn(B, N_out, 3, device=device)
    
    dt = 1.0 / num_steps
    ts = torch.linspace(0, 1 - dt, num_steps)
    
    for t_val in tqdm(ts, desc="Sampling (Heun)", leave=False):
        t = torch.full((B,), t_val.item(), device=device)
        t_next = torch.full((B,), t_val.item() + dt, device=device)
        
        # First step (Euler predictor)
        v1 = model(x_t, output_coords, t, input_coords, input_values)
        x_next_pred = x_t + dt * v1
        
        # Second step (corrector)
        v2 = model(x_next_pred, output_coords, t_next, input_coords, input_values)
        
        # Heun update (average of two velocities)
        x_t = x_t + dt * 0.5 * (v1 + v2)
    
    return torch.clamp(x_t, 0, 1)
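The accuracy gap between `euler_sample` and `heun_sample` can be seen on a toy ODE with a known solution. This sketch uses scalar floats and the same time grid as the samplers above; `integrate` and the test equation $dx/dt = x$ (exact answer $x(1) = e$) are illustrative choices. Halving the step size should roughly halve Euler's error and quarter Heun's, at the cost of two velocity evaluations per Heun step (so `num_steps=50` means 100 network calls).

```python
import math


def integrate(velocity, x0, num_steps, method="euler"):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed steps."""
    dt = 1.0 / num_steps
    x = x0
    for k in range(num_steps):
        t = k * dt
        v1 = velocity(x, t)
        if method == "euler":
            x = x + dt * v1
        else:  # Heun: average the slope at t and at the Euler-predicted point
            v2 = velocity(x + dt * v1, t + dt)
            x = x + dt * 0.5 * (v1 + v2)
    return x


velocity = lambda x, t: x  # toy ODE with exact solution x(1) = e * x(0)
errs = {}
for n in (10, 20, 40):
    err_euler = abs(integrate(velocity, 1.0, n, "euler") - math.e)
    err_heun = abs(integrate(velocity, 1.0, n, "heun") - math.e)
    errs[n] = (err_euler, err_heun)
    print(f"steps={n:3d}  Euler error={err_euler:.2e}  Heun error={err_heun:.2e}")
```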

## 4. Training: Conditional Flow Matching with Evaluation

In [None]:
def train_flow_matching(
    model,
    train_loader,
    test_loader,
    epochs=100,
    lr=1e-4,
    device='cuda',
    visualize_every=5,
    eval_every=2
):
    """Train velocity field with conditional flow matching"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    losses = []
    
    # Get a fixed batch for visualization
    viz_batch = next(iter(train_loader))
    viz_input_coords = viz_batch['input_coords'][:4].to(device)
    viz_input_values = viz_batch['input_values'][:4].to(device)
    viz_output_coords = viz_batch['output_coords'][:4].to(device)
    viz_output_values = viz_batch['output_values'][:4].to(device)
    viz_full_images = viz_batch['full_image'][:4].to(device)
    
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_coords = batch['input_coords'].to(device)
            input_values = batch['input_values'].to(device)
            output_coords = batch['output_coords'].to(device)
            output_values = batch['output_values'].to(device)  # x_1 (data)
            
            B = input_coords.shape[0]
            
            # Sample random time t ~ Uniform(0, 1)
            t = torch.rand(B, device=device)
            
            # Sample noise x_0 ~ N(0, I)
            x_0 = torch.randn_like(output_values)
            x_1 = output_values
            
            # Conditional flow: x_t = (1-t) * x_0 + t * x_1
            t_broadcast = t.view(B, 1, 1)
            x_t = conditional_flow(x_0, x_1, t_broadcast)
            
            # Target velocity: u_t = x_1 - x_0 (constant!)
            u_t = target_velocity(x_0, x_1)
            
            # Predict velocity
            v_pred = model(x_t, output_coords, t, input_coords, input_values)
            
            # Loss: match predicted velocity to target velocity
            loss = F.mse_loss(v_pred, u_t)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.6f}")
        
        # Evaluate every N epochs
        if (epoch + 1) % eval_every == 0 or epoch == 0:
            model.eval()
            with torch.no_grad():
                tracker = MetricsTracker()
                
                # Evaluate on small subset for speed
                for i, batch in enumerate(test_loader):
                    if i >= 10:
                        break
                    
                    input_coords = batch['input_coords'].to(device)
                    input_values = batch['input_values'].to(device)
                    output_coords = batch['output_coords'].to(device)
                    output_values = batch['output_values'].to(device)
                    
                    # Generate predictions
                    pred_values = heun_sample(
                        model, output_coords, input_coords, input_values,
                        num_steps=50, device=device
                    )
                    
                    tracker.update(pred_values, output_values)
                
                results = tracker.compute()
                print(f"  Eval - MSE: {results['mse']:.6f}, MAE: {results['mae']:.6f}")
            
            model.train()
        
        # Visualize predictions every N epochs
        if (epoch + 1) % visualize_every == 0 or epoch == 0:
            model.eval()
            with torch.no_grad():
                # Generate predictions using Heun ODE solver
                pred_values = heun_sample(
                    model, viz_output_coords, 
                    viz_input_coords, viz_input_values,
                    num_steps=50, device=device
                )
                
                # Visualize
                fig = visualize_predictions(
                    viz_input_coords, viz_input_values,
                    viz_output_coords, pred_values, viz_output_values,
                    viz_full_images, n_samples=4
                )
                plt.suptitle(f'Flow Matching - Epoch {epoch+1}/{epochs}', fontsize=14, y=1.02)
                plt.savefig(f'flow_matching_epoch_{epoch+1:03d}.png', dpi=150, bbox_inches='tight')
                plt.show()
                plt.close()
            
            model.train()
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, f'flow_matching_epoch_{epoch+1}.pt')
    
    return losses


# Create dataset
print("Loading CIFAR-10...")
train_dataset = SparseCIFAR10Dataset(
    root='../data',
    train=True,
    input_ratio=0.2,
    output_ratio=0.2,
    download=True,
    seed=42
)

test_dataset = SparseCIFAR10Dataset(
    root='../data',
    train=False,
    input_ratio=0.2,
    output_ratio=0.2,
    download=True,
    seed=42
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Initialize
model = VelocityField(
    num_latents=512,
    latent_dim=512,
    num_fourier_feats=256,
    num_blocks=6,
    num_heads=8
).to(device)

# Train
print("\nStarting training...")
losses = train_flow_matching(model, train_loader, test_loader, 
                             epochs=100, lr=1e-4, device=device, 
                             visualize_every=5, eval_every=2)

## 5. Final Evaluation: Full Image Reconstruction

In [None]:
# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Velocity Matching Loss')
plt.title('Training Loss: Flow Matching')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

# Final Evaluation: Reconstruct FULL images (all 1024 pixels)
print("\n" + "="*70)
print("FINAL EVALUATION: Full Image Reconstruction (1024 pixels)")
print("="*70)

model.eval()

# Full 32x32 grid of (x, y) coordinates in [0, 1], flattened row by row.
# Assumes SparseCIFAR10Dataset uses the same coordinate convention.
def create_full_grid(image_size=32, device='cuda'):
    """Create coordinate grid for full image"""
    y, x = torch.meshgrid(
        torch.linspace(0, 1, image_size),
        torch.linspace(0, 1, image_size),
        indexing='ij'
    )
    coords = torch.stack([x.flatten(), y.flatten()], dim=-1)  # (1024, 2)
    return coords.to(device)

full_coords = create_full_grid(32, device)  # (1024, 2)

# Evaluate on test set - reconstruct FULL images
tracker_full = MetricsTracker()

for i, batch in enumerate(tqdm(test_loader, desc="Full Image Reconstruction")):
    if i >= 50:  # Evaluate on 50 batches = 800 images
        break
    
    input_coords = batch['input_coords'].to(device)
    input_values = batch['input_values'].to(device)
    full_images = batch['full_image'].to(device)
    
    B = input_coords.shape[0]
    
    # Predict ALL pixels (1024) conditioned on sparse input (204)
    full_coords_batch = full_coords.unsqueeze(0).expand(B, -1, -1)  # (B, 1024, 2)
    
    pred_values = heun_sample(
        model, full_coords_batch, input_coords, input_values,
        num_steps=100, device=device
    )
    
    # Reshape predictions to image format
    pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)  # (B, 3, 32, 32)
    
    # Compute metrics on full images
    tracker_full.update(None, None, pred_images, full_images)

# Print results
results_full = tracker_full.compute()
print("\nFull Image Reconstruction Results:")
print(f"  PSNR: {results_full['psnr']:.2f} dB")
print(f"  SSIM: {results_full['ssim']:.4f}")
print("="*70)
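The reshape `pred_values.view(B, 32, 32, 3)` is only correct if the grid is flattened row by row with x as the column coordinate. A quick check of that ordering, using a CPU copy of `create_full_grid`; whether `SparseCIFAR10Dataset` encodes its `input_coords` with the same [0, 1] (x, y) convention is not shown in this notebook and should be verified separately.

```python
import torch


def create_full_grid(image_size=32, device="cpu"):
    """Copy of the notebook's helper, defaulting to CPU."""
    y, x = torch.meshgrid(
        torch.linspace(0, 1, image_size),
        torch.linspace(0, 1, image_size),
        indexing="ij",
    )
    return torch.stack([x.flatten(), y.flatten()], dim=-1).to(device)


size = 32
coords = create_full_grid(size)
# Flat index r * size + c should hold (x, y) = (c, r) / (size - 1)
r, c = 5, 17
print(coords[r * size + c] * (size - 1))

# Round trip: (B, 3, H, W) image -> (B, H*W, 3) row-major pixels -> back again
img = torch.rand(2, 3, size, size)
values = img.permute(0, 2, 3, 1).reshape(2, size * size, 3)
back = values.view(2, size, size, 3).permute(0, 3, 1, 2)  # same as the eval cell
print(torch.equal(back, img))
```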

## 6. Visualize Full Image Reconstructions

In [None]:
# Visualize full image reconstructions
sample_batch = next(iter(test_loader))
input_coords = sample_batch['input_coords'][:4].to(device)
input_values = sample_batch['input_values'][:4].to(device)
full_images = sample_batch['full_image'][:4].to(device)

B = input_coords.shape[0]
full_coords_batch = full_coords.unsqueeze(0).expand(B, -1, -1)

# Generate FULL image predictions
pred_values = heun_sample(
    model, full_coords_batch, input_coords, input_values,
    num_steps=100, device=device
)

pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)

# Visualize
fig, axes = plt.subplots(4, 3, figsize=(12, 16))

for i in range(4):
    # Ground truth
    gt_img = full_images[i].permute(1, 2, 0).cpu().numpy()
    axes[i, 0].imshow(gt_img)
    axes[i, 0].set_title('Ground Truth')
    axes[i, 0].axis('off')
    
    # Sparse input (visualize the 20% input pixels)
    input_img = torch.zeros(3, 32, 32, device=device)
    input_idx = sample_batch['input_indices'][i].to(device)
    input_img.view(3, -1)[:, input_idx] = input_values[i].T
    axes[i, 1].imshow(input_img.permute(1, 2, 0).cpu().numpy())
    axes[i, 1].set_title(f'Input (20% = {len(input_idx)} pixels)')
    axes[i, 1].axis('off')
    
    # Full reconstruction
    pred_img = pred_images[i].permute(1, 2, 0).cpu().numpy()
    axes[i, 2].imshow(np.clip(pred_img, 0, 1))
    axes[i, 2].set_title('Reconstructed (100%)')
    axes[i, 2].axis('off')

plt.suptitle('Flow Matching: Full Image Reconstruction from 20% Sparse Input', fontsize=14, y=0.995)
plt.tight_layout()
plt.savefig('flow_matching_full_reconstruction.png', dpi=150, bbox_inches='tight')
plt.show()
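The sparse-input panel relies on two things: `input_indices` being flat row-major pixel indices (an assumption about `SparseCIFAR10Dataset`), and indexed assignment through `.view(3, -1)` writing into the underlying image tensor. A small self-contained check with a random image and hypothetical indices:

```python
import torch

torch.manual_seed(0)
H = W = 32
img = torch.rand(3, H, W)
idx = torch.randperm(H * W)[:204]       # hypothetical flat indices (row-major)
vals = img.view(3, -1)[:, idx].T        # (204, 3), shaped like input_values[i]

canvas = torch.zeros(3, H, W)
canvas.view(3, -1)[:, idx] = vals.T     # writes through the view into canvas

mask = torch.zeros(H * W, dtype=torch.bool)
mask[idx] = True
print(torch.equal(canvas.view(3, -1)[:, mask], img.view(3, -1)[:, mask]))  # observed pixels copied
print(canvas.view(3, -1)[:, ~mask].abs().sum().item())                     # unobserved pixels stay 0
```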

## Summary

### ✅ Implemented
- Conditional flow matching
- Straight-path interpolation
- Velocity field prediction
- Heun ODE solver (2nd order)

### 📊 Results
See metrics above for:
- MSE/MAE on output pixels
- PSNR/SSIM on full images
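For images scaled to [0, 1], PSNR is $10 \log_{10}(1/\text{MSE})$, so every 10x drop in MSE adds 10 dB. A minimal helper for reading the two metrics side by side; `psnr_from_mse` is illustrative, and whether `MetricsTracker` averages per-image PSNR or computes it from pooled MSE is not shown here (the two can differ).

```python
import math


def psnr_from_mse(mse, max_val=1.0):
    """PSNR in dB for images with pixel values in [0, max_val]."""
    return 10.0 * math.log10(max_val ** 2 / mse)


for mse in (1e-2, 1e-3):
    print(f"MSE={mse:g} -> PSNR={psnr_from_mse(mse):.1f} dB")
```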

### ⚖️ Strengths & Weaknesses

**Strengths**:
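- ✅ **Simple training**: direct velocity regression (MSE)
- ✅ **Fast sampling**: straight conditional paths; this notebook uses 50 Heun steps for evaluation during training and 100 for the final evaluation (2 network evaluations per step)
- ✅ **Modern approach**: rectified-flow variants are used in SD3 and Flux
- ✅ **Deterministic given the noise**: the ODE maps a fixed $x_0$ to a fixed output (the samplers draw fresh noise on every call, so fix the random seed for reproducible outputs)
- ✅ **Flexible**: any ODE solver can be plugged in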

**Potential Weaknesses**:
- ⚠️ Less explored for sparse conditioning
- ⚠️ May need careful solver selection

### 🔄 Comparison with Other Approaches

| Aspect | Score-Based | NF Denoiser | Flow Matching |
|--------|-------------|-------------|---------------|
| Training | Complex (score matching) | Simple (MSE) | Simplest (velocity MSE) |
| Sampling | Slow (Langevin) | Fast (DDIM) | Fast (ODE: Euler/Heun) |
| Typical sampling steps | 100-1000 | 50-100 | 20-50 (this notebook: 50-100 Heun steps) |
| Theory | Mature | Established | Modern |

### 🏆 Final Verdict

Run all 3 notebooks and compare quantitative metrics!

**Expected Winner**:
- **Quality**: All three should be similar
- **Speed**: Flow Matching ≈ NF Denoiser > Score-Based
- **Simplicity**: Flow Matching > NF Denoiser > Score-Based

**Best Overall**: Likely Flow Matching or NF Denoiser depending on your priorities!