# CIFAR-10 Sparse Reconstruction: Neural Field as Denoiser

## Overview

**Approach**: Gaussian Random Field Denoising (Your Idea!)

**Key Concept**: Start with Gaussian noise at the output pixel locations, then progressively denoise it with a neural field while keeping the input pixels clean and fixed.

**Architecture**:
```
Step T: Gaussian Noise at Output Locations
        + Clean Input Pixels (FIXED)
        ↓
Neural Field Denoiser f_θ(noisy_out, clean_in, coords, t)
        ↓
Step T-1: Less Noisy Output
          + Clean Input Pixels (FIXED)
        ↓
        ...
        ↓
Step 0: Clean Output Predictions!
```

## Theory: Direct Denoising with Fixed Conditioning

### Forward Process (Add Noise to Output Only)
$$x_t^{out} = \sqrt{\bar{\alpha}_t} \cdot x_0^{out} + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

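where $x_0^{out}$ are the clean output pixels, $\epsilon \sim \mathcal{N}(0, I)$, and $\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$ is the cumulative signal level given by the noise schedule.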

**Key**: Input pixels $x^{in}$ remain **clean throughout**!
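
To make the forward process concrete, here is a minimal pure-Python sketch of the noising formula above. The helper name `q_sample` and the toy pixel values are assumptions for illustration only; the notebook's actual implementation is the tensor version inside the training loop.

```python
import math
import random

def q_sample(x0, alpha_bar, eps):
    """Noise a list of clean values: sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    a = math.sqrt(alpha_bar)
    b = math.sqrt(1.0 - alpha_bar)
    return [a * x + b * e for x, e in zip(x0, eps)]

rng = random.Random(0)
x0_out = [0.2, 0.5, 0.9]   # clean OUTPUT pixel values (single channel, toy data)
x_in = [0.1, 0.7]          # clean INPUT pixels: never passed through q_sample
eps = [rng.gauss(0.0, 1.0) for _ in x0_out]

# Large alpha_bar keeps mostly signal, small alpha_bar keeps mostly noise
for alpha_bar in (0.99, 0.5, 0.01):
    x_t = q_sample(x0_out, alpha_bar, eps)
    print(alpha_bar, [round(v, 3) for v in x_t])
```

At $\bar{\alpha}_t = 1$ the output values are returned unchanged, and at $\bar{\alpha}_t = 0$ they are pure noise; `x_in` is untouched at every noise level.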

### Reverse Process (Denoise Output)
$$x_0^{pred} = f_\theta(x_t^{out}, x^{in}, \text{coords}, t)$$

The neural field predicts clean output values from:
- Noisy output pixels (changing each step)
- Clean input pixels (fixed conditioning)
- Coordinates (spatial information)
- Timestep t (noise level)

### Training Objective
$$\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon} \left[\| f_\theta(x_t^{out}, x^{in}, \text{coords}, t) - x_0^{out} \|^2 \right]$$

**Direct prediction loss** - predict clean from noisy!
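
For a fixed noise level, the direct-prediction error and the more common $\epsilon$-prediction error differ only by the factor $\bar{\alpha}_t / (1 - \bar{\alpha}_t)$, because the noise implied by an $x_0$ prediction is $(x_t - \sqrt{\bar{\alpha}_t}\, x_0^{pred}) / \sqrt{1 - \bar{\alpha}_t}$. The scalar check below is a sketch with made-up numbers; `implied_eps` is a hypothetical helper name.

```python
import math

def implied_eps(x_t, x0_pred, alpha_bar):
    """Noise estimate implied by a clean-value prediction at noise level alpha_bar."""
    return (x_t - math.sqrt(alpha_bar) * x0_pred) / math.sqrt(1.0 - alpha_bar)

alpha_bar = 0.36
x0, eps = 0.8, -1.5                      # toy ground-truth value and noise draw
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

x0_pred = 0.6                            # an imperfect clean prediction
eps_pred = implied_eps(x_t, x0_pred, alpha_bar)

x0_err = (x0_pred - x0) ** 2             # what the direct-prediction loss penalizes
eps_err = (eps_pred - eps) ** 2          # what an epsilon-prediction loss would penalize
weight = alpha_bar / (1.0 - alpha_bar)   # per-noise-level factor linking the two

print(x0_err, eps_err, weight * x0_err)
```

So the direct loss puts relatively more weight on high-noise timesteps than an unweighted $\epsilon$ loss would.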

### Why This Approach?
- ✅ **Most intuitive**: Gaussian random field → clean field
- ✅ **Single objective**: Predict clean values directly
- ✅ **Clear separation**: Input clean, output noisy
- ✅ **Fast sampling**: DDIM possible (10-50 steps)
- ✅ **Flexible**: Can use DDPM or DDIM sampling

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Import shared components
from core.neural_fields.perceiver import PerceiverIO, FourierFeatures
from core.sparse.cifar10_sparse import SparseCIFAR10Dataset
from core.sparse.metrics import MetricsTracker, print_metrics, visualize_predictions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## 1. Noise Schedule (Same as DDPM)

In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule from Improved DDPM"""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


class DiffusionSchedule:
    """DDPM diffusion schedule"""
    def __init__(self, timesteps=1000, beta_schedule='cosine'):
        self.timesteps = timesteps
        
        if beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            betas = torch.linspace(0.0001, 0.02, timesteps)
        
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
        
        # Posterior variance for DDPM
        self.posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_log_variance = torch.log(torch.clamp(self.posterior_variance, min=1e-20))
    
    def to(self, device):
        """Move all tensors to device"""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.sqrt_recip_alphas_cumprod = self.sqrt_recip_alphas_cumprod.to(device)
        self.sqrt_recipm1_alphas_cumprod = self.sqrt_recipm1_alphas_cumprod.to(device)
        self.posterior_variance = self.posterior_variance.to(device)
        self.posterior_log_variance = self.posterior_log_variance.to(device)
        return self


# Visualize
schedule = DiffusionSchedule(timesteps=1000)

plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(schedule.betas.numpy())
plt.title('β_t')
plt.xlabel('Timestep')
plt.grid(alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(schedule.alphas_cumprod.numpy())
plt.title('ᾱ_t (Signal Strength)')
plt.xlabel('Timestep')
plt.grid(alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(schedule.sqrt_alphas_cumprod.numpy(), label='√ᾱ_t')
plt.plot(schedule.sqrt_one_minus_alphas_cumprod.numpy(), label='√(1-ᾱ_t)')
plt.title('Noise Mixing')
plt.xlabel('Timestep')
plt.legend()
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()
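
As a quick sanity check on the curves above, the cosine signal level can be evaluated directly without PyTorch. This is a sketch of the unclipped formula only: `cosine_alpha_bar` is a hypothetical helper, and the values in `DiffusionSchedule` can differ slightly because the betas there are clipped before the cumulative product.

```python
import math

def cosine_alpha_bar(t, timesteps, s=0.008):
    """Unclipped cumulative signal level f(t) / f(0) for t in [0, timesteps]."""
    def f(u):
        return math.cos(((u / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    return f(t) / f(0)

T = 1000
for t in (0, 250, 500, 750, 1000):
    ab = cosine_alpha_bar(t, T)
    # columns: t, alpha_bar, signal scale sqrt(alpha_bar), noise scale sqrt(1 - alpha_bar)
    print(t, round(ab, 4), round(math.sqrt(ab), 4), round(math.sqrt(1.0 - ab), 4))
```

The signal level starts at 1, decreases monotonically, and is essentially 0 at the final timestep, which is why sampling can start from pure Gaussian noise.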

## 2. Neural Field Denoiser Architecture

Key idea: input pixels stay clean; only the output pixels are noised!
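
The token layout behind this idea can be sketched in plain Python: clean input pixels and noisy output pixels are concatenated into one set, and only the output half changes between denoising steps. The `build_tokens` helper and its `is_output` flag are hypothetical and only illustrate the input/output marker concept; they are not part of the `PerceiverIO` interface.

```python
def build_tokens(input_coords, input_values, output_coords, noisy_output_values):
    """Concatenate clean input pixels and noisy output pixels into one token list.

    The is_output flag (0 = input, 1 = output) is an illustrative assumption,
    mirroring the 0/1 convention of the marker embedding.
    """
    coords = list(input_coords) + list(output_coords)
    values = list(input_values) + list(noisy_output_values)
    is_output = [0] * len(input_coords) + [1] * len(output_coords)
    return coords, values, is_output

# Toy example: 2 clean input pixels and 3 noisy output pixels (one channel each)
in_coords = [(0.0, 0.0), (0.5, 0.25)]
in_vals = [0.1, 0.7]
out_coords = [(0.25, 0.75), (1.0, 0.5), (0.75, 0.0)]
noisy_out = [1.3, -0.4, 0.2]

coords, values, is_output = build_tokens(in_coords, in_vals, out_coords, noisy_out)
print(len(coords), is_output)
```

The model is then queried at the output coordinates only, so it returns one prediction per noisy pixel.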

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding for noise level"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb


class NFDenoiser(nn.Module):
    """
    Neural Field Denoiser: Gaussian Random Field → Clean Field
    
    f_θ(noisy_output, clean_input, coords, t) → clean_output
    """
    def __init__(
        self,
        num_latents=512,
        latent_dim=512,
        num_fourier_feats=256,
        num_blocks=6,
        num_heads=8,
        dropout=0.1
    ):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Perceiver IO backbone
        self.perceiver = PerceiverIO(
            input_channels=3,
            output_channels=3,
            num_latents=num_latents,
            latent_dim=latent_dim,
            num_fourier_feats=num_fourier_feats,
            num_blocks=num_blocks,
            num_heads=num_heads,
            dropout=dropout
        )
        
        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(latent_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim)
        )
        
        # Marker for input vs output pixels (0=input, 1=output).
        # NOTE: declared here but not applied in forward() below.
        self.marker_embed = nn.Embedding(2, 3)
    
    def forward(self, noisy_output_values, output_coords, t, input_coords, input_values):
        """
        Args:
            noisy_output_values: (B, N_out, 3) NOISY output pixel values
            output_coords: (B, N_out, 2) output pixel coordinates
            t: (B,) continuous timesteps
            input_coords: (B, N_in, 2) input pixel coordinates
            input_values: (B, N_in, 3) CLEAN input pixel values
        
        Returns:
            clean_pred: (B, N_out, 3) predicted CLEAN output values
        """
        B, N_in = input_coords.shape[0], input_coords.shape[1]
        N_out = output_coords.shape[1]
        
        # Time embedding
        t_emb = self.time_mlp(self.time_embed(t))  # (B, latent_dim)
        
        # Add time information to pixel values: only the first 3 channels of
        # t_emb are used, broadcast across all pixels as an RGB offset
        time_signal = t_emb[:, :3].unsqueeze(1)  # (B, 1, 3)
        
        # IMPORTANT: Input pixels are CLEAN (no noise added)
        #            Output pixels are NOISY (progressively denoised)
        input_values_t = input_values + time_signal  # Add time to clean input
        noisy_output_values_t = noisy_output_values + time_signal  # Add time to noisy output
        
        # Concatenate all pixels (input=clean, output=noisy)
        all_coords = torch.cat([input_coords, output_coords], dim=1)  # (B, N_in+N_out, 2)
        all_values = torch.cat([input_values_t, noisy_output_values_t], dim=1)  # (B, N_in+N_out, 3)
        
        # Predict CLEAN values at output coordinates
        clean_pred = self.perceiver(all_coords, all_values, output_coords)
        
        return clean_pred


# Test denoiser
model = NFDenoiser().to(device)
test_noisy = torch.randn(4, 204, 3).to(device)
test_output_coords = torch.rand(4, 204, 2).to(device)
test_t = torch.rand(4).to(device) * 1000
test_input_coords = torch.rand(4, 204, 2).to(device)
test_input_values = torch.rand(4, 204, 3).to(device)

test_pred = model(test_noisy, test_output_coords, test_t, test_input_coords, test_input_values)
print(f"Denoiser test: {test_pred.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

## 3. Training: Direct Clean Prediction

In [None]:
def train_denoiser(
    model,
    train_loader,
    schedule,
    epochs=100,
    lr=1e-4,
    device='cuda'
):
    """Train neural field denoiser"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    losses = []
    
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_coords = batch['input_coords'].to(device)
            input_values = batch['input_values'].to(device)  # CLEAN (fixed)
            output_coords = batch['output_coords'].to(device)
            output_values = batch['output_values'].to(device)  # x_0 (target)
            
            B = input_coords.shape[0]
            
            # Sample random timestep
            t = torch.randint(0, schedule.timesteps, (B,), device=device).float()
            
            # Add noise to OUTPUT ONLY: x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
            noise = torch.randn_like(output_values)
            sqrt_alpha_t = schedule.sqrt_alphas_cumprod[t.long()].view(B, 1, 1)
            sqrt_one_minus_alpha_t = schedule.sqrt_one_minus_alphas_cumprod[t.long()].view(B, 1, 1)
            
            noisy_output = sqrt_alpha_t * output_values + sqrt_one_minus_alpha_t * noise
            
            # Predict CLEAN output from noisy output + clean input
            pred_clean = model(noisy_output, output_coords, t, input_coords, input_values)
            
            # Loss: predict clean values
            loss = F.mse_loss(pred_clean, output_values)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.6f}")
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, f'nf_denoiser_epoch_{epoch+1}.pt')
    
    return losses


# Create dataset
print("Loading CIFAR-10...")
train_dataset = SparseCIFAR10Dataset(
    root='../data',
    train=True,
    input_ratio=0.2,
    output_ratio=0.2,
    download=True,
    seed=42
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

print(f"Dataset size: {len(train_dataset)}")
print(f"Batches: {len(train_loader)}")

# Initialize
model = NFDenoiser(
    num_latents=512,
    latent_dim=512,
    num_fourier_feats=256,
    num_blocks=6,
    num_heads=8
).to(device)

schedule = DiffusionSchedule(timesteps=1000, beta_schedule='cosine').to(device)

# Train
print("\nStarting training...")
losses = train_denoiser(model, train_loader, schedule, epochs=100, lr=1e-4, device=device)

## 4. Sampling: DDPM or DDIM

We implement DDIM for faster sampling (10-50 steps instead of the full 1000-step schedule).
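
A single DDIM update can be worked through on scalars before reading the tensor version. The sketch below assumes the standard DDIM update for an $x_0$-predicting model; `ddim_step` and the toy numbers are illustrative, not the notebook's sampler.

```python
import math

def ddim_step(x_t, x0_pred, ab_t, ab_next, eta=0.0, noise=0.0):
    """One DDIM update from signal level ab_t to the less noisy level ab_next."""
    # Noise implied by the clean prediction
    eps_pred = (x_t - math.sqrt(ab_t) * x0_pred) / math.sqrt(1.0 - ab_t)
    # Step noise scale: 0 for deterministic DDIM, larger as eta approaches 1
    sigma = eta * math.sqrt((1.0 - ab_next) / (1.0 - ab_t) * (1.0 - ab_t / ab_next))
    # Direction coefficient shrinks so the total variance stays 1 - ab_next
    direction = math.sqrt(max(1.0 - ab_next - sigma ** 2, 0.0))
    return math.sqrt(ab_next) * x0_pred + direction * eps_pred + sigma * noise

# Toy check: with a perfect x0 prediction and eta=0, the step lands exactly on
# the forward-process value at the next noise level (same x0, same eps).
x0, eps = 0.5, 1.2
ab_t, ab_next = 0.2, 0.6
x_t = math.sqrt(ab_t) * x0 + math.sqrt(1.0 - ab_t) * eps
x_next = ddim_step(x_t, x0, ab_t, ab_next)
expected = math.sqrt(ab_next) * x0 + math.sqrt(1.0 - ab_next) * eps
print(x_next, expected)
```

Because each step re-noises the current prediction to the next level rather than walking through every timestep, a short subsequence of the 1000 timesteps is enough.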

In [None]:
@torch.no_grad()
def ddim_sample(
    model,
    schedule,
    output_coords,
    input_coords,
    input_values,
    num_steps=50,
    eta=0.0,  # 0=DDIM, 1=DDPM
    device='cuda'
):
    """
    DDIM sampling (faster than DDPM)
    
    Args:
        model: Trained denoiser
        schedule: Diffusion schedule
        output_coords: (B, N_out, 2)
        input_coords: (B, N_in, 2)
        input_values: (B, N_in, 3) CLEAN (conditioning)
        num_steps: Sampling steps (50 is good)
        eta: Stochasticity (0=deterministic DDIM)
    
    Returns:
        x_0: (B, N_out, 3) predicted clean values
    """
    B = output_coords.shape[0]
    N_out = output_coords.shape[1]
    
    # Start from Gaussian random field
    x_t = torch.randn(B, N_out, 3, device=device)
    
    # Uniform timestep schedule
    timesteps = torch.linspace(schedule.timesteps - 1, 0, num_steps).long()
    
    for i, t_idx in enumerate(tqdm(timesteps, desc="Sampling (DDIM)")):
        t = torch.full((B,), t_idx.item(), device=device, dtype=torch.float)
        
        # Predict x_0 from x_t
        x_0_pred = model(x_t, output_coords, t, input_coords, input_values)
        
        if i < len(timesteps) - 1:
            # Get next timestep
            t_next = timesteps[i + 1]
            
            # DDIM update
            alpha_t = schedule.alphas_cumprod[t_idx]
            alpha_t_next = schedule.alphas_cumprod[t_next]

            # Noise implied by the x_0 prediction
            eps_pred = (x_t - torch.sqrt(alpha_t) * x_0_pred) / torch.sqrt(1 - alpha_t)

            # Noise scale for this step (eta=0 gives sigma_t=0, i.e. deterministic DDIM)
            sigma_t = eta * torch.sqrt(
                (1 - alpha_t_next) / (1 - alpha_t) * (1 - alpha_t / alpha_t_next)
            )

            # DDIM step: the direction term shrinks by sigma_t^2 so the total
            # variance still matches the schedule when eta > 0
            x_t = (
                torch.sqrt(alpha_t_next) * x_0_pred +
                torch.sqrt(torch.clamp(1 - alpha_t_next - sigma_t ** 2, min=0.0)) * eps_pred +
                sigma_t * torch.randn_like(x_t)
            )
        else:
            # Final step
            x_t = x_0_pred
    
    return torch.clamp(x_t, 0, 1)

## 5. Evaluation with Unified Metrics
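
For reading the metrics below, it helps to know how MSE maps to PSNR. This sketch assumes the standard PSNR definition with pixel values in [0, 1]; `psnr_from_mse` is a hypothetical helper, and the definition used inside `MetricsTracker` is not shown here and may differ.

```python
import math

def psnr_from_mse(mse, max_val=1.0):
    """Standard PSNR in dB for a given mean squared error and peak value."""
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

# Each 10x reduction in MSE adds 10 dB of PSNR
for mse in (0.01, 0.001, 0.0001):
    print(mse, round(psnr_from_mse(mse), 2))
```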

In [None]:
# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss: Neural Field Denoiser')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

# Evaluate on test set
print("\nEvaluating on test set...")
test_dataset = SparseCIFAR10Dataset(
    root='../data',
    train=False,
    input_ratio=0.2,
    output_ratio=0.2,
    download=True,
    seed=42
)

test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

model.eval()
tracker = MetricsTracker()

for batch in tqdm(test_loader, desc="Testing"):
    input_coords = batch['input_coords'].to(device)
    input_values = batch['input_values'].to(device)
    output_coords = batch['output_coords'].to(device)
    output_values = batch['output_values'].to(device)
    
    # Sample predictions (DDIM with 50 steps)
    pred_values = ddim_sample(
        model, schedule, output_coords, input_coords, input_values,
        num_steps=50, eta=0.0, device=device
    )
    
    tracker.update(pred_values, output_values)

# Print results
results = tracker.compute()
print("\n" + "="*50)
print("NEURAL FIELD DENOISER - Test Results")
print("="*50)
print_metrics(results)
print("="*50)

## 6. Visualize Predictions

In [None]:
# Visualize predictions
sample_batch = next(iter(test_loader))
input_coords = sample_batch['input_coords'].to(device)
input_values = sample_batch['input_values'].to(device)
output_coords = sample_batch['output_coords'].to(device)
output_values = sample_batch['output_values'].to(device)
full_images = sample_batch['full_image'].to(device)

# Generate predictions
pred_values = ddim_sample(
    model, schedule, output_coords, input_coords, input_values,
    num_steps=50, eta=0.0, device=device
)

# Visualize
fig = visualize_predictions(
    input_coords, input_values,
    output_coords, pred_values, output_values,
    full_images, n_samples=4
)
plt.suptitle('Neural Field Denoiser: Predictions vs Ground Truth', fontsize=14, y=1.02)
plt.savefig('nf_denoiser_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

## Summary

### ✅ Implemented
- Gaussian Random Field denoising (your idea!)
- Direct clean prediction objective
- DDIM sampling (50 steps)
- Fixed clean input conditioning

### 📊 Results
See metrics above for:
- MSE/MAE on output pixels
- PSNR/SSIM on full images

### ⚖️ Strengths & Weaknesses

**Strengths**:
- ✅ **Most intuitive approach**: Gaussian noise → clean field
- ✅ **Simple training**: Single MSE loss on clean prediction
- ✅ **Fast sampling**: DDIM with 50 steps (vs 1000 for score-based)
- ✅ **Clear separation**: Input always clean, output denoised
- ✅ **Flexible**: Works with DDPM or DDIM

**Potential Weaknesses**:
- ⚠️ May need more steps for very high quality
- ⚠️ Direct $x_0$ prediction weights noise levels differently from ε- or score-prediction, which may affect sample quality

### 🔄 Next
Compare with:
- Notebook 3: Score-Based (already done)
- Notebook 5: Flow Matching (next)