# Neural Field Diffusion: Training and Testing

This notebook trains and tests the neural field diffusion model on toy point cloud data.

**Key Innovation**: Unlike traditional flow matching that predicts velocities at discrete sample points,
we learn a **continuous vector field** `v_θ: ℝ³ × [0,T] → ℝ³` that can be queried at ANY spatial location.

**Architecture (PixNerd-style)**:
1. **Global DiT Blocks**: Points → Shape context `s` (with 3D RoPE, AdaLN, SwiGLU)
2. **NerfBlocks (HyperNetwork)**: `s` → MLP weights (with weight normalization)
3. **Neural Field**: `(x, t, weights)` → Velocity `v(x, t)`
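
To make the three stages above concrete, here is a minimal NumPy sketch of the hypernetwork idea: a context vector `s` is mapped to the weights of a small MLP, and that MLP is then evaluated at arbitrary query locations. Everything in it (the sizes, the linear maps `H1`/`H2`, the `tanh` activation, the helper names) is an illustrative assumption and not the actual `NeuralFieldDiffusion` implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes: context dim, input dim (x, y, z, t), hidden width, velocity dim
D_CTX, D_IN, D_HID, D_OUT = 16, 4, 8, 3

# Hypothetical hypernetwork: two linear maps from the context s to flattened MLP weights
H1 = rng.normal(scale=0.1, size=(D_CTX, D_HID * D_IN))
H2 = rng.normal(scale=0.1, size=(D_CTX, D_OUT * D_HID))

def normalize_rows(w, eps=1e-8):
    """Weight normalization: rescale each output row to unit L2 norm."""
    return w / (np.linalg.norm(w, axis=-1, keepdims=True) + eps)

def make_field(s):
    """Turn a context vector s into a callable velocity field v(x, t)."""
    w1 = normalize_rows((s @ H1).reshape(D_HID, D_IN))
    w2 = normalize_rows((s @ H2).reshape(D_OUT, D_HID))

    def field(x, t):
        # x: (N, 3) query locations, t: scalar time
        xt = np.concatenate([x, np.full((x.shape[0], 1), t)], axis=1)  # (N, 4)
        h = np.tanh(xt @ w1.T)                                          # (N, D_HID)
        return h @ w2.T                                                 # (N, 3)

    return field

s = rng.normal(size=D_CTX)                       # stand-in for the DiT shape context
v = make_field(s)
print(v(rng.normal(size=(5, 3)), 0.5).shape)     # 5 query points
print(v(rng.normal(size=(1000, 3)), 0.5).shape)  # 1000 query points, same weights
```

The point of the sketch is that the generated weights depend only on `s`, so the same field can be queried with any number of points.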

**Key Components from PixNerd**:
- RMSNorm for efficient normalization
- SwiGLU feedforward networks
- 3D Rotary Position Embeddings (adapted from 2D)
- AdaLN modulation for condition injection
- HyperNetwork with weight normalization for stable training
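
For readers unfamiliar with these building blocks, the following NumPy sketch shows RMSNorm, a SwiGLU feedforward layer, and AdaLN-style modulation combined in one residual update. The shapes, the random weights, and the linear map `w_mod` from the condition to `(shift, scale)` are assumptions made for illustration; the layers in `src.models.neural_field` may be organized differently.

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    """RMSNorm: divide by the root-mean-square of the features (no mean subtraction)."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gain

def silu(x):
    """SiLU / swish activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def swiglu(x, w_gate, w_up, w_down):
    """SwiGLU feedforward: (SiLU(x W_gate) * (x W_up)) W_down."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def adaln_modulate(x, shift, scale):
    """AdaLN-style modulation: condition-dependent scale and shift."""
    return x * (1.0 + scale) + shift

rng = np.random.default_rng(0)
N, D, D_FF = 6, 8, 16
x = rng.normal(size=(N, D))        # N point tokens
cond = rng.normal(size=D)          # stand-in for a timestep embedding

# Hypothetical linear map from the condition to (shift, scale)
w_mod = rng.normal(scale=0.1, size=(D, 2 * D))
shift, scale = np.split(cond @ w_mod, 2)

# Normalize, modulate by the condition, then apply the feedforward as a residual update
h = adaln_modulate(rms_norm(x, gain=np.ones(D)), shift, scale)
out = x + swiglu(h,
                 rng.normal(scale=0.1, size=(D, D_FF)),
                 rng.normal(scale=0.1, size=(D, D_FF)),
                 rng.normal(scale=0.1, size=(D_FF, D)))
print(out.shape)
```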

**Important**: Each point is treated as a token, so this is O(N²) attention. We use 256-512 points for testing.
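
A quick back-of-the-envelope calculation shows why the point count is kept small. The helper below is hypothetical and assumes the full `(batch, heads, N, N)` score tensor is materialized in float32; fused attention kernels can avoid storing it, so treat the numbers as an upper bound on that one tensor rather than a measurement of this model.

```python
def attention_matrix_mib(n_points, num_heads=4, batch_size=32, bytes_per_elem=4):
    """Memory of the (batch, heads, N, N) attention score tensor in MiB (float32 assumed)."""
    n_elems = batch_size * num_heads * n_points * n_points
    return n_elems * bytes_per_elem / 2**20

for n in (256, 512, 1024, 2048):
    print(f"N = {n:5d}: {attention_matrix_mib(n):8.1f} MiB per attention layer")
```

With the defaults used in this notebook (batch size 32, 4 heads), doubling `N` quadruples the size: 32 MiB at 256 points, 128 MiB at 512, and 512 MiB at 1024.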

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm.notebook import tqdm
import time

# Our modules
from data.toy_data import (
    get_all_generators, generate_sphere, generate_torus,
    generate_helix, generate_multi_sphere_ring, generate_multi_sphere_cube,
    PointCloud, ManifoldDim
)
from src.models.neural_field import NeuralFieldDiffusion
from src.diffusion.flow_matching import FlowMatchingLoss, FlowMatchingSampler

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Configuration

**Key Settings**:
- `N_POINTS = 256-512` (each point is a token)
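- `SHAPES = ['multi_sphere_ring']` by default; single shapes such as `['torus']` or multi-shape lists also work
- `RANDOM_TRANSFORM = False` by default, since the multi-sphere shapes have built-in variation; set it to `True` to add rotation + anisotropic scaling for single shapes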
- **SMALL MODEL** - ~300K params is enough for toy data (vs. 10M+ for real data)
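
As an illustration of what "rotation + anisotropic scaling" does to a shape, here is a small NumPy sketch. It is not the `PointCloud.random_transform` method from `data.toy_data`; the function names, the QR-based rotation, and the scale-then-rotate order are assumptions chosen for the example.

```python
import numpy as np

def random_rotation(rng):
    """Random 3D rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))      # fix the sign ambiguity of QR
    if np.linalg.det(q) < 0:         # ensure a proper rotation (det = +1)
        q[:, 0] = -q[:, 0]
    return q

def random_transform(points, rng, scale_range=(0.7, 1.3), anisotropic=True):
    """Scale each axis (independently if anisotropic), then rotate the cloud."""
    scale = rng.uniform(*scale_range, size=3 if anisotropic else 1)
    return (points * scale) @ random_rotation(rng).T

rng = np.random.default_rng(0)
sphere = rng.normal(size=(512, 3))
sphere /= np.linalg.norm(sphere, axis=1, keepdims=True)   # points on the unit sphere

ellipsoid = random_transform(sphere, rng)
radii = np.linalg.norm(ellipsoid, axis=1)
print(ellipsoid.shape, radii.min(), radii.max())
```

Because the three axes are scaled independently, the unit sphere becomes an ellipsoid whose point radii stay within the chosen `scale_range`.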

In [None]:
# =============================================================================
# CONFIGURATION - Adjust these as needed
# =============================================================================

# Data - multi_sphere_ring has structured variation (8 spheres with jitter)
SHAPES = ['multi_sphere_ring']  # Try: ['multi_sphere_cube'], ['torus'], or multi-shape
N_POINTS = 512           # Points per cloud (512 for 8 spheres = 64 per sphere)
N_SAMPLES = 1000         # Training samples
RANDOM_TRANSFORM = False # multi_sphere already has built-in variation!
SCALE_RANGE = (0.7, 1.3) # Only used if RANDOM_TRANSFORM=True

# Model (SMALL for toy data - PixNerd-style architecture)
HIDDEN_SIZE = 128        # Transformer hidden dimension (small!)
HIDDEN_SIZE_X = 32       # NerfBlock hidden dimension
NUM_HEADS = 4            # Number of attention heads
NUM_BLOCKS = 6           # Total blocks (2 DiT + 4 NerfBlocks)
NUM_COND_BLOCKS = 2      # DiT blocks (rest are NerfBlocks)
NERF_MLP_RATIO = 2       # MLP ratio for NerfBlocks
MAX_FREQS = 6            # Fourier frequency bands

# Training
EPOCHS = 300             # Training epochs
BATCH_SIZE = 32          # Batch size
LR = 1e-4                # Learning rate

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
print(f"Training on: {SHAPES}")
print(f"Random transform: {RANDOM_TRANSFORM} (multi_sphere has built-in variation)")

## 2. Dataset
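
The `multi_sphere_ring` generator lives in `data.toy_data` and is not shown in this notebook. As a rough, hypothetical sketch of the kind of variation it is described as producing (8 spheres on a ring, position jitter of about ±15%, radius jitter of about ±30%, and a random ring rotation), something like the following would do. The function name, `ring_radius`, and `base_radius` are assumptions, and the real generator may differ in its details.

```python
import numpy as np

def sample_sphere_surface(n, rng):
    """Uniform samples on the unit sphere via normalized Gaussians."""
    p = rng.normal(size=(n, 3))
    return p / np.linalg.norm(p, axis=1, keepdims=True)

def multi_sphere_ring_sketch(n_points=512, n_spheres=8, ring_radius=1.0,
                             base_radius=0.2, pos_jitter=0.15, radius_jitter=0.30,
                             rng=None):
    """Hypothetical stand-in for a jittered ring of spheres (not the real generator)."""
    rng = np.random.default_rng() if rng is None else rng
    per_sphere = n_points // n_spheres
    offset = rng.uniform(0.0, 2.0 * np.pi)               # random ring rotation
    clouds = []
    for k in range(n_spheres):
        angle = offset + 2.0 * np.pi * k / n_spheres
        center = ring_radius * np.array([np.cos(angle), np.sin(angle), 0.0])
        center += ring_radius * rng.uniform(-pos_jitter, pos_jitter, size=3)
        radius = base_radius * (1.0 + rng.uniform(-radius_jitter, radius_jitter))
        clouds.append(center + radius * sample_sphere_surface(per_sphere, rng))
    return np.concatenate(clouds, axis=0)

a = multi_sphere_ring_sketch(rng=np.random.default_rng(0))
b = multi_sphere_ring_sketch(rng=np.random.default_rng(1))
print(a.shape, np.allclose(a, b))
```

Each call draws new centers and radii, which is why this kind of dataset needs no extra random transform to produce distinct training samples.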

In [None]:
class PointCloudDataset(Dataset):
    """Dataset of point clouds with random transforms for variation."""
    
    def __init__(self, shapes, n_samples, n_points, 
                 noise_std=0.001, random_transform=True, scale_range=(0.7, 1.3)):
        self.shapes = shapes if isinstance(shapes, list) else [shapes]
        self.n_samples = n_samples
        self.n_points = n_points
        self.noise_std = noise_std
        self.random_transform = random_transform
        self.scale_range = scale_range
        
        all_generators = get_all_generators()
        self.generators = [all_generators[s] for s in self.shapes]
    
    def __len__(self):
        return self.n_samples
    
    def __getitem__(self, idx):
        generator = self.generators[idx % len(self.generators)]
        pc = generator(n_points=self.n_points)
        
        # Apply random transforms for variation!
        if self.random_transform:
            pc = pc.random_transform(
                rotate=True,
                scale_range=self.scale_range,
                anisotropic=True  # sphere -> ellipsoid, etc.
            )
        
        pc = pc.normalize()
        if self.noise_std > 0:
            pc = pc.add_noise(self.noise_std)
        return torch.tensor(pc.points, dtype=torch.float32)

# Create dataset and dataloader
dataset = PointCloudDataset(
    SHAPES, N_SAMPLES, N_POINTS,
    random_transform=RANDOM_TRANSFORM,
    scale_range=SCALE_RANGE
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Dataset: {SHAPES}")
print(f"Samples: {N_SAMPLES}")
print(f"Points per sample: {N_POINTS}")
print(f"Random transform: {RANDOM_TRANSFORM}")
print(f"Batches per epoch: {len(dataloader)}")

In [None]:
# Visualize training samples - note the VARIATION in sphere positions!
fig = plt.figure(figsize=(16, 4))
for i in range(4):
    sample = dataset[i].numpy()
    ax = fig.add_subplot(1, 4, i+1, projection='3d')
    # Color by z to show structure
    colors = sample[:, 2]
    ax.scatter(sample[:, 0], sample[:, 1], sample[:, 2], 
               c=colors, cmap='viridis', s=3, alpha=0.7)
    ax.set_title(f'Sample {i+1}')
    ax.set_xlim([-1.2, 1.2])
    ax.set_ylim([-1.2, 1.2])
    ax.set_zlim([-1.2, 1.2])
    ax.view_init(elev=20, azim=30 + i*20)

shapes_str = ', '.join(SHAPES)
plt.suptitle(f'Training Data: {shapes_str} ({N_POINTS} pts) - Each sample has DIFFERENT sphere positions!', fontsize=12)
plt.tight_layout()
plt.show()

print("Multi-sphere variation comes from:")
print("  - Position jitter: ±15% offset from base positions")
print("  - Radius jitter: ±30% variation in sphere sizes")
print("  - Ring rotation: random angle offset each sample")

## 3. Model

In [None]:
# Create SMALL model (PixNerd-style: DiT blocks + NerfBlocks)
model = NeuralFieldDiffusion(
    in_channels=3,
    out_channels=3,
    hidden_size=HIDDEN_SIZE,
    hidden_size_x=HIDDEN_SIZE_X,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    num_cond_blocks=NUM_COND_BLOCKS,
    nerf_mlp_ratio=NERF_MLP_RATIO,
    max_freqs=MAX_FREQS,
).to(DEVICE)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print("Model: NeuralFieldDiffusion (PixNerd-style, SMALL)")
print(f"Architecture: {NUM_COND_BLOCKS} DiT blocks + {NUM_BLOCKS - NUM_COND_BLOCKS} NerfBlocks")
print(f"Parameters: {n_params:,}")
print("  (For reference: ~300K is fine for toy data, ~10M+ for real data)")

# Test forward pass (shape check only, so skip gradient tracking)
test_input = torch.randn(2, N_POINTS, 3, device=DEVICE)
test_t = torch.rand(2, device=DEVICE)
with torch.no_grad():
    test_output = model(test_input, test_t)
print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")

## 4. Training
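
Before the training cells, it may help to see what a linear-schedule flow matching loss typically computes. The NumPy sketch below assumes the convention used elsewhere in this notebook (data at `t = 0`, noise at `t = 1`) with the straight-line path `x_t = (1 - t) x0 + t ε` and target velocity `ε - x0`. It is an illustration of that common formulation, not the code inside `FlowMatchingLoss`, whose exact schedule and weighting may differ.

```python
import numpy as np

def flow_matching_loss(velocity_fn, x0, rng):
    """Monte Carlo estimate of the linear-path flow matching loss for one batch."""
    batch = x0.shape[0]
    t = rng.uniform(size=(batch, 1, 1))          # one time per cloud
    eps = rng.normal(size=x0.shape)              # noise endpoint (t = 1)
    x_t = (1.0 - t) * x0 + t * eps               # straight line from data to noise
    target = eps - x0                            # d x_t / d t along that line
    pred = velocity_fn(x_t, t[:, 0, 0])          # model sees points and per-cloud times
    return np.mean((pred - target) ** 2)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 256, 3))                # stand-in batch of point clouds

# A model that always predicts zero velocity, for illustration only
zero_model = lambda x, t: np.zeros_like(x)
print(flow_matching_loss(zero_model, x0, rng))
```

In the cells below, `loss_fn(model, x0)` plays the role of `flow_matching_loss`, with the network in place of `zero_model`.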

In [None]:
# Setup training
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
loss_fn = FlowMatchingLoss(schedule_type='linear')
sampler = FlowMatchingSampler(model)

# Training history
train_losses = []
epoch_times = []

In [None]:
def train_epoch(model, dataloader, optimizer, loss_fn, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    n_batches = 0
    
    for batch in dataloader:
        x0 = batch.to(device)
        
        optimizer.zero_grad()
        output = loss_fn(model, x0)
        loss = output['loss']
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        n_batches += 1
    
    return total_loss / n_batches

@torch.no_grad()
def generate_samples(model, sampler, n_samples=4, n_points=256, n_steps=50, device='cpu'):
    """Generate samples from the model."""
    model.eval()
    noise = torch.randn(n_samples, n_points, 3, device=device)
    samples = sampler.sample_euler(noise, n_steps=n_steps)
    return samples

In [None]:
# Training loop
print(f"Training for {EPOCHS} epochs...")
print("="*60)

pbar = tqdm(range(EPOCHS), desc="Training")

for epoch in pbar:
    start_time = time.time()
    loss = train_epoch(model, dataloader, optimizer, loss_fn, DEVICE)
    epoch_time = time.time() - start_time
    
    train_losses.append(loss)
    epoch_times.append(epoch_time)
    
    pbar.set_postfix({'loss': f'{loss:.4f}'})
    
    # Log every 50 epochs
    if (epoch + 1) % 50 == 0:
        tqdm.write(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss:.4f} | Time: {epoch_time:.2f}s")

print("\n" + "="*60)
print(f"Training complete!")
print(f"Final loss: {train_losses[-1]:.4f}")
print(f"Average epoch time: {np.mean(epoch_times):.2f}s")

In [None]:
# Plot training curve
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)

plt.subplot(1, 2, 2)
# Skip the first 10 epochs for a readable scale, keeping the true epoch index on the x-axis
plt.plot(range(10, len(train_losses)), train_losses[10:])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss (epochs 10+)')
plt.grid(True)

plt.tight_layout()
plt.show()

## 5. Generate Samples

In [None]:
# Generate samples
print("Generating samples...")
samples = generate_samples(model, sampler, n_samples=8, n_points=N_POINTS, 
                           n_steps=50, device=DEVICE)
samples = samples.cpu().numpy()
print(f"Generated {samples.shape[0]} samples with {samples.shape[1]} points each")

In [None]:
# Compare ground truth vs generated
fig = plt.figure(figsize=(16, 8))

# Ground truth (top row) - showing variation
for i in range(4):
    gt = dataset[i].numpy()
    ax = fig.add_subplot(2, 4, i + 1, projection='3d')
    ax.scatter(gt[:, 0], gt[:, 1], gt[:, 2], s=2, alpha=0.5, c='blue')
    ax.set_title(f'Ground Truth {i+1}')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])

# Generated (bottom row)
for i in range(4):
    ax = fig.add_subplot(2, 4, i + 5, projection='3d')
    ax.scatter(samples[i, :, 0], samples[i, :, 1], samples[i, :, 2], 
               s=2, alpha=0.5, c='red')
    ax.set_title(f'Generated {i+1}')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])

shapes_str = ', '.join(SHAPES)
plt.suptitle(f'{shapes_str}: GT (blue, varied) vs Generated (red)', fontsize=14)
plt.tight_layout()
plt.show()

print("Note: GT samples vary through the generator's built-in jitter")
print("(plus random transforms when RANDOM_TRANSFORM=True).")
print("Model should learn to generate similar variety.")

## 6. Resolution Independence Test

**Key Point**: This is UNCONDITIONAL generation - no reference shape needed!

The model should generate valid shapes starting from noise at ANY resolution.
We just call `sample_euler` with different sized noise tensors.
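
The reason a different-sized noise tensor works at all is that self-attention has no parameters tied to the sequence length. The single-head NumPy sketch below (random weights, no RoPE, no masking, all names assumed for illustration) applies the same projection matrices to 64, 256, and 1024 tokens.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over N tokens; the weights do not depend on N."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])            # (N, N)
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over keys
    return weights @ v                                  # (N, D)

rng = np.random.default_rng(0)
D = 16
w_q, w_k, w_v = (rng.normal(scale=0.1, size=(D, D)) for _ in range(3))

for n_pts in (64, 256, 1024):
    tokens = rng.normal(size=(n_pts, D))               # stand-in for embedded points
    print(n_pts, self_attention(tokens, w_q, w_k, w_v).shape)
```

Running at any `N` is a property of the architecture; whether the samples look good at resolutions far from the training `N_POINTS` is an empirical question that the test below checks visually.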

In [None]:
# =============================================================================
# Resolution Independence Test - UNCONDITIONAL GENERATION
# =============================================================================
# No reference shape! Just generate from noise at different resolutions.
# The model's forward() handles any number of points.

print("Testing UNCONDITIONAL resolution independence...")
print("(No reference shape - pure generation from noise)\n")

resolutions = [64, 128, 256, 512, 1024]

fig = plt.figure(figsize=(20, 4))

for i, n_pts in enumerate(resolutions):
    # Start from noise at this resolution
    noise = torch.randn(1, n_pts, 3, device=DEVICE)
    
    # Generate using standard Euler sampling - NO reference needed!
    with torch.no_grad():
        model.eval()
        sample = sampler.sample_euler(noise, n_steps=50)
    
    sample = sample.cpu().numpy()[0]
    
    ax = fig.add_subplot(1, 5, i + 1, projection='3d')
    ax.scatter(sample[:, 0], sample[:, 1], sample[:, 2], 
               s=max(1, 5 - i), alpha=0.5, c='green')
    ax.set_title(f'N = {n_pts}')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])
    
    print(f"  Generated {n_pts:5d} points (unconditional, from noise)")

plt.suptitle('Unconditional Resolution Independence: Same Model, Any Number of Points', fontsize=14)
plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("This is TRUE resolution independence:")
print("- No reference shape")
print("- Pure unconditional generation")
print("- Model handles any N because attention works on variable length")
print("="*60)

## 7. Generation Process Visualization

Watch the ODE integration transform noise into the target manifold.
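
For reference, fixed-step Euler integration of a velocity field from `t = 1` (noise) to `t = 0` (data) can be sketched in a few lines of NumPy. The function `euler_sample_sketch` and the toy field below are assumptions for illustration and are not the `FlowMatchingSampler.sample_euler` implementation. The toy field `(x - target) / t` is the exact linear-path velocity when every data point equals `target`, so the integration should collapse the noise onto that point.

```python
import numpy as np

def euler_sample_sketch(velocity_fn, noise, n_steps=50, return_trajectory=False):
    """Integrate dx/dt = v(x, t) from t = 1 (noise) down to t = 0 with fixed Euler steps."""
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    x = noise.copy()
    trajectory = [x.copy()]
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t_cur) * velocity_fn(x, t_cur)   # dt is negative
        trajectory.append(x.copy())
    return np.stack(trajectory) if return_trajectory else x

# Toy field: exact velocity for data concentrated at a single point
target = np.array([0.5, -0.2, 0.1])
field = lambda x, t: (x - target) / t

rng = np.random.default_rng(0)
traj = euler_sample_sketch(field, rng.normal(size=(256, 3)),
                           n_steps=50, return_trajectory=True)
print(traj.shape)                                # (n_steps + 1, N, 3)
print(np.abs(traj[-1] - target).max())           # distance of final points from target
```

In this sketch the trajectory has `n_steps + 1` frames, the first being pure noise and the last the generated points.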

In [None]:
@torch.no_grad()
def generate_with_trajectory(model, sampler, n_points=256, n_steps=50, device='cpu'):
    """Generate with full trajectory."""
    model.eval()
    noise = torch.randn(1, n_points, 3, device=device)
    trajectory = sampler.sample_euler(noise, n_steps=n_steps, return_trajectory=True)
    return trajectory

# Generate trajectory
trajectory = generate_with_trajectory(model, sampler, n_points=N_POINTS, 
                                      n_steps=50, device=DEVICE)
trajectory = trajectory.cpu().numpy()[:, 0]  # [steps, N, 3]

# Visualize at selected timesteps
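n_frames = trajectory.shape[0]
step_indices = [0, n_frames//4, n_frames//2, 3*n_frames//4, n_frames-1]
# Label each frame with its time, assuming uniform steps from t=1 (noise) to t=0 (data)
t_values = [1.0 - idx / (n_frames - 1) for idx in step_indices]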

fig = plt.figure(figsize=(20, 4))

for i, (step_idx, t_val) in enumerate(zip(step_indices, t_values)):
    points = trajectory[step_idx]
    
    ax = fig.add_subplot(1, 5, i + 1, projection='3d')
    color = plt.cm.coolwarm(1 - t_val)
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=2, alpha=0.5, c=[color])
    ax.set_title(f't = {t_val:.2f}')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])

plt.suptitle('Generation Process: Noise (t=1) → Manifold (t=0)', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Test on Different Shapes

Let's test the model (without retraining) by using context from different shapes.

This shows how the shape context drives generation.

In [None]:
# Get some different shapes for context transfer test
test_shapes = ['sphere', 'torus', 'helix', 'trefoil_knot']
generators = get_all_generators()

fig = plt.figure(figsize=(16, 8))

for i, shape_name in enumerate(test_shapes):
    # Generate a reference point cloud
    pc = generators[shape_name](n_points=N_POINTS).normalize()
    ref_points = torch.tensor(pc.points, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    
    # Get context from this shape and generate with it (inference only, no gradients)
    model.eval()
    with torch.no_grad():
        ref_t = torch.zeros(1, device=DEVICE)
        context = model.get_context(ref_points, ref_t)
        generated = sampler.sample_at_resolution(context, n_points=N_POINTS, n_steps=50)
    generated = generated.cpu().numpy()[0]
    ref_np = ref_points.cpu().numpy()[0]
    
    # Plot reference
    ax = fig.add_subplot(2, 4, i + 1, projection='3d')
    ax.scatter(ref_np[:, 0], ref_np[:, 1], ref_np[:, 2], s=2, alpha=0.5, c='blue')
    ax.set_title(f'{shape_name} (input)')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])
    
    # Plot generated
    ax = fig.add_subplot(2, 4, i + 5, projection='3d')
    ax.scatter(generated[:, 0], generated[:, 1], generated[:, 2], s=2, alpha=0.5, c='red')
    ax.set_title(f'{shape_name} (generated)')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])

shapes_str = ', '.join(SHAPES)
plt.suptitle(f'Context Transfer Test (model trained on {shapes_str})', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Note: Model was trained ONLY on '{shapes_str}'.")
print("Generation quality depends on how well the learned field generalizes.")

## 9. Summary

### What We Demonstrated:

1. **Multi-Object Scenes**: 8 spheres with position/radius jitter create meaningful variation

2. **Spatial Relationship Learning**: Model learns:
   - Number and arrangement of spheres
   - Relative positions (ring pattern)
   - Variation in positions and sizes

3. **Small Model**: ~300K params is sufficient for toy data (vs 10M+ for complex real-world shapes)

4. **Neural Field Architecture** (PixNerd-style): 
   - DiT blocks with 3D RoPE, AdaLN, SwiGLU for global context
   - NerfBlocks (hyper-network with weight normalization) for local field

5. **Resolution Independence**: Same model generates at 64, 256, or 1024 points

### Key Observations:

- Multi-sphere has built-in variation (no random transforms needed)
- Model must learn spatial relationships, not just single shape geometry
- Position jitter (±15%) and radius jitter (±30%) create unique samples

### Next Steps:

1. Try other arrangements: `SHAPES = ['multi_sphere_cube']` or `['multi_sphere_random']`
2. Try single shapes with transforms: `SHAPES = ['torus']` with `RANDOM_TRANSFORM = True`
3. Add conditional generation (shape class labels)
4. Scale to more complex shapes (ShapeNet)
5. Extract geometric information (normals, SDF) from learned field

In [None]:
# Save model checkpoint
shapes_str = '_'.join(SHAPES)
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': {
        'hidden_size': HIDDEN_SIZE,
        'hidden_size_x': HIDDEN_SIZE_X,
        'num_heads': NUM_HEADS,
        'num_blocks': NUM_BLOCKS,
        'num_cond_blocks': NUM_COND_BLOCKS,
        'nerf_mlp_ratio': NERF_MLP_RATIO,
        'max_freqs': MAX_FREQS,
    },
    'train_losses': train_losses,
    'shapes': SHAPES,
    'n_points': N_POINTS,
    'random_transform': RANDOM_TRANSFORM,
}

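import os

checkpoint_path = f'../experiments/outputs/notebook_checkpoint_{shapes_str}.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # create the output directory if missing
torch.save(checkpoint, checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path}")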