# Local Implicit Field Diffusion (Option 2)

## Architecture Overview

**Key Innovation**: Distance-aware local attention with explicit locality bias

**Advantages over Perceiver IO**:
- ✅ No latent bottleneck → preserves all spatial information
- ✅ Explicit locality modeling → better for continuous fields
- ✅ Distance-weighted attention → natural for sparse conditioning
- ✅ FiLM time modulation → a widely used, effective way to condition diffusion models on time

**Architecture**:
```
Sparse Input (20%) + Query Points
        ↓
Fourier Features + Distance Encoding
        ↓
FiLM Modulation (time conditioning)
        ↓
Local Attention (masked by distance < radius)
        ↓        [FiLM → Local Attention repeated num_layers times]
MLP Decoder → Predicted RGB
```

## Three Candidate Training Approaches
1. Flow Matching (simplest, fastest) — the approach implemented in this notebook
2. NF Denoiser (Gaussian RF idea)
3. Score-Based (most principled)

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

from core.neural_fields.perceiver import FourierFeatures
from core.sparse.cifar10_sparse import SparseCIFAR10Dataset
from core.sparse.metrics import MetricsTracker, print_metrics, visualize_predictions

# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
print(f"Device: {device}")

## 1. Core Components

### Time Embedding & FiLM Modulation

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal (transformer-style) embedding of a batch of scalar timesteps.

    ``forward(t)`` maps ``t`` of shape ``(B,)`` to ``(B, dim)``:
    - The first ``dim // 2`` columns are ``sin(t * f_i)``.
    - The next ``dim // 2`` columns are ``cos(t * f_i)``.
    - The frequencies ``f_i`` are geometrically spaced from 1 down to 1/10000.
    - When ``dim`` is odd, a trailing zero column is appended so the output
      width always equals ``dim``.

    Args:
        dim: Output embedding width.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Log-spaced frequency ladder. ``max(..., 1)`` guards half_dim == 1
        # (dim 2 or 3), where ``half_dim - 1`` would divide by zero; a single
        # frequency of 1 is used in that case.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2:
            # Pad odd widths so downstream layers sized to ``dim`` still fit.
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
        return emb


class FiLMLayer(nn.Module):
    """Time-conditioned feature-wise affine modulation (FiLM).

    Two independent linear heads turn the time embedding into a per-channel
    scale and shift. They are applied as ``x * (1 + scale) + shift``, so a zero
    scale leaves the features untouched apart from the shift.
    """
    def __init__(self, d_model):
        super().__init__()
        # Creation order (scale, then shift) is kept so seeded init is unchanged.
        self.scale = nn.Linear(d_model, d_model)
        self.shift = nn.Linear(d_model, d_model)

    def forward(self, x, t_emb):
        """
        Args:
            x: (B, N, d_model) token features
            t_emb: (B, d_model) time embedding, broadcast over the N tokens

        Returns:
            (B, N, d_model) modulated features
        """
        gamma = self.scale(t_emb)[:, None, :]
        beta = self.shift(t_emb)[:, None, :]
        return (1 + gamma) * x + beta

### Local Distance-Weighted Attention

In [None]:
class LocalAttentionLayer(nn.Module):
    """Per-query cross-attention over each query's own set of input tokens.

    Every query attends to its ``N_k`` key/value tokens (not to other queries).
    Attention logits receive a learned per-head bias computed from the
    query-key distance. Keys outside the locality mask are excluded. The block
    is pre-norm: the attention update and a 4x feed-forward update are each
    added back residually.

    Memory note: K and V are never materialised as ``(B, N_q, N_k, d_model)``
    tensors. ``k_proj`` and ``v_proj`` are linear, so:
    - the query is folded through ``k_proj``'s weight before the dot product;
    - the attention-weighted sum of raw tokens is pushed through ``v_proj``
      afterwards.
    This equals projecting every key/value first, with far less activation
    memory.
    """
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Query, Key, Value projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Scalar distance -> one bounded logit bias per head
        self.dist_bias = nn.Sequential(
            nn.Linear(1, num_heads),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

        # Feed-forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, queries, keys_values, distances, mask=None):
        """
        Args:
            queries: (B, N_q, d_model)
            keys_values: (B, N_q, N_k, d_model) - for each query, its candidate key/value tokens
            distances: (B, N_q, N_k) - query-to-key distances, mapped to a per-head logit bias
            mask: optional (B, N_q, N_k) locality mask (nonzero = attend, 0 = ignore).
                A query with no unmasked key gets a zero attention context.

        Returns:
            (B, N_q, d_model) updated query tokens
        """
        B, N_q, d_model = queries.shape
        H, d_h = self.num_heads, self.head_dim

        residual = queries
        q = self.q_proj(self.layer_norm1(queries)).view(B, N_q, H, d_h)

        # Logits q . (W_k kv + b_k). W_k is folded into the query so keys are
        # never projected. The q . b_k term is constant over keys and cancels
        # inside the softmax, so it is omitted.
        w_k = self.k_proj.weight.view(H, d_h, d_model)
        q_folded = torch.einsum('bqhe,hed->bqhd', q, w_k)
        scores = torch.einsum('bqhd,bqkd->bhqk', q_folded, keys_values)
        scores = scores / math.sqrt(d_h)

        # Learned distance bias per head: (B, N_q, N_k, H) -> (B, H, N_q, N_k)
        scores = scores + self.dist_bias(distances.unsqueeze(-1)).permute(0, 3, 1, 2)

        if mask is not None:
            keep = mask.bool().unsqueeze(1)  # (B, 1, N_q, N_k)
            # Use a finite fill value so rows with no neighbours do not produce
            # NaNs. Those rows are zeroed right after the softmax.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
            attn = F.softmax(scores, dim=-1).masked_fill(~keep, 0.0)
        else:
            attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # sum_k a_k (W_v kv_k + b_v) = W_v (sum_k a_k kv_k) + b_v * sum_k a_k
        context = torch.einsum('bhqk,bqkd->bqhd', attn, keys_values)
        w_v = self.v_proj.weight.view(H, d_h, d_model)
        out = torch.einsum('bqhd,hed->bqhe', context, w_v)  # (B, N_q, H, d_h)
        if self.v_proj.bias is not None:
            attn_mass = attn.sum(dim=-1).permute(0, 2, 1).unsqueeze(-1)  # (B, N_q, H, 1)
            out = out + attn_mass * self.v_proj.bias.view(H, d_h)
        out = out.reshape(B, N_q, d_model)
        out = self.dropout(self.out_proj(out))

        # Residual + Feed-forward
        out = out + residual
        out = out + self.ff(self.layer_norm2(out))

        return out

### Main Architecture: Local Implicit Diffusion

In [None]:
class LocalImplicitDiffusion(nn.Module):
    """
    Local neural fields with distance-aware attention

    Maps a noisy RGB value at each query point to a 3-channel output. The
    inputs are the query coordinates, a timestep, and a sparse set of
    observed (coordinate, RGB) input points.

    Key features:
    - No latent bottleneck (queries attend directly to per-input tokens)
    - Explicit locality via distance masking (inputs farther than
      ``local_radius`` from a query are masked out)
    - FiLM modulation for time conditioning before each attention layer
    """
    def __init__(
        self,
        num_fourier_feats=256,
        d_model=512,
        num_layers=4,
        num_heads=8,
        local_radius=0.3,  # Locality radius in normalized coords
        dropout=0.1
    ):
        super().__init__()
        self.local_radius = local_radius
        self.d_model = d_model

        # Fourier features
        self.fourier = FourierFeatures(coord_dim=2, num_freqs=num_fourier_feats, scale=10.0)
        # NOTE(review): assumes FourierFeatures(coord_dim=2) returns
        # 4 * num_freqs features (sin/cos for x and y) — confirm in core.neural_fields.perceiver.
        feat_dim = num_fourier_feats * 4

        # Distance encoding
        self.dist_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

        # Project inputs: fourier + RGB + distance
        self.input_proj = nn.Linear(feat_dim + 3 + 64, d_model)

        # Project queries: fourier + noisy RGB
        self.query_proj = nn.Linear(feat_dim + 3, d_model)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(d_model)

        # FiLM layers for time conditioning
        self.film_layers = nn.ModuleList([
            FiLMLayer(d_model) for _ in range(num_layers)
        ])

        # Local attention layers
        self.attn_layers = nn.ModuleList([
            LocalAttentionLayer(d_model, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Output decoder
        self.decoder = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 3)
        )

    def forward(self, noisy_values, query_coords, t, input_coords, input_values):
        """
        Args:
            noisy_values: (B, N_out, 3) - noisy RGB values at query points
            query_coords: (B, N_out, 2) - query coordinates
            t: (B,) - timestep
            input_coords: (B, N_in, 2) - sparse input coordinates
            input_values: (B, N_in, 3) - sparse input RGB values

        Returns:
            (B, N_out, 3) decoder output for each query point
        """
        # Time embedding
        t_emb = self.time_embed(t)  # (B, d_model)

        # Pairwise query-input distances and the locality mask derived from them
        dist = torch.cdist(query_coords, input_coords)  # (B, N_out, N_in)
        mask = (dist < self.local_radius).float()  # (B, N_out, N_in)

        # Encode distances
        dist_feats = self.dist_encoder(dist.unsqueeze(-1))  # (B, N_out, N_in, 64)

        # Fourier features
        input_feats = self.fourier(input_coords)  # (B, N_in, feat_dim)
        query_feats = self.fourier(query_coords)  # (B, N_out, feat_dim)

        # input_proj is a linear map on cat([fourier, rgb, dist_feats]), so its
        # weight can be split column-wise:
        # - the fourier+RGB part depends only on the input point, so it is
        #   projected once per input;
        # - that result is broadcast over queries and added to the projected
        #   per-pair distance features.
        # This gives the same tokens without materialising the wide
        # (B, N_out, N_in, feat_dim + 67) concatenation.
        dist_dim = dist_feats.shape[-1]
        weight = self.input_proj.weight
        point_tokens = F.linear(
            torch.cat([input_feats, input_values], dim=-1),
            weight[:, :-dist_dim],
            self.input_proj.bias,
        )  # (B, N_in, d_model)
        input_tokens = point_tokens.unsqueeze(1) + F.linear(dist_feats, weight[:, -dist_dim:])
        # input_tokens: (B, N_out, N_in, d_model)

        # Query tokens
        query_tokens = self.query_proj(
            torch.cat([query_feats, noisy_values], dim=-1)
        )  # (B, N_out, d_model)

        # Each layer: FiLM time modulation, then local attention over the inputs
        x = query_tokens
        for film_layer, attn_layer in zip(self.film_layers, self.attn_layers):
            x = film_layer(x, t_emb)
            x = attn_layer(x, input_tokens, dist, mask)

        # Decode to RGB
        return self.decoder(x)


# Smoke test: one forward pass on random tensors to check the output shape,
# then report the parameter count.
model = LocalImplicitDiffusion(
    num_fourier_feats=256, d_model=512, num_layers=4, num_heads=8, local_radius=0.3,
).to(device)

# Drawn in the same order as before: noisy values, query coords, t, input coords, input values.
test_noisy, test_query_coords, test_t, test_input_coords, test_input_values = (
    torch.rand(*shape).to(device)
    for shape in [(4, 204, 3), (4, 204, 2), (4,), (4, 204, 2), (4, 204, 3)]
)

test_out = model(test_noisy, test_query_coords, test_t, test_input_coords, test_input_values)
print(f"Model test: {test_out.shape}")
print(f"Parameters: {sum(param.numel() for param in model.parameters()):,}")

## 2. Training: Flow Matching (Recommended)

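Simplest and fastest approach: train with straight-path (linear-interpolation) flow matching, regressing the model output onto the constant velocity `x_1 - x_0` between noise `x_0` and data `x_1`.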

In [None]:
def conditional_flow(x_0, x_1, t):
    """Point at time ``t`` on the straight path from ``x_0`` (t=0) to ``x_1`` (t=1)."""
    weight_start = 1 - t
    return weight_start * x_0 + t * x_1

def target_velocity(x_0, x_1):
    """Velocity of the straight path: d/dt[(1-t) x_0 + t x_1], which does not depend on t."""
    displacement = x_1 - x_0
    return displacement

@torch.no_grad()
def heun_sample(model, output_coords, input_coords, input_values, num_steps=50, device=None):
    """Integrate the learned flow from Gaussian noise (t=0) to t=1 with Heun's method.

    Args:
        model: velocity network, called as
            ``model(x_t, output_coords, t, input_coords, input_values)``.
        output_coords: (B, N_out, 2) query coordinates.
        input_coords: sparse conditioning coordinates, passed through to ``model``.
        input_values: sparse conditioning values, passed through to ``model``.
        num_steps: number of Heun steps (two model evaluations each); must be >= 1.
        device: device for the noise and time tensors. Defaults to
            ``output_coords.device`` so CPU-only runs work without passing it.

    Returns:
        (B, N_out, 3) samples clamped to [0, 1].

    Raises:
        ValueError: if ``num_steps`` < 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if device is None:
        device = output_coords.device

    B, N_out = output_coords.shape[0], output_coords.shape[1]
    x_t = torch.randn(B, N_out, 3, device=device)

    dt = 1.0 / num_steps
    for step in tqdm(range(num_steps), desc="Sampling", leave=False):
        t_val = step * dt
        t = torch.full((B,), t_val, device=device)
        t_next = torch.full((B,), t_val + dt, device=device)

        # Euler predictor, then trapezoidal corrector using both velocities.
        v1 = model(x_t, output_coords, t, input_coords, input_values)
        x_next_pred = x_t + dt * v1

        v2 = model(x_next_pred, output_coords, t_next, input_coords, input_values)
        x_t = x_t + dt * 0.5 * (v1 + v2)

    return torch.clamp(x_t, 0, 1)

def _flow_matching_loss(model, batch, device):
    """Return the flow-matching MSE loss for one training batch.

    Draws one ``t ~ U[0, 1)`` per example and Gaussian noise ``x_0``. It then
    builds ``x_t`` on the straight path to the targets ``x_1`` and regresses
    the model output onto ``x_1 - x_0``.
    """
    input_coords = batch['input_coords'].to(device)
    input_values = batch['input_values'].to(device)
    output_coords = batch['output_coords'].to(device)
    output_values = batch['output_values'].to(device)

    B = input_coords.shape[0]
    t = torch.rand(B, device=device)

    x_0 = torch.randn_like(output_values)
    x_1 = output_values

    x_t = conditional_flow(x_0, x_1, t.view(B, 1, 1))
    u_t = target_velocity(x_0, x_1)

    v_pred = model(x_t, output_coords, t, input_coords, input_values)
    return F.mse_loss(v_pred, u_t)


def _is_due(epoch, every):
    """True on the first epoch and every ``every``-th epoch after; ``every <= 0`` disables."""
    return every > 0 and (epoch == 0 or (epoch + 1) % every == 0)


def _evaluate_flow(model, test_loader, device, max_batches, num_steps):
    """Sample up to ``max_batches`` test batches and print MetricsTracker MSE/MAE."""
    model.eval()
    tracker = MetricsTracker()
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            if i >= max_batches:
                break
            pred_values = heun_sample(
                model, batch['output_coords'].to(device),
                batch['input_coords'].to(device), batch['input_values'].to(device),
                num_steps=num_steps, device=device
            )
            tracker.update(pred_values, batch['output_values'].to(device))
        results = tracker.compute()
        print(f"  Eval - MSE: {results['mse']:.6f}, MAE: {results['mae']:.6f}")


def _visualize_flow(model, viz, epoch, device, num_steps):
    """Sample the fixed visualization batch, then show and save the comparison figure."""
    model.eval()
    with torch.no_grad():
        pred_values = heun_sample(
            model, viz['output_coords'], viz['input_coords'], viz['input_values'],
            num_steps=num_steps, device=device
        )
        visualize_predictions(
            viz['input_coords'], viz['input_values'], viz['output_coords'],
            pred_values, viz['output_values'], viz['full_image'], n_samples=4
        )
        plt.suptitle(f'Local Implicit - Epoch {epoch+1}', fontsize=14, y=1.02)
        plt.savefig(f'local_implicit_epoch_{epoch+1:03d}.png', dpi=150, bbox_inches='tight')
        plt.show()
        plt.close()


def train_flow_matching(
    model, train_loader, test_loader, epochs=100, lr=1e-4, device='cuda',
    visualize_every=5, eval_every=2, eval_batches=10, sample_steps=50
):
    """Train ``model`` with straight-path flow matching.

    Uses Adam with cosine LR annealing over ``epochs`` and clips gradients to a
    norm of 1.0.

    Evaluation runs on epoch 1 and then every ``eval_every`` epochs; it samples
    at most ``eval_batches`` test batches. Visualization runs on epoch 1 and
    then every ``visualize_every`` epochs, on the first 4 examples of the first
    training batch. A period <= 0 disables that step. Both use
    ``sample_steps`` Heun steps.

    Returns:
        list of per-epoch mean training losses.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    losses = []

    viz_batch = next(iter(train_loader), None)
    if viz_batch is None:
        raise ValueError("train_loader yielded no batches")
    viz = {
        key: viz_batch[key][:4].to(device)
        for key in ('input_coords', 'input_values', 'output_coords', 'output_values', 'full_image')
    }

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            loss = _flow_matching_loss(model, batch, device)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Counting batches (instead of len(loader)) also supports loaders without __len__.
        avg_loss = epoch_loss / max(num_batches, 1)
        losses.append(avg_loss)
        scheduler.step()

        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.6f}")

        if _is_due(epoch, eval_every):
            _evaluate_flow(model, test_loader, device, eval_batches, sample_steps)

        if _is_due(epoch, visualize_every):
            _visualize_flow(model, viz, epoch, device, sample_steps)

    return losses

## 3. Load Data and Train

In [None]:
# Sparse CIFAR-10 with input_ratio / output_ratio of 0.2 (the notebook's "20%" setting);
# both splits share the same root, download flag and seed.
dataset_kwargs = dict(root='../data', input_ratio=0.2, output_ratio=0.2, download=True, seed=42)
train_dataset = SparseCIFAR10Dataset(train=True, **dataset_kwargs)
test_dataset = SparseCIFAR10Dataset(train=False, **dataset_kwargs)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

# Fresh model for the real training run (replaces the smoke-test instance).
model_config = dict(num_fourier_feats=256, d_model=512, num_layers=4, num_heads=8, local_radius=0.3)
model = LocalImplicitDiffusion(**model_config).to(device)

losses = train_flow_matching(model, train_loader, test_loader, epochs=100, lr=1e-4, device=device)

## 4. Final Evaluation: Full Image Reconstruction

In [None]:
# Plot the per-epoch mean training loss returned by train_flow_matching.
loss_fig, loss_ax = plt.subplots(figsize=(10, 4))
loss_ax.plot(losses, linewidth=2)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss: Local Implicit Diffusion')
loss_ax.grid(alpha=0.3)
plt.show()

# Full image reconstruction
def create_full_grid(image_size=32, device='cuda'):
    """Return every pixel position of an ``image_size`` x ``image_size`` grid.

    The result is an ``(image_size**2, 2)`` tensor of ``(x, y)`` pairs with both
    coordinates evenly spaced over [0, 1] (endpoints included). Rows are ordered
    row-major: y is the outer index and x the inner one.
    """
    axis = torch.linspace(0, 1, image_size)
    ys = axis.repeat_interleave(image_size)  # y constant along each image row
    xs = axis.repeat(image_size)             # x sweeps across each row
    return torch.stack((xs, ys), dim=1).to(device)

full_coords = create_full_grid(32, device)

model.eval()
tracker_full = MetricsTracker()

# Reconstruct every pixel of the 32x32 image for the first 50 test batches,
# conditioning only on each batch's sparse inputs.
for batch_idx, batch in enumerate(tqdm(test_loader, desc="Full Reconstruction")):
    if batch_idx >= 50:
        break

    n_images = batch['input_coords'].shape[0]
    grid = full_coords.unsqueeze(0).expand(n_images, -1, -1)

    pred_values = heun_sample(
        model,
        grid,
        batch['input_coords'].to(device),
        batch['input_values'].to(device),
        num_steps=100,
        device=device,
    )

    # (B, H*W, 3) row-major pixels -> (B, 3, H, W) images
    pred_images = pred_values.view(n_images, 32, 32, 3).permute(0, 3, 1, 2)
    tracker_full.update(None, None, pred_images, batch['full_image'].to(device))

results = tracker_full.compute()
print("\nFull Image Reconstruction:")
print(f"  PSNR: {results['psnr']:.2f} dB")
print(f"  SSIM: {results['ssim']:.4f}")

## 5. Visualize Full Reconstructions

In [None]:
# Reconstruct the first 4 test images and show ground truth / sparse input / output side by side.
sample_batch = next(iter(test_loader))
B = 4
full_coords_batch = full_coords.unsqueeze(0).expand(B, -1, -1)

pred_values = heun_sample(
    model,
    full_coords_batch,
    sample_batch['input_coords'][:B].to(device),
    sample_batch['input_values'][:B].to(device),
    num_steps=100,
    device=device,
)
pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)

fig, axes = plt.subplots(4, 3, figsize=(12, 16))
for row, (ax_gt, ax_in, ax_rec) in enumerate(axes):
    ax_gt.imshow(sample_batch['full_image'][row].permute(1, 2, 0).numpy())
    ax_gt.set_title('Ground Truth')
    ax_gt.axis('off')

    # Scatter the observed pixels into an otherwise black image.
    sparse_img = torch.zeros(3, 32, 32)
    observed = sample_batch['input_indices'][row]
    sparse_img.view(3, -1)[:, observed] = sample_batch['input_values'][row].T
    ax_in.imshow(sparse_img.permute(1, 2, 0).numpy())
    ax_in.set_title('Input (20%)')
    ax_in.axis('off')

    reconstruction = pred_images[row].permute(1, 2, 0).cpu().numpy()
    ax_rec.imshow(np.clip(reconstruction, 0, 1))
    ax_rec.set_title('Reconstructed')
    ax_rec.axis('off')

plt.suptitle('Local Implicit Diffusion: Full Image Reconstruction', fontsize=14, y=0.995)
plt.tight_layout()
plt.savefig('local_implicit_full_reconstruction.png', dpi=150, bbox_inches='tight')
plt.show()