# MAMBA Diffusion - Fixed Version (Playbook-Aligned)

## Critical Fixes Applied

Based on `mamba_sparse_flow_playbook.md`, this notebook implements:

### 1. ✅ Pixel-Center Coordinates
- Use `(0.5/size, 1.5/size, ..., (size-0.5)/size)` instead of `linspace(0,1,size)`
- Eliminates half-pixel misalignment at 32×32

### 2. ✅ Geometry-Aware SSM
- **Morton (Z-order) curve** for spatial sorting (as recommended in playbook)
- **Per-step Δt** from Euclidean distances along the curve
- Ties SSM decay to spatial distance rather than token index, so behavior does not depend on sequence length

### 3. ✅ Band-Limited Fourier Features
- Resolution-aware with Nyquist cutoff
- `num_freqs=64`, `max_cycles=16` for 32×32 (linear spacing from 1 to 16 cycles)
- Prevents high-frequency aliasing

### 4. ✅ Local Cross-Attention
- KNN-based (k=32 neighbors)
- Relative position encoding
- RBF interpolation prior + residual learning

### 5. ✅ Training Hygiene
- RGB in `[-1, 1]` range
- EMA weights for sampling
- AdamW (lr 2e-4, weight decay 0.05) with 5,000-step linear warmup, cosine decay, and gradient-norm clipping at 1.0
- No dropout in SSM

### Expected Result
**32×32 reconstructions should now be clean** while maintaining super-resolution capability.

In [None]:
import sys
import os

notebook_dir = os.path.abspath('')
parent_dir = os.path.dirname(notebook_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from copy import deepcopy

from core.sparse.cifar10_sparse import SparseCIFAR10Dataset
from core.sparse.metrics import MetricsTracker

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"Working directory: {os.getcwd()}")

## 1. Pixel-Center Coordinate Grids

In [None]:
def make_center_grid(size, device='cuda'):
    """
    Create grid with pixel centers, not edges.
    
    Centers in [0,1]: 0.5/size, 1.5/size, ..., (size-0.5)/size
    
    This eliminates half-pixel offset that causes noise at native resolution.
    """
    p = torch.linspace(0.5/size, 1.0 - 0.5/size, size, device=device)
    y, x = torch.meshgrid(p, p, indexing='ij')
    return torch.stack([x.reshape(-1), y.reshape(-1)], -1)  # (size^2, 2)

# Test
grid_32 = make_center_grid(32, device)
print(f"32×32 grid: {grid_32.shape}")
print(f"  First coord: {grid_32[0].cpu().tolist()}  (should be ~[0.015625, 0.015625])")
print(f"  Last coord:  {grid_32[-1].cpu().tolist()} (should be ~[0.984375, 0.984375])")
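The difference between pixel-center and endpoint-inclusive coordinates can be checked in one dimension. The sketch below uses plain Python lists instead of tensors; `center_grid_1d` and `edge_grid_1d` are illustrative names, not part of the notebook.

```python
def center_grid_1d(size):
    """Pixel-center coordinates in [0, 1]: (i + 0.5) / size."""
    return [(i + 0.5) / size for i in range(size)]

def edge_grid_1d(size):
    """Endpoint-inclusive coordinates, like linspace(0, 1, size)."""
    return [i / (size - 1) for i in range(size)]

centers = center_grid_1d(32)
edges = edge_grid_1d(32)

center_step = centers[1] - centers[0]   # 1/32: one pixel width
edge_step = edges[1] - edges[0]         # 1/31: slightly wider than a pixel
print(centers[0], centers[-1])          # 0.015625 0.984375
print(center_step, edge_step)
```

The endpoint-inclusive grid stretches 32 samples over 31 intervals, so its spacing does not match the pixel pitch; the center grid keeps a spacing of exactly `1/size`.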

## 2. Morton (Z-Order) Curve Sorting

In [None]:
def morton_order(coords):
    """
    Sort coordinates along Morton (Z-order) curve.
    
    Playbook recommendation: Morton over Hilbert for simplicity and speed.
    
    Args:
        coords: (B, N, 2) in [0, 1]
    Returns:
        perm: (B, N) argsort indices
    """
    # Quantize to 16-bit integers
    xy = (coords.clamp(0, 1) * 65535).long()  # (B, N, 2)
    x, y = xy[..., 0], xy[..., 1]
    
    # Part-by-1 interleaving (Morton code)
    def part1by1(v):
        v = (v | (v << 8)) & 0x00FF00FF
        v = (v | (v << 4)) & 0x0F0F0F0F
        v = (v | (v << 2)) & 0x33333333
        v = (v | (v << 1)) & 0x55555555
        return v
    
    code = (part1by1(x) << 1) | part1by1(y)  # (B, N)
    return torch.argsort(code, dim=1)  # (B, N)

# Test
test_coords = torch.rand(2, 20, 2)
perm = morton_order(test_coords)
print(f"Morton order test:")
print(f"  Input: {test_coords.shape}")
print(f"  Permutation: {perm.shape}")

# Visualize Morton order
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Original
axes[0].scatter(test_coords[0, :, 0], test_coords[0, :, 1], c=range(20), cmap='viridis', s=100)
axes[0].plot(test_coords[0, :, 0], test_coords[0, :, 1], 'k-', alpha=0.3, linewidth=0.5)
axes[0].set_title('Original Order')
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 1)
axes[0].set_aspect('equal')

# Morton order
sorted_coords = test_coords[0][perm[0]]
axes[1].scatter(sorted_coords[:, 0], sorted_coords[:, 1], c=range(20), cmap='viridis', s=100)
axes[1].plot(sorted_coords[:, 0], sorted_coords[:, 1], 'r-', alpha=0.5, linewidth=1.5)
axes[1].set_title('Morton (Z-Order) - Locality Preserving')
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 1)
axes[1].set_aspect('equal')

plt.tight_layout()
plt.savefig('morton_curve_ordering.png', dpi=150, bbox_inches='tight')
plt.show()
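To see what the bit interleaving does, it helps to run it on a small integer grid. This is a pure-Python sketch of the same part-by-1 scheme used in `morton_order`; `morton_code` is an illustrative helper name.

```python
def part1by1(v):
    """Spread the low 16 bits of v so a zero bit sits between each pair."""
    v &= 0x0000FFFF
    v = (v | (v << 8)) & 0x00FF00FF
    v = (v | (v << 4)) & 0x0F0F0F0F
    v = (v | (v << 2)) & 0x33333333
    v = (v | (v << 1)) & 0x55555555
    return v

def morton_code(x, y):
    """Interleave bits with x in the odd positions, as in morton_order."""
    return (part1by1(x) << 1) | part1by1(y)

# Sort the cells of a 4x4 grid along the curve
cells = [(x, y) for x in range(4) for y in range(4)]
ordered = sorted(cells, key=lambda c: morton_code(*c))
codes = [morton_code(x, y) for x, y in ordered]
print(ordered[:4])   # the first 2x2 quadrant is visited before any other cell
```

Each group of four consecutive codes covers one 2×2 quadrant, which is where the locality of the ordering comes from.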

## 3. Resolution-Aware Fourier Features (Band-Limited)

In [None]:
class ResolutionAwareFourier(nn.Module):
    """
    Band-limited Fourier features that respect grid Nyquist frequency.
    
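    The frequency bank spans 1..max_cycles. With a size_hint, frequencies above
    size_hint // 2 (the grid Nyquist limit) are suppressed:
    for 32×32 the limit is 16 cycles, for 64×64 it is 32 cycles.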
    
    Prevents high-frequency aliasing that appears as speckle at 32×32.
    """
    def __init__(self, coord_dim=2, num_freqs=128, max_cycles=16):
        super().__init__()
        self.coord_dim = coord_dim
        self.num_freqs = num_freqs
        
        # Linear frequency spacing up to max_cycles
        freqs = torch.linspace(1.0, max_cycles, num_freqs)
        self.register_buffer('freqs', freqs)
    
    def forward(self, coords, size_hint=None):
        """
        Args:
            coords: (B, N, 2) in [0, 1]
            size_hint: Grid resolution (32, 64, 96, etc.)
        Returns:
            features: (B, N, 4*num_freqs); entries above the Nyquist cutoff are zeroed
        """
        # Zero out frequencies above the Nyquist cutoff instead of dropping them,
        # so the feature width stays fixed for the downstream linear projections
        f = self.freqs
        if size_hint is not None:
            cutoff = size_hint // 2  # Nyquist frequency
            keep = (f <= cutoff).to(coords.dtype)  # (F,)
        else:
            keep = torch.ones_like(f)
        
        # Separate frequencies for x and y
        x_proj = 2 * math.pi * coords[..., :1] * f  # (B, N, F)
        y_proj = 2 * math.pi * coords[..., 1:] * f  # (B, N, F)
        
        # sin/cos for both dimensions
        return torch.cat([
            x_proj.sin() * keep, x_proj.cos() * keep,
            y_proj.sin() * keep, y_proj.cos() * keep
        ], dim=-1)  # (B, N, 4*F)

# Test
fourier = ResolutionAwareFourier(coord_dim=2, num_freqs=64, max_cycles=16).to(device)
test_coords = torch.rand(4, 100, 2).to(device)

# At 32×32
feats_32 = fourier(test_coords, size_hint=32)
print(f"Resolution-aware Fourier test:")
print(f"  Input: {test_coords.shape}")
print(f"  Output at 32×32: {feats_32.shape} (max freq = 16 cycles)")

# At 64×64
feats_64 = fourier(test_coords, size_hint=64)
print(f"  Output at 64×64: {feats_64.shape} (cutoff = 32 cycles; bank capped at max_cycles = 16)")
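The reason for the Nyquist cutoff can be shown numerically: on a 32-sample pixel-center grid, a sinusoid above 16 cycles is indistinguishable from a lower-frequency one. The sketch below is a standalone check in plain Python; the choice of 20 and 12 cycles is illustrative.

```python
import math

size = 32
centers = [(i + 0.5) / size for i in range(size)]

def sampled_sine(cycles):
    """Sample sin(2*pi*cycles*x) at the pixel centers."""
    return [math.sin(2 * math.pi * cycles * x) for x in centers]

above_nyquist = sampled_sine(20)   # 20 > size // 2 = 16
alias = sampled_sine(12)           # 32 - 20 = 12

max_gap = max(abs(a - b) for a, b in zip(above_nyquist, alias))
print(f"max |sin(20 cycles) - sin(12 cycles)| on the grid: {max_gap:.2e}")
```

Because the two features coincide at every sample point, a 20-cycle feature adds no information at 32×32 and can only be fit as noise between samples.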

## 4. Geometry-Aware SSM with Per-Step Δt

In [None]:
class SSMBlockGeometric(nn.Module):
    """
    SSM with geometry-aware time steps.
    
    Key fix: dt derived from coordinate distances, not 1/N.
    Makes SSM behavior independent of sequence length.
    """
    def __init__(self, d_model, d_state=16, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # State space parameters
        self.A_log = nn.Parameter(torch.randn(d_state) * 0.1 - 1.0)
        self.B = nn.Linear(d_model, d_state, bias=False)
        self.C = nn.Linear(d_state, d_model, bias=False)
        self.D = nn.Parameter(torch.randn(d_model) * 0.01)
        
        nn.init.xavier_uniform_(self.B.weight, gain=0.5)
        nn.init.xavier_uniform_(self.C.weight, gain=0.5)
        
        # Gating (no dropout per playbook)
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )
        
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)  # Keep at 0 for SSM
        self.eps = 1e-8
    
    def forward(self, x, dt=None):
        """
        Args:
            x: (B, N, d_model)
            dt: (B, N) per-step time deltas from geometry
        """
        B, N, D = x.shape
        
        # Get A matrix
        A = -torch.exp(self.A_log).clamp(min=self.eps, max=10.0)  # (d_state,)
        
        # Input projection
        Bu = self.B(x)  # (B, N, d_state)
        
        # Handle dt
        if dt is None:
            dt = torch.ones(B, N, device=x.device, dtype=x.dtype) / N
        elif dt.ndim == 1:
            dt = dt.unsqueeze(0).expand(B, -1)
        
        # Cumulative time
        t = torch.cumsum(dt, dim=1)  # (B, N)
        
        # Compute decay matrix: decay[i,j] = exp(A * (t[i] - t[j])) if i >= j
        idx = torch.arange(N, device=x.device)
        mask = (idx.unsqueeze(0) >= idx.unsqueeze(1)).float()  # (N, N)
        
        t_diff = (t.unsqueeze(2) - t.unsqueeze(1)).clamp(min=0.0)  # (B, N, N)
        decay = torch.exp(t_diff.unsqueeze(-1) * A.view(1, 1, 1, -1))  # (B, N, N, d_state)
        decay = decay * mask.view(1, N, N, 1)
        
        # Compute hidden states: contract over source position j only, so each
        # state channel n keeps its own decay rate
        h = torch.einsum('bijn,bjn->bin', decay, Bu)  # (B, N, d_state)
        h = torch.clamp(h, min=-10.0, max=10.0)
        
        # Output
        y = self.C(h) + self.D * x
        
        # Gating and residual
        gate = self.gate(x)
        y = gate * y + (1 - gate) * x
        
        return self.dropout(self.norm(y))


class MambaBlockGeometric(nn.Module):
    """Mamba block with geometric SSM"""
    def __init__(self, d_model, d_state=16, expand_factor=2, dropout=0.1):
        super().__init__()
        
        self.proj_in = nn.Linear(d_model, d_model * expand_factor)
        self.ssm = SSMBlockGeometric(d_model * expand_factor, d_state, dropout=0.0)  # No dropout in SSM
        self.proj_out = nn.Linear(d_model * expand_factor, d_model)
        
        # MLP with dropout only here
        self.mlp = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x, dt=None):
        residual = x
        x = self.proj_in(x)
        x = self.ssm(x, dt=dt)
        x = self.proj_out(x)
        x = x + residual
        x = x + self.mlp(x)
        return x
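The decay-matrix form of a diagonal SSM is the unrolled version of a simple per-channel recurrence. The sketch below checks that equivalence on a tiny sequence in plain Python, assuming one decay rate per state channel; the values of `A`, `dt`, and `Bu` are made up for illustration.

```python
import math

A = [-0.5, -2.0]                      # one negative decay rate per state channel
dt = [0.0, 0.3, 0.1, 0.7]             # per-step gaps (first step has no predecessor)
Bu = [[1.0, 0.5], [0.2, -1.0], [0.0, 0.3], [-0.4, 0.8]]  # (N, d_state) inputs

# Cumulative time t[i] = dt[0] + ... + dt[i]
t, acc = [], 0.0
for step in dt:
    acc += step
    t.append(acc)

# Closed form: h[i][n] = sum_{j<=i} exp(A[n] * (t[i] - t[j])) * Bu[j][n]
closed = [[sum(math.exp(A[n] * (t[i] - t[j])) * Bu[j][n] for j in range(i + 1))
           for n in range(len(A))] for i in range(len(dt))]

# Recurrence: h[i] = exp(A * dt[i]) * h[i-1] + Bu[i]
recurrent, h = [], [0.0] * len(A)
for i in range(len(dt)):
    h = [math.exp(A[n] * dt[i]) * h[n] + Bu[i][n] for n in range(len(A))]
    recurrent.append(h)

max_err = max(abs(c - r) for cr, rr in zip(closed, recurrent) for c, r in zip(cr, rr))
print(f"max difference: {max_err:.2e}")
```

The closed form costs O(N²) memory per channel, while the recurrence is O(N); the notebook uses the closed form, which is simpler to vectorize at these sequence lengths.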

## 5. Time Embedding

In [None]:
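class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar timestep t of shape (B,) into (B, dim)."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # expected to be even and >= 4
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]  # (B, half_dim)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)  # (B, dim)
        return emb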

## 6. Local Cross-Attention with KNN

In [None]:
def knn_indices(query_coords, input_coords, k=32):
    """
    Find K nearest neighbors for each query point.
    
    Args:
        query_coords: (B, N_q, 2)
        input_coords: (B, N_in, 2)
        k: number of neighbors
    Returns:
        indices: (B, N_q, k)
    """
    dist = torch.cdist(query_coords, input_coords)  # (B, N_q, N_in)
    _, indices = torch.topk(dist, k, dim=-1, largest=False)  # (B, N_q, k)
    return indices


class LocalCrossAttention(nn.Module):
    """
    KNN-based local cross-attention with relative position encoding.
    
    Playbook: "Make cross-attention local (KNN) and use relative offsets."
    """
    def __init__(self, d_model, num_heads=8, k=32, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.k = k
        self.head_dim = d_model // num_heads
        
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        # Relative position MLP
        self.rel_pos_mlp = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, num_heads)
        )
        
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5
    
    def forward(self, query, key, value, query_coords, key_coords):
        """
        Args:
            query: (B, N_q, d_model)
            key: (B, N_k, d_model)
            value: (B, N_k, d_model)
            query_coords: (B, N_q, 2)
            key_coords: (B, N_k, 2)
        """
        B, N_q, _ = query.shape
        N_k = key.shape[1]
        
        # Find KNN
        knn_idx = knn_indices(query_coords, key_coords, k=min(self.k, N_k))  # (B, N_q, k)
        k_actual = knn_idx.shape[2]
        
        # Project Q, K, V
        Q = self.q_proj(query).view(B, N_q, self.num_heads, self.head_dim)  # (B, N_q, H, D)
        K = self.k_proj(key).view(B, N_k, self.num_heads, self.head_dim)    # (B, N_k, H, D)
        V = self.v_proj(value).view(B, N_k, self.num_heads, self.head_dim)  # (B, N_k, H, D)
        
        # Gather neighbors with advanced indexing. torch.gather requires the index
        # to have the same number of dims as the source, which does not hold here.
        batch_idx = torch.arange(B, device=query.device).view(B, 1, 1)  # (B, 1, 1)
        K_local = K[batch_idx, knn_idx].transpose(2, 3)  # (B, N_q, H, k, D)
        V_local = V[batch_idx, knn_idx].transpose(2, 3)  # (B, N_q, H, k, D)
        
        # Compute relative positions
        key_coords_local = key_coords[batch_idx, knn_idx]  # (B, N_q, k, 2)
        rel_pos = query_coords.unsqueeze(2) - key_coords_local  # (B, N_q, k, 2)
        rel_pos_bias = self.rel_pos_mlp(rel_pos)  # (B, N_q, k, H)
        rel_pos_bias = rel_pos_bias.permute(0, 1, 3, 2)  # (B, N_q, H, k)
        
        # Attention scores
        Q = Q.unsqueeze(-2)  # (B, N_q, H, 1, D)
        scores = (Q @ K_local.transpose(-2, -1)) * self.scale  # (B, N_q, H, 1, k)
        scores = scores.squeeze(-2) + rel_pos_bias  # (B, N_q, H, k)
        
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        # Aggregate
        out = (attn.unsqueeze(-1) * V_local).sum(dim=-2)  # (B, N_q, H, D)
        out = out.reshape(B, N_q, self.d_model)
        
        return self.out_proj(out)
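Batched neighbor lookup comes down to indexing a `(B, N_k, ...)` array with a `(B, N_q, k)` index array. The sketch below illustrates the shape logic with NumPy, whose advanced-indexing rules PyTorch follows; the sizes are arbitrary and the brute-force distance computation stands in for `torch.cdist`.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N_q, N_k, k, D = 2, 5, 7, 3, 4

query_coords = rng.random((B, N_q, 2))
key_coords = rng.random((B, N_k, 2))
K = rng.standard_normal((B, N_k, D))

# Pairwise distances (B, N_q, N_k), then indices of the k closest keys
dist = np.linalg.norm(query_coords[:, :, None, :] - key_coords[:, None, :, :], axis=-1)
knn_idx = np.argsort(dist, axis=-1)[..., :k]           # (B, N_q, k)

# A (B, 1, 1) batch index broadcasts against knn_idx to pick per-batch rows
batch_idx = np.arange(B).reshape(B, 1, 1)
K_local = K[batch_idx, knn_idx]                        # (B, N_q, k, D)
print(K_local.shape)                                   # (2, 5, 3, 4)
```

The trailing feature dimensions are carried along untouched, so the same pattern works for keys, values, and coordinates.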

## 7. Main Architecture - Fixed MAMBA

In [None]:
class MAMBADiffusionFixed(nn.Module):
    """
    MAMBA Diffusion with ALL playbook fixes:
    1. Resolution-aware band-limited Fourier features
    2. Morton curve sorting
    3. Geometry-aware SSM with per-step dt
    4. Local KNN cross-attention
    5. RBF interpolation prior + residual
    """
    def __init__(
        self,
        num_fourier_freqs=64,
        max_cycles=16,
        d_model=512,
        num_layers=6,
        d_state=16,
        dropout=0.1,
        knn_k=32
    ):
        super().__init__()
        self.d_model = d_model
        self.knn_k = knn_k
        
        # Resolution-aware Fourier features
        self.fourier = ResolutionAwareFourier(
            coord_dim=2,
            num_freqs=num_fourier_freqs,
            max_cycles=max_cycles
        )
        
        # Feature dimension (4 * num_freqs for separate x/y sin/cos)
        feat_dim = num_fourier_freqs * 4
        
        # Project inputs and queries
        self.input_proj = nn.Linear(feat_dim + 3, d_model)
        self.query_proj = nn.Linear(feat_dim + 3, d_model)
        
        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model)
        )
        
        # Mamba blocks with geometry-aware SSM
        self.mamba_blocks = nn.ModuleList([
            MambaBlockGeometric(d_model, d_state=d_state, expand_factor=2, dropout=dropout)
            for _ in range(num_layers)
        ])
        
        # Local cross-attention
        self.local_cross_attn = LocalCrossAttention(
            d_model, num_heads=8, k=knn_k, dropout=dropout
        )
        
        # Decoder (predict residual on top of RBF prior)
        self.decoder = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 3)
        )
        
        # RBF prior parameters
        self.rbf_sigma = 0.1
    
    def compute_rbf_prior(self, query_coords, input_coords, input_values):
        """
        RBF interpolation prior for color smoothness.
        
        Args:
            query_coords: (B, N_q, 2)
            input_coords: (B, N_in, 2)
            input_values: (B, N_in, 3) in [-1, 1]
        Returns:
            prior: (B, N_q, 3)
        """
        # Find KNN
        knn_idx = knn_indices(query_coords, input_coords, k=min(self.knn_k, input_coords.shape[1]))
        
        # Gather neighbor coordinates and values via advanced indexing
        # (torch.gather would need an index with the same number of dims as the source)
        batch_idx = torch.arange(input_coords.shape[0], device=input_coords.device).view(-1, 1, 1)
        nbr_coords = input_coords[batch_idx, knn_idx]  # (B, N_q, k, 2)
        nbr_values = input_values[batch_idx, knn_idx]  # (B, N_q, k, 3)
        
        # Compute RBF weights
        rel_pos = query_coords.unsqueeze(2) - nbr_coords  # (B, N_q, k, 2)
        dist_sq = (rel_pos ** 2).sum(-1)  # (B, N_q, k)
        weights = torch.softmax(-dist_sq / (2 * self.rbf_sigma ** 2), dim=-1)  # (B, N_q, k)
        
        # Weighted average
        prior = (weights.unsqueeze(-1) * nbr_values).sum(dim=2)  # (B, N_q, 3)
        return prior
    
    def forward(self, noisy_values, query_coords, t, input_coords, input_values, size_hint=32):
        """
        Args:
            noisy_values: (B, N_out, 3) in [-1, 1]
            query_coords: (B, N_out, 2) in [0, 1]
            t: (B,) timestep
            input_coords: (B, N_in, 2) in [0, 1]
            input_values: (B, N_in, 3) in [-1, 1]
            size_hint: Resolution for band-limiting (32, 64, 96)
        """
        B = query_coords.shape[0]
        N_in = input_coords.shape[1]
        N_out = query_coords.shape[1]
        
        # Time embedding
        t_emb = self.time_mlp(self.time_embed(t))  # (B, d_model)
        
        # Band-limited Fourier features
        input_feats = self.fourier(input_coords, size_hint=size_hint)
        query_feats = self.fourier(query_coords, size_hint=size_hint)
        
        # Encode tokens
        input_tokens = self.input_proj(
            torch.cat([input_feats, input_values], dim=-1)
        )  # (B, N_in, d_model)
        query_tokens = self.query_proj(
            torch.cat([query_feats, noisy_values], dim=-1)
        )  # (B, N_out, d_model)
        
        # Add time embedding
        input_tokens = input_tokens + t_emb.unsqueeze(1)
        query_tokens = query_tokens + t_emb.unsqueeze(1)
        
        # Concatenate and sort by Morton order
        all_coords = torch.cat([input_coords, query_coords], dim=1)  # (B, N_tot, 2)
        seq = torch.cat([input_tokens, query_tokens], dim=1)  # (B, N_tot, d_model)
        
        perm = morton_order(all_coords)  # (B, N_tot)
        inv_perm = torch.argsort(perm, dim=1)  # (B, N_tot)
        
        # Apply sorting
        sorted_coords = torch.gather(
            all_coords, 1, perm.unsqueeze(-1).expand(-1, -1, 2)
        )
        sorted_seq = torch.gather(
            seq, 1, perm.unsqueeze(-1).expand(-1, -1, self.d_model)
        )
        
        # Compute geometric dt
        diffs = sorted_coords[:, 1:, :] - sorted_coords[:, :-1, :]  # (B, N_tot-1, 2)
        ds = torch.norm(diffs, dim=-1)  # (B, N_tot-1)
        ds = F.pad(ds, (1, 0), value=0.0)  # (B, N_tot)
        dt = ds / (ds.mean(dim=1, keepdim=True) + 1e-8)  # Normalized
        
        # Process through Mamba blocks with geometric dt
        for mamba_block in self.mamba_blocks:
            sorted_seq = mamba_block(sorted_seq, dt=dt)
        
        # Unsort
        seq = torch.gather(
            sorted_seq, 1, inv_perm.unsqueeze(-1).expand(-1, -1, self.d_model)
        )
        
        # Split back
        input_seq = seq[:, :N_in, :]
        query_seq = seq[:, N_in:, :]
        
        # Local cross-attention
        output = self.local_cross_attn(
            query_seq, input_seq, input_seq,
            query_coords, input_coords
        )
        
        # RBF prior
        prior = self.compute_rbf_prior(query_coords, input_coords, input_values)
        
        # Predict residual
        residual = self.decoder(output)
        
        # Final prediction
        rgb = torch.clamp(prior + residual, -1.0, 1.0)
        
        return rgb


# Test model
print("\nTesting Fixed MAMBA:")
model = MAMBADiffusionFixed(
    num_fourier_freqs=64,
    max_cycles=16,
    d_model=512,
    num_layers=6,
    d_state=16,
    dropout=0.1,
    knn_k=32
).to(device)

# Test inputs: Gaussian noise for the noisy values, RGB conditioning in [-1, 1]
test_noisy = torch.randn(4, 204, 3).to(device)
test_query_coords = torch.rand(4, 204, 2).to(device)
test_t = torch.rand(4).to(device)
test_input_coords = torch.rand(4, 204, 2).to(device)
test_input_values = (torch.rand(4, 204, 3) * 2 - 1).to(device)

test_out = model(test_noisy, test_query_coords, test_t, test_input_coords, test_input_values, size_hint=32)
print(f"  Output shape: {test_out.shape}")
print(f"  Output range: [{test_out.min().item():.3f}, {test_out.max().item():.3f}] (should be [-1, 1])")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
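The RBF prior is a softmax over negative squared distances, which is the same as normalizing Gaussian weights over the neighbors. The sketch below works through one query point and one color channel in plain Python; the distances and neighbor values are made up, and `sigma = 0.1` matches `rbf_sigma` in the model.

```python
import math

sigma = 0.1
distances = [0.02, 0.05, 0.20]        # query-to-neighbor distances in [0, 1] coordinates

logits = [-(d ** 2) / (2 * sigma ** 2) for d in distances]
shift = max(logits)                   # subtract the max for numerical stability
exp_logits = [math.exp(l - shift) for l in logits]
weights = [e / sum(exp_logits) for e in exp_logits]

neighbor_values = [0.8, 0.4, -1.0]    # one color channel in [-1, 1]
prior = sum(w * v for w, v in zip(weights, neighbor_values))
print([round(w, 3) for w in weights], round(prior, 3))
```

Since the weights are non-negative and sum to one, the prior is a convex combination of neighbor colors and stays inside their range.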

## 8. EMA Wrapper

In [None]:
class EMA:
    """
    Exponential Moving Average of model weights.
    
    Playbook: "EMA: maintain an EMA of weights and use it for sampling—often a large PSNR boost."
    """
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        
        # Initialize shadow weights
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
    
    def update(self):
        """Update EMA weights"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data
    
    def apply_shadow(self):
        """Temporarily replace model weights with EMA weights"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]
    
    def restore(self):
        """Restore original model weights"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}
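The update rule is easiest to reason about on a single scalar. The sketch below is a plain-Python stand-in for the tensor version, with `ema_update` as an illustrative helper; it shows how slowly the shadow follows a sudden jump in the weights at `decay=0.999`.

```python
def ema_update(shadow, params, decay=0.999):
    """Return a new shadow dict: decay * shadow + (1 - decay) * params."""
    return {name: decay * shadow[name] + (1 - decay) * params[name] for name in shadow}

params = {"w": 1.0}
shadow = dict(params)            # shadow starts as a copy of the weights

for step in range(1000):
    params["w"] = 2.0            # pretend the optimizer jumped to a new value
    shadow = ema_update(shadow, params, decay=0.999)

print(shadow["w"])               # still well short of 2.0 after 1000 steps
```

After `n` updates toward a fixed target, the remaining gap shrinks by a factor of `decay ** n`, so a decay of 0.999 averages over roughly the last thousand steps.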

## 9. Training with All Fixes

In [None]:
def conditional_flow(x_0, x_1, t):
    """Linear interpolation: (1-t)*x_0 + t*x_1"""
    return (1 - t) * x_0 + t * x_1

def target_velocity(x_0, x_1):
    """Target velocity: x_1 - x_0"""
    return x_1 - x_0

@torch.no_grad()
def heun_sample(model, output_coords, input_coords, input_values, num_steps=50, size_hint=32, device='cuda'):
    """Heun ODE solver for flow matching"""
    B, N_out = output_coords.shape[0], output_coords.shape[1]
    x_t = torch.randn(B, N_out, 3, device=device)  # Start from N(0,1)
    
    dt = 1.0 / num_steps
    ts = torch.linspace(0, 1 - dt, num_steps)
    
    for t_val in tqdm(ts, desc="Sampling", leave=False):
        t = torch.full((B,), t_val.item(), device=device)
        t_next = torch.full((B,), t_val.item() + dt, device=device)
        
        v1 = model(x_t, output_coords, t, input_coords, input_values, size_hint=size_hint)
        x_next_pred = x_t + dt * v1
        
        v2 = model(x_next_pred, output_coords, t_next, input_coords, input_values, size_hint=size_hint)
        x_t = x_t + dt * 0.5 * (v1 + v2)
    
    return torch.clamp(x_t, -1, 1)

def train_flow_matching_fixed(
    model, train_loader, test_loader, epochs=100, lr=2e-4, wd=0.05,
    device='cuda', visualize_every=5, eval_every=2, save_dir='checkpoints_fixed'
):
    """Train with all playbook fixes"""
    os.makedirs(save_dir, exist_ok=True)
    
    # AdamW with playbook settings
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    
    # Warmup + cosine scheduler
    warmup_steps = 5000
    total_steps = epochs * len(train_loader)
    
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            # Guard against total_steps <= warmup_steps and clamp progress to [0, 1]
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            return 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    # EMA
    ema = EMA(model, decay=0.999)
    
    losses = []
    best_val_loss = float('inf')
    
    # Pixel-center grids for visualization
    full_coords = make_center_grid(32, device)
    
    viz_batch = next(iter(train_loader))
    viz_input_coords = viz_batch['input_coords'][:4].to(device)
    viz_input_values = viz_batch['input_values'][:4].to(device) * 2 - 1  # [0,1] → [-1,1]
    viz_output_coords = viz_batch['output_coords'][:4].to(device)
    viz_output_values = viz_batch['output_values'][:4].to(device) * 2 - 1
    viz_full_images = viz_batch['full_image'][:4].to(device)
    viz_input_indices = viz_batch['input_indices'][:4]
    
    global_step = 0
    
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_coords = batch['input_coords'].to(device)
            input_values = batch['input_values'].to(device) * 2 - 1  # [-1, 1]
            output_coords = batch['output_coords'].to(device)
            output_values = batch['output_values'].to(device) * 2 - 1  # [-1, 1]
            
            B = input_coords.shape[0]
            t = torch.rand(B, device=device)
            
            # Flow matching in [-1, 1] space
            x_0 = torch.randn_like(output_values)  # N(0, 1)
            x_1 = output_values
            
            t_broadcast = t.view(B, 1, 1)
            x_t = conditional_flow(x_0, x_1, t_broadcast)
            u_t = target_velocity(x_0, x_1)
            
            v_pred = model(x_t, output_coords, t, input_coords, input_values, size_hint=32)
            loss = F.mse_loss(v_pred, u_t)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            # Update EMA
            ema.update()
            
            epoch_loss += loss.item()
            global_step += 1
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.6f}")
        
        # Evaluation with EMA
        val_loss = None
        if (epoch + 1) % eval_every == 0 or epoch == 0:
            model.eval()
            ema.apply_shadow()  # Use EMA weights
            
            tracker = MetricsTracker()
            val_loss_accum = 0
            val_batches = 0
            
            with torch.no_grad():
                for i, batch in enumerate(test_loader):
                    if i >= 10:
                        break
                    
                    input_vals_norm = batch['input_values'].to(device) * 2 - 1
                    output_vals_norm = batch['output_values'].to(device) * 2 - 1
                    
                    pred_values = heun_sample(
                        model, batch['output_coords'].to(device),
                        batch['input_coords'].to(device), input_vals_norm,
                        num_steps=50, size_hint=32, device=device
                    )
                    
                    tracker.update(pred_values, output_vals_norm)
                    val_loss_accum += F.mse_loss(pred_values, output_vals_norm).item()
                    val_batches += 1
                
                results = tracker.compute()
                val_loss = val_loss_accum / val_batches
                print(f"  Eval (EMA) - MSE: {results['mse']:.6f}, MAE: {results['mae']:.6f}, Val Loss: {val_loss:.6f}")
                
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'ema_shadow': ema.shadow,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': avg_loss,
                        'val_loss': val_loss,
                        'best_val_loss': best_val_loss
                    }, f'{save_dir}/mamba_fixed_best.pth')
                    print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")
            
            ema.restore()  # Restore training weights
        
        # Save latest
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'ema_shadow': ema.shadow,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
            'val_loss': val_loss if val_loss is not None else avg_loss,
            'best_val_loss': best_val_loss
        }, f'{save_dir}/mamba_fixed_latest.pth')
        
        # Visualization
        if (epoch + 1) % visualize_every == 0 or epoch == 0:
            model.eval()
            ema.apply_shadow()
            
            with torch.no_grad():
                pred_values = heun_sample(
                    model, viz_output_coords, viz_input_coords, viz_input_values,
                    num_steps=50, size_hint=32, device=device
                )
                
                full_coords_batch = full_coords.unsqueeze(0).expand(4, -1, -1)
                full_pred_values = heun_sample(
                    model, full_coords_batch, viz_input_coords, viz_input_values,
                    num_steps=50, size_hint=32, device=device
                )
                
                # Convert back to [0, 1] for visualization
                pred_values_vis = (pred_values + 1) / 2
                full_pred_values_vis = (full_pred_values + 1) / 2
                full_pred_images = full_pred_values_vis.view(4, 32, 32, 3).permute(0, 3, 1, 2)
                
                fig, axes = plt.subplots(4, 5, figsize=(20, 16))
                
                for i in range(4):
                    axes[i, 0].imshow(viz_full_images[i].permute(1, 2, 0).cpu().numpy())
                    axes[i, 0].set_title('Ground Truth' if i == 0 else '', fontsize=10)
                    axes[i, 0].axis('off')
                    
                    input_img = torch.zeros(3, 32, 32, device=device)
                    input_idx = viz_input_indices[i]
                    input_img.view(3, -1)[:, input_idx] = ((viz_input_values[i] + 1) / 2).T
                    axes[i, 1].imshow(input_img.permute(1, 2, 0).cpu().numpy())
                    axes[i, 1].set_title('Sparse Input (20%)' if i == 0 else '', fontsize=10)
                    axes[i, 1].axis('off')
                    
                    target_img = torch.zeros(3, 32, 32, device=device)
                    output_idx = viz_batch['output_indices'][i]
                    target_img.view(3, -1)[:, output_idx] = ((viz_output_values[i] + 1) / 2).T
                    axes[i, 2].imshow(target_img.permute(1, 2, 0).cpu().numpy())
                    axes[i, 2].set_title('Sparse Target (20%)' if i == 0 else '', fontsize=10)
                    axes[i, 2].axis('off')
                    
                    pred_img = torch.zeros(3, 32, 32, device=device)
                    pred_img.view(3, -1)[:, output_idx] = pred_values_vis[i].T
                    axes[i, 3].imshow(np.clip(pred_img.permute(1, 2, 0).cpu().numpy(), 0, 1))
                    axes[i, 3].set_title('Sparse Prediction' if i == 0 else '', fontsize=10)
                    axes[i, 3].axis('off')
                    
                    axes[i, 4].imshow(np.clip(full_pred_images[i].permute(1, 2, 0).cpu().numpy(), 0, 1))
                    axes[i, 4].set_title('Full Field (FIXED!)' if i == 0 else '', fontsize=10)
                    axes[i, 4].axis('off')
                
                plt.suptitle(f'MAMBA FIXED - Epoch {epoch+1} (Best Val: {best_val_loss:.6f})', 
                           fontsize=14, y=0.995)
                plt.tight_layout()
                plt.savefig(f'{save_dir}/epoch_{epoch+1:03d}.png', dpi=150, bbox_inches='tight')
                plt.show()
                plt.close()
            
            ema.restore()
    
    print(f"\n✓ Training complete! Best validation loss: {best_val_loss:.6f}")
    return losses

## 10. Load Data and Train

In [None]:
# Load dataset
train_dataset = SparseCIFAR10Dataset(
    root='../data', train=True, input_ratio=0.2, output_ratio=0.2, download=True, seed=42
)
test_dataset = SparseCIFAR10Dataset(
    root='../data', train=False, input_ratio=0.2, output_ratio=0.2, download=True, seed=42
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

# Initialize model
model = MAMBADiffusionFixed(
    num_fourier_freqs=64,
    max_cycles=16,
    d_model=512,
    num_layers=6,
    d_state=16,
    dropout=0.1,
    knn_k=32
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train with all fixes
losses = train_flow_matching_fixed(
    model, train_loader, test_loader,
    epochs=100, lr=2e-4, wd=0.05,
    device=device, save_dir='checkpoints_fixed'
)

## 11. Final Evaluation

In [None]:
# Plot loss
plt.figure(figsize=(10, 4))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss: MAMBA FIXED (All Playbook Improvements)')
plt.grid(alpha=0.3)
plt.savefig('checkpoints_fixed/training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

# Load best model (map_location keeps this working on CPU-only machines)
checkpoint = torch.load('checkpoints_fixed/mamba_fixed_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Swap the EMA weights into the model for evaluation
ema_eval = EMA(model)
ema_eval.shadow = checkpoint['ema_shadow']
ema_eval.apply_shadow()

# Pixel-center grid
full_coords = make_center_grid(32, device)

model.eval()
tracker_full = MetricsTracker()

# Sampling only, so disable autograd to avoid storing graphs across solver steps
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader, desc="Full Reconstruction (32×32)")):
        # Evaluate the first 50 batches only
        if i >= 50:
            break

        B = batch['input_coords'].shape[0]
        full_coords_batch = full_coords.unsqueeze(0).expand(B, -1, -1)

        # Dataset values are in [0, 1]; the model works in [-1, 1]
        input_vals_norm = batch['input_values'].to(device) * 2 - 1

        pred_values = heun_sample(
            model, full_coords_batch,
            batch['input_coords'].to(device),
            input_vals_norm,
            num_steps=100, size_hint=32, device=device
        )

        # Convert back to [0, 1]
        pred_values = (pred_values + 1) / 2
        pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)
        tracker_full.update(None, None, pred_images, batch['full_image'].to(device))

results = tracker_full.compute()
print("\n32×32 Full Image Reconstruction (SHOULD BE CLEAN NOW):")
print(f"  PSNR: {results['psnr']:.2f} dB")
print(f"  SSIM: {results['ssim']:.4f}")

## 12. Multi-Scale Evaluation

In [None]:
# Reconstruct at native 32×32 and test super-resolution at 64×64 and 96×96
multi_scale_grids = {
    32: make_center_grid(32, device),
    64: make_center_grid(64, device),
    96: make_center_grid(96, device)
}

sample_batch = next(iter(test_loader))
B = 4

input_vals_norm = sample_batch['input_values'][:B].to(device) * 2 - 1

multi_scale_results = {}
# Sampling only, so disable autograd (the 96×96 grid has 9216 query points per image)
with torch.no_grad():
    for size, coords in multi_scale_grids.items():
        print(f"\nReconstructing at {size}×{size}...")
        coords_batch = coords.unsqueeze(0).expand(B, -1, -1)

        pred_values = heun_sample(
            model, coords_batch,
            sample_batch['input_coords'][:B].to(device),
            input_vals_norm,
            num_steps=100, size_hint=size, device=device
        )

        pred_values = (pred_values + 1) / 2  # Back to [0, 1]
        pred_images = pred_values.view(B, size, size, 3).permute(0, 3, 1, 2)
        multi_scale_results[size] = pred_images

# Visualize
fig, axes = plt.subplots(4, 4, figsize=(16, 16))
for i in range(4):
    axes[i, 0].imshow(sample_batch['full_image'][i].permute(1, 2, 0).numpy())
    axes[i, 0].set_title('Ground Truth\n(32×32)' if i == 0 else '')
    axes[i, 0].axis('off')
    
    axes[i, 1].imshow(np.clip(multi_scale_results[32][i].permute(1, 2, 0).cpu().numpy(), 0, 1))
    axes[i, 1].set_title('Recon 32×32\n(CLEAN!)' if i == 0 else '')
    axes[i, 1].axis('off')
    
    axes[i, 2].imshow(np.clip(multi_scale_results[64][i].permute(1, 2, 0).cpu().numpy(), 0, 1))
    axes[i, 2].set_title('Recon 64×64\n(2x Super-Res)' if i == 0 else '')
    axes[i, 2].axis('off')
    
    axes[i, 3].imshow(np.clip(multi_scale_results[96][i].permute(1, 2, 0).cpu().numpy(), 0, 1))
    axes[i, 3].set_title('Recon 96×96\n(3x Super-Res)' if i == 0 else '')
    axes[i, 3].axis('off')

plt.suptitle('MAMBA FIXED: Clean 32×32 + Super-Resolution', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig('checkpoints_fixed/multi_scale_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*60)
print("EXPECTED RESULTS:")
print("  ✅ 32×32 should now be CLEAN (no speckle noise)")
print("  ✅ 64×64 and 96×96 should show smooth super-resolution")
print("  ✅ All scales should have consistent colors and features")
print("="*60)

## Summary of Fixes

### ✅ Implemented from Playbook

1. **Pixel-Center Coordinates** (`make_center_grid`)
   - Eliminates half-pixel misalignment at 32×32

2. **Geometry-Aware SSM**
   - Morton (Z-order) curve sorting
   - Per-step Δt from coordinate distances
   - Sequence-length invariant behavior

3. **Band-Limited Fourier Features**
   - Resolution-aware with Nyquist cutoff
   - Prevents high-frequency aliasing
   - num_freqs=64, max_cycles=16 for 32×32

4. **Local Cross-Attention**
   - KNN-based (k=32)
   - Relative position encoding
   - Each query attends to its 32 nearest inputs rather than all of them, which is cheaper and more stable

5. **RBF Interpolation Prior**
   - Strong inductive bias toward smooth colors
   - Model predicts residual corrections on top of the interpolated field
   - Expected to converge faster, since only the residual has to be learned

6. **Training Hygiene**
   - RGB in `[-1, 1]` range
   - N(0,1) noise initialization
   - EMA weights (decay=0.999)
   - AdamW: lr=2e-4, wd=0.05
   - Warmup + cosine schedule
   - No dropout in SSM
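
The half-pixel misalignment from item 1 can be checked numerically. The sketch below is a plain-Python illustration, not the notebook's `make_center_grid`: `edge_grid` and `center_grid` are hypothetical helper names that reproduce `linspace(0, 1, size)` and the pixel-center formula `(i + 0.5) / size` along one axis.

```python
def edge_grid(size):
    """Plain-Python equivalent of linspace(0, 1, size): samples sit on the borders."""
    return [i / (size - 1) for i in range(size)]


def center_grid(size):
    """Pixel centers in [0, 1]: 0.5/size, 1.5/size, ..., (size - 0.5)/size."""
    return [(i + 0.5) / size for i in range(size)]


size = 32
# Absolute offset between the two conventions, per pixel index
offsets = [abs(e - c) for e, c in zip(edge_grid(size), center_grid(size))]

# Express the worst-case offset in units of one pixel width (1/size)
max_offset_px = max(offsets) * size
worst_index = offsets.index(max(offsets))

print(f"max offset: {max_offset_px:.3f} pixel widths at index {worst_index}")
```

The offset is largest at the image borders, where it reaches half a pixel width, and shrinks toward the middle of the row.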

### 📊 Expected Improvements

| Metric | Before | After (Expected) |
|--------|--------|------------------|
| 32×32 Quality | Noisy/speckle | Clean |
| PSNR @ 32×32 | ~20 dB | ~25-28 dB |
| SSIM @ 32×32 | ~0.70 | ~0.85-0.90 |
| Super-res | Good | Excellent |
| Training Speed | Baseline | ~Same |
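
To put the PSNR row in more concrete terms, the values can be converted to per-pixel RMSE. This sketch assumes PSNR is computed on images in `[0, 1]` with a peak value of 1.0; that is an assumption about `MetricsTracker`, whose implementation is not shown here. The helper names `psnr_to_rmse` and `rmse_to_psnr` are hypothetical.

```python
import math


def psnr_to_rmse(psnr_db, peak=1.0):
    """Invert PSNR = 20 * log10(peak / RMSE) to get the RMSE."""
    return peak * 10 ** (-psnr_db / 20)


def rmse_to_psnr(rmse, peak=1.0):
    """PSNR in dB for a given RMSE (assumes rmse > 0)."""
    return 20 * math.log10(peak / rmse)


# PSNR values taken from the table above
for psnr_db in (20, 25, 28):
    rmse = psnr_to_rmse(psnr_db)
    print(f"{psnr_db} dB -> RMSE {rmse:.4f} on a [0, 1] scale")
```

Under this assumption, 20 dB corresponds to an RMSE of 0.1, so moving to the 25-28 dB range would roughly halve the typical per-pixel error.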

### 🔬 Key Insights

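**Why 64×64/96×96 looked better before:**
1. With the old uniform step `dt = 1/N`, more tokens meant a smaller dt and smoother SSM dynamics, which masked the bug at higher resolutions
2. Querying a denser grid oversampled the high-frequency features, which acted as anti-aliasing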
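
The first point can be illustrated with a scalar toy model. The sketch below is a simplification, not the notebook's SSM: it assumes a single fixed decay rate `a`, a row-wise traversal instead of the Morton curve, and a state that decays by `exp(-a * dt)` per step. `decay_over_span` is a hypothetical helper that reports how much state survives across a fixed spatial span.

```python
import math


def decay_over_span(size, span, a, scheme):
    """Fraction of SSM state left after crossing `span` (in [0, 1] image units).

    scheme="uniform":   dt = 1/N with N = size * size tokens (the old behavior)
    scheme="geometric": dt = distance between neighboring pixel centers = 1/size
    """
    steps = round(span * size)  # neighbor-to-neighbor steps along one row
    if scheme == "uniform":
        dt = 1.0 / (size * size)
    elif scheme == "geometric":
        dt = 1.0 / size
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return math.exp(-a * dt * steps)


a, span = 4.0, 0.25  # hypothetical decay rate; span covers a quarter of the image
for size in (32, 64, 96):
    uniform = decay_over_span(size, span, a, "uniform")
    geometric = decay_over_span(size, span, a, "geometric")
    print(f"{size}x{size}: uniform={uniform:.4f}  geometric={geometric:.4f}")
```

With the uniform step, the surviving fraction grows with resolution, so denser grids carry state further and look smoother. With the distance-based step, the surviving fraction depends only on the spatial span, which is the resolution-independent behavior the per-step Δt is meant to provide.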

**After fixes (expected):**
- 32×32 should converge to clean reconstructions
- Super-resolution capability is preserved
- The model learns a continuous field rather than grid-dependent features