# Neural Posterior Estimation for Overlapping GW Signals

**Fully trainable notebook** for OverlapNeuralPE - neural parameter estimation with normalizing flows.

## Overview

OverlapNeuralPE performs Bayesian parameter estimation on overlapping gravitational wave signals using:
- **Flow-based posterior**: NSF (Neural Spline Flow) for efficient sampling
- **Context encoding**: CNN + BiLSTM for strain features
- **RL adaptation**: Dynamic complexity control
- **Bias correction**: Systematic error removal
- **Physics priors**: Domain constraints

**Parameter space**: 11D (9 orbital + 2 spin magnitudes)
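
As a minimal illustration of the flow-based posterior listed above, the sketch below applies the change-of-variables formula to a one-dimensional affine map. It is a stand-in for intuition, not the NSF used by OverlapNeuralPE: the function names are hypothetical, and in a real flow the shift and scale would come from a network conditioned on the strain context.

```python
import math

def standard_normal_log_prob(z: float) -> float:
    """Log-density of the base distribution N(0, 1)."""
    return -0.5 * (z * z + math.log(2.0 * math.pi))

def affine_flow_log_prob(theta: float, shift: float, log_scale: float) -> float:
    """log q(theta | context) for the map z = (theta - shift) / exp(log_scale).

    shift and log_scale are passed in directly here; a real conditional flow
    would predict them from the context embedding.
    """
    z = (theta - shift) * math.exp(-log_scale)
    log_det = -log_scale  # log |dz/dtheta| for the affine map
    return standard_normal_log_prob(z) + log_det

def nll(thetas, shift, log_scale):
    """Average negative log-likelihood in nats, the quantity a flow minimizes."""
    return -sum(affine_flow_log_prob(t, shift, log_scale) for t in thetas) / len(thetas)

samples = [29.5, 30.2, 30.9, 29.8]  # toy values standing in for mass_1
print(nll(samples, shift=30.0, log_scale=math.log(0.5)))
```

The Jacobian term is what distinguishes this from a plain regression loss: shrinking the scale sharpens the density but is only rewarded if the samples actually fall near the shift.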

## 1. Setup & Configuration

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import yaml
import logging
import time
from collections import defaultdict
import matplotlib.pyplot as plt
import json

# Setup paths
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / "src"))
sys.path.insert(0, str(project_root / "experiments"))

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('NeuralPETraining')

logger.info(f"Project root: {project_root}")
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {device}")

In [None]:
# Load configuration
config_path = project_root / "configs/enhanced_training.yaml"

with open(config_path) as f:
    config_dict = yaml.safe_load(f)

# Extract Neural PE configuration
neural_pe_config = config_dict.get('neural_posterior', {})
data_config = config_dict.get('data', {})

print("\n📋 Neural PE Configuration")
print("="*70)
print(f"Flow type: {neural_pe_config.get('flow_type', 'nsf')}")
print(f"Context dimension: {neural_pe_config.get('context_dim', 768)}")
print(f"Num layers: {neural_pe_config.get('num_layers', 8)}")
print(f"Hidden features: {neural_pe_config.get('hidden_features', 256)}")
print(f"\nTraining:")
print(f"  Batch size: {neural_pe_config.get('batch_size', 32)}")
print(f"  Learning rate: {neural_pe_config.get('learning_rate', 1e-5):.2e}")
print(f"  Epochs: {neural_pe_config.get('epochs', 50)}")
print(f"  Warmup epochs: {neural_pe_config.get('warmup_epochs', 5)}")
print(f"\nLoss weights:")
print(f"  NLL loss: {neural_pe_config.get('loss_weight', 1.0)}")
print(f"  Physics loss: {neural_pe_config.get('physics_loss_weight', 0.05)}")
print(f"  Bounds penalty: {neural_pe_config.get('bounds_penalty_weight', 0.5)}")
print(f"  Sample loss: {neural_pe_config.get('sample_loss_weight', 0.5)}")
print("="*70)
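
The repeated `.get(key, default)` calls above imply a set of default hyperparameters. One way to make those defaults explicit is to merge them with the loaded YAML section once, as sketched below. `NEURAL_PE_DEFAULTS` and `resolve_config` are hypothetical names; the values are copied from the fallbacks used in this notebook, and the partial override is invented for illustration.

```python
from copy import deepcopy

# Defaults copied from the .get() fallbacks used in this notebook
NEURAL_PE_DEFAULTS = {
    "flow_type": "nsf",
    "context_dim": 768,
    "num_layers": 8,
    "hidden_features": 256,
    "batch_size": 32,
    "learning_rate": 1e-5,
    "weight_decay": 1e-6,
    "epochs": 50,
    "warmup_epochs": 5,
    "loss_weight": 1.0,
    "physics_loss_weight": 0.05,
    "bounds_penalty_weight": 0.5,
    "sample_loss_weight": 0.5,
}

def resolve_config(loaded, defaults=NEURAL_PE_DEFAULTS):
    """Return the defaults overridden by whatever keys the YAML provided."""
    resolved = deepcopy(defaults)
    resolved.update(loaded or {})  # tolerate a missing or empty YAML section
    return resolved

# Hypothetical partial section, as if read from the YAML file
example = resolve_config({"learning_rate": 3e-5, "epochs": 100})
print(example["learning_rate"], example["batch_size"])
```

With a resolved dictionary, later cells could index keys directly instead of repeating the fallback values in several places.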

## 2. Import Model Classes

In [None]:
# Import core classes
from ahsd.models.overlap_neuralpe import OverlapNeuralPE
from ahsd.models.flows import create_flow_model

logger.info("✅ Imported OverlapNeuralPE classes")

## 3. Define Parameter Space

In [None]:
# Define 11D parameter space (9 orbital + 2 spin magnitudes)
param_names = [
    'mass_1',              # Primary mass (M_sun)
    'mass_2',              # Secondary mass (M_sun)
    'luminosity_distance', # Distance (Mpc)
    'theta_jn',            # Inclination angle
    'ra',                  # Right ascension
    'dec',                 # Declination
    'psi',                 # Polarization angle
    'phase',               # Coalescence phase
    'geocent_time',        # GPS time
    'a1',                  # Spin magnitude object 1
    'a2',                  # Spin magnitude object 2
]

param_dim = len(param_names)
logger.info(f"Parameter space: {param_dim}D")
for i, name in enumerate(param_names):
    print(f"  {i+1:2d}. {name}")
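
Because the model works with plain parameter vectors, column positions follow the order of `param_names`. A small name-to-index lookup, sketched below, avoids hard-coded positions when slicing samples. The helper `column` and the toy rows are hypothetical and only illustrate the indexing.

```python
param_names = [
    "mass_1", "mass_2", "luminosity_distance", "theta_jn", "ra", "dec",
    "psi", "phase", "geocent_time", "a1", "a2",
]

# Name -> column index, so later code never relies on magic numbers
param_index = {name: i for i, name in enumerate(param_names)}

def column(samples, name):
    """Extract one parameter's values from rows of length len(param_names)."""
    idx = param_index[name]
    return [row[idx] for row in samples]

# Two toy rows standing in for posterior samples (values are illustrative only)
toy_samples = [
    [36.0, 29.0, 410.0, 0.4, 1.2, -0.3, 0.7, 2.1, 0.01, 0.3, 0.1],
    [35.0, 30.0, 450.0, 0.5, 1.3, -0.2, 0.6, 2.0, 0.02, 0.2, 0.2],
]
print(column(toy_samples, "luminosity_distance"))
```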

## 4. Initialize Model

In [None]:
# Check for PriorityNet checkpoint (required by OverlapNeuralPE)
priority_net_path = project_root / "models/priority_net/priority_net_best.pth"

if not priority_net_path.exists():
    logger.warning(f"PriorityNet checkpoint not found at {priority_net_path}")
    logger.warning("OverlapNeuralPE can work without PriorityNet but will use random weights")
    priority_net_path = str(project_root / "models/priority_net/priority_net_checkpoint.pt")  # fallback stub
else:
    logger.info(f"✅ Found PriorityNet at {priority_net_path}")
    priority_net_path = str(priority_net_path)

print(f"\nPriorityNet path: {priority_net_path}")

In [None]:
# Initialize OverlapNeuralPE model
try:
    model = OverlapNeuralPE(
        param_names=param_names,
        priority_net_path=priority_net_path,
        config=neural_pe_config,
        device=device,
        event_type='BBH'  # Example: Binary Black Hole
    )
    model.to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"✅ OverlapNeuralPE initialized with {total_params:,} parameters")
    print(f"\n🧠 Model Summary:")
    print(f"  Parameter dimension: {param_dim}D")
    print(f"  Context dimension: {neural_pe_config.get('context_dim', 768)}")
    print(f"  Flow type: {neural_pe_config.get('flow_type', 'nsf')}")
    print(f"  Total parameters: {total_params:,}")
except Exception as e:
    logger.error(f"Failed to initialize model: {e}")
    raise

## 5. Data Loading Setup

In [None]:
# Check for training data
train_dir = project_root / "data/output/train"
val_dir = project_root / "data/output/val"

train_loader = None
val_loader = None
USE_SYNTHETIC_DATA = False

if train_dir.exists():
    train_samples = list(train_dir.glob("*.h5")) + list(train_dir.glob("*.pkl"))
    if train_samples:
        logger.info(f"✅ Found {len(train_samples)} training samples")
        print(f"\nTrain data: {len(train_samples)} samples found")
        
        try:
            # Import data loading utilities
            import importlib.util
            spec = importlib.util.spec_from_file_location(
                "train_priority_net",
                project_root / "experiments/train_priority_net.py"
            )
            if spec and spec.loader:
                train_module = importlib.util.module_from_spec(spec)
                sys.modules['train_priority_net'] = train_module
                spec.loader.exec_module(train_module)
                ChunkedGWDataLoader = train_module.ChunkedGWDataLoader
                
                train_loader = ChunkedGWDataLoader(
                    data_dir=train_dir,
                    batch_size=neural_pe_config.get('batch_size', 32),
                    shuffle=True,
                    num_workers=0
                )
                logger.info(f"✅ Created train loader")
        except Exception as e:
            logger.warning(f"Could not create real data loader: {e}")
            USE_SYNTHETIC_DATA = True
    else:
        logger.warning(f"No samples found in {train_dir}")
        USE_SYNTHETIC_DATA = True
else:
    logger.warning(f"Train directory not found: {train_dir}")
    USE_SYNTHETIC_DATA = True

if val_dir.exists() and train_loader is not None:
    val_samples = list(val_dir.glob("*.h5")) + list(val_dir.glob("*.pkl"))
    if val_samples:
        logger.info(f"✅ Found {len(val_samples)} validation samples")
        try:
            val_loader = ChunkedGWDataLoader(
                data_dir=val_dir,
                batch_size=neural_pe_config.get('batch_size', 32),
                shuffle=False,
                num_workers=0
            )
            logger.info(f"✅ Created val loader")
        except Exception as e:
            logger.warning(f"Could not create val loader: {e}")

if USE_SYNTHETIC_DATA:
    logger.info("⚠️  Using synthetic data for demonstration")
    print(f"\n⚠️  Synthetic data mode enabled")
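
The train and validation branches above repeat the same file discovery. A possible refactoring, sketched below with only the standard library, wraps it in one helper. `find_samples` and `should_use_synthetic` are hypothetical names, and the temporary directory exists only to exercise the helper.

```python
from pathlib import Path
import tempfile

def find_samples(data_dir, patterns=("*.h5", "*.pkl")):
    """Return sorted sample files in data_dir, or [] if the directory is missing."""
    data_dir = Path(data_dir)
    if not data_dir.exists():
        return []
    files = []
    for pattern in patterns:
        files.extend(data_dir.glob(pattern))
    return sorted(files)

def should_use_synthetic(train_dir):
    """Synthetic mode is needed whenever no training files are found."""
    return len(find_samples(train_dir)) == 0

# Exercise the helper against a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "chunk_0.h5").touch()
    (Path(tmp) / "notes.txt").touch()
    found = [p.name for p in find_samples(tmp)]
    synthetic_needed = should_use_synthetic(tmp)

# The directory has been removed by now, so the lookup returns an empty list
missing_dir_result = find_samples(Path(tmp) / "does_not_exist")
print(found, synthetic_needed, missing_dir_result)
```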

## 6. Define Training Functions

In [None]:
def create_synthetic_batch(batch_size=16, param_dim=11, n_signals_range=(1, 3)):
    """
    Create synthetic batch for demonstration.
    
    Args:
        batch_size: Samples in batch
        param_dim: Parameter dimension (11)
        n_signals_range: Range of signals per sample
    
    Returns:
        Batch dict with strain and parameters
    """
    batch = {}
    
    # Strain data: [batch, 2 detectors (H1, L1), time_samples]
    batch['strain_data'] = torch.randn(batch_size, 2, 16384).to(device)
    
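    # Parameters: [batch, n_signals, param_dim]
    # np.random.randint excludes the upper bound, so the default (1, 3) yields 1 or 2
    # signals; the same n_signals is shared by every sample in the batch
    n_signals = np.random.randint(*n_signals_range)
    batch['parameters'] = torch.randn(batch_size, n_signals, param_dim).to(device)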
    
    # SNR values for weighting
    batch['snr'] = torch.ones(batch_size, n_signals).to(device) * 20.0
    
    return batch

logger.info("✅ Defined synthetic batch generator")

In [None]:
def train_one_epoch(model, train_loader, optimizer, device='cpu', use_synthetic=False, param_dim=11):
    """
    Train for one epoch.
    
    Args:
        model: OverlapNeuralPE
        train_loader: Training data loader
        optimizer: PyTorch optimizer
        device: 'cpu' or 'cuda'
        use_synthetic: Use synthetic data
        param_dim: Parameter dimension
    
    Returns:
        avg_loss: Average loss over epoch
        metrics: Dict with loss components
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    loss_components = defaultdict(float)
    
    # Use synthetic batches if no real data
    if use_synthetic:
        batches = [create_synthetic_batch(batch_size=16, param_dim=param_dim) for _ in range(5)]
    else:
        batches = train_loader
    
    for batch_idx, batch in enumerate(batches):
        try:
            # Handle dict vs tuple batches
            if isinstance(batch, dict):
                strain_data = batch.get('strain_data', batch.get('strain', None))
                parameters = batch.get('parameters', None)
                snr_values = batch.get('snr', None)
            else:
                # Fallback for tuple returns
                strain_data, parameters, snr_values = batch[:3]
            
            if strain_data is None or parameters is None:
                continue
            
            # Move to device (.to() is a no-op for tensors already on the target device)
            strain_data = strain_data.to(device)
            parameters = parameters.to(device)
            if snr_values is not None:
                snr_values = snr_values.to(device)
            
            # Forward pass: compute NLL and additional losses
            try:
                # Call model forward which computes loss internally
                loss, loss_dict = model.forward_with_loss(
                    strain_data=strain_data,
                    parameters=parameters,
                    snr_weights=snr_values
                )
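            except Exception as e:
                # Fallback: MSE between posterior samples and the first signal's
                # ground truth (only meaningful for the synthetic demo)
                logger.debug(f"forward_with_loss failed, using MSE fallback: {e}")
                posterior_samples = model.sample_posterior(
                    strain_data=strain_data,
                    n_samples=100
                )
                
                # posterior_samples: [batch, n_samples, param_dim]
                # target: [batch, param_dim], broadcast over the sample axis
                target = parameters[:, 0, :]  # First signal ground truth
                loss = torch.mean((posterior_samples - target.unsqueeze(1)) ** 2)
                # Stored under 'nll' for logging, although this value is an MSE
                loss_dict = {'nll': loss.item()}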
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            # Track loss components
            for key, val in loss_dict.items():
                if isinstance(val, (int, float)):
                    loss_components[key] += val
            
            if (batch_idx + 1) % 5 == 0 or batch_idx == 0:
                logger.info(f"  Batch {batch_idx+1} | Loss: {loss.item():.6f}")
        
        except Exception as e:
            logger.error(f"Error in batch {batch_idx}: {e}")
            continue
    
    # Average losses
    avg_loss = total_loss / max(num_batches, 1)
    for key in loss_components:
        loss_components[key] /= max(num_batches, 1)
    
    return avg_loss, dict(loss_components)

logger.info("✅ Defined train_one_epoch function")
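
The training step clips the joint gradient norm at 10.0 before each optimizer update. The sketch below reproduces the idea on plain Python lists to show what norm clipping does: gradients are rescaled together so their combined L2 norm does not exceed the threshold, and small gradients are left unchanged. `clip_grad_norm` here is an illustrative re-implementation, not the PyTorch function itself.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # The scaling factor is capped at 1 so small gradients are left untouched
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm

grads = [[30.0, 40.0], [0.0]]  # joint norm is 50
clipped, norm_before = clip_grad_norm(grads, max_norm=10.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length changes.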

In [None]:
def validate(model, val_loader, device='cpu', use_synthetic=False, param_dim=11):
    """
    Validate model.
    
    Args:
        model: OverlapNeuralPE
        val_loader: Validation data loader
        device: 'cpu' or 'cuda'
        use_synthetic: Use synthetic data
        param_dim: Parameter dimension
    
    Returns:
        avg_loss: Average validation loss
        metrics: Dict of metrics
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_nll_values = []
    
    # Use synthetic batches if no real data
    if use_synthetic:
        batches = [create_synthetic_batch(batch_size=16, param_dim=param_dim) for _ in range(3)]
    else:
        batches = val_loader
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(batches):
            try:
                # Handle dict vs tuple batches
                if isinstance(batch, dict):
                    strain_data = batch.get('strain_data', batch.get('strain', None))
                    parameters = batch.get('parameters', None)
                    snr_values = batch.get('snr', None)
                else:
                    strain_data, parameters, snr_values = batch[:3]
                
                if strain_data is None or parameters is None:
                    continue
                
                # Move to device (.to() is a no-op for tensors already on the target device)
                strain_data = strain_data.to(device)
                parameters = parameters.to(device)
                if snr_values is not None:
                    snr_values = snr_values.to(device)
                
                # Forward pass
                try:
                    loss, _ = model.forward_with_loss(
                        strain_data=strain_data,
                        parameters=parameters,
                        snr_weights=snr_values
                    )
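                except Exception as e:
                    # Fallback: MSE against the first signal's ground truth
                    logger.debug(f"forward_with_loss failed, using MSE fallback: {e}")
                    posterior_samples = model.sample_posterior(
                        strain_data=strain_data,
                        n_samples=100
                    )
                    # posterior_samples: [batch, n_samples, param_dim]
                    target = parameters[:, 0, :]
                    loss = torch.mean((posterior_samples - target.unsqueeze(1)) ** 2)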
                
                total_loss += loss.item()
                all_nll_values.append(loss.item())
                num_batches += 1
            
            except Exception as e:
                logger.error(f"Error in val batch {batch_idx}: {e}")
                continue
    
    avg_loss = total_loss / max(num_batches, 1)
    
    metrics = {}
    if all_nll_values:
        metrics['nll_mean'] = float(np.mean(all_nll_values))
        metrics['nll_std'] = float(np.std(all_nll_values))
    
    return avg_loss, metrics

logger.info("✅ Defined validate function")

## 7. Setup Optimizer & Scheduler

In [None]:
# Initialize optimizer
lr = neural_pe_config.get('learning_rate', 1e-5)
weight_decay = neural_pe_config.get('weight_decay', 1e-6)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay
)

logger.info(f"✅ Optimizer initialized")
print(f"\n⚙️ Optimizer Configuration:")
print(f"  Type: AdamW")
print(f"  Learning rate: {lr:.2e}")
print(f"  Weight decay: {weight_decay:.2e}")

# Warmup scheduler
warmup_epochs = neural_pe_config.get('warmup_epochs', 5)
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=warmup_epochs
)

# Main scheduler
main_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=7,
    threshold=1e-3,
    min_lr=1e-7
)

logger.info(f"✅ Schedulers initialized")
print(f"  Warmup epochs: {warmup_epochs}")
print(f"  ReduceLROnPlateau enabled")
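
To see what the warmup does to the learning rate, the sketch below evaluates the linear interpolation between `start_factor` and `end_factor` that `LinearLR` applies, using the notebook's defaults (base rate 1e-5, factor 0.1 to 1.0 over 5 steps). `linear_warmup_lr` is a hypothetical helper written in plain Python so it can be checked without an optimizer.

```python
def linear_warmup_lr(base_lr, step, start_factor=0.1, end_factor=1.0, total_iters=5):
    """Learning rate after `step` scheduler steps of a linear warmup."""
    progress = min(step, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * progress)

base_lr = 1e-5
schedule = [linear_warmup_lr(base_lr, step) for step in range(7)]
for step, lr in enumerate(schedule):
    print(f"after {step} scheduler steps: lr = {lr:.2e}")
```

The rate starts at one tenth of the configured value and reaches the full value after `total_iters` steps, after which further steps leave it unchanged.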

## 8. Run Training Loop

In [None]:
# Training configuration
num_epochs = min(neural_pe_config.get('epochs', 50), 5)  # Limit to 5 for demo
patience = 15

print(f"\n" + "="*70)
print(f"🚀 STARTING NEURAL PE TRAINING")
print(f"="*70)
print(f"Epochs: {num_epochs}")
print(f"Warmup epochs: {warmup_epochs}")
print(f"Parameter dimension: {param_dim}D")
print(f"Device: {device}")
print(f"Using synthetic data: {USE_SYNTHETIC_DATA}")
print("="*70 + "\n")

# Create results directory
checkpoint_dir = project_root / "models/neural_pe"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Checkpoint directory: {checkpoint_dir}")

In [None]:
# Main training loop
history = defaultdict(list)
best_val_loss = float('inf')
patience_counter = 0
start_time = time.time()

for epoch in range(num_epochs):
    epoch_start = time.time()
    
    # Epoch header (warmup epochs also report the current learning rate)
    current_lr = optimizer.param_groups[0]['lr']
    logger.info(f"\n{'='*70}")
    if epoch < warmup_epochs:
        logger.info(f"Epoch {epoch+1}/{num_epochs} [WARMUP] - LR: {current_lr:.2e}")
    else:
        logger.info(f"Epoch {epoch+1}/{num_epochs}")
    logger.info(f"{'='*70}")
    
    # Training
    print(f"\n📈 Training...")
    train_loss, train_components = train_one_epoch(
        model, train_loader, optimizer, device,
        use_synthetic=USE_SYNTHETIC_DATA or train_loader is None,
        param_dim=param_dim
    )
    
    # Validation
    print(f"✅ Validating...")
    val_loss, val_metrics = validate(
        model, val_loader, device,
        use_synthetic=USE_SYNTHETIC_DATA or val_loader is None,
        param_dim=param_dim
    )
    
    # Learning rate scheduling: step schedulers after the epoch's optimizer
    # updates, as PyTorch expects (warmup first, then ReduceLROnPlateau)
    if epoch < warmup_epochs:
        warmup_scheduler.step()
    else:
        main_scheduler.step(val_loss)
    
    # Track history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    for key, val in train_components.items():
        history[f'train_{key}'].append(val)
    for key, val in val_metrics.items():
        history[f'val_{key}'].append(val)
    
    # Log epoch results
    epoch_time = time.time() - epoch_start
    logger.info(
        f"\nResults:")
    logger.info(
        f"  Train loss: {train_loss:.6f}")
    logger.info(
        f"  Val loss: {val_loss:.6f}")
    logger.info(
        f"  Time: {epoch_time:.1f}s")
    
    if train_components:
        logger.info(f"  Train components: {train_components}")
    if val_metrics:
        logger.info(f"  Val metrics: {val_metrics}")
    
    # Save checkpoint if improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        
        checkpoint = {
            'epoch': epoch,
            'loss': val_loss,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': neural_pe_config,
            'param_names': param_names,
            'history': dict(history)
        }
        
        checkpoint_path = checkpoint_dir / "neural_pe_best.pth"
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"\n✅ CHECKPOINT SAVED: {checkpoint_path}")
    
    else:
        patience_counter += 1
        logger.info(f"\nNo improvement ({patience_counter}/{patience})")
        if patience_counter >= patience:
            logger.info(f"\n🛑 Early stopping: no improvement for {patience} epochs")
            break

# Summary
total_time = time.time() - start_time
print(f"\n" + "="*70)
print(f"✅ TRAINING COMPLETE")
print(f"="*70)
print(f"Total time: {total_time/60:.1f} minutes")
print(f"Best validation loss: {best_val_loss:.6f}")
print(f"Checkpoint saved to: {checkpoint_dir / 'neural_pe_best.pth'}")
print(f"="*70)
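
The checkpoint-and-patience bookkeeping in the loop above can be isolated into a small tracker, sketched below. `EarlyStopping` is a hypothetical helper, not part of the project code. Note that with the demo cap of 5 epochs and a patience of 15, early stopping cannot trigger; it only matters for full-length runs.

```python
class EarlyStopping:
    """Track the best validation loss and signal when patience runs out."""

    def __init__(self, patience=15):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def update(self, val_loss):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience

# Toy loss sequence with a short patience to show the stop signal
stopper = EarlyStopping(patience=2)
events = [stopper.update(loss) for loss in [1.0, 0.8, 0.9, 0.85]]
print(events)
```

In the training loop, `improved` would gate the checkpoint save and `should_stop` would gate the `break`.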

## 9. Plot Training Results

In [None]:
# Plot training history
if history['train_loss']:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Loss plot
    axes[0].plot(history['train_loss'], label='Train', marker='o')
    axes[0].plot(history['val_loss'], label='Val', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Neural PE Training Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
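    # A log axis needs positive values, and a flow NLL can become negative
    if min(history['train_loss'] + history['val_loss']) > 0:
        axes[0].set_yscale('log')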
    
    # NLL convergence (if available)
    if 'val_nll_mean' in history:
        axes[1].plot(history['val_nll_mean'], label='Val NLL', marker='o')
        axes[1].fill_between(
            range(len(history['val_nll_mean'])),
            np.array(history['val_nll_mean']) - np.array(history.get('val_nll_std', [0]*len(history['val_nll_mean']))),
            np.array(history['val_nll_mean']) + np.array(history.get('val_nll_std', [0]*len(history['val_nll_mean']))),
            alpha=0.2
        )
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('NLL (nats)')
        axes[1].set_title('Negative Log-Likelihood')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(checkpoint_dir / 'training_history.png', dpi=150, bbox_inches='tight')
    print(f"\n📊 Training plots saved to {checkpoint_dir / 'training_history.png'}")
    plt.show()
else:
    print("No training history to plot")

## 10. Load and Test Checkpoint

In [None]:
# Load best checkpoint
checkpoint_path = checkpoint_dir / "neural_pe_best.pth"

if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
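    print("🔧 Checkpoint Details")
    print("="*70)
    print(f"Path: {checkpoint_path}")
    print(f"Epoch: {checkpoint.get('epoch', 'unknown')}")
    # Format the loss only if present; ':.6f' would fail on a string placeholder
    loss_value = checkpoint.get('loss')
    print(f"Loss: {loss_value:.6f}" if loss_value is not None else "Loss: unknown")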
    
    # Count state dict keys
    state_dict = checkpoint.get('model_state_dict', {})
    print(f"Model parameters: {len(state_dict)} keys")
    print(f"Total model size: {sum(v.numel() for v in state_dict.values()):,} params")
    
    # Load into model
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    print(f"\n✅ Model loaded successfully")
else:
    print(f"⚠️  Checkpoint not found at {checkpoint_path}")

## 11. Posterior Sampling Test

In [None]:
# Test posterior sampling
print("\n🔮 Testing Posterior Sampling")
print("="*70)

with torch.no_grad():
    # Create test batch
    test_batch = create_synthetic_batch(batch_size=2, param_dim=param_dim)
    
    try:
        # Sample from posterior
        posterior_samples = model.sample_posterior(
            strain_data=test_batch['strain_data'],
            n_samples=100
        )
        
        print(f"\nPosterior samples shape: {posterior_samples.shape}")
        print(f"Expected: [batch_size=2, n_samples=100, param_dim={param_dim}]")
        
        # Statistics for first sample, first signal
        samples_s0 = posterior_samples[0, :, :].cpu().numpy()  # [100, 11]
        
        print(f"\nSample 0, Signal 0 statistics:")
        for i, param_name in enumerate(param_names):
            mean = samples_s0[:, i].mean()
            std = samples_s0[:, i].std()
            print(f"  {param_name:20s}: {mean:8.4f} ± {std:6.4f}")
        
        print(f"\n✅ Posterior sampling successful")
    
    except Exception as e:
        logger.error(f"Error in posterior sampling: {e}")
        print(f"⚠️  Posterior sampling not available: {e}")

print("="*70)
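
Mean and standard deviation, as printed above, can hide skewed or multimodal posteriors, so a median with a central credible interval is often reported as well. The sketch below computes both from one-dimensional samples with the standard library. `quantile` and `credible_interval` are hypothetical helpers, and the Gaussian draws are synthetic stand-ins for one column of `posterior_samples`.

```python
import random

def quantile(sorted_values, q):
    """Empirical quantile with linear interpolation between order statistics."""
    pos = q * (len(sorted_values) - 1)
    lower = int(pos)
    upper = min(lower + 1, len(sorted_values) - 1)
    frac = pos - lower
    return sorted_values[lower] * (1 - frac) + sorted_values[upper] * frac

def credible_interval(samples, level=0.90):
    """Return (median, lower, upper) of the central interval at the given level."""
    ordered = sorted(samples)
    tail = (1.0 - level) / 2.0
    return quantile(ordered, 0.5), quantile(ordered, tail), quantile(ordered, 1.0 - tail)

# Synthetic draws standing in for one parameter's posterior samples
random.seed(0)
draws = [random.gauss(30.0, 2.0) for _ in range(2000)]
median, lo, hi = credible_interval(draws)
print(f"median={median:.2f}, 90% interval=[{lo:.2f}, {hi:.2f}]")
```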

## 12. Next Steps

In [None]:
print("""
📚 Next Steps:

1. TRAIN WITH REAL DATA:
   - Ensure data is generated: python experiments/data_generation.py --n-samples 1000
   - Remove the 5-epoch demo cap on num_epochs in section 8 (the config default is 50)
   - Re-run sections 8-9 for full training
   - Expected runtime: 2-3 hours on GPU
   
2. VALIDATION:
   python experiments/test_neural_pe.py \\
     --model_path models/neural_pe/neural_pe_best.pth \\
     --data_path data/test \\
     --device cuda
   
3. PARAMETER INFERENCE:
   from ahsd.models.overlap_neuralpe import OverlapNeuralPE
   
   model = OverlapNeuralPE(
       param_names=param_names,
       priority_net_path='models/priority_net/priority_net_best.pth',
       config=neural_pe_config,
       device='cuda'
   )
   
   # Load checkpoint
   checkpoint = torch.load('models/neural_pe/neural_pe_best.pth')
   model.load_state_dict(checkpoint['model_state_dict'])
   
   # Sample posterior
   posterior_samples = model.sample_posterior(strain_data, n_samples=500)
   
4. PIPELINE INTEGRATION:
   See 04_inference_pipeline.ipynb for full workflow

📊 Key Metrics:
   - NLL < 3.0 nats (excellent)
   - NLL < 5.0 nats (good)
   - Inference time < 1.0s per sample
   - Posterior mean error < 10% of parameter range

⚙️ Troubleshooting:
   - If NLL plateaus: Increase learning rate or flow capacity
   - If gradients explode: Reduce learning rate or lower the gradient clipping threshold (currently 10.0)
   - If memory error: Reduce batch_size or n_samples
   - If divergence: Check flow_type setting (nsf vs flowmatching)
""")