# MAMBA Diffusion with Gaussian Fourier Features + Hilbert Curve Sorting

## Key Innovations

### 1. Gaussian Fourier Features (GFF)
- **Random projection**: Frequencies come from a random Gaussian projection matrix $B$, which can optionally be made learnable
- **Spectral coverage**: Frequencies are sampled in all directions (not only axis-aligned) over a bandwidth set by $\sigma$, rather than taken from a fixed sinusoidal ladder
- **Theoretical basis**: Approximates an RBF kernel (Rahimi & Recht, 2007)
- **Formula**: $\gamma(v) = [\cos(2\pi B v), \sin(2\pi B v)]$, with the entries of $B$ drawn i.i.d. from $\mathcal{N}(0, \sigma^2)$

### 2. Hilbert Curve Scanning
- **Space-filling curve**: Maps 2D grid cells to a 1D sequence while preserving locality
- **Locality preservation**: Consecutive elements of the sequence are always spatially close, and most (though not all) 2D neighbors also stay close in the sequence
- **Why it matters for SSMs**: State propagation runs along the sequence, so it is expected to benefit when adjacent elements are spatially coherent

### Architecture Flow
```
Sparse Input + Query Points
        ↓
Gaussian Fourier Features (learnable B matrix)
        ↓
Hilbert Curve Sorting (2D → 1D locality-preserving)
        ↓
SSM Layers (state space propagation)
        ↓
Unsort back to original order
        ↓
Cross-Attention (extract query features)
        ↓
MLP Decoder → Predicted velocity (3 channels; the ODE sampler integrates it into RGB)
```
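
To make the flow concrete, here is a back-of-envelope sketch of the sequence lengths and of the size of the causal decay kernel that `SSMBlockFast` (section 3) materializes. It assumes the dataset samples `int(0.2 * 1024) = 204` pixels each for inputs and queries, which matches the 204-point test tensors used later; the variable names are illustrative only.

```python
# Assumed sampling: 20% of a 32x32 image observed, 20% queried (int() rounding)
image_side = 32
num_pixels = image_side * image_side
n_in = int(0.2 * num_pixels)
n_out = int(0.2 * num_pixels)
seq_len = n_in + n_out            # input + query tokens scanned by the SSM layers

d_state = 16
bytes_per_float = 4               # float32

# SSMBlockFast materializes an (N, N, d_state) causal decay kernel per layer call
decay_elems = seq_len * seq_len * d_state

# Full-field reconstruction queries every pixel instead of a 20% subset
full_seq_len = n_in + num_pixels
full_decay_elems = full_seq_len * full_seq_len * d_state

print(f"sparse: N={seq_len}, decay kernel ~{decay_elems * bytes_per_float / 1e6:.1f} MB")
print(f"full:   N={full_seq_len}, decay kernel ~{full_decay_elems * bytes_per_float / 1e6:.1f} MB")
```

Under these assumptions the kernel is about 10.7 MB for sparse training and about 96.5 MB when sampling the full 32×32 field, so memory grows quadratically with the number of query points.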

In [None]:
import sys
import os

# Add parent directory to path
notebook_dir = os.path.abspath('')
parent_dir = os.path.dirname(notebook_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from hilbertcurve.hilbertcurve import HilbertCurve

from core.sparse.cifar10_sparse import SparseCIFAR10Dataset
from core.sparse.metrics import MetricsTracker, print_metrics, visualize_predictions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"Working directory: {os.getcwd()}")
print(f"Parent directory added to path: {parent_dir}")

## 1. Gaussian Fourier Features

Based on "Random Features for Large-Scale Kernel Machines" (Rahimi & Recht, 2007)

Key idea: Random Fourier features approximate shift-invariant kernels.
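For the RBF kernel $k(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))$, we can use:
$$\phi(x) = \frac{1}{\sqrt{D}} [\cos(\omega_1^T x), ..., \cos(\omega_D^T x), \sin(\omega_1^T x), ..., \sin(\omega_D^T x)]$$
where $\omega_i \sim \mathcal{N}(0, \sigma^{-2} I)$, so that $\phi(x)^T \phi(y) = \frac{1}{D} \sum_{i=1}^{D} \cos(\omega_i^T (x - y)) \approx k(x, y)$. (The $\sqrt{2/D}$ factor belongs to the cosine-only variant with a random phase; with paired cosines and sines the factor is $1/\sqrt{D}$.)

The `GaussianFourierFeatures` module below uses $\omega_i = 2\pi b_i$, where $b_i$ is the $i$-th column of `B` with entries drawn from $\mathcal{N}(0, \text{scale}^2)$, and omits the constant $1/\sqrt{D}$ (the following linear layer can absorb it). This corresponds to an RBF bandwidth $\sigma = 1/(2\pi \cdot \text{scale})$, about $0.016$ for `scale=10`, roughly half the pixel spacing ($1/31$) of the 32×32 grid on $[0, 1]^2$.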
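
The kernel approximation can be checked numerically. A minimal NumPy sketch, assuming an illustrative bandwidth `sigma = 0.5` and `D = 20000` frequencies, using the $1/\sqrt{D}$ normalization (which also makes $\|\phi(x)\| = 1$ exactly, since $\cos^2 + \sin^2 = 1$):

```python
import numpy as np

rng = np.random.default_rng(0)
D, sigma = 20000, 0.5                                  # illustrative values
omega = rng.normal(0.0, 1.0 / sigma, size=(D, 2))      # omega_i ~ N(0, sigma^-2 I)

def phi(x):
    """Random Fourier feature map: [cos(omega x), sin(omega x)] / sqrt(D)."""
    proj = omega @ x                                   # (D,)
    return np.concatenate([np.cos(proj), np.sin(proj)]) / np.sqrt(D)

x = np.array([0.2, 0.3])
y = np.array([0.5, 0.1])
approx = phi(x) @ phi(y)                               # Monte Carlo kernel estimate
exact = np.exp(-np.sum((x - y) ** 2) / (2 * sigma ** 2))
print(f"approx={approx:.4f}  exact={exact:.4f}")
```

The estimate's standard error shrinks like $1/\sqrt{D}$, so a few thousand frequencies already give a close match; the notebook's 256 features are a much coarser (but learnable) sample.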

In [None]:
class GaussianFourierFeatures(nn.Module):
    """
    Gaussian Fourier Features for positional encoding
    
    Uses learnable random Gaussian projection matrix for better
    spectral coverage compared to fixed sinusoidal encoding.
    
    Args:
        coord_dim: Input coordinate dimension (2 for 2D images)
        num_features: Number of Fourier features (output will be 2*num_features)
        scale: Standard deviation for Gaussian sampling (controls frequency range)
        learnable: Whether to make the projection matrix learnable
    """
    def __init__(self, coord_dim=2, num_features=256, scale=10.0, learnable=True):
        super().__init__()
        self.coord_dim = coord_dim
        self.num_features = num_features
        self.scale = scale
        
        # Random Gaussian projection matrix: (coord_dim, num_features)
        # Sample from N(0, scale^2)
        B = torch.randn(coord_dim, num_features) * scale
        
        if learnable:
            self.B = nn.Parameter(B)
        else:
            self.register_buffer('B', B)
    
    def forward(self, coords):
        """
        Args:
            coords: (B, N, coord_dim) coordinates in [0, 1]
        Returns:
            features: (B, N, 2*num_features) Fourier features
        """
        # Project coordinates: (B, N, coord_dim) @ (coord_dim, num_features) -> (B, N, num_features)
        proj = 2 * math.pi * coords @ self.B
        
        # Concatenate sin and cos
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)  # (B, N, 2*num_features)


# Test Gaussian Fourier Features
gff = GaussianFourierFeatures(coord_dim=2, num_features=128, scale=10.0, learnable=True).to(device)
test_coords = torch.rand(4, 100, 2).to(device)
test_feats = gff(test_coords)
print(f"Gaussian Fourier Features test:")
print(f"  Input: {test_coords.shape}")
print(f"  Output: {test_feats.shape}")
print(f"  Learnable parameters: {sum(p.numel() for p in gff.parameters()):,}")

## 2. Hilbert Curve Sorting

Space-filling curves map multi-dimensional space to 1D while preserving locality.
This matters for SSMs, which propagate state token by token: we want spatially
adjacent pixels to sit close together in the sequence.

**Why Hilbert over other orderings:**
- Better locality preservation than Z-order/Morton curves
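- No jumps across the image (vs raster scan): consecutive indices are always grid-adjacent cells
- The limiting curve is continuous but nowhere differentiable; the resulting sort is a fixed permutation, so gradients flow through the reordered values but not through the ordering itself
- Defined recursively on $2^p \times 2^p$ grids; for other resolutions, use the smallest $2^p$ that covers the image (the curve then occasionally leaves and re-enters the image, so a few consecutive points are not adjacent)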
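
The locality claims can be checked numerically. A pure-Python sketch using the classic iterative index-to-cell conversion for the Hilbert curve (as given, e.g., on Wikipedia; its orientation may differ from the `hilbertcurve` package's, which does not affect locality), compared with raster and Z-order (Morton) scans of a 32×32 grid. `hilbert_d2xy`, `morton_d2xy` and `max_step` are illustrative helpers, not part of the notebook:

```python
def rot(s, x, y, rx, ry):
    """Rotate/flip a quadrant of side s (standard Hilbert-curve helper)."""
    if ry == 0:
        if rx == 1:
            x, y = s - 1 - x, s - 1 - y
        x, y = y, x
    return x, y

def hilbert_d2xy(n, d):
    """Map Hilbert index d to cell (x, y) on an n x n grid, n a power of two."""
    x = y = 0
    s, t = 1, d
    while s < n:
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        x, y = rot(s, x, y, rx, ry)
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y

def morton_d2xy(d):
    """Z-order: x takes the even bits of d, y takes the odd bits."""
    x = y = 0
    for bit in range(16):
        x |= ((d >> (2 * bit)) & 1) << bit
        y |= ((d >> (2 * bit + 1)) & 1) << bit
    return x, y

def max_step(path):
    """Largest Manhattan distance between consecutive cells of an ordering."""
    return max(abs(x1 - x0) + abs(y1 - y0) for (x0, y0), (x1, y1) in zip(path, path[1:]))

n = 32
hilbert = [hilbert_d2xy(n, d) for d in range(n * n)]
raster = [(d % n, d // n) for d in range(n * n)]
morton = [morton_d2xy(d) for d in range(n * n)]

for name, path in [("hilbert", hilbert), ("raster", raster), ("morton", morton)]:
    print(name, "covers grid:", len(set(path)) == n * n, "max step:", max_step(path))
```

All three orderings visit every cell exactly once, but only the Hilbert order never takes a step longer than one cell; the raster scan jumps a full row width at each line end, and the Z-order jumps at quadrant boundaries.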

In [None]:
class HilbertSorter:
    """
    Utility for sorting 2D coordinates using Hilbert curve order
    
    This preserves spatial locality when feeding coordinates to sequential models.
    """
    def __init__(self, grid_size=32, p=5):
        """
        Args:
            grid_size: Resolution of coordinate grid (assumes square)
            p: Hilbert curve order; the curve covers a 2^p x 2^p grid,
               so grid_size must not exceed 2**p
        """
        assert grid_size <= 2 ** p, f"grid_size={grid_size} requires p >= {(grid_size - 1).bit_length()}"
        self.grid_size = grid_size
        self.p = p
        self.hilbert = HilbertCurve(p, 2)  # 2D Hilbert curve
        
        # Lookup tables lut[x, y] -> Hilbert index, cached per (grid_size, device)
        self._lut_cache = {}
    
    def _get_lut(self, grid_size, device):
        """Build (once) and return a (grid_size, grid_size) table of Hilbert indices"""
        key = (grid_size, str(device))
        if key not in self._lut_cache:
            lut = torch.empty(grid_size, grid_size, dtype=torch.long)
            for x in range(grid_size):
                for y in range(grid_size):
                    lut[x, y] = self.hilbert.distance_from_point([x, y])
            self._lut_cache[key] = lut.to(device)
        return self._lut_cache[key]
    
    def coords_to_hilbert_indices(self, coords, grid_size=None):
        """
        Convert continuous coordinates [0, 1] to Hilbert curve indices
        
        Args:
            coords: (B, N, 2) coordinates in [0, 1]
            grid_size: Override grid size (must still be <= 2**p)
        Returns:
            hilbert_indices: (B, N) integer indices
        """
        if grid_size is None:
            grid_size = self.grid_size
        assert grid_size <= 2 ** self.p, "grid_size must not exceed 2**p"
        
        # Discretize to the nearest grid cell; rounding (not truncation) keeps
        # values like k/(grid_size-1) that land just below k in cell k
        grid_coords = torch.round(coords * (grid_size - 1)).long()  # (B, N, 2)
        grid_coords = torch.clamp(grid_coords, 0, grid_size - 1)
        
        # Vectorized table lookup instead of a per-point Python loop
        # (which forced a GPU->CPU sync for every point on every forward pass)
        lut = self._get_lut(grid_size, coords.device)
        return lut[grid_coords[..., 0], grid_coords[..., 1]]  # (B, N)
    
    def sort_by_hilbert(self, coords, values):
        """
        Sort coordinates and values by Hilbert curve order
        
        Args:
            coords: (B, N, 2) coordinates
            values: (B, N, D) associated values
        Returns:
            sorted_coords: (B, N, 2)
            sorted_values: (B, N, D)
            sort_indices: (B, N) for unsorting later
        """
        B, N = coords.shape[:2]
        
        # Get Hilbert indices
        hilbert_indices = self.coords_to_hilbert_indices(coords)  # (B, N)
        
        # Sort by Hilbert order
        sort_indices = torch.argsort(hilbert_indices, dim=1)  # (B, N)
        
        # Apply sorting
        sorted_coords = torch.gather(
            coords, 1, sort_indices.unsqueeze(-1).expand(-1, -1, 2)
        )
        sorted_values = torch.gather(
            values, 1, sort_indices.unsqueeze(-1).expand(-1, -1, values.shape[-1])
        )
        
        return sorted_coords, sorted_values, sort_indices
    
    def unsort(self, sorted_values, sort_indices):
        """
        Restore original order from Hilbert-sorted sequence
        
        Args:
            sorted_values: (B, N, D) Hilbert-sorted values
            sort_indices: (B, N) indices from sort_by_hilbert
        Returns:
            original_values: (B, N, D) in original order
        """
        B, N, D = sorted_values.shape
        
        # Create inverse permutation
        unsort_indices = torch.argsort(sort_indices, dim=1)  # (B, N)
        
        # Apply inverse sorting
        original_values = torch.gather(
            sorted_values, 1, unsort_indices.unsqueeze(-1).expand(-1, -1, D)
        )
        
        return original_values


# Test Hilbert sorting
print("\nTesting Hilbert Curve Sorting:")
sorter = HilbertSorter(grid_size=32, p=5)

# Create test coordinates
test_coords = torch.rand(2, 20, 2)
test_values = torch.rand(2, 20, 3)

# Sort
sorted_coords, sorted_values, sort_idx = sorter.sort_by_hilbert(test_coords, test_values)
print(f"  Original coords: {test_coords.shape}")
print(f"  Sorted coords: {sorted_coords.shape}")

# Unsort
restored_values = sorter.unsort(sorted_values, sort_idx)
print(f"  Restored values: {restored_values.shape}")
print(f"  Restoration error: {(restored_values - test_values).abs().max().item():.6f}")

# Visualize Hilbert curve order
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Original order
axes[0].scatter(test_coords[0, :, 0], test_coords[0, :, 1], c=range(20), cmap='viridis', s=100)
axes[0].plot(test_coords[0, :, 0], test_coords[0, :, 1], 'k-', alpha=0.3, linewidth=0.5)
axes[0].set_title('Original Order (Random)')
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 1)
axes[0].set_aspect('equal')

# Hilbert order
axes[1].scatter(sorted_coords[0, :, 0], sorted_coords[0, :, 1], c=range(20), cmap='viridis', s=100)
axes[1].plot(sorted_coords[0, :, 0], sorted_coords[0, :, 1], 'r-', alpha=0.5, linewidth=1.5)
axes[1].set_title('Hilbert Order (Locality-Preserving)')
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 1)
axes[1].set_aspect('equal')

plt.tight_layout()
plt.savefig('hilbert_curve_ordering.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. SSM Blocks (from original MAMBA)
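
`SSMBlockFast` below computes its states in parallel with a materialized causal decay kernel rather than a sequential loop. A minimal NumPy check that the two forms agree, with small made-up sizes and random values standing in for `A_bar` and the projected input `Bu` (names mirror the code; values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
N, S = 6, 3                                  # sequence length, state size
A_bar = rng.uniform(0.5, 0.99, size=S)       # per-state decay in (0, 1)
Bu = rng.normal(size=(N, S))                 # projected inputs, one row per token

# 1) Sequential recurrence: h[m] = A_bar * h[m-1] + Bu[m], with h[-1] = 0
h_rec = np.zeros((N, S))
state = np.zeros(S)
for m in range(N):
    state = A_bar * state + Bu[m]
    h_rec[m] = state

# 2) Parallel form: h[m] = sum_{n <= m} A_bar^(m - n) * Bu[n]
idx = np.arange(N)
lag = idx[None, :] - idx[:, None]                        # lag[n, m] = m - n
decay = A_bar[None, None, :] ** np.clip(lag, 0, None)[..., None]
decay = decay * (lag >= 0)[..., None]                    # zero out future tokens
h_par = np.einsum('nmd,nd->md', decay, Bu)

print(np.allclose(h_rec, h_par))
```

The parallel form trades O(N² · d_state) memory for avoiding a Python-level loop over N. Note that this SSM is linear time-invariant (`A_bar` does not depend on the input), unlike the input-dependent selective scan of the Mamba paper.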

In [None]:
class SSMBlockFast(nn.Module):
    """Ultra-fast SSM using cumulative scan"""
    def __init__(self, d_model, d_state=16, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # State space parameters
        self.A_log = nn.Parameter(torch.randn(d_state) * 0.1 - 1.0)
        self.B = nn.Linear(d_model, d_state, bias=False)
        self.C = nn.Linear(d_state, d_model, bias=False)
        self.D = nn.Parameter(torch.randn(d_model) * 0.01)
        
        nn.init.xavier_uniform_(self.B.weight, gain=0.5)
        nn.init.xavier_uniform_(self.C.weight, gain=0.5)
        
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )
        
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.eps = 1e-8
    
    def forward(self, x):
        """Fully vectorized SSM forward pass"""
        B, N, D = x.shape
        
        # Discretization
        A = -torch.exp(self.A_log).clamp(min=self.eps, max=10.0)
        dt = 1.0 / N
        A_bar = torch.exp(dt * A)
        B_bar = torch.where(
            torch.abs(A) > self.eps,
            (A_bar - 1.0) / (A + self.eps),
            torch.ones_like(A) * dt
        )
        
        # Input projection
        Bu = self.B(x) * B_bar
        
        # Causal decay kernel: decay[n, m, s] = A_bar[s] ** (m - n) for m >= n, else 0
        indices = torch.arange(N, device=x.device)
        decay = A_bar.unsqueeze(0).pow(
            (indices.unsqueeze(0) - indices.unsqueeze(1)).clamp(min=0).unsqueeze(-1)
        )
        mask = indices.unsqueeze(0) >= indices.unsqueeze(1)
        decay = decay * mask.unsqueeze(-1).float()
        
        # States h[b, m] = sum_{n <= m} A_bar^(m - n) * Bu[b, n],
        # equivalent to the recurrence h[m] = A_bar * h[m-1] + Bu[m]
        h = torch.einsum('nmd,bnd->bmd', decay, Bu)
        h = torch.clamp(h, min=-10.0, max=10.0)
        
        # Output
        y = self.C(h) + self.D * x
        
        # Gating and residual
        gate = self.gate(x)
        y = gate * y + (1 - gate) * x
        
        return self.dropout(self.norm(y))


class MambaBlock(nn.Module):
    """Complete Mamba block with FAST SSM + MLP"""
    def __init__(self, d_model, d_state=16, expand_factor=2, dropout=0.1):
        super().__init__()
        
        self.proj_in = nn.Linear(d_model, d_model * expand_factor)
        self.ssm = SSMBlockFast(d_model * expand_factor, d_state, dropout)
        self.proj_out = nn.Linear(d_model * expand_factor, d_model)
        
        self.mlp = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        # SSM branch
        residual = x
        x = self.proj_in(x)
        x = self.ssm(x)
        x = self.proj_out(x)
        x = x + residual
        
        # MLP branch
        x = x + self.mlp(x)
        
        return x

## 4. Time Embedding
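
`SinusoidalTimeEmbedding` reuses the Transformer frequency ladder $\text{freq}_k = 10000^{-k/(d/2 - 1)}$, but flow-matching time here is continuous, $t \in [0, 1]$. A quick arithmetic check in plain Python (with `dim=512`, matching `d_model`) of how far each channel's angle $t \cdot \text{freq}_k$ can move over the whole time range:

```python
import math

dim = 512                                    # d_model used by the model
half_dim = dim // 2
step = math.log(10000) / (half_dim - 1)      # same exponent spacing as the module
freqs = [math.exp(-k * step) for k in range(half_dim)]

# With t in [0, 1], channel k's angle t * freqs[k] sweeps at most freqs[k] radians
max_angle = max(freqs)
nearly_constant = sum(f < 0.1 for f in freqs)   # channels moving < 0.1 rad over all t
print(max_angle, nearly_constant, half_dim)
```

The fastest channel moves at most 1 radian, and three quarters of the frequencies (192 of 256) move less than 0.1 radian, so most of the embedding is nearly constant in $t$. Rescaling $t$ before embedding (e.g. multiplying by 1000) is a common remedy; this notebook does not do it, and whether it helps here is untested.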

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb

## 5. Main Architecture: MAMBA with Gaussian Fourier + Hilbert

In [None]:
class MAMBADiffusionGaussianHilbert(nn.Module):
    """
    MAMBA Diffusion with:
    1. Gaussian Fourier Features (learnable random projection)
    2. Hilbert Curve sorting (locality-preserving sequence ordering)
    """
    def __init__(
        self,
        num_fourier_feats=256,
        d_model=512,
        num_layers=6,
        d_state=16,
        dropout=0.1,
        gaussian_scale=10.0,
        learnable_gff=True,
        use_hilbert=True,
        hilbert_grid_size=32
    ):
        super().__init__()
        self.d_model = d_model
        self.use_hilbert = use_hilbert
        
        # Gaussian Fourier features
        self.fourier = GaussianFourierFeatures(
            coord_dim=2,
            num_features=num_fourier_feats,
            scale=gaussian_scale,
            learnable=learnable_gff
        )
        feat_dim = num_fourier_feats * 2
        
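        # Hilbert curve sorter: smallest order p with 2**p >= grid size
        if use_hilbert:
            hilbert_p = max(1, (hilbert_grid_size - 1).bit_length())
            self.hilbert_sorter = HilbertSorter(grid_size=hilbert_grid_size, p=hilbert_p)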
        
        # Project inputs and queries
        self.input_proj = nn.Linear(feat_dim + 3, d_model)
        self.query_proj = nn.Linear(feat_dim + 3, d_model)
        
        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model)
        )
        
        # Mamba blocks
        self.mamba_blocks = nn.ModuleList([
            MambaBlock(d_model, d_state=d_state, expand_factor=2, dropout=dropout)
            for _ in range(num_layers)
        ])
        
        # Cross-attention
        self.query_cross_attn = nn.MultiheadAttention(
            d_model, num_heads=8, dropout=dropout, batch_first=True
        )
        
        # Output decoder
        self.decoder = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 3)
        )
    
    def forward(self, noisy_values, query_coords, t, input_coords, input_values):
        """
        Args:
            noisy_values: (B, N_out, 3)
            query_coords: (B, N_out, 2)
            t: (B,) timestep
            input_coords: (B, N_in, 2)
            input_values: (B, N_in, 3)
        """
        B = query_coords.shape[0]
        N_in = input_coords.shape[1]
        N_out = query_coords.shape[1]
        
        # Time embedding
        t_emb = self.time_mlp(self.time_embed(t))
        
        # Gaussian Fourier features
        input_feats = self.fourier(input_coords)
        query_feats = self.fourier(query_coords)
        
        # Encode tokens
        input_tokens = self.input_proj(
            torch.cat([input_feats, input_values], dim=-1)
        )
        query_tokens = self.query_proj(
            torch.cat([query_feats, noisy_values], dim=-1)
        )
        
        # Add time embedding
        input_tokens = input_tokens + t_emb.unsqueeze(1)
        query_tokens = query_tokens + t_emb.unsqueeze(1)
        
        # Concatenate into sequence
        seq = torch.cat([input_tokens, query_tokens], dim=1)
        seq_coords = torch.cat([input_coords, query_coords], dim=1)
        
        # Hilbert curve sorting (optional). sort_by_hilbert returns
        # (sorted_coords, sorted_values, sort_indices), so keep the values
        if self.use_hilbert:
            _, seq, sort_indices = self.hilbert_sorter.sort_by_hilbert(seq_coords, seq)
        
        # Process through Mamba blocks (SSM)
        for mamba_block in self.mamba_blocks:
            seq = mamba_block(seq)
        
        # Unsort if we used Hilbert ordering
        if self.use_hilbert:
            seq = self.hilbert_sorter.unsort(seq, sort_indices)
        
        # Split back into input and query sequences
        input_seq = seq[:, :N_in, :]
        query_seq = seq[:, N_in:, :]
        
        # Cross-attention
        output, _ = self.query_cross_attn(query_seq, input_seq, input_seq)
        
        # Decode to a 3-channel velocity (the flow-matching target), not RGB directly
        return self.decoder(output)


# Test model with both features
print("\nTesting MAMBA with Gaussian Fourier + Hilbert:")
model = MAMBADiffusionGaussianHilbert(
    num_fourier_feats=256,
    d_model=512,
    num_layers=6,
    d_state=16,
    gaussian_scale=10.0,
    learnable_gff=True,
    use_hilbert=True,
    hilbert_grid_size=32
).to(device)

test_noisy = torch.rand(4, 204, 3).to(device)
test_query_coords = torch.rand(4, 204, 2).to(device)
test_t = torch.rand(4).to(device)
test_input_coords = torch.rand(4, 204, 2).to(device)
test_input_values = torch.rand(4, 204, 3).to(device)

test_out = model(test_noisy, test_query_coords, test_t, test_input_coords, test_input_values)
print(f"  Output shape: {test_out.shape}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  GFF learnable params: {sum(p.numel() for p in model.fourier.parameters()):,}")

## 6. Training: Flow Matching
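
`conditional_flow` and `target_velocity` define a straight path from noise $x_0$ to data $x_1$ with constant velocity $x_1 - x_0$, and `heun_sample` integrates a learned velocity from $t=0$ to $t=1$ with Heun's predictor-corrector step. A minimal NumPy sanity check of that integrator, using a hypothetical oracle `exact_velocity` in place of the trained model: when the velocity is exact, the Heun loop lands on $x_1$ up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(0)
x_0 = rng.normal(size=(5, 3))          # noise sample
x_1 = rng.uniform(size=(5, 3))         # "data" sample (RGB values in [0, 1])

def exact_velocity(x_t, t):
    """Hypothetical oracle: velocity of x_t = (1 - t) x_0 + t x_1, constant in t."""
    return x_1 - x_0

num_steps = 10
dt = 1.0 / num_steps
x_t = x_0.copy()
for k in range(num_steps):
    t = k * dt
    v1 = exact_velocity(x_t, t)                 # slope at the current point
    x_pred = x_t + dt * v1                      # Euler predictor
    v2 = exact_velocity(x_pred, t + dt)         # slope at the predicted point
    x_t = x_t + dt * 0.5 * (v1 + v2)            # Heun (trapezoidal) corrector

print(np.abs(x_t - x_1).max())
```

A trained model only approximates the marginal velocity field, which is generally not constant along trajectories, so real samples carry discretization error; that is where a second-order solver and more steps help.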

In [None]:
def conditional_flow(x_0, x_1, t):
    """Linear interpolation: (1-t)*x_0 + t*x_1"""
    return (1 - t) * x_0 + t * x_1

def target_velocity(x_0, x_1):
    """Target velocity: x_1 - x_0"""
    return x_1 - x_0

@torch.no_grad()
def heun_sample(model, output_coords, input_coords, input_values, num_steps=50, device='cuda'):
    """Heun ODE solver for flow matching"""
    B, N_out = output_coords.shape[0], output_coords.shape[1]
    x_t = torch.randn(B, N_out, 3, device=device)
    
    dt = 1.0 / num_steps
    ts = torch.linspace(0, 1 - dt, num_steps)
    
    for t_val in tqdm(ts, desc="Sampling", leave=False):
        t = torch.full((B,), t_val.item(), device=device)
        t_next = torch.full((B,), t_val.item() + dt, device=device)
        
        v1 = model(x_t, output_coords, t, input_coords, input_values)
        x_next_pred = x_t + dt * v1
        
        v2 = model(x_next_pred, output_coords, t_next, input_coords, input_values)
        x_t = x_t + dt * 0.5 * (v1 + v2)
    
    return torch.clamp(x_t, 0, 1)

def train_flow_matching(
    model, train_loader, test_loader, epochs=100, lr=1e-4, device='cuda',
    visualize_every=5, eval_every=2, save_dir='checkpoints_gff_hilbert'
):
    """Train with flow matching"""
    os.makedirs(save_dir, exist_ok=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    losses = []
    
    best_val_loss = float('inf')
    
    # Visualization setup
    y, x = torch.meshgrid(
        torch.linspace(0, 1, 32),
        torch.linspace(0, 1, 32),
        indexing='ij'
    )
    full_coords = torch.stack([x.flatten(), y.flatten()], dim=-1).to(device)
    
    viz_batch = next(iter(train_loader))
    viz_input_coords = viz_batch['input_coords'][:4].to(device)
    viz_input_values = viz_batch['input_values'][:4].to(device)
    viz_output_coords = viz_batch['output_coords'][:4].to(device)
    viz_output_values = viz_batch['output_values'][:4].to(device)
    viz_full_images = viz_batch['full_image'][:4].to(device)
    viz_input_indices = viz_batch['input_indices'][:4]
    
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_coords = batch['input_coords'].to(device)
            input_values = batch['input_values'].to(device)
            output_coords = batch['output_coords'].to(device)
            output_values = batch['output_values'].to(device)
            
            B = input_coords.shape[0]
            t = torch.rand(B, device=device)
            
            x_0 = torch.randn_like(output_values)
            x_1 = output_values
            
            t_broadcast = t.view(B, 1, 1)
            x_t = conditional_flow(x_0, x_1, t_broadcast)
            u_t = target_velocity(x_0, x_1)
            
            v_pred = model(x_t, output_coords, t, input_coords, input_values)
            loss = F.mse_loss(v_pred, u_t)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.6f}")
        
        # Evaluation
        val_loss = None
        if (epoch + 1) % eval_every == 0 or epoch == 0:
            model.eval()
            tracker = MetricsTracker()
            val_loss_accum = 0
            val_batches = 0
            with torch.no_grad():
                for i, batch in enumerate(test_loader):
                    if i >= 10:
                        break
                    pred_values = heun_sample(
                        model, batch['output_coords'].to(device),
                        batch['input_coords'].to(device), batch['input_values'].to(device),
                        num_steps=50, device=device
                    )
                    tracker.update(pred_values, batch['output_values'].to(device))
                    val_loss_accum += F.mse_loss(pred_values, batch['output_values'].to(device)).item()
                    val_batches += 1
                    
                results = tracker.compute()
                val_loss = val_loss_accum / val_batches
                print(f"  Eval - MSE: {results['mse']:.6f}, MAE: {results['mae']:.6f}, Val Loss: {val_loss:.6f}")
                
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': avg_loss,
                        'val_loss': val_loss,
                        'best_val_loss': best_val_loss
                    }, f'{save_dir}/mamba_gff_hilbert_best.pth')
                    print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")
        
        # Save latest
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
            'val_loss': val_loss if val_loss is not None else avg_loss,
            'best_val_loss': best_val_loss
        }, f'{save_dir}/mamba_gff_hilbert_latest.pth')
        
        # Visualization
        if (epoch + 1) % visualize_every == 0 or epoch == 0:
            model.eval()
            with torch.no_grad():
                pred_values = heun_sample(
                    model, viz_output_coords, viz_input_coords, viz_input_values,
                    num_steps=50, device=device
                )
                
                full_coords_batch = full_coords.unsqueeze(0).expand(4, -1, -1)
                full_pred_values = heun_sample(
                    model, full_coords_batch, viz_input_coords, viz_input_values,
                    num_steps=50, device=device
                )
                full_pred_images = full_pred_values.view(4, 32, 32, 3).permute(0, 3, 1, 2)
                
                fig, axes = plt.subplots(4, 5, figsize=(20, 16))
                
                for i in range(4):
                    axes[i, 0].imshow(viz_full_images[i].permute(1, 2, 0).cpu().numpy())
                    axes[i, 0].set_title('Ground Truth' if i == 0 else '', fontsize=10)
                    axes[i, 0].axis('off')
                    
                    input_img = torch.zeros(3, 32, 32, device=device)
                    input_idx = viz_input_indices[i]
                    input_img.view(3, -1)[:, input_idx] = viz_input_values[i].T
                    axes[i, 1].imshow(input_img.permute(1, 2, 0).cpu().numpy())
                    axes[i, 1].set_title('Sparse Input (20%)' if i == 0 else '', fontsize=10)
                    axes[i, 1].axis('off')
                    
                    target_img = torch.zeros(3, 32, 32, device=device)
                    output_idx = viz_batch['output_indices'][i]
                    target_img.view(3, -1)[:, output_idx] = viz_output_values[i].T
                    axes[i, 2].imshow(target_img.permute(1, 2, 0).cpu().numpy())
                    axes[i, 2].set_title('Sparse Target (20%)' if i == 0 else '', fontsize=10)
                    axes[i, 2].axis('off')
                    
                    pred_img = torch.zeros(3, 32, 32, device=device)
                    pred_img.view(3, -1)[:, output_idx] = pred_values[i].T
                    axes[i, 3].imshow(np.clip(pred_img.permute(1, 2, 0).cpu().numpy(), 0, 1))
                    axes[i, 3].set_title('Sparse Prediction' if i == 0 else '', fontsize=10)
                    axes[i, 3].axis('off')
                    
                    axes[i, 4].imshow(np.clip(full_pred_images[i].permute(1, 2, 0).cpu().numpy(), 0, 1))
                    axes[i, 4].set_title('Full Field Recon' if i == 0 else '', fontsize=10)
                    axes[i, 4].axis('off')
                
                plt.suptitle(f'MAMBA GFF+Hilbert - Epoch {epoch+1} (Best Val: {best_val_loss:.6f})', 
                           fontsize=14, y=0.995)
                plt.tight_layout()
                plt.savefig(f'{save_dir}/epoch_{epoch+1:03d}.png', dpi=150, bbox_inches='tight')
                plt.show()
                plt.close()
    
    print(f"\n✓ Training complete! Best validation loss: {best_val_loss:.6f}")
    return losses

## 7. Load Data and Train

In [None]:
# Load dataset
train_dataset = SparseCIFAR10Dataset(
    root='../data', train=True, input_ratio=0.2, output_ratio=0.2, download=True, seed=42
)
test_dataset = SparseCIFAR10Dataset(
    root='../data', train=False, input_ratio=0.2, output_ratio=0.2, download=True, seed=42
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

# Initialize model
model = MAMBADiffusionGaussianHilbert(
    num_fourier_feats=256,
    d_model=512,
    num_layers=6,
    d_state=16,
    gaussian_scale=10.0,
    learnable_gff=True,
    use_hilbert=True,
    hilbert_grid_size=32
).to(device)

# Train
losses = train_flow_matching(
    model, train_loader, test_loader,
    epochs=100, lr=1e-4, device=device,
    save_dir='checkpoints_gff_hilbert'
)

## 8. Final Evaluation

In [None]:
# Plot loss
plt.figure(figsize=(10, 4))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss: MAMBA with GFF + Hilbert')
plt.grid(alpha=0.3)
plt.savefig('checkpoints_gff_hilbert/training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

# Full image reconstruction
def create_full_grid(image_size=32, device='cuda'):
    y, x = torch.meshgrid(
        torch.linspace(0, 1, image_size),
        torch.linspace(0, 1, image_size),
        indexing='ij'
    )
    return torch.stack([x.flatten(), y.flatten()], dim=-1).to(device)

full_coords = create_full_grid(32, device)

model.eval()
tracker_full = MetricsTracker()

for i, batch in enumerate(tqdm(test_loader, desc="Full Reconstruction")):
    if i >= 50:
        break
    
    B = batch['input_coords'].shape[0]
    full_coords_batch = full_coords.unsqueeze(0).expand(B, -1, -1)
    
    pred_values = heun_sample(
        model, full_coords_batch,
        batch['input_coords'].to(device),
        batch['input_values'].to(device),
        num_steps=100, device=device
    )
    
    pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)
    tracker_full.update(None, None, pred_images, batch['full_image'].to(device))

results = tracker_full.compute()
print(f"\nFull Image Reconstruction:")
print(f"  PSNR: {results['psnr']:.2f} dB")
print(f"  SSIM: {results['ssim']:.4f}")

## 9. Visualize Results

In [None]:
sample_batch = next(iter(test_loader))
B = 4
full_coords_batch = full_coords.unsqueeze(0).expand(B, -1, -1)

pred_values = heun_sample(
    model, full_coords_batch,
    sample_batch['input_coords'][:B].to(device),
    sample_batch['input_values'][:B].to(device),
    num_steps=100, device=device
)
pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)

fig, axes = plt.subplots(4, 3, figsize=(12, 16))
for i in range(4):
    axes[i, 0].imshow(sample_batch['full_image'][i].permute(1, 2, 0).numpy())
    axes[i, 0].set_title('Ground Truth')
    axes[i, 0].axis('off')
    
    input_img = torch.zeros(3, 32, 32)
    input_idx = sample_batch['input_indices'][i]
    input_img.view(3, -1)[:, input_idx] = sample_batch['input_values'][i].T
    axes[i, 1].imshow(input_img.permute(1, 2, 0).numpy())
    axes[i, 1].set_title(f'Input (20%)')
    axes[i, 1].axis('off')
    
    axes[i, 2].imshow(np.clip(pred_images[i].permute(1, 2, 0).cpu().numpy(), 0, 1))
    axes[i, 2].set_title('Reconstructed')
    axes[i, 2].axis('off')

plt.suptitle('MAMBA GFF+Hilbert: Full Image Reconstruction', fontsize=14, y=0.995)
plt.tight_layout()
plt.savefig('checkpoints_gff_hilbert/final_reconstruction.png', dpi=150, bbox_inches='tight')
plt.show()

## 10. Ablation Study: Compare Configurations

Test different combinations to see the impact of each component:
1. Baseline (fixed Fourier, no Hilbert)
2. Gaussian Fourier only
3. Hilbert only
4. Both (GFF + Hilbert)
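
`MAMBADiffusionGaussianHilbert` exposes `learnable_gff` and `use_hilbert`, but no switch for a deterministic sinusoidal encoding, so the "fixed Fourier" baseline cannot be built from this class alone. A sketch of one possible mapping, where `config_to_kwargs` is a hypothetical helper and `gff: False` is interpreted as a frozen random `B` (which is not the same as a fixed sinusoidal baseline):

```python
ablation_configs = [
    {'name': 'Baseline', 'gff': False, 'hilbert': False},
    {'name': 'GFF Only', 'gff': True, 'hilbert': False},
    {'name': 'Hilbert Only', 'gff': False, 'hilbert': True},
    {'name': 'GFF + Hilbert', 'gff': True, 'hilbert': True},
]

def config_to_kwargs(cfg):
    """Hypothetical mapping from an ablation config to constructor kwargs."""
    return {
        'learnable_gff': cfg['gff'],     # False -> frozen random B, NOT sinusoidal
        'use_hilbert': cfg['hilbert'],
        'gaussian_scale': 10.0,
        'hilbert_grid_size': 32,
    }

for cfg in ablation_configs:
    print(f"{cfg['name']:15s} {config_to_kwargs(cfg)}")
```

Each model would then be built as `MAMBADiffusionGaussianHilbert(**config_to_kwargs(cfg))` and trained with `train_flow_matching` using a distinct `save_dir`, so checkpoints do not overwrite each other.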

In [None]:
print("="*60)
print("ABLATION STUDY")
print("="*60)

# This would require training 4 models - for now, we document the approach
ablation_configs = [
    {'name': 'Baseline', 'gff': False, 'hilbert': False},
    {'name': 'GFF Only', 'gff': True, 'hilbert': False},
    {'name': 'Hilbert Only', 'gff': False, 'hilbert': True},
    {'name': 'GFF + Hilbert', 'gff': True, 'hilbert': True}
]

print("\nTo run full ablation, train 4 models with these configs:")
for config in ablation_configs:
    print(f"  {config['name']:20s} - GFF: {config['gff']}, Hilbert: {config['hilbert']}")

print("\nExpected results:")
print("  1. GFF should improve frequency coverage → better fine details")
print("  2. Hilbert should improve SSM efficiency → faster convergence")
print("  3. Both together should give best performance")

## Summary

### Key Innovations

#### 1. Gaussian Fourier Features
- **Learnable frequency sampling** from random Gaussian projection
- **Direction-diverse spectral coverage**: frequencies are sampled in all directions rather than on a fixed axis-aligned sinusoidal ladder
- **Theoretical foundation** in RBF kernel approximation
- **Adaptive learning** - model can adjust frequency bands during training

#### 2. Hilbert Curve Sorting
- **Locality preservation** - spatially adjacent pixels stay together in sequence
- **Better for SSM** - state propagation benefits from spatial coherence
- **No information loss** - the sort is a permutation, and `unsort` inverts it exactly
- **Scale-flexible** - defined on any $2^p \times 2^p$ grid; other sizes use the smallest covering $2^p$

### Expected Benefits

| Component | Benefit | Impact |
|-----------|---------|--------|
| GFF | Better frequency representation | Sharper details, better textures |
| Hilbert | Improved sequence coherence | Faster training, better long-range deps |
| Combined | Synergistic improvements | Best overall performance |

### Implementation Notes

1. **GFF Scale**: The `gaussian_scale` parameter controls frequency range
   - Higher values → higher frequencies (fine details)
   - Lower values → lower frequencies (smooth variations)
   - The learnable version can adjust `B` during training, but `gaussian_scale` still sets the initial frequency range

2. **Hilbert Order**: The `p` parameter controls curve resolution
   - p=5 → 2^10 = 1024 points (exactly one per pixel of a 32×32 image)
   - Larger images need the smallest p with 2^p ≥ image side (e.g. p=6 for 64×64)

3. **Computational Cost**: 
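   - GFF: Minimal overhead (one matrix multiply plus sin/cos)
   - Hilbert: O(N) index lookup plus O(N log N) argsort per forward pass, provided indices come from a vectorized lookup (a per-point Python loop dominates runtime otherwise)
   - Total: ~5-10% slower than baseline (an estimate; not benchmarked in this notebook)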
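
The two sizing rules above can be computed directly. A small sketch with hypothetical helper names: `min_hilbert_order` picks the smallest $p$ with $2^p$ at least the image side, and `gff_rbf_bandwidth` converts `gaussian_scale` into the RBF bandwidth $\sigma = 1/(2\pi \cdot \text{scale})$ implied by $\omega = 2\pi B$ with $B \sim \mathcal{N}(0, \text{scale}^2)$:

```python
import math

def min_hilbert_order(image_side):
    """Smallest p >= 1 with 2**p >= image_side, so the curve covers the image."""
    return max(1, (image_side - 1).bit_length())

def gff_rbf_bandwidth(scale):
    """RBF bandwidth sigma matching omega = 2*pi*B with B ~ N(0, scale^2)."""
    return 1.0 / (2 * math.pi * scale)

for side in (28, 32, 64, 100):
    p = min_hilbert_order(side)
    print(side, p, 4 ** p)        # image side, curve order p, cells on the 2^p x 2^p grid
print(round(gff_rbf_bandwidth(10.0), 4))
```

Non-power-of-two sides such as 28 or 100 fit inside the next larger curve; only the cells that fall inside the image are ever visited.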

### Future Extensions

1. **Adaptive GFF**: Learn frequency scale per layer
2. **3D Hilbert**: Extend to video/volumetric data
3. **Learned Scanning**: Neural attention-based ordering
4. **Multi-scale Hilbert**: Different scales for different layers