# Multi-Directional MAMBA Diffusion (v2)

## Key Innovation: Multi-Directional Spatial Scanning

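**Problem Solved**: The original MAMBA model (v1) processes the coordinates as a single 1D sequence, which loses the 2D spatial structure of the image.

**Solution**: Scan the points in 4 directions (horizontal, vertical, diagonal, anti-diagonal), then fuse the four outputs.

**Expected**: Smoother reconstructions with a 2-4 dB PSNR improvement over v1 (a target to verify in section 10, not a measured result).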

```
Single Direction (v1):        Multi-Direction (v2), 4×4 grid A..P:
A→B→C→D→E→F→...              Horizontal: A→B→C→D→E→...
                              Vertical:   A→E→I→M→B→...
Missing spatial neighbors     Diagonal:   A→E→B→I→F→...
→ Noise and wiggliness        Anti-diag:  M→I→N→E→J→...
                              → Full 2D awareness!
```
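
To make the four orders concrete, the sketch below lists the visit order on a 4×4 grid labelled A to P row by row. It is plain Python with integer coordinates, chosen here for readability (the notebook itself sorts normalized float coordinates with `torch.argsort`); the sort keys mirror those of the ordering functions in section 2, and the helper name `scan_order` is hypothetical.

```python
SIZE = 4

# (x, y) for every cell: x is the column index, y the row index (pointing down)
cells = [(x, y) for y in range(SIZE) for x in range(SIZE)]
labels = {(x, y): chr(ord("A") + y * SIZE + x) for (x, y) in cells}

# Same primary/secondary sort keys as the ordering functions in section 2
SORT_KEYS = {
    "horizontal": lambda c: (c[1], c[0]),            # by y, then x
    "vertical": lambda c: (c[0], c[1]),              # by x, then y
    "diagonal": lambda c: (c[0] + c[1], c[0]),       # by x + y, then x
    "anti-diagonal": lambda c: (c[0] - c[1], c[0]),  # by x - y, then x
}

def scan_order(direction):
    """Return the cell labels in the order the given scan visits them."""
    ordered = sorted(cells, key=SORT_KEYS[direction])
    return "".join(labels[c] for c in ordered)

for name in SORT_KEYS:
    print(f"{name:>13}: {'→'.join(scan_order(name))}")
```

Every direction visits all 16 cells; what changes is which cells precede a given cell in the causal scan, and therefore which neighbors can influence it.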

In [None]:
import sys
import os

# Add parent directory to path
notebook_dir = os.path.abspath('')
parent_dir = os.path.dirname(notebook_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

from core.neural_fields.perceiver import FourierFeatures
from core.sparse.cifar10_sparse import SparseCIFAR10Dataset
from core.sparse.metrics import MetricsTracker, print_metrics, visualize_predictions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"Working directory: {os.getcwd()}")
print(f"Parent directory added to path: {parent_dir}")

## 1. Core SSM Components (Fast Implementation)

In [None]:
class SSMBlockFast(nn.Module):
    """
    SSM layer evaluated in parallel with a dense decay tensor
    (no Python loop over sequence positions; memory grows as O(N^2)).
    Optimized version from MAMBA v1
    """
    def __init__(self, d_model, d_state=16, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # State space parameters
        self.A_log = nn.Parameter(torch.randn(d_state) * 0.1 - 1.0)
        self.B = nn.Linear(d_model, d_state, bias=False)
        self.C = nn.Linear(d_state, d_model, bias=False)
        self.D = nn.Parameter(torch.randn(d_model) * 0.01)
        
        nn.init.xavier_uniform_(self.B.weight, gain=0.5)
        nn.init.xavier_uniform_(self.C.weight, gain=0.5)
        
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )
        
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.eps = 1e-8
    
    def forward(self, x):
        B, N, D = x.shape
        
        # Discretization
        A = -torch.exp(self.A_log).clamp(min=self.eps, max=10.0)
        dt = 1.0 / N
        A_bar = torch.exp(dt * A)
        B_bar = torch.where(
            torch.abs(A) > self.eps,
            (A_bar - 1.0) / (A + self.eps),
            torch.ones_like(A) * dt
        )
        
        Bu = self.B(x) * B_bar
        
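        # Causal state computation via a dense decay tensor:
        #   decay[n, m, s] = A_bar[s] ** (m - n)  for m >= n, else 0
        #   h[b, m, s]     = sum_{n <= m} decay[n, m, s] * Bu[b, n, s]
        # This unrolls the recurrence h_m = A_bar * h_{m-1} + Bu_m.
        # decay has shape (N, N, d_state), so memory grows quadratically in N.
        indices = torch.arange(N, device=x.device)
        decay = A_bar.unsqueeze(0).pow(
            (indices.unsqueeze(0) - indices.unsqueeze(1)).clamp(min=0).unsqueeze(-1)
        )
        mask = indices.unsqueeze(0) >= indices.unsqueeze(1)
        decay = decay * mask.unsqueeze(-1).float()
        
        h = torch.einsum('nmd,bnd->bmd', decay, Bu)
        h = torch.clamp(h, min=-10.0, max=10.0)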
        
        # Output
        y = self.C(h) + self.D * x
        
        # Gating
        gate = self.gate(x)
        y = gate * y + (1 - gate) * x
        
        return self.dropout(self.norm(y))


class MambaBlock(nn.Module):
    """Standard Mamba block (for comparison)"""
    def __init__(self, d_model, d_state=16, expand_factor=2, dropout=0.1):
        super().__init__()
        
        self.proj_in = nn.Linear(d_model, d_model * expand_factor)
        self.ssm = SSMBlockFast(d_model * expand_factor, d_state, dropout)
        self.proj_out = nn.Linear(d_model * expand_factor, d_model)
        
        self.mlp = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        residual = x
        x = self.proj_in(x)
        x = self.ssm(x)
        x = self.proj_out(x)
        x = x + residual
        x = x + self.mlp(x)
        return x


print("✓ Core SSM components loaded")
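
The `einsum` in `SSMBlockFast.forward` evaluates a causal sum instead of stepping through the sequence. As a sanity check, the sketch below compares that closed form with the step-by-step recurrence for a single scalar state. The function names and sample values are illustrative assumptions, not part of the notebook, and the clamping applied in the layer is left out.

```python
def scan_recurrent(a_bar, bu):
    """Sequential form: h_m = a_bar * h_{m-1} + bu_m, starting from zero state."""
    h, out = 0.0, []
    for u in bu:
        h = a_bar * h + u
        out.append(h)
    return out

def scan_closed_form(a_bar, bu):
    """Parallel form: h_m = sum over n <= m of a_bar ** (m - n) * bu_n."""
    return [
        sum(a_bar ** (m - n) * bu[n] for n in range(m + 1))
        for m in range(len(bu))
    ]

a_bar = 0.9                      # discretized decay for one state channel
bu = [1.0, 0.5, -0.25, 2.0]      # projected inputs for that channel

recurrent = scan_recurrent(a_bar, bu)
closed = scan_closed_form(a_bar, bu)
print("recurrent  :", recurrent)
print("closed form:", closed)
```

The closed form trades O(N) sequential steps for O(N²) work and memory that can run in parallel, which is the design choice the dense decay tensor makes.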

## 2. Multi-Directional MAMBA Components

**Innovation**: Process coordinates in 4 directions to capture full 2D spatial structure

In [None]:
# ============================================
# Coordinate Ordering Functions
# ============================================

def order_by_row(coords):
    """Row-major ordering (horizontal scan): sort by y, then x.

    The key y * 1000 + x is lexicographic only while coords lie in [0, 1]
    and distinct y values differ by more than 1e-3 (true for the regular
    grids used in this notebook).
    """
    x_vals, y_vals = coords[..., 0], coords[..., 1]
    return torch.argsort(y_vals * 1000 + x_vals, dim=1)

def order_by_column(coords):
    """Column-major ordering (vertical scan): sort by x, then y"""
    x_vals, y_vals = coords[..., 0], coords[..., 1]
    return torch.argsort(x_vals * 1000 + y_vals, dim=1)

def order_by_diagonal(coords):
    """Diagonal ordering (top-left to bottom-right): sort by x + y, then x"""
    x_vals, y_vals = coords[..., 0], coords[..., 1]
    return torch.argsort((x_vals + y_vals) * 1000 + x_vals, dim=1)

def order_by_antidiagonal(coords):
    """Anti-diagonal ordering (bottom-left to top-right, y pointing down): sort by x - y, then x"""
    x_vals, y_vals = coords[..., 0], coords[..., 1]
    return torch.argsort((x_vals - y_vals) * 1000 + x_vals, dim=1)

def reorder_sequence(x, indices):
    """Apply ordering to sequence"""
    B, N, D = x.shape
    indices_expanded = indices.unsqueeze(-1).expand(B, N, D)
    return torch.gather(x, dim=1, index=indices_expanded)

def inverse_reorder(x, indices):
    """Reverse ordering back to original positions"""
    B, N, D = x.shape
    # The argsort of a permutation is its inverse permutation
    inverse_indices = torch.argsort(indices, dim=1)
    indices_expanded = inverse_indices.unsqueeze(-1).expand(B, N, D)
    return torch.gather(x, dim=1, index=indices_expanded)


# ============================================
# Multi-Directional SSM
# ============================================

class MultiDirectionalSSM(nn.Module):
    """Process sequence in 4 directions, fuse results"""
    def __init__(self, d_model, d_state=16, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # 4 separate SSM blocks for each direction
        self.ssm_horizontal = SSMBlockFast(d_model, d_state, dropout)
        self.ssm_vertical = SSMBlockFast(d_model, d_state, dropout)
        self.ssm_diagonal = SSMBlockFast(d_model, d_state, dropout)
        self.ssm_antidiagonal = SSMBlockFast(d_model, d_state, dropout)
        
        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(4 * d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model)
        )
        
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x, coords):
        # Get orderings
        indices_h = order_by_row(coords)
        indices_v = order_by_column(coords)
        indices_d = order_by_diagonal(coords)
        indices_a = order_by_antidiagonal(coords)
        
        # Horizontal
        x_h = reorder_sequence(x, indices_h)
        y_h = self.ssm_horizontal(x_h)
        y_h = inverse_reorder(y_h, indices_h)
        
        # Vertical
        x_v = reorder_sequence(x, indices_v)
        y_v = self.ssm_vertical(x_v)
        y_v = inverse_reorder(y_v, indices_v)
        
        # Diagonal
        x_d = reorder_sequence(x, indices_d)
        y_d = self.ssm_diagonal(x_d)
        y_d = inverse_reorder(y_d, indices_d)
        
        # Anti-diagonal
        x_a = reorder_sequence(x, indices_a)
        y_a = self.ssm_antidiagonal(x_a)
        y_a = inverse_reorder(y_a, indices_a)
        
        # Fuse all 4 directions
        y_concat = torch.cat([y_h, y_v, y_d, y_a], dim=-1)
        y_fused = self.fusion(y_concat)
        
        # Residual
        y = x + y_fused
        y = self.norm(y)
        
        return y


class MultiDirectionalMambaBlock(nn.Module):
    """Complete multi-directional Mamba block"""
    def __init__(self, d_model, d_state=16, expand_factor=2, dropout=0.1):
        super().__init__()
        
        self.proj_in = nn.Linear(d_model, d_model * expand_factor)
        
        self.multi_ssm = MultiDirectionalSSM(
            d_model * expand_factor,
            d_state,
            dropout
        )
        
        self.proj_out = nn.Linear(d_model * expand_factor, d_model)
        
        self.mlp = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x, coords):
        # SSM branch with multi-directional processing
        residual = x
        x = self.proj_in(x)
        x = self.multi_ssm(x, coords)
        x = self.proj_out(x)
        x = x + residual
        
        # MLP branch
        x = x + self.mlp(x)
        
        return x


print("✓ Multi-directional MAMBA components loaded")
print("  - 4 scanning directions: horizontal, vertical, diagonal, anti-diagonal")
print("  - Fusion mechanism to combine all directions")
print("  - Full 2D spatial awareness!")
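
Each directional pass relies on `reorder_sequence` followed by `inverse_reorder` returning every token to its original slot, so that the four outputs line up before fusion. The sketch below checks that round trip with plain Python lists; `gather` and `invert_permutation` are hypothetical stand-ins for the tensor operations, and the token names are made up.

```python
def gather(seq, index):
    """out[k] = seq[index[k]], the list analogue of torch.gather along dim=1."""
    return [seq[i] for i in index]

def invert_permutation(perm):
    """Build inverse so that inverse[perm[k]] = k."""
    inverse = [0] * len(perm)
    for position, source in enumerate(perm):
        inverse[source] = position
    return inverse

tokens = ["t0", "t1", "t2", "t3"]
order = [2, 0, 3, 1]                  # scan order produced by some sort key

scanned = gather(tokens, order)       # sequence as the SSM sees it
restored = gather(scanned, invert_permutation(order))

print("scanned :", scanned)
print("restored:", restored)
```

If the inverse step were skipped, the concatenation in the fusion layer would mix features that belong to different spatial positions.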

## 3. Time Embedding

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb

print("✓ Time embedding loaded")
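
For intuition, here is a pure-Python version of the same embedding for one scalar timestep. It follows the formula in `SinusoidalTimeEmbedding.forward`; the function name and the choice `dim=8` are assumptions made for the illustration.

```python
import math

def sinusoidal_embedding(t, dim):
    """Sinusoidal embedding of a scalar t: sine half followed by cosine half."""
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    # Frequencies decay geometrically from 1 down to 1/10000
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    angles = [t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

emb = sinusoidal_embedding(0.5, dim=8)
print(len(emb), [round(v, 4) for v in emb])
```

Since the notebook passes `t` in [0, 1] without rescaling, the largest angle is 1 radian and the lowest-frequency channels change very little across the whole range; whether that limits the model here is not established.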

## 4. Multi-Directional MAMBA Diffusion Model

In [None]:
class MAMBADiffusion(nn.Module):
    """
    Multi-Directional State Space Model for Sparse Field Diffusion
    
    Key Innovation:
    - Processes coordinates in 4 directions (horizontal, vertical, diagonal, anti-diagonal)
    - Fuses directional outputs to capture full 2D spatial structure
    - Addresses MAMBA's limitation of 1D sequential processing
    """
    def __init__(
        self,
        num_fourier_feats=256,
        d_model=512,
        num_layers=6,
        d_state=16,
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        
        # Fourier features
        self.fourier = FourierFeatures(coord_dim=2, num_freqs=num_fourier_feats, scale=10.0)
        feat_dim = num_fourier_feats * 2
        
        # Project inputs and queries
        self.input_proj = nn.Linear(feat_dim + 3, d_model)
        self.query_proj = nn.Linear(feat_dim + 3, d_model)
        
        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model)
        )
        
        # Multi-directional MAMBA blocks
        self.mamba_blocks = nn.ModuleList([
            MultiDirectionalMambaBlock(
                d_model,
                d_state=d_state,
                expand_factor=2,
                dropout=dropout
            )
            for _ in range(num_layers)
        ])
        
        # Cross-attention
        self.query_cross_attn = nn.MultiheadAttention(
            d_model, num_heads=8, dropout=dropout, batch_first=True
        )
        
        # Output decoder
        self.decoder = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 3)
        )
    
    def forward(self, noisy_values, query_coords, t, input_coords, input_values):
        B = query_coords.shape[0]
        N_in = input_coords.shape[1]
        N_out = query_coords.shape[1]
        
        # Time embedding
        t_emb = self.time_mlp(self.time_embed(t))
        
        # Fourier features
        input_feats = self.fourier(input_coords)
        query_feats = self.fourier(query_coords)
        
        # Encode
        input_tokens = self.input_proj(
            torch.cat([input_feats, input_values], dim=-1)
        )
        query_tokens = self.query_proj(
            torch.cat([query_feats, noisy_values], dim=-1)
        )
        
        # Add time
        input_tokens = input_tokens + t_emb.unsqueeze(1)
        query_tokens = query_tokens + t_emb.unsqueeze(1)
        
        # Concatenate coordinates and sequences
        all_coords = torch.cat([input_coords, query_coords], dim=1)
        seq = torch.cat([input_tokens, query_tokens], dim=1)
        
        # Process through multi-directional MAMBA blocks
        for mamba_block in self.mamba_blocks:
            seq = mamba_block(seq, all_coords)  # Pass coords for directional scanning
        
        # Split back
        input_seq = seq[:, :N_in, :]
        query_seq = seq[:, N_in:, :]
        
        # Cross-attention
        output, _ = self.query_cross_attn(query_seq, input_seq, input_seq)
        
        # Decode
        return self.decoder(output)


# Test model
model = MAMBADiffusion(
    num_fourier_feats=256,
    d_model=512,
    num_layers=6,
    d_state=16
).to(device)

test_noisy = torch.rand(4, 204, 3).to(device)
test_query_coords = torch.rand(4, 204, 2).to(device)
test_t = torch.rand(4).to(device)
test_input_coords = torch.rand(4, 204, 2).to(device)
test_input_values = torch.rand(4, 204, 3).to(device)

test_out = model(test_noisy, test_query_coords, test_t, test_input_coords, test_input_values)
print(f"✓ Model test passed: {test_out.shape}")
print(f"✓ Parameters: {sum(p.numel() for p in model.parameters()):,}")
print("  (Expected: roughly 4× more than v1: 4 directional SSMs plus a fusion MLP per block)")
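
The parameter increase can be estimated from the layer shapes alone. The sketch below counts parameters per block for `MambaBlock` and `MultiDirectionalMambaBlock` at the default sizes, assuming that v1 used the `MambaBlock` from section 1; the helper names are hypothetical and the counts cover one block, not the whole model.

```python
def linear(n_in, n_out, bias=True):
    """Parameter count of nn.Linear(n_in, n_out)."""
    return n_in * n_out + (n_out if bias else 0)

def layer_norm(dim):
    """Parameter count of nn.LayerNorm(dim): weight and bias."""
    return 2 * dim

def ssm_block_fast(d, s):
    # A_log + B + C + D + gate Linear + LayerNorm
    return (s + linear(d, s, bias=False) + linear(s, d, bias=False)
            + d + linear(d, d) + layer_norm(d))

def multi_directional_ssm(d, s):
    # Four SSMs, the two-layer fusion MLP, and the final LayerNorm
    fusion = linear(4 * d, 2 * d) + linear(2 * d, d)
    return 4 * ssm_block_fast(d, s) + fusion + layer_norm(d)

def shared_block_parts(d_model, expand):
    # proj_in + proj_out + MLP (LayerNorm and two Linear layers)
    d_inner = d_model * expand
    mlp = (layer_norm(d_model) + linear(d_model, 4 * d_model)
           + linear(4 * d_model, d_model))
    return linear(d_model, d_inner) + linear(d_inner, d_model) + mlp

d_model, d_state, expand = 512, 16, 2
d_inner = d_model * expand

v1_block = shared_block_parts(d_model, expand) + ssm_block_fast(d_inner, d_state)
v2_block = shared_block_parts(d_model, expand) + multi_directional_ssm(d_inner, d_state)

print(f"v1 MambaBlock:                 {v1_block:,}")
print(f"v2 MultiDirectionalMambaBlock: {v2_block:,}")
print(f"ratio:                         {v2_block / v1_block:.2f}x")
```

Under these assumptions the fusion MLP contributes more parameters than the four SSMs combined, so the growth is not explained by the extra SSMs alone.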

## 5. Flow Matching Training

In [None]:
def conditional_flow(x_0, x_1, t):
    """Linear interpolation"""
    return (1 - t) * x_0 + t * x_1

def target_velocity(x_0, x_1):
    """Target velocity"""
    return x_1 - x_0

@torch.no_grad()
def heun_sample(model, output_coords, input_coords, input_values, num_steps=50, device='cuda'):
    """Heun (second-order) ODE solver: integrates the predicted velocity from t=0 (noise) to t=1 (data)"""
    B, N_out = output_coords.shape[0], output_coords.shape[1]
    x_t = torch.randn(B, N_out, 3, device=device)
    
    dt = 1.0 / num_steps
    ts = torch.linspace(0, 1 - dt, num_steps)
    
    for t_val in tqdm(ts, desc="Sampling", leave=False):
        t = torch.full((B,), t_val.item(), device=device)
        t_next = torch.full((B,), t_val.item() + dt, device=device)
        
        v1 = model(x_t, output_coords, t, input_coords, input_values)
        x_next_pred = x_t + dt * v1
        
        v2 = model(x_next_pred, output_coords, t_next, input_coords, input_values)
        x_t = x_t + dt * 0.5 * (v1 + v2)
    
    return torch.clamp(x_t, 0, 1)

def train_flow_matching(
    model, train_loader, test_loader, epochs=100, lr=1e-4, device='cuda',
    visualize_every=5, eval_every=2, save_dir='checkpoints_multidir'
):
    """Train with flow matching"""
    os.makedirs(save_dir, exist_ok=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    losses = []
    best_val_loss = float('inf')
    
    # Full grid for visualization
    y, x = torch.meshgrid(
        torch.linspace(0, 1, 32),
        torch.linspace(0, 1, 32),
        indexing='ij'
    )
    full_coords = torch.stack([x.flatten(), y.flatten()], dim=-1).to(device)
    
    viz_batch = next(iter(train_loader))
    viz_input_coords = viz_batch['input_coords'][:4].to(device)
    viz_input_values = viz_batch['input_values'][:4].to(device)
    viz_output_coords = viz_batch['output_coords'][:4].to(device)
    viz_output_values = viz_batch['output_values'][:4].to(device)
    viz_full_images = viz_batch['full_image'][:4].to(device)
    viz_input_indices = viz_batch['input_indices'][:4]
    
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_coords = batch['input_coords'].to(device)
            input_values = batch['input_values'].to(device)
            output_coords = batch['output_coords'].to(device)
            output_values = batch['output_values'].to(device)
            
            B = input_coords.shape[0]
            t = torch.rand(B, device=device)
            
            x_0 = torch.randn_like(output_values)
            x_1 = output_values
            
            t_broadcast = t.view(B, 1, 1)
            x_t = conditional_flow(x_0, x_1, t_broadcast)
            u_t = target_velocity(x_0, x_1)
            
            v_pred = model(x_t, output_coords, t, input_coords, input_values)
            loss = F.mse_loss(v_pred, u_t)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.6f}")
        
        # Evaluation
        val_loss = None
        if (epoch + 1) % eval_every == 0 or epoch == 0:
            model.eval()
            tracker = MetricsTracker()
            val_loss_accum = 0
            val_batches = 0
            with torch.no_grad():
                for i, batch in enumerate(test_loader):
                    if i >= 10:
                        break
                    pred_values = heun_sample(
                        model, batch['output_coords'].to(device),
                        batch['input_coords'].to(device), batch['input_values'].to(device),
                        num_steps=50, device=device
                    )
                    tracker.update(pred_values, batch['output_values'].to(device))
                    val_loss_accum += F.mse_loss(pred_values, batch['output_values'].to(device)).item()
                    val_batches += 1
                    
                results = tracker.compute()
                val_loss = val_loss_accum / val_batches
                print(f"  Eval - MSE: {results['mse']:.6f}, MAE: {results['mae']:.6f}, Val Loss: {val_loss:.6f}")
                
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': avg_loss,
                        'val_loss': val_loss,
                        'best_val_loss': best_val_loss
                    }, f'{save_dir}/mamba_multidir_best.pth')
                    print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")
        
        # Save latest
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
            'val_loss': val_loss,  # None on epochs without an evaluation pass
            'best_val_loss': best_val_loss
        }, f'{save_dir}/mamba_multidir_latest.pth')
        
        # Visualization
        if (epoch + 1) % visualize_every == 0 or epoch == 0:
            model.eval()
            with torch.no_grad():
                pred_values = heun_sample(
                    model, viz_output_coords, viz_input_coords, viz_input_values,
                    num_steps=50, device=device
                )
                
                full_coords_batch = full_coords.unsqueeze(0).expand(4, -1, -1)
                full_pred_values = heun_sample(
                    model, full_coords_batch, viz_input_coords, viz_input_values,
                    num_steps=50, device=device
                )
                full_pred_images = full_pred_values.view(4, 32, 32, 3).permute(0, 3, 1, 2)
                
                fig, axes = plt.subplots(4, 5, figsize=(20, 16))
                
                for i in range(4):
                    # Ground truth
                    axes[i, 0].imshow(viz_full_images[i].permute(1, 2, 0).cpu().numpy())
                    axes[i, 0].set_title('Ground Truth' if i == 0 else '', fontsize=10)
                    axes[i, 0].axis('off')
                    
                    # Sparse input
                    input_img = torch.zeros(3, 32, 32, device=device)
                    input_idx = viz_input_indices[i]
                    input_img.view(3, -1)[:, input_idx] = viz_input_values[i].T
                    axes[i, 1].imshow(input_img.permute(1, 2, 0).cpu().numpy())
                    axes[i, 1].set_title('Sparse Input (20%)' if i == 0 else '', fontsize=10)
                    axes[i, 1].axis('off')
                    
                    # Sparse target
                    target_img = torch.zeros(3, 32, 32, device=device)
                    output_idx = viz_batch['output_indices'][i]
                    target_img.view(3, -1)[:, output_idx] = viz_output_values[i].T
                    axes[i, 2].imshow(target_img.permute(1, 2, 0).cpu().numpy())
                    axes[i, 2].set_title('Sparse Target (20%)' if i == 0 else '', fontsize=10)
                    axes[i, 2].axis('off')
                    
                    # Sparse prediction
                    pred_img = torch.zeros(3, 32, 32, device=device)
                    pred_img.view(3, -1)[:, output_idx] = pred_values[i].T
                    axes[i, 3].imshow(np.clip(pred_img.permute(1, 2, 0).cpu().numpy(), 0, 1))
                    axes[i, 3].set_title('Sparse Prediction' if i == 0 else '', fontsize=10)
                    axes[i, 3].axis('off')
                    
                    # Full field
                    axes[i, 4].imshow(np.clip(full_pred_images[i].permute(1, 2, 0).cpu().numpy(), 0, 1))
                    axes[i, 4].set_title('Full Field' if i == 0 else '', fontsize=10)
                    axes[i, 4].axis('off')
                
                plt.suptitle(f'Multi-Dir MAMBA - Epoch {epoch+1} (Best Val: {best_val_loss:.6f})', 
                           fontsize=14, y=0.995)
                plt.tight_layout()
                plt.savefig(f'{save_dir}/epoch_{epoch+1:03d}.png', dpi=150, bbox_inches='tight')
                plt.show()
                plt.close()
    
    print(f"\n✓ Training complete! Best validation loss: {best_val_loss:.6f}")
    return losses

print("✓ Training functions loaded")
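
As a sanity check on the training target and the sampler, the sketch below works with scalars instead of tensors. It restates `conditional_flow` and `target_velocity` so the block is self-contained, and `heun_integrate` is a hypothetical scalar analogue of `heun_sample` in which an oracle velocity replaces the trained model.

```python
def conditional_flow(x_0, x_1, t):
    """Linear interpolation between noise x_0 (t=0) and data x_1 (t=1)."""
    return (1 - t) * x_0 + t * x_1

def target_velocity(x_0, x_1):
    """Time derivative of the linear path, constant in t."""
    return x_1 - x_0

def heun_integrate(velocity, x, num_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with Heun's method."""
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v1 = velocity(x, t)
        x_pred = x + dt * v1               # Euler predictor
        v2 = velocity(x_pred, t + dt)
        x = x + dt * 0.5 * (v1 + v2)       # trapezoidal corrector
    return x

x_0, x_1 = -1.3, 0.7   # one scalar noise sample and one scalar data sample

# Central finite difference of the path should equal the target velocity
t, h = 0.4, 1e-6
fd = (conditional_flow(x_0, x_1, t + h) - conditional_flow(x_0, x_1, t - h)) / (2 * h)

# With the exact velocity, integration from x_0 should land on x_1
x_end = heun_integrate(lambda x, t: target_velocity(x_0, x_1), x_0, num_steps=50)

print(f"finite difference: {fd:.6f}, target velocity: {target_velocity(x_0, x_1):.6f}")
print(f"x(1) = {x_end:.6f}, x_1 = {x_1:.6f}")
```

With a trained model the velocity is only approximate, so the endpoint error depends on both the model and `num_steps`.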

## 6. Load Data and Train

In [None]:
# Load dataset
train_dataset = SparseCIFAR10Dataset(
    root='../data', train=True, input_ratio=0.2, output_ratio=0.2, download=True, seed=42
)
test_dataset = SparseCIFAR10Dataset(
    root='../data', train=False, input_ratio=0.2, output_ratio=0.2, download=True, seed=42
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

# Initialize model
model = MAMBADiffusion(
    num_fourier_feats=256,
    d_model=512,
    num_layers=6,
    d_state=16
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train
losses = train_flow_matching(model, train_loader, test_loader, epochs=100, lr=1e-4, device=device)

## 7. Final Evaluation & Loss Plot

In [None]:
# Plot loss
plt.figure(figsize=(10, 4))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss: Multi-Directional MAMBA')
plt.grid(alpha=0.3)
plt.savefig('checkpoints_multidir/training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Multi-Scale Evaluation

Evaluate the continuous representation by querying the same trained model at 32×32 (native), 64×64, and 96×96.

In [None]:
def create_multi_scale_grids(device='cuda'):
    """Create coordinate grids at different resolutions"""
    grids = {}
    for size in [32, 64, 96]:
        y, x = torch.meshgrid(
            torch.linspace(0, 1, size),
            torch.linspace(0, 1, size),
            indexing='ij'
        )
        grids[size] = torch.stack([x.flatten(), y.flatten()], dim=-1).to(device)
    return grids

@torch.no_grad()
def multi_scale_reconstruction(model, input_coords, input_values, grids, num_steps=100, device='cuda'):
    """Reconstruct at multiple scales"""
    model.eval()
    B = input_coords.shape[0]
    reconstructions = {}
    
    for size, coords in grids.items():
        print(f"Reconstructing at {size}×{size}...")
        coords_batch = coords.unsqueeze(0).expand(B, -1, -1)
        pred_values = heun_sample(
            model, coords_batch, input_coords, input_values,
            num_steps=num_steps, device=device
        )
        pred_images = pred_values.view(B, size, size, 3).permute(0, 3, 1, 2)
        reconstructions[size] = pred_images
    
    return reconstructions

# Create grids
multi_scale_grids = create_multi_scale_grids(device)

# Test on batch
test_batch = next(iter(test_loader))
B_test = 4

multi_scale_results = multi_scale_reconstruction(
    model,
    test_batch['input_coords'][:B_test].to(device),
    test_batch['input_values'][:B_test].to(device),
    multi_scale_grids,
    num_steps=100,
    device=device
)

print("\nMulti-scale reconstruction complete!")
for size, imgs in multi_scale_results.items():
    print(f"  {size}×{size}: {imgs.shape}")
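
All three grids span the same [0, 1] range, so a larger grid means finer spacing and a longer sequence for every SSM call. The sketch below estimates both, along with the size of the `(N, N, d_state)` decay tensor built in `SSMBlockFast.forward`. It assumes 204 input points per image (as in the smoke test in section 4), `d_state=16`, and float32 storage; `linspace` is a hypothetical stand-in for `torch.linspace`.

```python
N_INPUT = 204            # assumed number of sparse input points per image
D_STATE = 16             # default d_state in this notebook
BYTES_PER_FLOAT32 = 4

def linspace(start, stop, num):
    """Evenly spaced values including both endpoints."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]

def sequence_length(size, n_input=N_INPUT):
    """Tokens seen by each SSM: all query points plus the input points."""
    return size * size + n_input

def decay_tensor_bytes(seq_len, d_state=D_STATE):
    """Memory of one (N, N, d_state) float32 decay tensor."""
    return seq_len * seq_len * d_state * BYTES_PER_FLOAT32

for size in [32, 64, 96]:
    axis = linspace(0.0, 1.0, size)
    n = sequence_length(size)
    gib = decay_tensor_bytes(n) / 2**30
    print(f"{size}x{size}: spacing={axis[1] - axis[0]:.5f}, "
          f"sequence length={n}, decay tensor ~ {gib:.2f} GiB")
```

If these assumptions hold, the 96×96 reconstruction needs several GiB for a single decay tensor, so a smaller batch or chunked queries may be necessary on memory-limited GPUs.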

## 9. Visualize Multi-Scale Results

In [None]:
def visualize_multi_scale(ground_truth, sparse_input_img, multi_scale_results, sample_idx=0):
    """Visualize multi-scale reconstructions"""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Row 1
    axes[0, 0].imshow(ground_truth.permute(1, 2, 0).cpu().numpy())
    axes[0, 0].set_title('Ground Truth (32×32)', fontsize=12, fontweight='bold')
    axes[0, 0].axis('off')
    
    axes[0, 1].imshow(sparse_input_img.permute(1, 2, 0).cpu().numpy())
    axes[0, 1].set_title('Sparse Input (20%)', fontsize=12, fontweight='bold')
    axes[0, 1].axis('off')
    
    img_32 = multi_scale_results[32][sample_idx].permute(1, 2, 0).cpu().numpy()
    axes[0, 2].imshow(np.clip(img_32, 0, 1))
    axes[0, 2].set_title('Reconstructed 32×32\n(Native)', fontsize=12, fontweight='bold')
    axes[0, 2].axis('off')
    
    # Row 2
    img_64 = multi_scale_results[64][sample_idx].permute(1, 2, 0).cpu().numpy()
    axes[1, 0].imshow(np.clip(img_64, 0, 1))
    axes[1, 0].set_title('Reconstructed 64×64\n(2× Upsampling)', fontsize=12, fontweight='bold')
    axes[1, 0].axis('off')
    
    img_96 = multi_scale_results[96][sample_idx].permute(1, 2, 0).cpu().numpy()
    axes[1, 1].imshow(np.clip(img_96, 0, 1))
    axes[1, 1].set_title('Reconstructed 96×96\n(3× Upsampling)', fontsize=12, fontweight='bold')
    axes[1, 1].axis('off')
    
    # Comparison
    img_32_up = torch.nn.functional.interpolate(
        multi_scale_results[32][sample_idx:sample_idx+1],
        size=64, mode='nearest'
    )[0].permute(1, 2, 0).cpu().numpy()
    axes[1, 2].imshow(np.clip(img_32_up, 0, 1))
    axes[1, 2].set_title('32×32 Upsampled\n(Nearest)', fontsize=12, fontweight='bold')
    axes[1, 2].axis('off')
    
    plt.suptitle('Multi-Directional MAMBA: Scale-Invariant Reconstruction', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    return fig

# Visualize samples
for i in range(min(B_test, 4)):
    sparse_img = torch.zeros(3, 32, 32)
    input_idx = test_batch['input_indices'][i]
    sparse_img.view(3, -1)[:, input_idx] = test_batch['input_values'][i].T
    
    fig = visualize_multi_scale(
        test_batch['full_image'][i],
        sparse_img,
        multi_scale_results,
        sample_idx=i
    )
    plt.savefig(f'checkpoints_multidir/multiscale_sample_{i}.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

## 10. Quantitative Evaluation

In [None]:
print("="*60)
print("QUANTITATIVE EVALUATION: Full Field Reconstruction at 32×32")
print("="*60)

model.eval()
tracker_full_field = MetricsTracker()

with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader, desc="Full field evaluation")):
        if i >= 100:
            break
        
        B = batch['input_coords'].shape[0]
        full_coords_batch = multi_scale_grids[32].unsqueeze(0).expand(B, -1, -1)
        
        pred_values = heun_sample(
            model, full_coords_batch,
            batch['input_coords'].to(device),
            batch['input_values'].to(device),
            num_steps=100, device=device
        )
        
        pred_images = pred_values.view(B, 32, 32, 3).permute(0, 3, 1, 2)
        tracker_full_field.update(None, None, pred_images, batch['full_image'].to(device))

results = tracker_full_field.compute()
results_std = tracker_full_field.compute_std()

print(f"\nMulti-Directional MAMBA Results (32×32):")
print(f"  PSNR: {results['psnr']:.2f} ± {results_std['psnr_std']:.2f} dB")
print(f"  SSIM: {results['ssim']:.4f} ± {results_std['ssim_std']:.4f}")
print(f"  MSE:  {results['mse']:.6f} ± {results_std['mse_std']:.6f}")
print(f"  MAE:  {results['mae']:.6f} ± {results_std['mae_std']:.6f}")

print("\n" + "="*60)
print("Expected Improvement over v1 (single-direction):")
print("  PSNR: +2-4 dB (if spatial locality was the issue)")
print("  SSIM: +0.05-0.07")
print("  Visual: Significantly smoother textures")
print("="*60)

## Summary

### Multi-Directional MAMBA Advantages

**vs v1 (Single-Direction)** (expected, pending the results in section 10):
- ✅ **Full 2D spatial awareness**: 4 scanning directions give each point context along rows, columns, and both diagonals
- ✅ **Better texture quality**: Horizontal, vertical, and diagonal context preserved
- ✅ **Smoother reconstructions**: Reduced noise from proper spatial structure
- ✅ **Improved multi-scale**: Better generalization to 64×64 and 96×96

**Trade-offs**:
- ⚠️ **~4× more parameters**: one SSM per direction (4 per layer) plus a fusion MLP
- ⚠️ **3-4× slower training**: 4 directional passes per layer
- ⚠️ **More memory**: ~4× VRAM for SSM layers

**Expected Performance**:
- **PSNR**: 26-28 dB (vs v1: ~24 dB) → +2-4 dB improvement
- **SSIM**: 0.90-0.92 (vs v1: ~0.85) → +0.05-0.07 improvement
- **Visual**: Smooth textures, clear edges, natural multi-scale
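
To relate the PSNR targets above to the MSE values printed during training, the sketch below uses the standard PSNR definition for images scaled to [0, 1]. This is an assumption about the metric: `MetricsTracker` may average per image or differ in other details, so treat the numbers as a rough conversion.

```python
import math

def psnr_from_mse(mse, max_val=1.0):
    """PSNR in dB for a given mean squared error and peak value."""
    return 10.0 * math.log10(max_val ** 2 / mse)

def mse_from_psnr(psnr_db, max_val=1.0):
    """Inverse of psnr_from_mse."""
    return max_val ** 2 / (10.0 ** (psnr_db / 10.0))

for psnr_db in [24, 26, 28]:
    print(f"{psnr_db} dB  ->  MSE = {mse_from_psnr(psnr_db):.5f}")

ratio = mse_from_psnr(27) / mse_from_psnr(24)
print(f"MSE ratio for a +3 dB gain: {ratio:.3f}")
```

A gain of about 3 dB corresponds to roughly halving the MSE, which gives a sense of how large the expected improvement is.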

### If Multi-Directional Works:
→ Supports the hypothesis that missing **spatial locality** was the main cause of the noise

### If No Improvement:
→ Investigate other causes:
1. Training duration (the 100 epochs used in section 6 may be insufficient)
2. Fourier scale (`scale=10.0` may be too high)
3. ODE steps (the 50 steps used for evaluation during training may be too few)