# Neural Operator Training Demo: CDON Dataset

This notebook demonstrates end-to-end training of neural operator models (DeepONet, FNO, UNet) on the CDON dataset.

**Features:**
- Trains on **real CDON data**
- **Configurable loss functions** (Baseline, BSP, SA-BSP)
- Minimal custom code: reuses the existing codebase
- Includes visualizations of training progress
- Compatible with Google Colab

**Models available:**
- `deeponet`: Branch-trunk architecture (~235K params)
- `fno`: Fourier Neural Operator (~261K params)
- `unet`: Encoder-decoder with skip connections (~249K params)

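The branch-trunk idea behind `deeponet` fits in a few lines:

- A branch network encodes the sampled input function into `p` coefficients.
- A trunk network maps a query coordinate (for CDON, presumably a time point) to `p` basis values.
- The prediction is their dot product plus a bias.

The code below is a minimal pure-Python sketch with random, untrained weights and hypothetical sizes (`m`, `p`). It is not the repository's `deeponet` implementation, whose layer widths and activations are not shown in this notebook.

```python
# Minimal DeepONet forward-pass sketch (pure Python, random untrained weights).
# Hypothetical sizes: m samples of the input function, p latent basis functions.
import math
import random

def linear(x, weights, bias):
    """Dense layer: y_j = sum_i W[j][i] * x[i] + b[j]."""
    return [sum(w_ji * x_i for w_ji, x_i in zip(row, x)) + b_j
            for row, b_j in zip(weights, bias)]

def random_layer(n_in, n_out, rng):
    """Gaussian-initialised weight matrix (n_out x n_in) and zero bias."""
    weights = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)]
               for _ in range(n_out)]
    return weights, [0.0] * n_out

def deeponet_forward(u_samples, query_points, branch, trunk, out_bias=0.0):
    """G(u)(y) ~= sum_k b_k(u) * t_k(y) + out_bias."""
    # Branch: p coefficients computed once from the whole input function.
    b = [math.tanh(v) for v in linear(u_samples, *branch)]
    outputs = []
    for y in query_points:
        # Trunk: p basis values evaluated at the query location y.
        t = [math.tanh(v) for v in linear([y], *trunk)]
        outputs.append(sum(b_k * t_k for b_k, t_k in zip(b, t)) + out_bias)
    return outputs

rng = random.Random(0)
m, p = 8, 4
branch = random_layer(m, p, rng)
trunk = random_layer(1, p, rng)
u = [math.sin(2 * math.pi * i / m) for i in range(m)]
ys = [i / 9 for i in range(10)]
pred = deeponet_forward(u, ys, branch, trunk)
print(len(pred))  # one prediction per query point: 10
```

Because the branch output is computed once per input function, the same coefficients are reused for every query point. This is what lets a DeepONet be evaluated at arbitrary output locations.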
## Cell 1: Setup & Imports (Colab-Ready)

In [None]:
# Google Colab setup
import sys
import os
from pathlib import Path

# Ensure we're in /content (the Colab working directory); skip if it doesn't exist
try:
    os.chdir('/content')
except OSError:
    pass

# Clone repository if running in Colab
repo_path = Path('/content/local')
if not repo_path.exists():
    print("📥 Cloning repository...")
    !git clone https://github.com/maximbeekenkamp/local.git
    print("✅ Repository cloned")
else:
    print("✅ Repository exists")

# Change to repo directory
try:
    os.chdir('/content/local')
    print(f"✅ Changed to: {os.getcwd()}")
except OSError:
    pass

# Install dependencies
print("\n📦 Installing dependencies...")
!pip install -r requirements.txt -q
print("✅ Dependencies installed")

# Standard imports
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

# Project imports
from src.core.data_processing.cdon_dataset import CDONDataset
from src.core.data_processing.cdon_transforms import CDONNormalization
from src.core.models.model_factory import create_model
from src.core.training.simple_trainer import SimpleTrainer
from configs.training_config import TrainingConfig

print("\n✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

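Cell 1 assumes Colab paths (`/content`) and silently skips `os.chdir` failures elsewhere. To run the same notebook locally, one option is to detect Colab explicitly and only clone and change directory there. This is a minimal sketch that assumes an importable `google.colab` package is an acceptable Colab test. `running_in_colab` and `choose_repo_root` are hypothetical helpers, not part of the repository.

```python
import importlib.util
from pathlib import Path

def running_in_colab() -> bool:
    """Assumed Colab test: True when the google.colab package is importable."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # find_spec imports the parent package first; no "google" package means no Colab.
        return False

def choose_repo_root(colab_path: str = "/content/local") -> Path:
    """Hypothetical helper: use the Colab clone location on Colab, else the current directory."""
    return Path(colab_path) if running_in_colab() else Path.cwd()
```

With a check like this, the clone and install steps in Cell 1 could be limited to Colab, and `project_root` in Cell 2 could come from `choose_repo_root()` instead of `Path.cwd()`.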
## Cell 2: Load Real CDON Data

In [None]:
# Get project root
project_root = Path.cwd()
print(f"Project root: {project_root}")

# Data directory
DATA_DIR = project_root / 'CDONData'
print(f"Data directory: {DATA_DIR}")

# Create normalization object (required by CDONDataset)
stats_path = project_root / 'configs' / 'cdon_stats.json'
print(f"Loading stats from: {stats_path}")
normalizer = CDONNormalization(stats_path=str(stats_path))

# Create datasets
train_dataset = CDONDataset(
    data_dir=str(DATA_DIR),
    split='train',
    normalize=normalizer  # Pass normalizer object, not boolean
)

val_dataset = CDONDataset(
    data_dir=str(DATA_DIR),
    split='test',  # the test split is used as the validation set here
    normalize=normalizer
)

# Create dataloaders
BATCH_SIZE = 16

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print("\n✓ Data loaded successfully")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Val samples: {len(val_dataset)}")
print(f"  Batch size: {BATCH_SIZE}")

# Inspect a sample
sample_input, sample_target = train_dataset[0]
print(f"\nSample shapes:")
print(f"  Input: {sample_input.shape}")
print(f"  Target: {sample_target.shape}")

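`CDONNormalization` is built from `configs/cdon_stats.json`, but this notebook does not show its exact transform. A common choice for this kind of preprocessing is a z-score using a stored mean and standard deviation, plus an inverse transform that maps predictions back to physical units. The sketch below assumes that scheme and uses hypothetical keys `input_mean` and `input_std`. Check the actual stats file and class before relying on it.

```python
import json

def normalize(values, mean, std, eps=1e-8):
    """Z-score: (x - mean) / std, with eps guarding against std == 0."""
    return [(v - mean) / (std + eps) for v in values]

def denormalize(values, mean, std, eps=1e-8):
    """Inverse of normalize: z * std + mean."""
    return [v * (std + eps) + mean for v in values]

# Hypothetical stats layout; the real cdon_stats.json schema may differ.
stats = json.loads('{"input_mean": 0.5, "input_std": 2.0}')
x = [0.5, 2.5, -1.5]
z = normalize(x, stats["input_mean"], stats["input_std"])
x_back = denormalize(z, stats["input_mean"], stats["input_std"])
```

Keeping the inverse transform alongside the forward one matters when reporting errors in physical units rather than in normalized space.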
## Cell 3: Choose Model Architecture

**Change `MODEL_ARCH` to try different models:**
- `'deeponet'`: Branch-trunk architecture
- `'fno'`: Fourier Neural Operator
- `'unet'`: U-Net encoder-decoder

In [None]:
# Choose model architecture
MODEL_ARCH = 'deeponet'  # Options: 'deeponet', 'fno', 'unet'

model = create_model(MODEL_ARCH)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Created {MODEL_ARCH.upper()} model")
print(f"  Parameters: {num_params:,}")

## Cell 4A: Import Loss Configurations

In [None]:
from configs.loss_config import BASELINE_CONFIG, BSP_CONFIG, SA_BSP_CONFIG
from src.core.evaluation.loss_factory import create_loss

print("✓ Loss configurations imported successfully")
print("\nAvailable loss types:")
print(f"  1. BASELINE: {BASELINE_CONFIG.description}")
print(f"  2. BSP:      {BSP_CONFIG.description}")
print(f"  3. SA-BSP:   {SA_BSP_CONFIG.description}")

## Cell 4B: Select Loss Type

**Change `LOSS_TYPE` below to experiment:**
- `'baseline'`: Standard Relative L2 loss (default)
- `'bsp'`: Binned Spectral Power loss
- `'sa-bsp'`: Self-Adaptive BSP with learnable weights

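The loss options differ in what they compare. Relative L2 compares fields point by point. A binned spectral loss compares how energy is distributed across frequency bands, so it can penalize missing high-frequency content that a pointwise loss averages away. Conversely, circularly shifting a signal changes its pointwise error but not its power spectrum, which is one reason to pair a spectral term with a field-space term.

The sketch below is a plain-Python illustration built on several choices made here, not necessarily what `create_loss` builds for `BSP_CONFIG`:

- an O(n²) DFT,
- equal-width frequency bins,
- a relative-L2 comparison of per-bin energies.

SA-BSP's learnable weights are not shown.

```python
import cmath
import math

def relative_l2(pred, target, eps=1e-12):
    """||pred - target||_2 / ||target||_2 for one 1-D signal."""
    num = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
    den = math.sqrt(sum(t ** 2 for t in target))
    return num / (den + eps)

def power_spectrum(signal):
    """|DFT|^2 at non-negative frequencies k = 0 .. n//2 (naive O(n^2) DFT)."""
    n = len(signal)
    return [abs(sum(x * cmath.exp(-2j * math.pi * k * i / n)
                    for i, x in enumerate(signal))) ** 2
            for k in range(n // 2 + 1)]

def binned_energies(signal, num_bins):
    """Sum the power spectrum over equal-width frequency bins (an assumed binning)."""
    spec = power_spectrum(signal)
    n_freq = len(spec)
    edges = [round(b * n_freq / num_bins) for b in range(num_bins + 1)]
    return [sum(spec[lo:hi]) for lo, hi in zip(edges[:-1], edges[1:])]

def binned_spectral_loss(pred, target, num_bins=4):
    """Relative L2 between the binned spectral energies of pred and target."""
    return relative_l2(binned_energies(pred, num_bins),
                       binned_energies(target, num_bins))

# Target: a low-frequency sine. A half-period shift flips its sign, so the
# pointwise error is large while the power spectrum is unchanged.
target = [math.sin(2 * math.pi * 2 * i / 32) for i in range(32)]
shifted = target[-8:] + target[:-8]
print(relative_l2(shifted, target), binned_spectral_loss(shifted, target))
```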
In [None]:
# Choose loss type (CHANGE THIS TO EXPERIMENT)
LOSS_TYPE = 'baseline'  # Options: 'baseline', 'bsp', 'sa-bsp'

# Map to configuration
loss_config_map = {
    'baseline': BASELINE_CONFIG,
    'bsp': BSP_CONFIG,
    'sa-bsp': SA_BSP_CONFIG
}

if LOSS_TYPE not in loss_config_map:
    raise ValueError(f"Invalid LOSS_TYPE: '{LOSS_TYPE}'")

selected_loss_config = loss_config_map[LOSS_TYPE]

print(f"✓ Selected loss type: {LOSS_TYPE.upper()}")
print(f"  Description: {selected_loss_config.description}")

## Cell 4C: Create and Validate Loss Function

In [None]:
# Create loss function
criterion = create_loss(selected_loss_config)

print(f"✓ Loss function created: {type(criterion).__name__}")

# Validate with random dummy tensors of shape (batch, channels, length)
dummy_pred = torch.randn(4, 1, 1000)
dummy_target = torch.randn(4, 1, 1000)

test_loss = criterion(dummy_pred, dummy_target)
print(f"✓ Dummy loss value: {test_loss.item():.6f}")
print(f"✓ Loss is finite: {torch.isfinite(test_loss).item()}")

## Cell 4D: Test Loss on Real Data

In [None]:
# Test on real CDON data. The input batch stands in for a model prediction here,
# so the value itself is not meaningful; this only checks shapes and finiteness.
sample_batch_input, sample_batch_target = next(iter(train_loader))

real_data_loss = criterion(sample_batch_input, sample_batch_target)

print(f"✓ Loss on real data: {real_data_loss.item():.6f}")
print(f"✓ Loss is finite: {torch.isfinite(real_data_loss).item()}")
print("✓ Ready for training!")

## Cell 5: Configure Training

In [None]:
config = TrainingConfig(
    num_epochs=50,
    learning_rate=1e-3,
    batch_size=BATCH_SIZE,
    weight_decay=1e-4,
    scheduler_type='cosine',
    cosine_eta_min=1e-6,
    eval_metrics=['field_error', 'spectrum_error'],
    eval_frequency=1,
    checkpoint_dir=f'checkpoints/{MODEL_ARCH}',
    save_best=False,      # Disabled for debugging
    save_latest=False,    # Disabled for debugging
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_workers=2,
    verbose=True
)

print("✓ Training configuration:")
print(f"  Epochs: {config.num_epochs}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Device: {config.device}")
print("  ⚠ Checkpointing: DISABLED")

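With `scheduler_type='cosine'` and `cosine_eta_min=1e-6`, the learning rate should decay from `learning_rate` toward `cosine_eta_min` over training. This sketch assumes the trainer steps a `CosineAnnealingLR`-style schedule once per epoch with `T_max = num_epochs`, which this notebook does not confirm. `cosine_lr` is a hypothetical helper that evaluates the closed form.

```python
import math

def cosine_lr(epoch, num_epochs, lr_max, lr_min):
    """Closed-form cosine annealing: lr_min + 0.5*(lr_max - lr_min)*(1 + cos(pi*epoch/num_epochs))."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / num_epochs))

# Values matching Cell 5: 50 epochs, lr 1e-3 decaying toward eta_min 1e-6.
schedule = [cosine_lr(e, 50, 1e-3, 1e-6) for e in range(51)]
```

Under that assumption, epoch 25 sits at the midpoint (about 5.0e-4). The decay is slow at the start and end of training and fastest in the middle.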
## Cell 6: Create Trainer and Train

In [None]:
# Create trainer with loss_config (required parameter)
trainer = SimpleTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    loss_config=selected_loss_config,  # Required parameter
    experiment_name=f'{MODEL_ARCH}_{LOSS_TYPE}'
)

print("✓ Trainer initialized")
print(f"  Device: {trainer.device}")
print(f"  Optimizer: {type(trainer.optimizer).__name__}")
print(f"  Loss: {type(trainer.criterion).__name__}")

# Check for weight optimizer (SA-BSP only)
if LOSS_TYPE == 'sa-bsp':
    if trainer.weight_optimizer is not None:
        print(f"  Weight optimizer: {type(trainer.weight_optimizer).__name__} ✓")
    else:
        print("  ⚠ WARNING: SA-BSP selected but weight_optimizer is None!")
else:
    print(f"  Weight optimizer: None")

print(f"\nStarting training...\n")

# Train
results = trainer.train()

print("\n✓ Training complete!")
print(f"  Best val loss: {results['best_val_loss']:.6f}")

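Cell 6 checks for a separate `weight_optimizer` when `LOSS_TYPE == 'sa-bsp'`. In self-adaptive weighting schemes of this kind, the model parameters minimize a weighted loss while the weights take gradient-ascent steps on the same loss. As a result, bins with persistently large error get more emphasis.

The sketch below shows the weight update alone. It assumes a loss that is linear in per-bin weights and holds the per-bin errors fixed. `sa_weight_step`, `weighted_loss`, and the error values are illustrative, not the repository's implementation, whose weight parametrization and constraints are not shown here.

```python
def weighted_loss(weights, bin_errors):
    """Assumed SA form: L = sum_i w_i * e_i."""
    return sum(w * e for w, e in zip(weights, bin_errors))

def sa_weight_step(weights, bin_errors, lr=0.1):
    """One gradient-ASCENT step on the weights: dL/dw_i = e_i, so w_i += lr * e_i."""
    return [w + lr * e for w, e in zip(weights, bin_errors)]

weights = [1.0, 1.0, 1.0, 1.0]
bin_errors = [0.01, 0.02, 0.10, 0.40]  # illustrative per-bin errors, held fixed here
for _ in range(5):
    weights = sa_weight_step(weights, bin_errors)
```

In real training the model's descent steps shrink the errors while the weights grow, so the two updates act as a min-max game rather than the one-sided loop shown here.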
## Cell 7: Plot Training History

In [None]:
# Extract metrics
train_losses = [h['loss'] for h in results['train_history']]
val_losses = [h['loss'] for h in results['val_history']]
val_field_errors = [h['field_error'] for h in results['val_history']]
val_spectrum_errors = [h['spectrum_error'] for h in results['val_history']]
epochs = range(1, len(train_losses) + 1)

# Create plots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Loss
axes[0].plot(epochs, train_losses, label='Train Loss', marker='o', markersize=3)
axes[0].plot(epochs, val_losses, label='Val Loss', marker='s', markersize=3)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title(f'Training Loss ({LOSS_TYPE.upper()})')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Field Error
axes[1].plot(epochs, val_field_errors, label='Val Field Error', marker='s', markersize=3, color='orange')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Field Error')
axes[1].set_title('Field Error (Real Space)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot 3: Spectrum Error
axes[2].plot(epochs, val_spectrum_errors, label='Val Spectrum Error', marker='s', markersize=3, color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Spectrum Error')
axes[2].set_title('Spectrum Error (Frequency Space)')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_ARCH.upper()} Training Results', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Print final metrics
print(f"\nFinal Metrics:")
print(f"  Train Loss: {train_losses[-1]:.6f}")
print(f"  Val Loss: {val_losses[-1]:.6f}")
print(f"  Val Field Error: {val_field_errors[-1]:.6f}")
print(f"  Val Spectrum Error: {val_spectrum_errors[-1]:.6f}")

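The final metrics above come from the last epoch. With checkpointing disabled, it can also help to know which epoch reached the lowest validation loss, and whether that agrees with `results['best_val_loss']`. This sketch assumes `results['val_history']` is a list of per-epoch dicts with a `'loss'` key, as Cell 7 already relies on. `best_epoch` is a hypothetical helper, not part of `SimpleTrainer`.

```python
def best_epoch(val_history, key="loss"):
    """Return (1-based epoch, value) of the minimum `key` across per-epoch dicts."""
    idx = min(range(len(val_history)), key=lambda i: val_history[i][key])
    return idx + 1, val_history[idx][key]

history = [{"loss": 0.9}, {"loss": 0.4}, {"loss": 0.5}]
print(best_epoch(history))  # (2, 0.4)
```

In the notebook this would be `best_epoch(results['val_history'])`. Passing `key='spectrum_error'` instead picks the epoch with the best spectral agreement, which need not coincide with the best loss.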
## Summary

This notebook demonstrated:
1. ✓ Loading real CDON data with proper normalization
2. ✓ Creating neural operator models (DeepONet, FNO, UNet)
3. ✓ **Configurable loss functions** (Baseline, BSP, SA-BSP)
4. ✓ Training with SimpleTrainer
5. ✓ Visualizing training metrics

**Experiment with different configurations:**
- **Cell 3**: Change `MODEL_ARCH` to try different models
- **Cell 4B**: Change `LOSS_TYPE` to try different loss functions
- **Cell 5**: Adjust hyperparameters (epochs, learning rate, etc.)