# OmniField for N-MNIST Classification

This notebook implements OmniField for Neuromorphic MNIST (N-MNIST) classification.

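**N-MNIST** is a spiking (event-based) version of MNIST, recorded with a neuromorphic vision sensor (ATIS) that makes small saccade-like movements in front of the original MNIST digits.
Each event is a tuple `(x, y, t, polarity)`: a pixel location, a timestamp in microseconds, and the sign of the brightness change. Because the data already arrive as sparse samples at coordinates, they map naturally onto neural-field models, which take coordinates as input.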
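
To make the `(x, y, t, polarity)` layout concrete, here is a tiny synthetic event stream stored as a NumPy structured array with the same field names this notebook indexes (`'x'`, `'y'`, `'t'`, `'p'`), accumulated into a signed 34×34 frame. The dtype is an assumption for illustration and is not necessarily the exact dtype tonic returns.

```python
import numpy as np

# Assumed record layout for illustration: integer pixel coordinates,
# a timestamp in microseconds and a 0/1 polarity.
event_dtype = np.dtype([("x", np.int64), ("y", np.int64), ("t", np.int64), ("p", np.int64)])

events = np.array(
    [(3, 5, 100, 1),    # ON event at pixel (x=3, y=5)
     (3, 5, 250, 1),    # a second ON event at the same pixel
     (10, 2, 400, 0)],  # OFF event at pixel (x=10, y=2)
    dtype=event_dtype,
)

# Signed frame: +1 per ON event, -1 per OFF event.
# np.add.at is unbuffered, so repeated (y, x) indices are all counted.
frame = np.zeros((34, 34), dtype=np.int64)
np.add.at(frame, (events["y"], events["x"]), 2 * events["p"] - 1)

print(frame[5, 3], frame[2, 10])  # 2 -1
```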

## Two Classification Approaches
1. **End-to-End Classification**: classify directly from the events with OmniField
2. **Reconstruction + Fine-tuning**: first reconstruct the image, then classify

## Architecture
- Small base OmniField (`dim=64`, 64 latent tokens)
- Deep ICMR (Iterative Cross-Modal Refinement) stack of 6 blocks

---

## 1. Setup and Imports

In [None]:
# Install required packages
# !pip install tonic torch torchvision matplotlib numpy seaborn einops tqdm

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import Counter
from functools import wraps
from math import log

from einops import rearrange, repeat

# Try to import tonic for N-MNIST
try:
    import tonic
    from tonic import transforms as tonic_transforms
    TONIC_AVAILABLE = True
except ImportError:
    print("Tonic not available. Install with: pip install tonic")
    TONIC_AVAILABLE = False

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

## 2. Load and Explore N-MNIST Dataset

In [None]:
# Load N-MNIST dataset using tonic
if TONIC_AVAILABLE:
    # Download and load N-MNIST
    nmnist_train = tonic.datasets.NMNIST(save_to='./data', train=True)
    nmnist_test = tonic.datasets.NMNIST(save_to='./data', train=False)
    
    print(f"Training samples: {len(nmnist_train)}")
    print(f"Test samples: {len(nmnist_test)}")
    print(f"Sensor size: {nmnist_train.sensor_size}")
else:
    print("Please install tonic to load N-MNIST dataset")

In [None]:
# Examine a single sample
if TONIC_AVAILABLE:
    events, label = nmnist_train[0]
    
    print(f"\n=== Sample 0 (Label: {label}) ===")
    print(f"Events shape: {events.shape}")
    print(f"Event dtype: {events.dtype}")
    print(f"\nEvent structure (first 5 events):")
    print(f"  x: {events['x'][:5]}")
    print(f"  y: {events['y'][:5]}")
    print(f"  t: {events['t'][:5]} (microseconds)")
    print(f"  p: {events['p'][:5]} (polarity: 0=OFF, 1=ON)")
    
    print(f"\nStatistics:")
    print(f"  Total events: {len(events)}")
    print(f"  Time range: {events['t'].min()} - {events['t'].max()} μs")
    print(f"  Duration: {(events['t'].max() - events['t'].min()) / 1000:.2f} ms")
    print(f"  X range: {events['x'].min()} - {events['x'].max()}")
    print(f"  Y range: {events['y'].min()} - {events['y'].max()}")
    print(f"  ON events: {(events['p'] == 1).sum()} ({100*(events['p'] == 1).mean():.1f}%)")
    print(f"  OFF events: {(events['p'] == 0).sum()} ({100*(events['p'] == 0).mean():.1f}%)")

## 3. Dataset Statistics and Visualization

In [None]:
def compute_dataset_statistics(dataset, num_samples=1000):
    """Compute statistics over N-MNIST dataset."""
    stats = {
        'num_events': [],
        'duration_ms': [],
        'on_ratio': [],
        'labels': []
    }
    
    indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)
    
    for idx in tqdm(indices, desc="Computing statistics"):
        events, label = dataset[idx]
        stats['num_events'].append(len(events))
        stats['duration_ms'].append((events['t'].max() - events['t'].min()) / 1000)
        stats['on_ratio'].append((events['p'] == 1).mean())
        stats['labels'].append(label)
    
    return {k: np.array(v) for k, v in stats.items()}

if TONIC_AVAILABLE:
    stats = compute_dataset_statistics(nmnist_train, num_samples=2000)

In [None]:
if TONIC_AVAILABLE:
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # 1. Distribution of number of events
    ax = axes[0, 0]
    ax.hist(stats['num_events'], bins=50, edgecolor='black', alpha=0.7)
    ax.axvline(np.mean(stats['num_events']), color='red', linestyle='--', label=f'Mean: {np.mean(stats["num_events"]):.0f}')
    ax.set_xlabel('Number of Events')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of Events per Sample')
    ax.legend()
    
    # 2. Distribution of event duration
    ax = axes[0, 1]
    ax.hist(stats['duration_ms'], bins=50, edgecolor='black', alpha=0.7, color='green')
    ax.axvline(np.mean(stats['duration_ms']), color='red', linestyle='--', label=f'Mean: {np.mean(stats["duration_ms"]):.1f} ms')
    ax.set_xlabel('Duration (ms)')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of Event Duration')
    ax.legend()
    
    # 3. ON/OFF polarity ratio
    ax = axes[1, 0]
    ax.hist(stats['on_ratio'], bins=50, edgecolor='black', alpha=0.7, color='orange')
    ax.axvline(np.mean(stats['on_ratio']), color='red', linestyle='--', label=f'Mean: {np.mean(stats["on_ratio"]):.2f}')
    ax.set_xlabel('ON Event Ratio')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of ON Polarity Ratio')
    ax.legend()
    
    # 4. Class distribution
    ax = axes[1, 1]
    label_counts = Counter(stats['labels'])
    ax.bar(range(10), [label_counts[i] for i in range(10)], edgecolor='black', alpha=0.7, color='purple')
    ax.set_xlabel('Digit Class')
    ax.set_ylabel('Count')
    ax.set_title('Class Distribution (Sampled)')
    ax.set_xticks(range(10))
    
    plt.tight_layout()
    plt.savefig('nmnist_statistics.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print summary statistics
    print("\n=== N-MNIST Dataset Statistics ===")
    print(f"Events per sample: {np.mean(stats['num_events']):.0f} ± {np.std(stats['num_events']):.0f}")
    print(f"Duration: {np.mean(stats['duration_ms']):.1f} ± {np.std(stats['duration_ms']):.1f} ms")
    print(f"ON ratio: {np.mean(stats['on_ratio']):.3f} ± {np.std(stats['on_ratio']):.3f}")

In [None]:
def visualize_events(events, label, ax=None, title=None):
    """Visualize events as a 2D histogram (accumulated frame)."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(5, 5))
    
    # Create accumulated frame: +1 per ON event, -1 per OFF event.
    # np.add.at counts repeated pixels correctly, unlike fancy-index +=.
    frame = np.zeros((34, 34))
    np.add.at(frame, (events['y'], events['x']), 2 * events['p'].astype(np.int64) - 1)
    
    # Clip to a fixed range so all digits share one color scale
    frame = np.clip(frame, -50, 50)
    
    im = ax.imshow(frame, cmap='RdBu_r', vmin=-50, vmax=50)
    ax.set_title(title or f'Label: {label}')
    ax.axis('off')
    return im

def visualize_events_3d(events, label, ax=None):
    """Visualize events in 3D (x, y, t)."""
    if ax is None:
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection='3d')
    
    # Subsample for visualization
    n = min(2000, len(events))
    idx = np.random.choice(len(events), n, replace=False)
    
    x = events['x'][idx]
    y = events['y'][idx]
    t = (events['t'][idx] - events['t'].min()) / 1000  # Convert to ms
    p = events['p'][idx]
    
    colors = ['blue' if pol == 0 else 'red' for pol in p]
    ax.scatter(x, y, t, c=colors, s=1, alpha=0.5)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Time (ms)')
    ax.set_title(f'Label: {label}')
    return ax

In [None]:
if TONIC_AVAILABLE:
    # Visualize samples from each class
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    
    # Find one sample per class
    class_samples = {}
    for idx in range(len(nmnist_train)):
        events, label = nmnist_train[idx]
        if label not in class_samples:
            class_samples[label] = (events, label)
        if len(class_samples) == 10:
            break
    
    for i in range(10):
        events, label = class_samples[i]
        ax = axes[i // 5, i % 5]
        visualize_events(events, label, ax=ax, title=f'Digit: {label}')
    
    plt.suptitle('N-MNIST Samples (Accumulated Events)', fontsize=14)
    plt.tight_layout()
    plt.savefig('nmnist_samples.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
if TONIC_AVAILABLE:
    # 3D visualization of a few samples
    fig = plt.figure(figsize=(15, 5))
    
    for i, digit in enumerate([0, 3, 8]):
        events, label = class_samples[digit]
        ax = fig.add_subplot(1, 3, i+1, projection='3d')
        visualize_events_3d(events, label, ax=ax)
    
    plt.suptitle('N-MNIST Events in 3D (x, y, time)', fontsize=14)
    plt.tight_layout()
    plt.savefig('nmnist_3d.png', dpi=150, bbox_inches='tight')
    plt.show()

## 4. OmniField Architecture for N-MNIST

Key adaptations:
- **Input**: Sparse events (x, y, t, polarity) → Neural field
- **Small base model** with **6 ICMR layers**
- **Two output heads**: Classification and Reconstruction
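
The per-event encoder input is the polarity concatenated with two Gaussian Fourier feature encodings. The NumPy sketch below mirrors the sizes used in `OmniFieldNMNIST` (16 spatial and 8 temporal frequencies, scale 10) to show where the 49-dimensional input comes from; the random matrices `B_xy` and `B_t` are illustrative stand-ins for the model's fixed `B` buffers.

```python
import numpy as np

rng = np.random.default_rng(0)
scale = 10.0
B_xy = rng.normal(0.0, scale, size=(2, 16))  # spatial frequencies: 2 -> 16
B_t = rng.normal(0.0, scale, size=(1, 8))    # temporal frequencies: 1 -> 8

def gff(coords, B):
    """Map coordinates v to [sin(2*pi*vB), cos(2*pi*vB)]."""
    proj = 2 * np.pi * coords @ B
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)

# One ON event at pixel (x=17, y=8) on the 34x34 sensor, halfway through the recording
x, y, t, p = 17 / 33, 8 / 33, 0.5, 1.0
feat = np.concatenate([[p], gff(np.array([x, y]), B_xy), gff(np.array([t]), B_t)])
print(feat.shape)  # (49,)
```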

In [None]:
# ============================================================
# Helper Functions and Modules
# ============================================================

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)
        return self.fn(x, **kwargs)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )
    
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
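
The mask handling in `Attention.forward` fills the scores of padded keys with the most negative finite value before the softmax, so those keys receive (numerically) zero weight. A minimal standalone check of that step, with one batch element, one head and two padded keys:

```python
import torch

torch.manual_seed(0)
sim = torch.randn(1, 3, 5)                               # [batch*heads, queries, keys]
mask = torch.tensor([[True, True, True, False, False]])  # last two keys are padding
mask = mask[:, None, :]                                  # broadcast over queries: [1, 1, 5]

sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim=-1)

print(attn[..., 3:].abs().max().item())                  # 0.0
print(torch.allclose(attn.sum(-1), torch.ones(1, 3)))    # True
```

If every key in a row were masked, all scores would be equal and the softmax would fall back to uniform weights, so each sample should contain at least one real event.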

In [None]:
# ============================================================
# Gaussian Fourier Features (GFF)
# ============================================================

class GaussianFourierFeatures(nn.Module):
    """Gaussian Fourier Features for positional encoding."""
    def __init__(self, in_features, mapping_size, scale=10.0):
        super().__init__()
        self.in_features = in_features
        self.mapping_size = mapping_size
        # B ~ N(0, scale^2)
        self.register_buffer('B', torch.randn((in_features, mapping_size)) * scale)

    def forward(self, coords):
        # coords: [..., in_features]
        projections = coords @ self.B  # [..., mapping_size]
        fourier_feats = torch.cat([torch.sin(2 * np.pi * projections), 
                                   torch.cos(2 * np.pi * projections)], dim=-1)
        return fourier_feats  # [..., 2 * mapping_size]
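
Because each frequency contributes one sine and one cosine, every Gaussian Fourier feature vector has the same Euclidean norm, √mapping_size, whatever the input coordinate; only the phases change, and larger `scale` values draw higher frequencies. A quick check with a standalone copy of the projection (same formula as `GaussianFourierFeatures.forward`):

```python
import math
import torch

torch.manual_seed(0)
in_features, mapping_size, scale = 2, 16, 10.0
B = torch.randn(in_features, mapping_size) * scale  # same construction as the module's buffer

coords = torch.rand(5, in_features)                 # 5 random points in [0, 1]^2
proj = coords @ B
feats = torch.cat([torch.sin(2 * math.pi * proj), torch.cos(2 * math.pi * proj)], dim=-1)

norms = feats.norm(dim=-1)
print(feats.shape)  # torch.Size([5, 32])
print(torch.allclose(norms, torch.full((5,), math.sqrt(mapping_size))))  # True
```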

In [None]:
# ============================================================
# Sinusoidal Initialization for Learnable Queries
# ============================================================

def get_sinusoidal_embeddings(n, d):
    """Generate sinusoidal positional embeddings."""
    assert d % 2 == 0, "Dimension must be even"
    position = torch.arange(n, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d, 2).float() * -(log(10000.0) / d))
    pe = torch.zeros(n, d)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
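
A property worth knowing about this initialization: for each frequency pair, the embedding of position n + k is a fixed 2×2 rotation of the embedding of position n, with an angle that depends only on k. The NumPy check below uses the same frequencies as `get_sinusoidal_embeddings` (base 10000, interleaved sin/cos columns).

```python
import numpy as np

def sinusoidal(n, d):
    """NumPy mirror of get_sinusoidal_embeddings."""
    pos = np.arange(n)[:, None]
    freqs = np.exp(np.arange(0, d, 2) * -(np.log(10000.0) / d))
    pe = np.zeros((n, d))
    pe[:, 0::2] = np.sin(pos * freqs)
    pe[:, 1::2] = np.cos(pos * freqs)
    return pe

n, d, k = 64, 64, 5
pe = sinusoidal(n, d)
freqs = np.exp(np.arange(0, d, 2) * -(np.log(10000.0) / d))

# [sin((m+k)w), cos((m+k)w)] = R(kw) @ [sin(mw), cos(mw)],
# with R(a) = [[cos a, sin a], [-sin a, cos a]].
max_err = 0.0
for i, w in enumerate(freqs):
    a = k * w
    R = np.array([[np.cos(a), np.sin(a)], [-np.sin(a), np.cos(a)]])
    pairs = pe[:, 2 * i:2 * i + 2]   # (sin, cos) columns for frequency i
    shifted = pairs[:-k] @ R.T       # rotate positions 0 .. n-k-1
    max_err = max(max_err, np.abs(shifted - pairs[k:]).max())

print(max_err < 1e-9)  # True
```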

In [None]:
# ============================================================
# ICMR Block (Iterative Cross-Modal Refinement)
# ============================================================

class ICMRBlock(nn.Module):
    """Single ICMR block with cross-attention and self-attention."""
    def __init__(self, dim, num_latents, input_dim, heads=4, dim_head=32, dropout=0.0):
        super().__init__()
        self.latents = nn.Parameter(get_sinusoidal_embeddings(num_latents, dim))
        
        # Cross-attention: latents attend to input
        self.cross_attn = PreNorm(
            dim, 
            Attention(dim, input_dim, heads=heads, dim_head=dim_head, dropout=dropout),
            context_dim=input_dim
        )
        
        # Self-attention among latents
        self.self_attn = PreNorm(
            dim,
            Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        )
        
        # Feed-forward
        self.ff = PreNorm(dim, FeedForward(dim, dropout=dropout))
        
        # Global feature projection (for ICMR)
        self.global_proj = nn.Linear(dim, dim)
    
    def forward(self, context, global_feature=None, mask=None):
        b = context.size(0)
        latents = repeat(self.latents, 'n d -> b n d', b=b)
        
        # Add global feature if provided (ICMR)
        if global_feature is not None:
            global_bias = self.global_proj(global_feature).unsqueeze(1)
            latents = latents + global_bias
        
        # Cross-attention
        latents = self.cross_attn(latents, context=context, mask=mask) + latents
        
        # Self-attention
        latents = self.self_attn(latents) + latents
        
        # Feed-forward
        latents = self.ff(latents) + latents
        
        return latents
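
The ICMR block never runs self-attention over the raw events: the latents cross-attend to the N event tokens and then self-attend among themselves. With the sizes used later in this notebook (64 latents, up to 2048 events per sample), a back-of-the-envelope count of attention scores per head and per block (ignoring projections and feed-forward layers) looks like this:

```python
num_latents, num_events = 64, 2048

cross_attn_scores = num_latents * num_events   # latents x events
latent_self_scores = num_latents ** 2          # latents x latents
event_self_scores = num_events ** 2            # full self-attention over events, for comparison

icmr_scores = cross_attn_scores + latent_self_scores
print(cross_attn_scores, latent_self_scores, icmr_scores)  # 131072 4096 135168
print(event_self_scores // icmr_scores)                    # 31
```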

In [None]:
# ============================================================
# OmniField for N-MNIST (Small base, 6 ICMR layers)
# ============================================================

class OmniFieldNMNIST(nn.Module):
    """
    OmniField adapted for N-MNIST classification and reconstruction.
    
    Architecture:
    - Small base model (dim=64)
    - 6 ICMR layers for deep cross-modal refinement
    - Dual heads: classification and reconstruction
    """
    def __init__(
        self,
        dim=64,                    # Small base dimension
        num_latents=64,            # Number of latent tokens
        num_icmr_layers=6,         # Deep ICMR (6 layers)
        heads=4,
        dim_head=16,
        num_classes=10,
        spatial_size=34,           # N-MNIST is 34x34
        dropout=0.1,
        gff_scale=10.0,
    ):
        super().__init__()
        self.dim = dim
        self.num_latents = num_latents
        self.num_icmr_layers = num_icmr_layers
        self.spatial_size = spatial_size
        
        # ========== Encoder ==========
        # Spatial positional encoding (x, y)
        self.spatial_enc = GaussianFourierFeatures(2, 16, scale=gff_scale)
        # Temporal encoding (t)
        self.temporal_enc = GaussianFourierFeatures(1, 8, scale=gff_scale)
        
        # Input projection: polarity (1) + spatial_enc (32) + temporal_enc (16) -> dim
        input_dim = 1 + 32 + 16  # 49
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )
        
        # ========== ICMR Blocks (6 layers) ==========
        self.icmr_blocks = nn.ModuleList([
            ICMRBlock(
                dim=dim,
                num_latents=num_latents,
                input_dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout
            )
            for _ in range(num_icmr_layers)
        ])
        
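        # ========== Classification Head ==========
        # Note: cls_token is currently unused; forward_classification reads the
        # pooled global feature instead.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))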
        self.cls_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes)
        )
        
        # ========== Reconstruction Head ==========
        # Query encoding for reconstruction
        self.query_enc = GaussianFourierFeatures(2, 16, scale=gff_scale)
        self.query_proj = nn.Linear(32, dim)
        
        # Decoder cross-attention
        self.decoder_cross_attn = PreNorm(
            dim,
            Attention(dim, dim, heads=heads, dim_head=dim_head, dropout=dropout),
            context_dim=dim
        )
        self.decoder_ff = PreNorm(dim, FeedForward(dim, dropout=dropout))
        
        # Output: predict intensity at query locations
        self.recon_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, 1)
        )
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def encode_events(self, x, y, t, p):
        """
        Encode sparse events into dense features.
        
        Args:
            x: [B, N] x-coordinates (normalized to [0, 1])
            y: [B, N] y-coordinates (normalized to [0, 1])
            t: [B, N] timestamps (normalized to [0, 1])
            p: [B, N] polarity (-1 or 1)
        
        Returns:
            features: [B, N, dim]
        """
        B, N = x.shape
        
        # Spatial encoding
        spatial_coords = torch.stack([x, y], dim=-1)  # [B, N, 2]
        spatial_feat = self.spatial_enc(spatial_coords)  # [B, N, 32]
        
        # Temporal encoding
        temporal_feat = self.temporal_enc(t.unsqueeze(-1))  # [B, N, 16]
        
        # Combine with polarity
        p_feat = p.unsqueeze(-1)  # [B, N, 1]
        
        combined = torch.cat([p_feat, spatial_feat, temporal_feat], dim=-1)  # [B, N, 49]
        features = self.input_proj(combined)  # [B, N, dim]
        
        return features
    
    def forward_encoder(self, x, y, t, p, mask=None):
        """
        Forward pass through encoder and ICMR blocks.
        
        Returns:
            latents: [B, num_latents, dim]
            global_feature: [B, dim]
        """
        # Encode events
        features = self.encode_events(x, y, t, p)  # [B, N, dim]
        
        # ICMR: iterative refinement with global feature
        global_feature = None
        latents = None
        
        for block in self.icmr_blocks:
            latents = block(features, global_feature=global_feature, mask=mask)
            # Update global feature (mean pooling)
            global_feature = latents.mean(dim=1)  # [B, dim]
        
        return latents, global_feature
    
    def forward_classification(self, latents, global_feature):
        """
        Classification head.
        
        Args:
            latents: [B, num_latents, dim]
            global_feature: [B, dim]
        
        Returns:
            logits: [B, num_classes]
        """
        # Use global feature for classification
        logits = self.cls_head(global_feature)
        return logits
    
    def forward_reconstruction(self, latents, query_coords):
        """
        Reconstruction head.
        
        Args:
            latents: [B, num_latents, dim]
            query_coords: [B, Q, 2] query coordinates (x, y) normalized to [0, 1]
        
        Returns:
            recon: [B, Q, 1] reconstructed intensities
        """
        # Encode query positions
        query_feat = self.query_enc(query_coords)  # [B, Q, 32]
        query_feat = self.query_proj(query_feat)   # [B, Q, dim]
        
        # Cross-attention: queries attend to latents
        decoded = self.decoder_cross_attn(query_feat, context=latents) + query_feat
        decoded = self.decoder_ff(decoded) + decoded
        
        # Output intensities
        recon = self.recon_head(decoded)
        return recon
    
    def forward(self, x, y, t, p, mask=None, query_coords=None, return_recon=False):
        """
        Full forward pass.
        
        Args:
            x, y, t, p: Event coordinates and polarity
            mask: Optional attention mask
            query_coords: Query coordinates for reconstruction
            return_recon: Whether to return reconstruction
        
        Returns:
            logits: Classification logits
            recon: (optional) Reconstruction output
        """
        latents, global_feature = self.forward_encoder(x, y, t, p, mask)
        logits = self.forward_classification(latents, global_feature)
        
        if return_recon and query_coords is not None:
            recon = self.forward_reconstruction(latents, query_coords)
            return logits, recon
        
        return logits

In [None]:
# Test the model
model = OmniFieldNMNIST(
    dim=64,
    num_latents=64,
    num_icmr_layers=6,
    heads=4,
    dim_head=16,
    num_classes=10,
    dropout=0.1
).to(DEVICE)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,} ({num_params/1e6:.2f}M)")

# Test forward pass
B, N = 4, 1000
test_x = torch.rand(B, N).to(DEVICE)
test_y = torch.rand(B, N).to(DEVICE)
test_t = torch.rand(B, N).to(DEVICE)
test_p = torch.randint(0, 2, (B, N)).float().to(DEVICE) * 2 - 1  # -1 or 1

with torch.no_grad():
    logits = model(test_x, test_y, test_t, test_p)
    print(f"Output shape: {logits.shape}")
    
    # With reconstruction
    Q = 34 * 34
    # meshgrid(indexing='ij') returns (rows, cols) = (y, x); stack as (x, y)
    ys, xs = torch.meshgrid(
        torch.linspace(0, 1, 34),
        torch.linspace(0, 1, 34),
        indexing='ij'
    )
    query_coords = torch.stack([xs, ys], dim=-1).reshape(1, Q, 2).expand(B, -1, -1).to(DEVICE)
    
    logits, recon = model(test_x, test_y, test_t, test_p, query_coords=query_coords, return_recon=True)
    print(f"Reconstruction shape: {recon.shape}")
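
The reconstruction target is indexed as `target_img[y, x]` (row = y, column = x), so each flattened query should carry `(x, y) = (column, row)`. With `indexing='ij'`, `torch.meshgrid` returns the row coordinate first; unpacking it as `ys, xs` and stacking `[xs, ys]` gives that convention. A small 3×3 check:

```python
import torch

S = 3
lin = torch.linspace(0, 1, S)
ys, xs = torch.meshgrid(lin, lin, indexing="ij")  # ys varies down rows, xs across columns
coords = torch.stack([xs, ys], dim=-1).reshape(S * S, 2)

# Flattened index r * S + c should hold (x = column c, y = row r)
r, c = 2, 0
print(coords[r * S + c].tolist())  # [0.0, 1.0]
```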

## 5. Dataset Wrapper for Training

In [None]:
class NMNISTDataset(Dataset):
    """
    N-MNIST dataset wrapper for OmniField.
    
    Converts raw events to normalized tensors.
    """
    def __init__(self, tonic_dataset, max_events=2048, spatial_size=34, augment=False):
        self.dataset = tonic_dataset
        self.max_events = max_events
        self.spatial_size = spatial_size
        self.augment = augment
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        events, label = self.dataset[idx]
        
        # Extract event components
        x = events['x'].astype(np.float32)
        y = events['y'].astype(np.float32)
        t = events['t'].astype(np.float32)
        p = events['p'].astype(np.float32)
        
        # Normalize coordinates to [0, 1]
        x = x / (self.spatial_size - 1)
        y = y / (self.spatial_size - 1)
        
        # Normalize time to [0, 1]
        t_min, t_max = t.min(), t.max()
        if t_max > t_min:
            t = (t - t_min) / (t_max - t_min)
        else:
            t = np.zeros_like(t)
        
        # Convert polarity to -1/+1
        p = p * 2 - 1
        
        # Data augmentation
        if self.augment:
            # Random horizontal flip
            if np.random.rand() > 0.5:
                x = 1.0 - x
            # Random time reversal (reversing time also inverts polarity)
            if np.random.rand() > 0.5:
                t = 1.0 - t
                p = -p
        
        # Create the reconstruction target (signed accumulated events) from the
        # full, augmented event stream so it matches what the model sees.
        # np.add.at is unbuffered, so repeated pixels are all counted.
        xi = np.rint(x * (self.spatial_size - 1)).astype(np.int64)
        yi = np.rint(y * (self.spatial_size - 1)).astype(np.int64)
        target_img = np.zeros((self.spatial_size, self.spatial_size), dtype=np.float32)
        np.add.at(target_img, (yi, xi), p)
        target_img = np.clip(target_img / 50.0, -1, 1)  # Normalize to [-1, 1]
        
        # Subsample or pad to max_events
        n_events = len(x)
        if n_events > self.max_events:
            # Random subsample (sorted indices keep the original event order)
            idx_sub = np.random.choice(n_events, self.max_events, replace=False)
            idx_sub = np.sort(idx_sub)
            x, y, t, p = x[idx_sub], y[idx_sub], t[idx_sub], p[idx_sub]
            mask = np.ones(self.max_events, dtype=np.float32)
        else:
            # Pad with zeros; the mask marks real (1) vs. padded (0) events
            pad_len = self.max_events - n_events
            x = np.pad(x, (0, pad_len), constant_values=0)
            y = np.pad(y, (0, pad_len), constant_values=0)
            t = np.pad(t, (0, pad_len), constant_values=0)
            p = np.pad(p, (0, pad_len), constant_values=0)
            mask = np.concatenate([np.ones(n_events), np.zeros(pad_len)]).astype(np.float32)
        
        return {
            'x': torch.from_numpy(x),
            'y': torch.from_numpy(y),
            't': torch.from_numpy(t),
            'p': torch.from_numpy(p),
            'mask': torch.from_numpy(mask),
            'label': torch.tensor(label, dtype=torch.long),
            'target_img': torch.from_numpy(target_img)
        }
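
`NMNISTDataset` fixes every sample to `max_events` entries so batches can be stacked: longer streams are randomly subsampled (keeping event order), shorter ones are zero-padded, and a 0/1 mask marks the real events. A standalone sketch of that step, using a hypothetical helper `pad_or_subsample` that works on all event fields at once so they share the same indices:

```python
import numpy as np

def pad_or_subsample(events, max_events, rng):
    """Hypothetical helper: events is [n, F] (e.g. columns x, y, t, p)."""
    n = len(events)
    if n > max_events:
        # Same rows for every field; sorting keeps the original order
        keep = np.sort(rng.choice(n, max_events, replace=False))
        return events[keep], np.ones(max_events, dtype=np.float32)
    pad = max_events - n
    padded = np.pad(events, ((0, pad), (0, 0)), constant_values=0)  # pad rows only
    mask = np.concatenate([np.ones(n), np.zeros(pad)]).astype(np.float32)
    return padded, mask

rng = np.random.default_rng(0)
short_out, short_mask = pad_or_subsample(np.arange(10, dtype=np.float32).reshape(5, 2), 8, rng)
long_out, long_mask = pad_or_subsample(np.arange(40, dtype=np.float32).reshape(20, 2), 8, rng)

print(short_out.shape, short_mask.sum())                          # (8, 2) 5.0
print(long_out.shape, bool(np.all(np.diff(long_out[:, 0]) > 0)))  # (8, 2) True
```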

In [None]:
def collate_fn(batch):
    """Custom collate function."""
    return {
        'x': torch.stack([b['x'] for b in batch]),
        'y': torch.stack([b['y'] for b in batch]),
        't': torch.stack([b['t'] for b in batch]),
        'p': torch.stack([b['p'] for b in batch]),
        'mask': torch.stack([b['mask'] for b in batch]),
        'label': torch.stack([b['label'] for b in batch]),
        'target_img': torch.stack([b['target_img'] for b in batch])
    }
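
Since every field is already a fixed-shape tensor, this `collate_fn` should behave like PyTorch's built-in `default_collate`, which stacks a list of dicts of tensors key by key. A quick equivalence check on a toy batch (assuming PyTorch 1.11 or later, where `default_collate` is exported from `torch.utils.data`):

```python
import torch
from torch.utils.data import default_collate

batch = [
    {"x": torch.rand(4), "label": torch.tensor(i, dtype=torch.long)}
    for i in range(3)
]

manual = {k: torch.stack([b[k] for b in batch]) for k in batch[0]}
auto = default_collate(batch)

print(auto["x"].shape, auto["label"].tolist())               # torch.Size([3, 4]) [0, 1, 2]
print(all(torch.equal(manual[k], auto[k]) for k in manual))  # True
```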

In [None]:
if TONIC_AVAILABLE:
    # Create datasets
    train_dataset = NMNISTDataset(nmnist_train, max_events=2048, augment=True)
    test_dataset = NMNISTDataset(nmnist_test, max_events=2048, augment=False)
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=64, 
        shuffle=True, 
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, 
        batch_size=64, 
        shuffle=False, 
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    # Test a batch
    batch = next(iter(train_loader))
    print("Batch shapes:")
    for k, v in batch.items():
        print(f"  {k}: {v.shape}")

## 6. Training Functions

In [None]:
def create_query_grid(batch_size, spatial_size=34, device='cpu'):
    """Create a query grid for reconstruction.

    Flattened index i * W + j holds (x, y) = (column j, row i), matching
    target images indexed as img[y, x].
    """
    lin = torch.linspace(0, 1, spatial_size)
    ys, xs = torch.meshgrid(lin, lin, indexing='ij')  # [H, W] each
    coords = torch.stack([xs, ys], dim=-1)  # [H, W, 2] as (x, y)
    coords = coords.reshape(1, spatial_size * spatial_size, 2)  # [1, H*W, 2]
    coords = coords.expand(batch_size, -1, -1).to(device)  # [B, H*W, 2]
    return coords


def train_epoch_e2e(model, loader, optimizer, device, epoch):
    """Train one epoch - End-to-End classification."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc=f"Epoch {epoch} [E2E]")
    for batch in pbar:
        x = batch['x'].to(device)
        y = batch['y'].to(device)
        t = batch['t'].to(device)
        p = batch['p'].to(device)
        mask = batch['mask'].to(device).bool()
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        logits = model(x, y, t, p, mask=mask)
        loss = F.cross_entropy(logits, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        pred = logits.argmax(dim=-1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.2f}%'})
    
    return total_loss / len(loader), 100 * correct / total


def train_epoch_recon(model, loader, optimizer, device, epoch, recon_weight=1.0, cls_weight=0.1):
    """Train one epoch - Reconstruction + Classification."""
    model.train()
    total_loss = 0
    total_recon_loss = 0
    total_cls_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc=f"Epoch {epoch} [Recon]")
    for batch in pbar:
        x = batch['x'].to(device)
        y = batch['y'].to(device)
        t = batch['t'].to(device)
        p = batch['p'].to(device)
        mask = batch['mask'].to(device).bool()
        labels = batch['label'].to(device)
        target_img = batch['target_img'].to(device)  # [B, H, W]
        
        B = x.size(0)
        query_coords = create_query_grid(B, spatial_size=34, device=device)
        
        optimizer.zero_grad()
        
        logits, recon = model(x, y, t, p, mask=mask, query_coords=query_coords, return_recon=True)
        
        # Reconstruction loss
        recon = recon.reshape(B, 34, 34)
        recon_loss = F.mse_loss(recon, target_img)
        
        # Classification loss
        cls_loss = F.cross_entropy(logits, labels)
        
        # Combined loss
        loss = recon_weight * recon_loss + cls_weight * cls_loss
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_cls_loss += cls_loss.item()
        pred = logits.argmax(dim=-1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'recon': f'{recon_loss.item():.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })
    
    return (
        total_loss / len(loader),
        total_recon_loss / len(loader),
        total_cls_loss / len(loader),
        100 * correct / total
    )


@torch.no_grad()
def evaluate(model, loader, device, return_recon=False):
    """Evaluate model."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    recon_samples = []
    
    for batch in tqdm(loader, desc="Evaluating"):
        x = batch['x'].to(device)
        y = batch['y'].to(device)
        t = batch['t'].to(device)
        p = batch['p'].to(device)
        mask = batch['mask'].to(device).bool()
        labels = batch['label'].to(device)
        
        B = x.size(0)
        
        if return_recon:
            query_coords = create_query_grid(B, spatial_size=34, device=device)
            logits, recon = model(x, y, t, p, mask=mask, query_coords=query_coords, return_recon=True)
            if len(recon_samples) < 10:
                recon_samples.append({
                    'recon': recon.cpu(),
                    'target': batch['target_img'],
                    'label': batch['label']
                })
        else:
            logits = model(x, y, t, p, mask=mask)
        
        loss = F.cross_entropy(logits, labels)
        total_loss += loss.item()
        
        pred = logits.argmax(dim=-1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        
        all_preds.extend(pred.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100 * correct / total
    avg_loss = total_loss / len(loader)
    
    if return_recon:
        return avg_loss, accuracy, all_preds, all_labels, recon_samples
    return avg_loss, accuracy, all_preds, all_labels

## 7. Approach 1: End-to-End Classification

In [None]:
# Create model for E2E classification
model_e2e = OmniFieldNMNIST(
    dim=64,
    num_latents=64,
    num_icmr_layers=6,
    heads=4,
    dim_head=16,
    num_classes=10,
    dropout=0.1
).to(DEVICE)

print(f"E2E Model parameters: {sum(p.numel() for p in model_e2e.parameters()):,}")

In [None]:
# Training configuration
NUM_EPOCHS_E2E = 30
LR = 1e-3

optimizer_e2e = AdamW(model_e2e.parameters(), lr=LR, weight_decay=0.01)
scheduler_e2e = CosineAnnealingLR(optimizer_e2e, T_max=NUM_EPOCHS_E2E, eta_min=1e-5)
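The training loops in this section select the best checkpoint by accuracy on `test_loader` and later report final accuracy on that same loader, so the reported numbers are optimistically biased. A minimal sketch of carving a held-out validation set out of the training data with `torch.utils.data.random_split` (already imported in the setup cell) is shown here on a toy `TensorDataset`. The helper name `make_train_val_loaders`, the 10% validation fraction, and the seed are illustrative assumptions; the notebook's real N-MNIST dataset and its collate function would need to be passed in instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

def make_train_val_loaders(dataset, val_fraction=0.1, batch_size=64, seed=42, collate_fn=None):
    """Hypothetical helper: split a dataset into disjoint train/val subsets and loaders."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    # A seeded generator makes the split reproducible across runs
    generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(dataset, [n_train, n_val], generator=generator)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return train_subset, val_subset, train_loader, val_loader

# Toy stand-in for the N-MNIST training set: 1000 samples with labels 0..9
toy = TensorDataset(torch.randn(1000, 4), torch.arange(1000) % 10)
train_subset, val_subset, train_loader_split, val_loader_split = make_train_val_loaders(toy)
print(len(train_subset), len(val_subset))  # 900 100
```

With a split like this, the validation loader would take the place of `test_loader` inside the epoch loops, and `test_loader` would be evaluated once, after training, in the final-evaluation cell.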

In [None]:
if TONIC_AVAILABLE:
    # Training loop - E2E
    history_e2e = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0
    
    for epoch in range(1, NUM_EPOCHS_E2E + 1):
        train_loss, train_acc = train_epoch_e2e(model_e2e, train_loader, optimizer_e2e, DEVICE, epoch)
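        # Note: test_loader doubles as the validation set here, so the best-epoch accuracy is optimistically biased
        val_loss, val_acc, _, _ = evaluate(model_e2e, test_loader, DEVICE)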
        scheduler_e2e.step()
        
        history_e2e['train_loss'].append(train_loss)
        history_e2e['train_acc'].append(train_acc)
        history_e2e['val_loss'].append(val_loss)
        history_e2e['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model_e2e.state_dict(), 'omnifield_nmnist_e2e_best.pt')
        
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
              f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")
    
    print(f"\nBest E2E Validation Accuracy: {best_val_acc:.2f}%")

In [None]:
if TONIC_AVAILABLE and len(history_e2e['train_loss']) > 0:
    # Plot E2E training curves
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    axes[0].plot(history_e2e['train_loss'], label='Train')
    axes[0].plot(history_e2e['val_loss'], label='Val')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('E2E Classification Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    axes[1].plot(history_e2e['train_acc'], label='Train')
    axes[1].plot(history_e2e['val_acc'], label='Val')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('E2E Classification Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.savefig('e2e_training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

## 8. Approach 2: Reconstruction + Fine-tuning Classification

In [None]:
# Create model for reconstruction + classification
model_recon = OmniFieldNMNIST(
    dim=64,
    num_latents=64,
    num_icmr_layers=6,
    heads=4,
    dim_head=16,
    num_classes=10,
    dropout=0.1
).to(DEVICE)

print(f"Recon Model parameters: {sum(p.numel() for p in model_recon.parameters()):,}")

In [None]:
# Phase 1: Pre-training with reconstruction (+ light classification)
NUM_EPOCHS_PRETRAIN = 20
LR_PRETRAIN = 1e-3

optimizer_pretrain = AdamW(model_recon.parameters(), lr=LR_PRETRAIN, weight_decay=0.01)
scheduler_pretrain = CosineAnnealingLR(optimizer_pretrain, T_max=NUM_EPOCHS_PRETRAIN, eta_min=1e-5)

In [None]:
if TONIC_AVAILABLE:
    # Phase 1: Reconstruction pre-training
    print("=" * 50)
    print("Phase 1: Reconstruction Pre-training")
    print("=" * 50)
    
    history_pretrain = {'loss': [], 'recon_loss': [], 'cls_loss': [], 'acc': []}
    
    for epoch in range(1, NUM_EPOCHS_PRETRAIN + 1):
        loss, recon_loss, cls_loss, acc = train_epoch_recon(
            model_recon, train_loader, optimizer_pretrain, DEVICE, epoch,
            recon_weight=1.0, cls_weight=0.1
        )
        scheduler_pretrain.step()
        
        history_pretrain['loss'].append(loss)
        history_pretrain['recon_loss'].append(recon_loss)
        history_pretrain['cls_loss'].append(cls_loss)
        history_pretrain['acc'].append(acc)
        
        print(f"Epoch {epoch}: Loss={loss:.4f}, Recon={recon_loss:.4f}, Cls={cls_loss:.4f}, Acc={acc:.2f}%")
    
    # Save pre-trained model
    torch.save(model_recon.state_dict(), 'omnifield_nmnist_pretrained.pt')

In [None]:
# Phase 2: Fine-tuning for classification
NUM_EPOCHS_FINETUNE = 15
LR_FINETUNE = 5e-4

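# Optionally freeze the ICMR blocks; the remaining unfrozen layers (e.g., the classification head) are still trained
# for param in model_recon.icmr_blocks.parameters():
#     param.requires_grad = False

# Give the optimizer only trainable parameters so any frozen weights stay untouched
optimizer_finetune = AdamW([p for p in model_recon.parameters() if p.requires_grad], lr=LR_FINETUNE, weight_decay=0.01)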
scheduler_finetune = CosineAnnealingLR(optimizer_finetune, T_max=NUM_EPOCHS_FINETUNE, eta_min=1e-6)
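If the freezing lines are enabled, it helps to confirm that frozen weights really stay fixed while the rest of the model trains. The sketch below demonstrates the pattern on a toy model; the attribute names `icmr_blocks` and `classifier` merely mirror the notebook's naming and belong to a hypothetical module, not to the real `OmniFieldNMNIST` layout.

```python
import torch
import torch.nn as nn
from torch.optim import AdamW

class ToyModel(nn.Module):
    """Hypothetical stand-in: a freezable trunk plus a classification head."""
    def __init__(self, dim=16, num_classes=10):
        super().__init__()
        self.icmr_blocks = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        return self.classifier(self.icmr_blocks(x))

torch.manual_seed(0)
model = ToyModel()

# Freeze the trunk so only the head receives gradient updates
for param in model.icmr_blocks.parameters():
    param.requires_grad = False

# Pass only trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable, lr=5e-4, weight_decay=0.01)

# Snapshot weights before one training step
frozen_before = [p.detach().clone() for p in model.icmr_blocks.parameters()]
head_before = model.classifier.weight.detach().clone()

x, y = torch.randn(8, 16), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

n_trainable = sum(p.numel() for p in trainable)
print(f"trainable params: {n_trainable}")  # trainable params: 170
```

Filtering the parameter list makes the intent explicit and does not rely on stale `.grad` tensors left over from pre-training being cleared before the first fine-tuning step.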

In [None]:
if TONIC_AVAILABLE:
    # Phase 2: Fine-tuning
    print("\n" + "=" * 50)
    print("Phase 2: Fine-tuning for Classification")
    print("=" * 50)
    
    history_finetune = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc_ft = 0
    
    for epoch in range(1, NUM_EPOCHS_FINETUNE + 1):
        train_loss, train_acc = train_epoch_e2e(model_recon, train_loader, optimizer_finetune, DEVICE, epoch)
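        # Note: test_loader doubles as the validation set here, so the best-epoch accuracy is optimistically biased
        val_loss, val_acc, _, _ = evaluate(model_recon, test_loader, DEVICE)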
        scheduler_finetune.step()
        
        history_finetune['train_loss'].append(train_loss)
        history_finetune['train_acc'].append(train_acc)
        history_finetune['val_loss'].append(val_loss)
        history_finetune['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc_ft:
            best_val_acc_ft = val_acc
            torch.save(model_recon.state_dict(), 'omnifield_nmnist_finetuned_best.pt')
        
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
              f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")
    
    print(f"\nBest Fine-tuned Validation Accuracy: {best_val_acc_ft:.2f}%")

In [None]:
if TONIC_AVAILABLE and len(history_pretrain['loss']) > 0:
    # Plot reconstruction pre-training curves
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    axes[0].plot(history_pretrain['recon_loss'], label='Reconstruction')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('MSE Loss')
    axes[0].set_title('Pre-training: Reconstruction Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    axes[1].plot(history_pretrain['cls_loss'], label='Classification', color='orange')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('CE Loss')
    axes[1].set_title('Pre-training: Classification Loss')
    axes[1].legend()
    axes[1].grid(True)
    
    axes[2].plot(history_pretrain['acc'], label='Accuracy', color='green')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Accuracy (%)')
    axes[2].set_title('Pre-training: Classification Accuracy')
    axes[2].legend()
    axes[2].grid(True)
    
    plt.tight_layout()
    plt.savefig('pretrain_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
if TONIC_AVAILABLE and len(history_finetune['train_loss']) > 0:
    # Plot fine-tuning curves
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    axes[0].plot(history_finetune['train_loss'], label='Train')
    axes[0].plot(history_finetune['val_loss'], label='Val')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Fine-tuning: Classification Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    axes[1].plot(history_finetune['train_acc'], label='Train')
    axes[1].plot(history_finetune['val_acc'], label='Val')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Fine-tuning: Classification Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.savefig('finetune_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

## 9. Results Visualization and Comparison

In [None]:
if TONIC_AVAILABLE:
    # Load best models and evaluate
    model_e2e.load_state_dict(torch.load('omnifield_nmnist_e2e_best.pt', map_location=DEVICE))
    model_recon.load_state_dict(torch.load('omnifield_nmnist_finetuned_best.pt', map_location=DEVICE))
    
    # Final evaluation
    print("\n" + "=" * 50)
    print("Final Evaluation")
    print("=" * 50)
    
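    _, e2e_acc, e2e_preds, e2e_labels = evaluate(model_e2e, test_loader, DEVICE)
    # Reconstructions come from the fine-tuned checkpoint; Phase 2 trains with train_epoch_e2e
    # (classification objective), so reconstruction quality may differ from the pre-trained model
    _, ft_acc, ft_preds, ft_labels, recon_samples = evaluate(model_recon, test_loader, DEVICE, return_recon=True)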
    
    print(f"\nEnd-to-End Classification Accuracy: {e2e_acc:.2f}%")
    print(f"Reconstruction + Fine-tuning Accuracy: {ft_acc:.2f}%")

In [None]:
if TONIC_AVAILABLE:
    # Confusion matrices
    from sklearn.metrics import confusion_matrix
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    cm_e2e = confusion_matrix(e2e_labels, e2e_preds)
    cm_ft = confusion_matrix(ft_labels, ft_preds)
    
    sns.heatmap(cm_e2e, annot=True, fmt='d', cmap='Blues', ax=axes[0])
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')
    axes[0].set_title(f'E2E Classification (Acc: {e2e_acc:.2f}%)')
    
    sns.heatmap(cm_ft, annot=True, fmt='d', cmap='Blues', ax=axes[1])
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
    axes[1].set_title(f'Recon + Fine-tune (Acc: {ft_acc:.2f}%)')
    
    plt.tight_layout()
    plt.savefig('confusion_matrices.png', dpi=150, bbox_inches='tight')
    plt.show()
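The confusion matrices can also be reduced to per-class accuracy (per-class recall): each diagonal entry divided by its row sum, because the rows of `sklearn.metrics.confusion_matrix` index the true labels. A small sketch follows; the helper name `per_class_accuracy` is chosen for illustration and it is shown on a made-up 3-class matrix.

```python
import numpy as np

def per_class_accuracy(cm):
    """Per-class accuracy (recall) from a confusion matrix whose rows are true labels."""
    cm = np.asarray(cm)
    row_totals = cm.sum(axis=1)
    # Classes with no true samples get 0.0 instead of a division-by-zero warning
    return np.divide(cm.diagonal(), row_totals,
                     out=np.zeros(len(row_totals), dtype=float), where=row_totals > 0)

# Toy 3-class confusion matrix (rows = true label, columns = predicted label)
cm_toy = np.array([[1, 1, 0],
                   [0, 2, 0],
                   [1, 0, 1]])
acc = per_class_accuracy(cm_toy)
print(acc)
```

In the notebook this could be applied as `per_class_accuracy(cm_e2e)` and `per_class_accuracy(cm_ft)` to see which digits each approach handles worst.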

In [None]:
if TONIC_AVAILABLE and len(recon_samples) > 0:
    # Visualize reconstructions
    fig, axes = plt.subplots(3, 6, figsize=(15, 8))
    
    sample = recon_samples[0]
    
    for i in range(6):
        # Target
        axes[0, i].imshow(sample['target'][i].numpy(), cmap='RdBu_r', vmin=-1, vmax=1)
        axes[0, i].set_title(f'Target (Label: {sample["label"][i].item()})')
        axes[0, i].axis('off')
        
        # Reconstruction
        recon_img = sample['recon'][i].reshape(34, 34).numpy()
        axes[1, i].imshow(recon_img, cmap='RdBu_r', vmin=-1, vmax=1)
        axes[1, i].set_title('Reconstruction')
        axes[1, i].axis('off')
        
        # Difference
        diff = np.abs(sample['target'][i].numpy() - recon_img)
        axes[2, i].imshow(diff, cmap='hot', vmin=0, vmax=0.5)
        axes[2, i].set_title(f'|Error| (MSE: {(diff ** 2).mean():.4f})')
        axes[2, i].axis('off')
    
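    # axis('off') also hides y-labels, so label each row with text placed left of the first column
    for row, name in enumerate(['Target', 'Recon', 'Error']):
        axes[row, 0].text(-0.1, 0.5, name, transform=axes[row, 0].transAxes,
                          rotation=90, va='center', ha='right', fontsize=12)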
    
    plt.suptitle('Reconstruction Results', fontsize=14)
    plt.tight_layout()
    plt.savefig('reconstruction_results.png', dpi=150, bbox_inches='tight')
    plt.show()

## 10. Summary and Model Architecture

In [None]:
# Print model summary
print("=" * 60)
print("OmniField N-MNIST Model Summary")
print("=" * 60)
print(f"\nArchitecture:")
print(f"  - Base dimension: 64 (small)")
print(f"  - Number of latents: 64")
print(f"  - ICMR layers: 6 (deep)")
print(f"  - Attention heads: 4")
print(f"  - Head dimension: 16")
print(f"  - Total parameters: {sum(p.numel() for p in model_e2e.parameters()):,}")

print(f"\nPositional Encoding:")
print(f"  - Spatial GFF: 2 -> 32 features")
print(f"  - Temporal GFF: 1 -> 16 features")

print(f"\nTraining:")
print(f"  - E2E epochs: {NUM_EPOCHS_E2E}")
print(f"  - Pre-train epochs: {NUM_EPOCHS_PRETRAIN}")
print(f"  - Fine-tune epochs: {NUM_EPOCHS_FINETUNE}")

if TONIC_AVAILABLE:
    print(f"\nResults:")
    print(f"  - E2E Classification: {e2e_acc:.2f}%")
    print(f"  - Recon + Fine-tune: {ft_acc:.2f}%")

In [None]:
# Architecture diagram (text-based)
print("""
╔══════════════════════════════════════════════════════════════════╗
║                      OmniField for N-MNIST                       ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                  ║
║  INPUT: Sparse Events (x, y, t, polarity)                        ║
║         └── N events per sample (~2000 avg)                      ║
║                                                                  ║
║  ┌──────────────────────────────────────────────────────────┐    ║
║  │ ENCODER                                                  │    ║
║  │  ├── Spatial GFF: (x,y) → 32-dim                         │    ║
║  │  ├── Temporal GFF: t → 16-dim                            │    ║
║  │  ├── Polarity: p → 1-dim                                 │    ║
║  │  └── MLP: 49 → 64-dim                                    │    ║
║  └──────────────────────────────────────────────────────────┘    ║
║                               ↓                                  ║
║  ┌──────────────────────────────────────────────────────────┐    ║
║  │ ICMR BLOCKS (×6)                                         │    ║
║  │  ├── Sinusoidal Latents: 64 tokens × 64-dim              │    ║
║  │  ├── Cross-Attention: latents ← events                   │    ║
║  │  ├── Global Feature: mean(latents) → 64-dim              │    ║
║  │  ├── Self-Attention: latents ← latents                   │    ║
║  │  └── Feed-Forward: GEGLU                                 │    ║
║  │                                                          │    ║
║  │  [Iterative Refinement via Global Feature z]             │    ║
║  └──────────────────────────────────────────────────────────┘    ║
║                               ↓                                  ║
║  ┌─────────────────┐     ┌──────────────────────────────────┐    ║
║  │ CLASSIFICATION  │     │ RECONSTRUCTION                   │    ║
║  │  └── Global     │     │  ├── Query Grid: 34×34           │    ║
║  │      Feature    │     │  ├── Cross-Attn: queries←latents │    ║
║  │      → MLP      │     │  └── MLP → intensity             │    ║
║  │      → 10 cls   │     │                                  │    ║
║  └─────────────────┘     └──────────────────────────────────┘    ║
║                                                                  ║
╚══════════════════════════════════════════════════════════════════╝
""")
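The encoder path in the diagram (spatial and temporal Gaussian Fourier features plus the raw polarity, concatenated to 32 + 16 + 1 = 49 features and projected to 64 dimensions) can be sketched as follows. This is an illustrative reimplementation, not the notebook's own encoder: the class names, the `[sin, cos]` split (so 16 random frequencies yield 32 spatial features and 8 yield 16 temporal features), the frequency scale `sigma=10.0`, the two-layer GELU MLP, and coordinates normalized to [0, 1] are all assumptions.

```python
import math
import torch
import torch.nn as nn

class GaussianFourierFeatures(nn.Module):
    """Hypothetical GFF layer: v -> [sin(2*pi*vB), cos(2*pi*vB)] with a fixed random B."""
    def __init__(self, in_dim, out_dim, sigma=10.0):
        super().__init__()
        assert out_dim % 2 == 0, "out_dim must be even (sin and cos halves)"
        # B ~ N(0, sigma^2), sampled once and stored as a non-trainable buffer
        self.register_buffer("B", torch.randn(in_dim, out_dim // 2) * sigma)

    def forward(self, coords):  # coords: (..., in_dim)
        proj = 2 * math.pi * coords @ self.B
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

class EventEncoderSketch(nn.Module):
    """Hypothetical encoder: events (x, y, t, p) -> one dim-sized token per event."""
    def __init__(self, dim=64, spatial_feats=32, temporal_feats=16):
        super().__init__()
        self.spatial_gff = GaussianFourierFeatures(2, spatial_feats)
        self.temporal_gff = GaussianFourierFeatures(1, temporal_feats)
        in_dim = spatial_feats + temporal_feats + 1  # 32 + 16 + 1 = 49
        self.mlp = nn.Sequential(nn.Linear(in_dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x, y, t, p):  # each (B, N); coordinates assumed normalized to [0, 1]
        spatial = self.spatial_gff(torch.stack([x, y], dim=-1))        # (B, N, 32)
        temporal = self.temporal_gff(t.unsqueeze(-1))                  # (B, N, 16)
        feats = torch.cat([spatial, temporal, p.unsqueeze(-1)], dim=-1)  # (B, N, 49)
        return self.mlp(feats)                                         # (B, N, 64)

torch.manual_seed(0)
enc = EventEncoderSketch()
B, N = 2, 100
tokens = enc(torch.rand(B, N), torch.rand(B, N), torch.rand(B, N), torch.randint(0, 2, (B, N)).float())
print(tokens.shape)  # torch.Size([2, 100, 64])
```

Registering `B` as a buffer keeps the random frequencies fixed during training while still moving them with `.to(device)` and saving them in the `state_dict`.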

---

## Notes

### Key Design Decisions:

1. **Small Base Model (dim=64)**: Keeps the model lightweight while remaining expressive

2. **Deep ICMR (6 layers)**: Allows iterative refinement of the global representation
   - Each ICMR block refines the latents using cross-attention to the events
   - The global feature z is updated via mean pooling and passed to the next layer
   - Stacking blocks lets the model progressively refine its summary of the sparse events

3. **Gaussian Fourier Features**: Random Fourier projections of (x, y) and t help the network represent high-frequency detail that raw coordinates, and often standard positional encodings, capture poorly

4. **Dual-Task Training**:
   - Reconstruction pre-training encourages spatially informative representations
   - Fine-tuning then focuses the model on classification
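The iterative refinement in item 2 can be sketched as a stack of blocks in which the latents cross-attend to the event tokens, self-attend, pass through a GEGLU feed-forward layer, and emit a mean-pooled global feature `z` that conditions the next block. This is a hypothetical reconstruction for illustration only: injecting `z` by adding a linear projection of it to every latent, the pre-norm residual layout, the use of `nn.MultiheadAttention`, random latents standing in for the sinusoidal latents, and the convention that `mask` is `True` for real (non-padded) events are all assumptions, not details confirmed by the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    """GELU-gated linear unit: one half of the input gates the other half."""
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return x * F.gelu(gate)

class ICMRBlockSketch(nn.Module):
    """Hypothetical ICMR-style block: cross-attention to events, self-attention, GEGLU feed-forward."""
    def __init__(self, dim=64, heads=4, ff_mult=4):
        super().__init__()
        self.z_proj = nn.Linear(dim, dim)  # assumed way of injecting the global feature z
        self.norm_cross = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_self = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_ff = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * ff_mult * 2), GEGLU(), nn.Linear(dim * ff_mult, dim))

    def forward(self, latents, events, mask, z):
        # Condition every latent on the global feature from the previous block
        latents = latents + self.z_proj(z).unsqueeze(1)
        # Cross-attention: latents query the events; key_padding_mask=True marks padding to ignore
        q = self.norm_cross(latents)
        latents = latents + self.cross_attn(q, events, events, key_padding_mask=~mask)[0]
        # Self-attention among latents, then the GEGLU feed-forward
        h = self.norm_self(latents)
        latents = latents + self.self_attn(h, h, h)[0]
        latents = latents + self.ff(self.norm_ff(latents))
        # Mean-pool the latents into the new global feature z
        return latents, latents.mean(dim=1)

torch.manual_seed(0)
B, N, L, D = 2, 50, 64, 64
blocks = nn.ModuleList([ICMRBlockSketch(D) for _ in range(6)]).eval()
events = torch.randn(B, N, D)                          # stand-in for encoded event tokens
mask = torch.ones(B, N, dtype=torch.bool)
mask[:, 40:] = False                                   # last 10 events of each sample are padding
init_latents = torch.randn(1, L, D).expand(B, -1, -1)  # stand-in for the sinusoidal latents

def run_blocks(events, mask):
    latents, z = init_latents, torch.zeros(B, D)
    with torch.no_grad():
        for block in blocks:
            latents, z = block(latents, events, mask, z)
    return latents, z

latents, z = run_blocks(events, mask)
print(latents.shape, z.shape)  # torch.Size([2, 64, 64]) torch.Size([2, 64])
```

Because padded events are excluded from the cross-attention, changing their values should leave the latents and `z` unchanged, which is a quick sanity check for whichever masking convention the real model uses.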

### Expected Results:
- N-MNIST is relatively easy (typical accuracy >95%)
- The reconstruction + fine-tuning approach may provide marginal improvements
- The neural field approach naturally handles the sparse, irregular event data

### Future Improvements:
- Add temporal attention for better handling of event sequences
- Experiment with different ICMR depths and latent sizes
- Try harder datasets like N-Caltech101 or DVS-CIFAR10