# Approach 1: Latent Diffusion + Neural Field Decoder

## Overview

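This notebook builds the first half of the simplest baseline for combining neural fields with diffusion models: the encoder and the neural field decoder. The diffusion stage is deferred to the next notebook. The full pipeline is: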

```
Sparse Input → Encoder → Latent z
    ↓
Diffusion Process → z_0 (denoised latent)
    ↓
Neural Field Decoder(coords, z_0) → Continuous Output
```

## Why Start Here?

- **Easiest to implement**: Well-understood components
- **Debuggable**: Can test each component separately
- **Foundation**: Concepts transfer to more complex approaches

## Toy Problem: 1D Signal Reconstruction

We'll use a simple 1D problem to validate the approach:
- **Input**: Sparse samples from a 1D signal (e.g., sine wave + noise)
- **Goal**: Reconstruct continuous signal at arbitrary resolution
- **Sparse Pattern**: Random 20% of points observed

## Architecture Components

1. **Encoder**: MLP that maps sparse observations → latent code z
2. **Neural Field**: SIREN network f(x, z) that maps coordinates + latent → values
3. **Diffusion**: Simple DDPM on latent codes (added in the next notebook, not trained here)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Data Generation: 1D Signals with Sparse Observations

In [None]:
def generate_1d_signal(n_points=256, signal_type='sine_mix'):
    """
    Generate various 1D signals for testing
    
    Args:
        n_points: Number of sample points
        signal_type: 'sine_mix', 'square', 'triangle', 'random_fourier'
    
    Returns:
        coords: (n_points,) normalized coordinates in [0, 1]
        values: (n_points,) signal values
    """
    coords = np.linspace(0, 1, n_points)
    
    if signal_type == 'sine_mix':
        # Mix of sine waves with different frequencies
        values = (
            0.5 * np.sin(2 * np.pi * coords * 2) +
            0.3 * np.sin(2 * np.pi * coords * 5) +
            0.2 * np.sin(2 * np.pi * coords * 10)
        )
    elif signal_type == 'square':
        values = np.sign(np.sin(2 * np.pi * coords * 3))
    elif signal_type == 'triangle':
        values = 2 * np.abs(2 * (coords * 4 - np.floor(coords * 4 + 0.5))) - 1
    elif signal_type == 'random_fourier':
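        # Random Fourier series with 1/k amplitude decay.
        # Uses the global NumPy RNG, so seed it in the caller for reproducibility.
        # (Reseeding here would make every call return the same signal and
        # reset the caller's random stream.)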
        n_freqs = 10
        values = np.zeros_like(coords)
        for k in range(1, n_freqs + 1):
            amp = np.random.randn() * (1 / k)
            phase = np.random.rand() * 2 * np.pi
            values += amp * np.sin(2 * np.pi * k * coords + phase)
    else:
        raise ValueError(f"Unknown signal type: {signal_type}")
    
    # Normalize to [-1, 1]
    values = values / (np.abs(values).max() + 1e-8)
    
    return coords.astype(np.float32), values.astype(np.float32)


class Sparse1DDataset(Dataset):
    """
    Dataset of 1D signals with sparse observations
    """
    def __init__(self, n_samples=1000, n_points=256, sparsity=0.2, signal_types=None):
        """
        Args:
            n_samples: Number of different signals
            n_points: Resolution of each signal
            sparsity: Fraction of points observed (0.2 = 20% observed)
            signal_types: List of signal types to sample from
        """
        self.n_samples = n_samples
        self.n_points = n_points
        self.sparsity = sparsity
        self.n_observed = int(n_points * sparsity)
        
        if signal_types is None:
            signal_types = ['sine_mix', 'random_fourier']
        self.signal_types = signal_types
        
        # Pre-generate all signals for reproducibility
        self.signals = []
        np.random.seed(42)
        for i in range(n_samples):
            signal_type = np.random.choice(signal_types)
            coords, values = generate_1d_signal(n_points, signal_type)
            self.signals.append((coords, values))
    
    def __len__(self):
        return self.n_samples
    
    def __getitem__(self, idx):
        coords, values = self.signals[idx]
        
        # Random sparse sampling, re-drawn on every access,
        # so each epoch sees a different observation pattern per signal
        observed_idxs = np.random.choice(
            self.n_points, 
            size=self.n_observed, 
            replace=False
        )
        observed_idxs = np.sort(observed_idxs)
        
        sparse_coords = coords[observed_idxs]
        sparse_values = values[observed_idxs]
        
        return {
            'sparse_coords': torch.from_numpy(sparse_coords),  # (n_observed,)
            'sparse_values': torch.from_numpy(sparse_values),  # (n_observed,)
            'full_coords': torch.from_numpy(coords),           # (n_points,)
            'full_values': torch.from_numpy(values),           # (n_points,)
        }


# Test data generation
coords, values = generate_1d_signal(256, 'sine_mix')
plt.figure(figsize=(12, 3))
plt.plot(coords, values, 'b-', linewidth=2, label='Full signal')

# Show sparse sampling
sparse_idxs = np.random.choice(256, size=int(256 * 0.2), replace=False)
plt.scatter(coords[sparse_idxs], values[sparse_idxs], c='r', s=30, label='Sparse observations (20%)', zorder=10)
plt.xlabel('Coordinate')
plt.ylabel('Value')
plt.title('Example 1D Signal with Sparse Observations')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 2. Neural Field Decoder (SIREN-based)

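SIREN uses periodic (sine) activations, which are well suited to representing smooth continuous signals. The factor `omega_0` scales the pre-activations and controls how high a frequency each layer can express at initialization.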
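
The hidden-layer initialization used in `SineLayer` can be sanity-checked numerically. The sketch below uses NumPy only and assumes the previous layer's outputs behave like the sine of a uniform phase (variance 1/2). Under that assumption, the bound `sqrt(6 / n_in) / omega_0` gives the argument of the sine roughly unit standard deviation, regardless of layer width.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, omega_0 = 128, 256, 30.0

# Hidden-layer SIREN init: W ~ U(-c, c) with c = sqrt(6 / n_in) / omega_0
c = np.sqrt(6.0 / n_in) / omega_0
W = rng.uniform(-c, c, size=(n_out, n_in))

# Stand-in for the previous layer's output: sine of a uniform phase,
# which lies in [-1, 1] and has variance 1/2 (an assumption of this check)
x = np.sin(rng.uniform(-np.pi, np.pi, size=(2000, n_in)))

# Argument of sin(...) in SineLayer.forward, with the bias omitted
pre_activation = omega_0 * (x @ W.T)
pre_std = float(pre_activation.std())

# Var(W) = c^2 / 3 = 2 / (n_in * omega_0^2), so
# Var(omega_0 * W x) = omega_0^2 * n_in * Var(W) * Var(x) = 1
print(f"std of omega_0 * (W x): {pre_std:.3f}")
```

Changing `n_in` or `omega_0` in this sketch should leave the printed value close to 1, which is the point of dividing the bound by `omega_0`.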

In [None]:
class SineLayer(nn.Module):
    """Sine activation with frequency modulation (SIREN layer)"""
    
    def __init__(self, in_features, out_features, bias=True, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()
    
    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.linear.in_features, 
                                           1 / self.linear.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.linear.in_features) / self.omega_0,
                                           np.sqrt(6 / self.linear.in_features) / self.omega_0)
    
    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))


class NeuralFieldDecoder(nn.Module):
    """
    Neural Field Decoder: f(coords, z) → values
    
    Uses SIREN architecture conditioned on latent code z
    """
    def __init__(self, latent_dim=64, hidden_dim=128, n_layers=3, omega_0=30.0):
        super().__init__()
        self.latent_dim = latent_dim
        
        # First layer: coords (1D) + latent (latent_dim) → hidden
        self.first_layer = SineLayer(
            in_features=1 + latent_dim,
            out_features=hidden_dim,
            omega_0=omega_0,
            is_first=True
        )
        
        # Hidden layers
        self.hidden_layers = nn.ModuleList([
            SineLayer(hidden_dim, hidden_dim, omega_0=omega_0)
            for _ in range(n_layers - 1)
        ])
        
        # Output layer: hidden → 1 (value)
        self.output_layer = nn.Linear(hidden_dim, 1)
        
        # Initialize output layer
        with torch.no_grad():
            self.output_layer.weight.uniform_(
                -np.sqrt(6 / hidden_dim) / omega_0,
                np.sqrt(6 / hidden_dim) / omega_0
            )
    
    def forward(self, coords, latent):
        """
        Args:
            coords: (B, N, 1) query coordinates
            latent: (B, latent_dim) latent code
        
        Returns:
            values: (B, N, 1) predicted values
        """
        B, N, _ = coords.shape
        
        # Broadcast latent to all coordinates
        latent_expanded = latent.unsqueeze(1).expand(B, N, self.latent_dim)  # (B, N, latent_dim)
        
        # Concatenate coords with latent
        x = torch.cat([coords, latent_expanded], dim=-1)  # (B, N, 1 + latent_dim)
        
        # Forward through SIREN
        x = self.first_layer(x)
        for layer in self.hidden_layers:
            x = layer(x)
        values = self.output_layer(x)
        
        return values


# Test neural field
nf = NeuralFieldDecoder(latent_dim=64, hidden_dim=128, n_layers=3).to(device)
test_coords = torch.linspace(0, 1, 100).unsqueeze(0).unsqueeze(-1).to(device)  # (1, 100, 1)
test_latent = torch.randn(1, 64).to(device)  # (1, 64)
test_output = nf(test_coords, test_latent)
print(f"Neural Field test: coords {test_coords.shape} + latent {test_latent.shape} → output {test_output.shape}")

# Visualize random neural field outputs
plt.figure(figsize=(12, 4))
coords_np = test_coords[0].cpu().numpy()
for i in range(5):
    latent = torch.randn(1, 64).to(device)
    output = nf(test_coords, latent)
    plt.plot(coords_np, output[0].detach().cpu().numpy(), alpha=0.7, label=f'Random latent {i+1}')
plt.xlabel('Coordinate')
plt.ylabel('Value')
plt.title('Neural Field Decoder with Random Latent Codes')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 3. Encoder: Sparse Observations → Latent Code

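A simple MLP encodes the sparse observations into a fixed-size latent code. The (coord, value) pairs are flattened and zero-padded to `max_sparse_points`, so the encoder expects a bounded number of observations and is sensitive to their order (the dataset sorts them by coordinate).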
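
The input layout the encoder sees can be illustrated without PyTorch. This is a minimal NumPy sketch of the stack, pad, and flatten steps. The helper name `pad_and_flatten` and the `with_mask` option are hypothetical and not part of this notebook. The mask channel is one possible way to separate padding from a real observation at coordinate 0 with value 0, which the `sine_mix` signal does pass through.

```python
import numpy as np

def pad_and_flatten(sparse_coords, sparse_values, max_points=64, with_mask=False):
    """Stack (coord, value) pairs, zero-pad to max_points, and flatten per sample."""
    batch, n_sparse = sparse_coords.shape
    if n_sparse > max_points:
        raise ValueError(f"{n_sparse} points exceed max_points={max_points}")
    channels = [sparse_coords, sparse_values]
    if with_mask:
        # Hypothetical extra channel: 1 for real observations, 0 for padding
        channels.append(np.ones_like(sparse_coords))
    pairs = np.stack(channels, axis=-1)                # (B, n_sparse, C)
    padded = np.zeros((batch, max_points, pairs.shape[-1]), dtype=pairs.dtype)
    padded[:, :n_sparse] = pairs                       # real points first, zeros after
    return padded.reshape(batch, -1)                   # (B, max_points * C)

n_observed = int(256 * 0.2)  # 51 points, as in Sparse1DDataset
rng = np.random.default_rng(0)
coords = np.sort(rng.random((4, n_observed)).astype(np.float32), axis=1)
values = rng.standard_normal((4, n_observed)).astype(np.float32)

flat = pad_and_flatten(coords, values)
flat_masked = pad_and_flatten(coords, values, with_mask=True)
print(flat.shape, flat_masked.shape)  # (4, 128) (4, 192)
```

With 51 observed points and `max_points=64`, the last 13 pairs of every row are padding, so 26 of the 128 encoder inputs carry no information.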

In [None]:
class SparseEncoder(nn.Module):
    """
    Encoder: Sparse observations → Latent code
    
    Uses simple MLP on flattened sparse (coord, value) pairs
    """
    def __init__(self, max_sparse_points=64, latent_dim=64, hidden_dim=256):
        super().__init__()
        self.max_sparse_points = max_sparse_points
        self.latent_dim = latent_dim
        
        # Input: flattened (coord, value) pairs
        input_dim = max_sparse_points * 2  # (coord, value) for each point
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
    
    def forward(self, sparse_coords, sparse_values):
        """
        Args:
            sparse_coords: (B, n_sparse) observed coordinates
            sparse_values: (B, n_sparse) observed values
        
        Returns:
            latent: (B, latent_dim) latent code
        """
        B, n_sparse = sparse_coords.shape
        
        # Concatenate coords and values
        sparse_data = torch.stack([sparse_coords, sparse_values], dim=-1)  # (B, n_sparse, 2)
        
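        # Pad with zeros up to max_sparse_points; the MLP input size is fixed
        if n_sparse > self.max_sparse_points:
            raise ValueError(
                f"Got {n_sparse} sparse points, but max_sparse_points={self.max_sparse_points}"
            )
        if n_sparse < self.max_sparse_points:
            padding = torch.zeros(
                B, self.max_sparse_points - n_sparse, 2,
                device=sparse_data.device, dtype=sparse_data.dtype
            )
            sparse_data = torch.cat([sparse_data, padding], dim=1)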
        
        # Flatten and encode
        sparse_flat = sparse_data.view(B, -1)  # (B, max_sparse_points * 2)
        latent = self.encoder(sparse_flat)  # (B, latent_dim)
        
        return latent


# Test encoder
encoder = SparseEncoder(max_sparse_points=64, latent_dim=64).to(device)
test_sparse_coords = torch.rand(4, 51).to(device)  # Batch of 4, 51 sparse points
test_sparse_values = torch.rand(4, 51).to(device)
test_latent = encoder(test_sparse_coords, test_sparse_values)
print(f"Encoder test: sparse coords {test_sparse_coords.shape} + values {test_sparse_values.shape} → latent {test_latent.shape}")

## 4. Complete Model: Encoder + Neural Field (No Diffusion Yet)

First, let's train the encoder + neural field to reconstruct signals from sparse observations.
This validates that the architecture can learn the mapping before adding diffusion.

In [None]:
class SparseToFieldModel(nn.Module):
    """Complete model: Sparse observations → Latent → Continuous field"""
    
    def __init__(self, max_sparse_points=64, latent_dim=64, hidden_dim=256):
        super().__init__()
        # hidden_dim sets the encoder width only; the decoder width is fixed at 128
        self.encoder = SparseEncoder(max_sparse_points, latent_dim, hidden_dim)
        self.decoder = NeuralFieldDecoder(latent_dim, hidden_dim=128, n_layers=3)
    
    def forward(self, sparse_coords, sparse_values, query_coords):
        """
        Args:
            sparse_coords: (B, n_sparse) observed coordinates
            sparse_values: (B, n_sparse) observed values
            query_coords: (B, n_query, 1) coordinates to query
        
        Returns:
            pred_values: (B, n_query, 1) predicted values
            latent: (B, latent_dim) latent code (for diffusion later)
        """
        latent = self.encoder(sparse_coords, sparse_values)
        pred_values = self.decoder(query_coords, latent)
        return pred_values, latent


# Training function
def train_sparse_to_field(model, train_loader, epochs=50, lr=1e-4):
    """Train encoder + decoder on sparse reconstruction task"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            sparse_coords = batch['sparse_coords'].to(device)  # (B, n_sparse)
            sparse_values = batch['sparse_values'].to(device)  # (B, n_sparse)
            full_coords = batch['full_coords'].to(device).unsqueeze(-1)  # (B, n_full, 1)
            full_values = batch['full_values'].to(device).unsqueeze(-1)  # (B, n_full, 1)
            
            # Forward pass
            pred_values, latent = model(sparse_coords, sparse_values, full_coords)
            
            # MSE loss on full signal reconstruction
            loss = F.mse_loss(pred_values, full_values)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}")
    
    return losses


# Create dataset and dataloader
train_dataset = Sparse1DDataset(n_samples=1000, n_points=256, sparsity=0.2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Initialize model
model = SparseToFieldModel(max_sparse_points=64, latent_dim=64, hidden_dim=256).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train
losses = train_sparse_to_field(model, train_loader, epochs=20, lr=1e-4)

## 5. Evaluation: Visualize Reconstructions

In [None]:
# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss: Sparse to Field Reconstruction')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

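# Visualize reconstructions
# Note: Sparse1DDataset reseeds NumPy with 42 in __init__, so these 5 signals
# repeat the first 5 training signals. This checks fit, not generalization.
model.eval()
test_dataset = Sparse1DDataset(n_samples=5, n_points=256, sparsity=0.2)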

fig, axes = plt.subplots(1, 5, figsize=(20, 3))
with torch.no_grad():
    for i, ax in enumerate(axes):
        sample = test_dataset[i]
        sparse_coords = sample['sparse_coords'].unsqueeze(0).to(device)
        sparse_values = sample['sparse_values'].unsqueeze(0).to(device)
        full_coords = sample['full_coords'].unsqueeze(0).unsqueeze(-1).to(device)
        full_values = sample['full_values'].cpu().numpy()
        
        # Predict
        pred_values, _ = model(sparse_coords, sparse_values, full_coords)
        pred_values = pred_values[0].squeeze().cpu().numpy()
        
        # Plot
        coords_np = sample['full_coords'].numpy()
        ax.plot(coords_np, full_values, 'b-', linewidth=2, alpha=0.7, label='Ground truth')
        ax.plot(coords_np, pred_values, 'r--', linewidth=2, alpha=0.7, label='Reconstruction')
        ax.scatter(
            sample['sparse_coords'].numpy(), 
            sample['sparse_values'].numpy(), 
            c='green', s=30, zorder=10, label='Sparse obs'
        )
        ax.set_xlabel('Coordinate')
        ax.set_ylabel('Value')
        ax.set_title(f'Test Signal {i+1}')
        ax.legend(fontsize=8)
        ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Compute metrics
mse_errors = []
with torch.no_grad():
    for i in range(len(test_dataset)):
        sample = test_dataset[i]
        sparse_coords = sample['sparse_coords'].unsqueeze(0).to(device)
        sparse_values = sample['sparse_values'].unsqueeze(0).to(device)
        full_coords = sample['full_coords'].unsqueeze(0).unsqueeze(-1).to(device)
        full_values = sample['full_values'].unsqueeze(0).unsqueeze(-1).to(device)
        
        pred_values, _ = model(sparse_coords, sparse_values, full_coords)
        mse = F.mse_loss(pred_values, full_values).item()
        mse_errors.append(mse)

print(f"\nTest MSE: {np.mean(mse_errors):.6f} ± {np.std(mse_errors):.6f}")

## 6. Next Steps: Add Diffusion Model

Now that we have a working encoder + neural field decoder, the next step is to add a diffusion model on the latent codes.

### Plan:
1. **Extract latent codes** from trained encoder on full dataset
2. **Train DDPM** on latent space
3. **Generation**: Sample z ~ p(z) from diffusion → Decode with neural field
4. **Conditional generation**: Condition diffusion on sparse observations
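
Step 2 of the plan relies on the DDPM forward (noising) process, which has a closed form. The sketch below applies it to latent codes using NumPy. The linear schedule from `1e-4` to `0.02` over 1000 steps is a common default and an assumption here, not a choice made in this notebook. The helper names `make_linear_schedule` and `q_sample` are illustrative, and the latents are random placeholders rather than real encoder outputs.

```python
import numpy as np

def make_linear_schedule(n_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule and the cumulative products alpha_bar_t."""
    betas = np.linspace(beta_start, beta_end, n_steps)
    alpha_bars = np.cumprod(1.0 - betas)
    return betas, alpha_bars

def q_sample(z0, t, alpha_bars, rng):
    """Draw z_t ~ q(z_t | z_0) = N(sqrt(alpha_bar_t) z_0, (1 - alpha_bar_t) I)."""
    noise = rng.standard_normal(z0.shape)
    a = alpha_bars[t][:, None]              # (B, 1), broadcast over latent_dim
    z_t = np.sqrt(a) * z0 + np.sqrt(1.0 - a) * noise
    return z_t, noise

rng = np.random.default_rng(0)
betas, alpha_bars = make_linear_schedule()

# Placeholder for encoder outputs: 8 latent codes of dimension 64
z0 = rng.standard_normal((8, 64))
t = rng.integers(0, len(betas), size=8)     # one timestep per latent
z_t, noise = q_sample(z0, t, alpha_bars, rng)

# A denoiser eps_theta(z_t, t) would be trained with an MSE loss against `noise`
print(z_t.shape, float(alpha_bars[0]), float(alpha_bars[-1]))
```

Because the encoder in this notebook is trained without any regularization on the latent distribution, the extracted latents may need to be standardized before they are passed to a process like this one.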

### Why Diffusion on Latents?
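- **Smaller space**: 64D latent vs. 256D signal at the training resolution
- **Semantic**: Latents are trained to capture signal structure rather than individual samples
- **Cheaper sampling**: Each denoising step runs on a 64D vector, so the denoiser can be a small network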

This will be implemented in the next notebook: `02_approach1_with_diffusion.ipynb`

## Summary

✅ **Implemented**:
- SIREN-based neural field decoder for continuous 1D signal representation
- Sparse observation encoder (MLP-based)
- End-to-end training on sparse-to-dense reconstruction

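✅ **Validated** (qualitatively, on the toy signals):
- Neural field can represent continuous signals from latent codes
- Encoder can compress sparse observations into latents the decoder can use
- Model reconstructs full signals from 20% sparse observations; the evaluation signals repeat training signals, so generalization to held-out signals is still untested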

📋 **Next Steps**:
1. Add DDPM diffusion model on latent space
2. Train diffusion for unconditional generation
3. Add conditioning on sparse observations
4. Test arbitrary-resolution querying (super-resolution)
5. Extend to 2D images (CIFAR-10 or MNIST)
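
For step 4, the key property is that a field is a function of continuous coordinates, so the query grid can have any resolution. The sketch below shows this with a hypothetical stand-in, `toy_field`, which is not the trained decoder. Querying a 4x denser grid yields new in-between values while agreeing with the coarse grid at shared coordinates.

```python
import numpy as np

def toy_field(coords, latent):
    """Hypothetical stand-in for NeuralFieldDecoder: a function of (coords, latent)."""
    freq, phase = latent
    return np.sin(2 * np.pi * freq * coords + phase)

latent = (3.0, 0.5)

# A coarse grid and a 4x denser grid that share their endpoints
coarse = np.linspace(0.0, 1.0, 257)
fine = np.linspace(0.0, 1.0, 1025)

coarse_values = toy_field(coarse, latent)
fine_values = toy_field(fine, latent)

# Every coarse coordinate is also every 4th coordinate of the fine grid
shared_match = bool(np.allclose(fine_values[::4], coarse_values))
print(coarse_values.shape, fine_values.shape, shared_match)
```

With the model from this notebook, the equivalent query would pass a tensor of shape `(B, n_query, 1)`, for example `torch.linspace(0, 1, 1024).view(1, -1, 1)`, as `query_coords`. How well the trained decoder behaves between training coordinates is what the super-resolution test would measure.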