# SDF-Based Neural Field Diffusion: Training and Testing

**Key Innovation**: Instead of directly predicting velocity `v(x,t)`, we predict a scalar distance field `f(x,t)` and derive velocity as `v = -∇_x f(x,t)`.

## Benefits of SDF-Based Approach:

1. **Smoother Training**: No directional discontinuities in the output
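2. **Implicit Surface**: At t=0, the zero level set of `f` can serve as an implicit surface for the learned shape, provided training anchors `f` (a loss on `∇f` alone leaves `f` defined only up to an additive constant)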
3. **Natural Gradients**: If `f` is smooth (C¹), its gradient — and hence the velocity — is continuous
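4. **Constrained Capacity**: A pure gradient field is curl-free, so parameterizing `v = -∇_x f` restricts the set of representable velocity fields compared with predicting `v(x,t) ∈ ℝ³` directly; it is not equivalent in general, but it trades that generality for a single, smoother scalar output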

## Architecture:
- DiT blocks with 3D RoPE, AdaLN, SwiGLU for global context
- NerfBlocks (hyper-network) for local neural field
- Output: scalar `f(x,t) ∈ ℝ` instead of vector `v(x,t) ∈ ℝ³`
- Velocity: `v(x,t) = -∇_x f(x,t)` computed via autograd

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm.notebook import tqdm
import time

# Our modules
from data.toy_data import get_all_generators, generate_multi_sphere_ring
from src.models.sdf_field import SDFNeuralField, SDFFlowMatchingLoss
from src.diffusion.flow_matching import FlowMatchingSampler

# IPython magic (notebook-only, not plain Python): render figures inline.
%matplotlib inline
# Global plot style applied to every figure created below.
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Configuration

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================
# Every hyperparameter read by the cells below is a module-level constant here.

# --- Data --------------------------------------------------------------------
SHAPES = ['multi_sphere_ring']   # 8 spheres in a ring with jitter
N_POINTS = 512                   # points requested per cloud
N_SAMPLES = 1000                 # samples per epoch (dataset length)
RANDOM_TRANSFORM = False         # multi_sphere has built-in variation

# --- Model (kept SMALL for toy data) -----------------------------------------
HIDDEN_SIZE = 128
HIDDEN_SIZE_X = 32
NUM_HEADS = 4
NUM_BLOCKS = 6
NUM_COND_BLOCKS = 2
NERF_MLP_RATIO = 2
MAX_FREQS = 6

# --- Training ----------------------------------------------------------------
EPOCHS = 300
BATCH_SIZE = 32
LR = 1e-4
EIKONAL_WEIGHT = 0.0             # Try 0.1 for SDF regularization

# --- Device ------------------------------------------------------------------
_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _HAS_CUDA else 'cpu')

print(
    f"Using device: {DEVICE}",
    f"Training on: {SHAPES}",
    f"Eikonal weight: {EIKONAL_WEIGHT}",
    sep="\n",
)

## 2. Dataset

In [None]:
class PointCloudDataset(Dataset):
    """Map-style dataset that generates toy point clouds on the fly.

    Each ``__getitem__`` call runs one of the shape generators returned by
    ``get_all_generators()`` (picked round-robin by ``idx``). It then
    optionally applies ``random_transform``, calls ``normalize``, optionally
    calls ``add_noise``, and returns the points as a ``float32`` tensor.
    Because a generator is invoked on every access, indexing the same
    ``idx`` twice may give different clouds if the generators are random.

    Args:
        shapes: a single shape name, or a list/tuple of names, each a key of
            ``get_all_generators()``.
        n_samples: dataset length reported by ``__len__``.
        n_points: number of points requested from each generator call.
        noise_std: value passed to ``add_noise``; ``0`` disables noise.
        random_transform: if true, call ``random_transform(rotate=True,
            scale_range=scale_range, anisotropic=True)`` before normalizing.
        scale_range: scale range forwarded to ``random_transform``.

    Raises:
        ValueError: if ``shapes`` is an empty list/tuple.
        KeyError: if a shape name has no registered generator.
    """

    def __init__(self, shapes, n_samples, n_points, noise_std=0.001,
                 random_transform=True, scale_range=(0.7, 1.3)):
        # Lists AND tuples are collections of names; anything else (e.g. a
        # single string) is one name. Previously only `list` was recognized,
        # so a tuple of names was wrapped whole and looked up as one key.
        if isinstance(shapes, (list, tuple)):
            self.shapes = list(shapes)
        else:
            self.shapes = [shapes]
        if not self.shapes:
            # Fail early instead of with ZeroDivisionError in __getitem__.
            raise ValueError("PointCloudDataset requires at least one shape name")

        self.n_samples = n_samples
        self.n_points = n_points
        self.noise_std = noise_std
        self.random_transform = random_transform
        self.scale_range = scale_range

        all_generators = get_all_generators()
        missing = [s for s in self.shapes if s not in all_generators]
        if missing:
            # Same exception type as the original bare dict lookup, but with
            # the available names listed.
            raise KeyError(
                f"No generator for shape(s) {missing}; "
                f"available: {list(all_generators)}"
            )
        self.generators = [all_generators[s] for s in self.shapes]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Round-robin over shapes so mixed-shape datasets stay balanced.
        generator = self.generators[idx % len(self.generators)]
        pc = generator(n_points=self.n_points)

        if self.random_transform:
            pc = pc.random_transform(rotate=True, scale_range=self.scale_range, anisotropic=True)

        pc = pc.normalize()
        if self.noise_std > 0:
            pc = pc.add_noise(self.noise_std)
        return torch.tensor(pc.points, dtype=torch.float32)

# Build the on-the-fly dataset and a shuffling loader over it.
dataset = PointCloudDataset(
    SHAPES,
    N_SAMPLES,
    N_POINTS,
    random_transform=RANDOM_TRANSFORM,
)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Dataset: {len(dataset)} samples, {N_POINTS} points each")

In [None]:
# Preview the first four training clouds, coloured by height (z), each
# viewed from a slightly different azimuth.
fig = plt.figure(figsize=(16, 4))
for col in range(4):
    pts = dataset[col].numpy()
    xs, ys, zs = pts[:, 0], pts[:, 1], pts[:, 2]

    ax = fig.add_subplot(1, 4, col + 1, projection='3d')
    ax.scatter(xs, ys, zs, c=zs, cmap='viridis', s=3, alpha=0.7)
    ax.set_title(f'Sample {col + 1}')
    for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_lim([-1.2, 1.2])
    ax.view_init(elev=20, azim=30 + col * 20)

plt.suptitle(f'Training Data: {SHAPES[0]} - Each sample has different sphere positions', fontsize=12)
plt.tight_layout()
plt.show()

## 3. SDF Model

In [None]:
# Build the SDF model. It predicts one scalar f(x, t) per point; the velocity
# is derived from that scalar as -grad_x f.
sdf_model_kwargs = dict(
    in_channels=3,
    hidden_size=HIDDEN_SIZE,
    hidden_size_x=HIDDEN_SIZE_X,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    num_cond_blocks=NUM_COND_BLOCKS,
    nerf_mlp_ratio=NERF_MLP_RATIO,
    max_freqs=MAX_FREQS,
)
model = SDFNeuralField(**sdf_model_kwargs).to(DEVICE)

n_params = sum(map(torch.numel, model.parameters()))
print(f"SDF Model Parameters: {n_params:,}")

# Smoke-test both outputs on a random batch of two clouds.
test_input = torch.randn(2, N_POINTS, 3, device=DEVICE)
test_t = torch.rand(2, device=DEVICE)

# Scalar field: plain forward pass, no gradients needed.
with torch.no_grad():
    sdf_out = model.forward_sdf(test_input, test_t)
print(f"SDF output shape: {sdf_out.shape} (scalar per point)")

# Velocity: deliberately called outside no_grad, because it is obtained from
# the gradient of f (-∇f).
vel_out = model.get_velocity(test_input, test_t)
print(f"Velocity output shape: {vel_out.shape} (derived from -∇f)")

## 4. Training

In [None]:
# Optimizer, flow-matching loss (with optional Eikonal term) and sampler.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=0.01)
loss_fn = SDFFlowMatchingLoss(schedule_type='linear', eikonal_weight=EIKONAL_WEIGHT)
sampler = FlowMatchingSampler(model)

# Per-epoch history, appended to by the training loop.
train_losses, velocity_losses, epoch_times = [], [], []

In [None]:
def train_epoch(model, dataloader, optimizer, loss_fn, device, *, max_grad_norm=1.0):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: module being trained; switched to train mode here.
        dataloader: iterable of batches; each batch is moved with ``.to(device)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(model, x0)`` returning a dict with at
            least ``'loss'`` (differentiable scalar tensor) and
            ``'velocity_loss'`` (a scalar tensor or a Python number).
        device: target device for each batch.
        max_grad_norm: gradient-norm clipping threshold (default 1.0, the
            value this function always used).

    Returns:
        ``(mean_total_loss, mean_velocity_loss)`` over the batches, as floats.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_vel_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        x0 = batch.to(device)

        optimizer.zero_grad()
        output = loss_fn(model, x0)
        loss = output['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        # float() accepts both Python numbers and 0-d tensors. If
        # 'velocity_loss' is a tensor, accumulating it raw would chain
        # autograd history across batches and return a tensor (possibly on
        # CUDA or requiring grad) that matplotlib cannot convert when plotting.
        total_vel_loss += float(output['velocity_loss'])
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_epoch: dataloader yielded no batches")
    return total_loss / n_batches, total_vel_loss / n_batches

def generate_samples(model, sampler, n_samples=4, n_points=256, n_steps=50, device='cpu'):
    """Generate point clouds by running ``sampler.sample_euler`` from Gaussian noise.

    Args:
        model: the trained model; switched to eval mode here.
        sampler: object exposing ``sample_euler(noise, n_steps=...)``.
        n_samples: number of clouds to draw.
        n_points: points per cloud.
        n_steps: number of Euler integration steps.
        device: device for the initial noise.

    Returns:
        The sampler's output, detached from autograd. It is presumably
        ``(n_samples, n_points, 3)`` like the noise; confirm against the sampler.
    """
    model.eval()
    noise = torch.randn(n_samples, n_points, 3, device=device)
    # Do NOT run this under torch.no_grad(). The SDF model's velocity is
    # -grad_x f, which needs autograd w.r.t. the points, and no_grad would
    # suppress it unless the model re-enables grad internally.
    # enable_grad() also overrides any no_grad context the caller is in.
    with torch.enable_grad():
        samples = sampler.sample_euler(noise, n_steps=n_steps)
    # Detach so callers can call .numpy() and no autograd history is retained.
    return samples.detach()

In [None]:
# Training loop: one pass over the data per epoch, recording the losses and
# the wall-clock time of every epoch.
rule = "=" * 60
print(f"Training SDF model for {EPOCHS} epochs...")
print(rule)

progress = tqdm(range(EPOCHS), desc="Training SDF")

for epoch in progress:
    tic = time.time()
    loss, vel_loss = train_epoch(model, dataloader, optimizer, loss_fn, DEVICE)
    elapsed = time.time() - tic

    for history, value in zip((train_losses, velocity_losses, epoch_times),
                              (loss, vel_loss, elapsed)):
        history.append(value)

    postfix = {'loss': f'{loss:.4f}', 'vel': f'{vel_loss:.4f}'}
    progress.set_postfix(postfix)

    # Every 50 epochs also emit a permanent log line (tqdm.write keeps the bar intact).
    epoch_no = epoch + 1
    if epoch_no % 50 == 0:
        tqdm.write(f"Epoch {epoch_no}/{EPOCHS} | Loss: {loss:.4f} | Vel: {vel_loss:.4f}")

print("\n" + rule)
print("Training complete!")
print(f"Final loss: {train_losses[-1]:.4f}")
print(f"Average epoch time: {np.mean(epoch_times):.2f}s")

In [None]:
# Training curves: full history on the left, history from epoch 10 on the right.
def _label_loss_axes(title):
    """Apply the shared Epoch/Loss labels, a title and a grid to the current axes."""
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.grid(True)


plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
for series, name in ((train_losses, 'Total'), (velocity_losses, 'Velocity')):
    plt.plot(series, label=name)
_label_loss_axes('SDF Training Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_losses[10:])
_label_loss_axes('Training Loss (after warmup)')

plt.tight_layout()
plt.show()

## 5. Generate Samples

In [None]:
# Draw 8 clouds from the trained model and move them to NumPy for plotting.
print("Generating samples from SDF model...")
samples = generate_samples(
    model,
    sampler,
    n_samples=8,
    n_points=N_POINTS,
    n_steps=50,
    device=DEVICE,
).cpu().numpy()
print(f"Generated {len(samples)} samples")

In [None]:
# Reference clouds drawn from the dataset (top row, blue) next to generated
# clouds (bottom row, red), all on identical [-1.5, 1.5] axes.
def _draw_cloud(figure, position, cloud, colour, title):
    """Add one 3-D scatter subplot at `position` in a 2x4 grid of `figure`."""
    ax = figure.add_subplot(2, 4, position, projection='3d')
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=2, alpha=0.5, c=colour)
    ax.set_title(title)
    for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_lim([-1.5, 1.5])
    return ax


fig = plt.figure(figsize=(16, 8))

for k in range(4):
    reference = dataset[k].numpy()
    _draw_cloud(fig, k + 1, reference, 'blue', f'Ground Truth {k+1}')

for k in range(4):
    _draw_cloud(fig, k + 5, samples[k], 'red', f'Generated {k+1}')

plt.suptitle('SDF Model: GT (blue) vs Generated (red)', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Visualize Learned SDF

One advantage of the SDF approach: we can visualize the learned scalar field at t=0. It is a true distance field only if training enforces it (e.g. with the Eikonal term).

In [None]:
@torch.no_grad()
def visualize_sdf_slices(model, device, resolution=50, z_slices=(-0.5, 0.0, 0.5), extent=1.5):
    """Plot the learned scalar field f(x, t=0) on axis-aligned z-slices.

    For each z in ``z_slices``, the model is evaluated via ``model.get_sdf``
    on a ``resolution x resolution`` grid covering ``[-extent, extent]`` in x
    and y. Each slice becomes one subplot of filled contours, with the f = 0
    level drawn in black.

    Args:
        model: trained SDF model; switched to eval mode here.
        device: device the query points are moved to.
        resolution: grid points per axis.
        z_slices: z values to slice at (default -0.5, 0.0, 0.5).
        extent: half-width of the plotted xy square (default 1.5).

    Raises:
        ValueError: if ``z_slices`` is empty.
    """
    z_slices = list(z_slices)
    if not z_slices:
        raise ValueError("visualize_sdf_slices needs at least one z slice")

    model.eval()

    coords = torch.linspace(-extent, extent, resolution)
    # 'ij' indexing: xx[i, j] = coords[i], yy[i, j] = coords[j]. This matches
    # the row-major reshape of the flattened predictions below.
    xx, yy = torch.meshgrid(coords, coords, indexing='ij')
    xx_np, yy_np = xx.numpy(), yy.numpy()
    t = torch.zeros(1, device=device)  # same t = 0 for every slice

    # squeeze=False keeps `axes` 2-D, so a single slice still indexes cleanly
    # (plt.subplots(1, 1) would otherwise return a bare Axes).
    _, axes = plt.subplots(1, len(z_slices), figsize=(5 * len(z_slices), 5), squeeze=False)

    for ax, z_val in zip(axes[0], z_slices):
        zz = torch.full_like(xx, z_val)
        points = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=-1)
        points = points.unsqueeze(0).to(device)  # (1, resolution**2, 3)

        # NOTE(review): the model-creation cell probes `forward_sdf`; confirm
        # that `get_sdf` is the intended accessor on SDFNeuralField.
        sdf = model.get_sdf(points, t)
        sdf = sdf.squeeze().cpu().numpy().reshape(resolution, resolution)

        im = ax.contourf(xx_np, yy_np, sdf, levels=20, cmap='RdBu')
        ax.contour(xx_np, yy_np, sdf, levels=[0], colors='black', linewidths=2)
        ax.set_title(f'SDF at z={z_val:.1f}')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_aspect('equal')
        plt.colorbar(im, ax=ax)

    plt.suptitle('Learned SDF at t=0 (black contour = zero level set)', fontsize=14)
    plt.tight_layout()
    plt.show()


visualize_sdf_slices(model, DEVICE)

## 7. Resolution Independence Test

In [None]:
# Resolution independence: draw one cloud at each point count from the same
# trained model and plot them side by side.
print("Testing resolution independence (SDF model)...")

resolutions = [64, 128, 256, 512, 1024]

fig = plt.figure(figsize=(4 * len(resolutions), 4))
model.eval()  # set once; nothing in the loop changes the mode

for i, n_pts in enumerate(resolutions):
    noise = torch.randn(1, n_pts, 3, device=DEVICE)

    # Keep autograd ON: the SDF velocity is -grad_x f, computed by autograd
    # w.r.t. the points. torch.no_grad() here would suppress that gradient
    # unless the model re-enables it internally.
    with torch.enable_grad():
        sample = sampler.sample_euler(noise, n_steps=50)

    # detach() so .numpy() works even if the result carries autograd history.
    sample = sample.detach().cpu().numpy()[0]

    ax = fig.add_subplot(1, len(resolutions), i + 1, projection='3d')
    # Smaller markers for denser clouds (never below 1).
    ax.scatter(sample[:, 0], sample[:, 1], sample[:, 2],
               s=max(1, 5 - i), alpha=0.5, c='green')
    ax.set_title(f'N = {n_pts}')
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])

    print(f"  Generated {n_pts:5d} points")

plt.suptitle('SDF Model Resolution Independence', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Summary

### SDF-Based Approach:

1. **Model outputs scalar**: f(x,t) ∈ ℝ instead of v(x,t) ∈ ℝ³

2. **Velocity via gradient**: v(x,t) = -∇_x f(x,t) computed by autograd

3. **Smoother training**: If f is smooth (C¹), its gradient — and hence the velocity — is continuous

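4. **Implicit surface**: At t=0, the zero level set f(x,0)=0 can represent the learned shape. Note that a loss on ∇f alone fixes f only up to an additive constant, so a meaningful level set needs an extra anchoring term (e.g. a surface constraint)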

### Comparison to Direct Velocity:

| Aspect | Direct Velocity | SDF-Based |
|--------|----------------|------------|
| Output | v(x,t) ∈ ℝ³ | f(x,t) ∈ ℝ |
| Training | May have discontinuities | Smoother gradients |
| Surface representation | Only implicit via the flow | Scalar field (zero level set) |
| Computation | Direct forward | Forward + autograd |

### When to use SDF:
- Training is unstable with direct velocity
- You want to visualize/use the learned SDF
- You need a velocity field that is guaranteed smooth whenever the scalar network f is smooth

In [None]:
# Save the trained weights plus enough configuration to rebuild the model.
from pathlib import Path

shapes_str = '_'.join(SHAPES)
checkpoint = {
    'model_state_dict': model.state_dict(),
    # Constructor kwargs of SDFNeuralField as used in the model cell
    # (in_channels included so the model can be rebuilt from this dict alone).
    'config': {
        'in_channels': 3,
        'hidden_size': HIDDEN_SIZE,
        'hidden_size_x': HIDDEN_SIZE_X,
        'num_heads': NUM_HEADS,
        'num_blocks': NUM_BLOCKS,
        'num_cond_blocks': NUM_COND_BLOCKS,
        'nerf_mlp_ratio': NERF_MLP_RATIO,
        'max_freqs': MAX_FREQS,
    },
    'train_losses': train_losses,
    'velocity_losses': velocity_losses,
    'eikonal_weight': EIKONAL_WEIGHT,
    'shapes': SHAPES,
    'n_points': N_POINTS,
    'model_type': 'SDFNeuralField',
}

ckpt_path = Path('../experiments/outputs') / f'sdf_notebook_checkpoint_{shapes_str}.pt'
# torch.save does not create missing directories; make sure the target exists.
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, ckpt_path)
print("Saved SDF checkpoint")