# Neural Operator Training Demo: CDON Dataset

This notebook demonstrates end-to-end training of neural operator models (DeepONet, FNO, UNet) on the CDON dataset.

**Features:**
- Trains on **real CDON data**
- **Sequential training with all 6 loss functions**:
  - **BASELINE**: Relative L2 loss only (baseline MSE)
  - **BSP**: MSE + fixed BSP loss with k² weighting
  - **Log-BSP**: MSE + BSP with log₁₀ spectral energies (uniform λ_k weighting)
  - **SA-BSP (Per-bin)**: MSE + 32 adaptive per-bin weights (negated gradients for frequency emphasis)
  - **SA-BSP (Global)**: MSE + 2 adaptive weights (w_mse + w_bsp) for MSE/BSP balance
  - **SA-BSP (Combined)**: MSE + 34 weights (w_mse + w_bsp + 32 per-bin) with full competitive dynamics
- **Multi-loss comparison plots** showing training metrics
- **Energy spectrum visualization** (E(k) vs wavenumber) to identify spectral bias
- **Spectral bias quantification** with metrics and comparison plots
- Compatible with Google Colab

**Models available:**
- `deeponet`: Branch-trunk architecture with SIREN activation (~235K params)
- `fno`: Fourier Neural Operator (~261K params)
- `unet`: Encoder-decoder with skip connections (~249K params)

**SA-PINNs Implementation:**
Uses saddle-point optimization with negated gradients (gradient ascent on loss) to enable competitive dynamics. This automatically emphasizes difficult frequency bins and finds optimal loss balance through min-max optimization.
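
The min-max idea can be illustrated on a toy problem. The sketch below is not the repository's SA-BSP code: it assumes a hypothetical loss `L = sum_k w_k * (theta_k - target_k)**2`, where `theta` stands in for the model parameters and `w` for the per-bin weights. The parameters take a gradient descent step while the weights take a gradient ascent step (the "negated gradient"), so bins that start with larger error accumulate larger weights.

```python
import numpy as np

def saddle_point_step(theta, weights, target, lr_theta=0.05, lr_w=0.1):
    """One min-max update on the toy loss L = sum_k w_k * (theta_k - target_k)**2."""
    residual = theta - target
    per_bin_error = residual ** 2
    grad_theta = 2.0 * weights * residual      # dL/dtheta
    grad_w = per_bin_error                     # dL/dw
    new_theta = theta - lr_theta * grad_theta  # descent: the model minimizes the loss
    new_weights = weights + lr_w * grad_w      # ascent: the weights maximize the loss
    return new_theta, new_weights

# Three "frequency bins" with increasingly large initial error (illustrative values).
target = np.array([1.0, 2.0, 4.0])
theta = np.zeros(3)
weights = np.ones(3)
initial_error = float(np.sum((theta - target) ** 2))

for _ in range(200):
    theta, weights = saddle_point_step(theta, weights, target)

final_error = float(np.sum((theta - target) ** 2))
print("learned weights:", np.round(weights, 3))
print("total error: %.3e -> %.3e" % (initial_error, final_error))
```

In this toy setting the weights stop growing once the corresponding error vanishes, which is what keeps the ascent step from diverging; the step sizes `lr_theta` and `lr_w` are arbitrary choices for the illustration.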

## Cell 0: Force Reload Modules

Run this cell to reload all project modules after code changes.

In [None]:
# Force reload of all modules
import sys
import importlib

# Get list of all loaded modules from the project
modules_to_reload = []
for module_name in list(sys.modules.keys()):
    if any(x in module_name for x in ['src.', 'configs.']):
        modules_to_reload.append(module_name)

# Remove modules from sys.modules to force reload
for module_name in modules_to_reload:
    if module_name in sys.modules:
        del sys.modules[module_name]

print(f"✓ Cleared {len(modules_to_reload)} cached modules")
print("  Run Cell 1 to reimport all modules with latest code")
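
The cell above selects modules with a substring test (`'src.' in module_name`), which also matches names such as `other.src.x` and does not match a top-level module named exactly `src`. A stricter variant, sketched here as a hypothetical helper `clear_project_modules` (not part of the project), matches on dotted-name prefixes instead. The `demo_pkg` modules are throwaway placeholders used only to show the behaviour.

```python
import sys
import types

def clear_project_modules(prefixes=("src", "configs")):
    """Drop cached modules named exactly `prefix` or starting with `prefix.`."""
    removed = [
        name for name in list(sys.modules)
        if any(name == p or name.startswith(p + ".") for p in prefixes)
    ]
    for name in removed:
        del sys.modules[name]
    return removed

# Register placeholder modules so the demo does not depend on the real project.
for fake in ("demo_pkg", "demo_pkg.sub", "other.demo_pkg.sub"):
    sys.modules[fake] = types.ModuleType(fake)

removed = clear_project_modules(prefixes=("demo_pkg",))
print("removed:", sorted(removed))

# The nested look-alike is left untouched by the prefix match.
leftover_kept = "other.demo_pkg.sub" in sys.modules
sys.modules.pop("other.demo_pkg.sub", None)  # tidy up the demo entry
```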

## Cell 1: Setup & Imports (Colab-Ready)

In [None]:
# Google Colab setup
import sys
import os
from pathlib import Path

# Ensure we're in /content
try:
    os.chdir('/content')
except OSError:
    # Not running in Colab (no /content directory); stay in the current directory
    pass

# Clone repository if running in Colab
repo_path = Path('/content/local')
if not repo_path.exists():
    print("📥 Cloning repository...")
    !git clone https://github.com/maximbeekenkamp/local.git
    print("✅ Repository cloned")
else:
    print("📥 Updating repository...")
    !git -C /content/local pull
    print("✅ Repository updated")

# Change to repo directory
try:
    os.chdir('/content/local')
    print(f"✅ Changed to: {os.getcwd()}")
except OSError:
    # Repository directory is missing (e.g. the clone failed); stay put
    pass

# Install dependencies
print("\n📦 Installing dependencies...")
!pip install -r requirements.txt -q
print("✅ Dependencies installed")

# Standard imports
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

# Project imports
from src.core.data_processing.cdon_dataset import CDONDataset
from src.core.data_processing.cdon_transforms import CDONNormalization
from src.core.models.model_factory import create_model
from src.core.training.simple_trainer import SimpleTrainer
from configs.training_config import TrainingConfig

print("\n✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Cell 2: Load Real CDON Data

In [None]:
# Get project root
project_root = Path.cwd()
print(f"Project root: {project_root}")

# Data directory
DATA_DIR = project_root / 'CDONData'
print(f"Data directory: {DATA_DIR}")

# Create normalization object (required by CDONDataset)
stats_path = project_root / 'configs' / 'cdon_stats.json'
print(f"Loading stats from: {stats_path}")
normalizer = CDONNormalization(stats_path=str(stats_path))

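# Create datasets with optional causal padding.
# USE_CAUSAL_PADDING is set in the configuration cell that follows this one;
# fall back to the documented default (True) if that cell has not been run yet.
USE_CAUSAL_PADDING = globals().get('USE_CAUSAL_PADDING', True)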
train_dataset = CDONDataset(
    data_dir=str(DATA_DIR),
    split='train',
    normalize=normalizer,
    use_causal_padding=USE_CAUSAL_PADDING,  # NEW: Apply zero-padding if enabled
    signal_length=4000
)

val_dataset = CDONDataset(
    data_dir=str(DATA_DIR),
    split='test',
    normalize=normalizer,
    use_causal_padding=USE_CAUSAL_PADDING,  # NEW: Apply zero-padding if enabled
    signal_length=4000
)

# Create dataloaders
BATCH_SIZE = 16

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"\n✓ Data loaded successfully")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Val samples: {len(val_dataset)}")
print(f"  Batch size: {BATCH_SIZE}")

# Inspect a sample
sample_input, sample_target = train_dataset[0]
print(f"\nSample shapes:")
print(f"  Input: {sample_input.shape}")
print(f"  Target: {sample_target.shape}")

if USE_CAUSAL_PADDING:
    expected_input_len = 4000 + (4000 - 1)  # signal_length + padding
    if sample_input.shape[-1] == expected_input_len:
        print(f"  ✓ Causal padding applied correctly (input length: {expected_input_len})")
    else:
        print(f"  ⚠ Warning: Expected input length {expected_input_len}, got {sample_input.shape[-1]}")
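
The length check above expects `signal_length - 1` zeros to be prepended to each input. How `CDONDataset` implements this is not shown in the notebook, so the following is only a minimal NumPy sketch of simple left-padding under that assumption; `left_pad_causal` is a hypothetical name.

```python
import numpy as np

def left_pad_causal(signal, signal_length):
    """Prepend signal_length - 1 zeros along the last (time) axis."""
    pad = signal_length - 1
    # Pad only the time axis, and only on the left.
    pad_width = [(0, 0)] * (signal.ndim - 1) + [(pad, 0)]
    return np.pad(signal, pad_width, mode="constant", constant_values=0.0)

x = np.random.default_rng(0).standard_normal((1, 4000))
padded = left_pad_causal(x, signal_length=4000)
print(x.shape, "->", padded.shape)
```

With this layout, any window of 4000 consecutive padded samples ending at time step `t` contains only values from steps up to `t`, which is the property a causal model relies on.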

In [None]:
# ============================================================================
# NEW FEATURES CONFIGURATION
# ============================================================================

# 1. CAUSALITY: Zero-padding preprocessing (Reference CausalityDeepONet)
USE_CAUSAL_PADDING = True  # ENABLED BY DEFAULT (matches reference)
# Set to False to disable causal padding (standard preprocessing)
# NOTE: DeepONet uses per-timestep windowing (handled in preprocessing_utils.py)
#       UNet/FNO use simple left-padding (enabled here)

# 2. DEEPONET ACTIVATION: Choose activation function
DEEPONET_ACTIVATION = 'requ'  # Options: 'requ' (default), 'tanh', 'relu', 'siren'
# 'requ' = ReLU² (reference default, smooth gradients)
# 'tanh' = Stable for operator learning
# 'relu' = Standard ReLU
# 'siren' = Sinusoidal activation (requires siren-pytorch)

# 3. PENALTY LOSS: Optional inverse-variance weighting
USE_PENALTY_LOSS = False  # Set to True to enable penalty weighting
PENALTY_EPSILON = 1e-8     # Numerical stability for penalty
PENALTY_PER_SAMPLE = True  # Per-sample (True) or global (False) penalty

print("✓ New features configured:")
print(f"  Causal padding:     {'ENABLED' if USE_CAUSAL_PADDING else 'DISABLED'} (default: ENABLED)")
print(f"  DeepONet activation: {DEEPONET_ACTIVATION.upper()} (default: REQU)")
print(f"  Penalty loss:       {'ENABLED' if USE_PENALTY_LOSS else 'DISABLED'} (default: DISABLED)")
if USE_PENALTY_LOSS:
    print(f"    - Epsilon:        {PENALTY_EPSILON}")
    print(f"    - Per-sample:     {PENALTY_PER_SAMPLE}")
print()

# These settings will be applied in subsequent cells
print("📝 NOTE:")
if USE_CAUSAL_PADDING:
    print("  → Zero-padding ENABLED (matches reference CausalityDeepONet)")
    print("  → Inputs will be left-padded: [1, 4000] → [1, 7999]")
    print("  → Outputs remain unchanged: [1, 4000]")
else:
    print("  → Standard preprocessing (inputs: [1, 4000], outputs: [1, 4000])")

if USE_PENALTY_LOSS:
    print("  → Penalty weighting will be applied to all loss functions")
    print(f"    Formula: loss *= 1 / (max(abs(target))² + {PENALTY_EPSILON})")

print(f"  → DeepONet will use {DEEPONET_ACTIVATION.upper()} activation")

# Choose model architecture
MODEL_ARCH = 'deeponet'  # Options: 'deeponet', 'fno', 'unet'

# Create model with optional DeepONet activation
if MODEL_ARCH == 'deeponet':
    model = create_model(MODEL_ARCH, config={'activation': DEEPONET_ACTIVATION})
    print(f"✓ Created {MODEL_ARCH.upper()} model with {DEEPONET_ACTIVATION.upper()} activation")
else:
    model = create_model(MODEL_ARCH)
    print(f"✓ Created {MODEL_ARCH.upper()} model")

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  Parameters: {num_params:,}")
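
The penalty formula printed by this cell, `loss *= 1 / (max(abs(target))² + epsilon)`, can be written out for a plain MSE. This is a sketch under assumptions: the function name `penalty_weighted_mse` is hypothetical, and "per-sample" is taken to mean that the peak amplitude is computed separately for each sample in the batch, while "global" uses one peak for the whole batch.

```python
import numpy as np

def penalty_weighted_mse(pred, target, epsilon=1e-8, per_sample=True):
    """MSE scaled by 1 / (max|target|^2 + epsilon), per sample or batch-wide."""
    reduce_axes = tuple(range(1, pred.ndim))
    per_sample_mse = np.mean((pred - target) ** 2, axis=reduce_axes)
    if per_sample:
        peak = np.max(np.abs(target), axis=reduce_axes)  # one peak per sample
    else:
        peak = np.max(np.abs(target))                    # one peak for the batch
    weights = 1.0 / (peak ** 2 + epsilon)
    return float(np.mean(weights * per_sample_mse))

# Two samples with the same 10% relative error but very different amplitudes.
target = np.stack([np.ones((1, 8)), 10.0 * np.ones((1, 8))])  # shape [2, 1, 8]
pred = 0.9 * target

loss_per_sample = penalty_weighted_mse(pred, target, per_sample=True)
loss_global = penalty_weighted_mse(pred, target, per_sample=False)
print(f"per-sample: {loss_per_sample:.5f}, global: {loss_global:.5f}")
```

With per-sample weighting both samples contribute equally because their relative errors match; with a global peak the small-amplitude sample is down-weighted along with the large one.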

## Cell 3: Choose Model Architecture

**Change `MODEL_ARCH` to try different models:**
- `'deeponet'`: Branch-trunk architecture with SIREN activation
- `'fno'`: Fourier Neural Operator
- `'unet'`: U-Net encoder-decoder
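
For readers unfamiliar with the branch-trunk idea, the sketch below shows the core computation of a DeepONet in NumPy: a branch network encodes the sampled input function, a trunk network encodes the query coordinates, and the output is their inner product. Layer sizes, initialization, and the use of ReQU here are illustrative assumptions, not the architecture returned by `create_model`.

```python
import numpy as np

def requ(x):
    """ReQU activation: ReLU(x) squared."""
    return np.maximum(x, 0.0) ** 2

def mlp(x, layers):
    """Apply dense layers, with ReQU on every layer except the last."""
    for i, (W, b) in enumerate(layers):
        x = x @ W + b
        if i < len(layers) - 1:
            x = requ(x)
    return x

def init_layers(sizes, rng):
    """Random dense layers with 1/sqrt(fan_in) scaling (illustrative only)."""
    return [(rng.standard_normal((m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

rng = np.random.default_rng(0)
n_sensors, n_basis, n_query = 64, 16, 100
branch = init_layers([n_sensors, 32, n_basis], rng)  # encodes the input function
trunk = init_layers([1, 32, n_basis], rng)           # encodes query coordinates

u = rng.standard_normal((4, n_sensors))              # batch of 4 sampled input functions
t = np.linspace(0.0, 1.0, n_query)[:, None]          # query times, shape [100, 1]

# Output at each query point is the inner product of branch and trunk features.
output = mlp(u, branch) @ mlp(t, trunk).T
print(output.shape)
```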

In [None]:
# Choose model architecture
MODEL_ARCH = 'deeponet'  # Options: 'deeponet', 'fno', 'unet'

model = create_model(MODEL_ARCH)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Created {MODEL_ARCH.upper()} model")
print(f"  Parameters: {num_params:,}")

## Cell 4: Initialize Results Storage

We'll train with all 6 loss types sequentially and store results for comparison:
- **BASELINE**: Relative L2 loss only (MSE baseline)
- **BSP**: MSE + fixed BSP loss with k² weighting
- **Log-BSP**: MSE + BSP with log₁₀ spectral energies (uniform weighting)
- **SA-BSP-PERBIN**: MSE + 32 adaptive per-bin weights (negated gradients for frequency emphasis)
- **SA-BSP-GLOBAL**: 2 adaptive weights (w_mse + w_bsp) with negated gradients for MSE/BSP balance
- **SA-BSP-COMBINED**: 34 weights (w_mse + w_bsp + 32 per-bin) with full competitive dynamics
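
The exact BSP definition lives in the project's loss code and is not reproduced in this notebook. As a rough orientation only, the sketch below shows one plausible form of a binned spectral loss: spectral power is averaged into 32 wavenumber bins and compared bin by bin, either as a relative error or in log₁₀ space. The function names, the relative-error form, and the `eps` regularizer are all assumptions made for illustration.

```python
import numpy as np

def binned_spectrum(signal, n_bins=32):
    """Mean spectral power in n_bins contiguous wavenumber bins (last axis = time)."""
    power = np.abs(np.fft.rfft(signal, axis=-1)) ** 2
    power = power[..., 1:]  # drop the k = 0 (mean) component
    bins = np.array_split(np.arange(power.shape[-1]), n_bins)
    return np.stack([power[..., idx].mean(axis=-1) for idx in bins], axis=-1)

def bsp_loss(pred, target, n_bins=32, bin_weights=None, use_log=False, eps=1e-8):
    """Weighted mean of per-bin spectral mismatch (relative or log10 form)."""
    e_pred = binned_spectrum(pred, n_bins)
    e_true = binned_spectrum(target, n_bins)
    if use_log:
        per_bin = (np.log10(e_pred + eps) - np.log10(e_true + eps)) ** 2
    else:
        per_bin = (1.0 - (e_pred + eps) / (e_true + eps)) ** 2
    if bin_weights is None:
        bin_weights = np.ones(n_bins)  # uniform weighting
    return float(np.mean(per_bin * bin_weights))

# Target has a weak high-frequency component that the "prediction" misses entirely.
t = np.linspace(0.0, 1.0, 4000, endpoint=False)
target = np.sin(2 * np.pi * 5 * t) + 0.1 * np.sin(2 * np.pi * 800 * t)
pred = np.sin(2 * np.pi * 5 * t)

loss_plain = bsp_loss(pred, target)
loss_log = bsp_loss(pred, target, use_log=True)
loss_self = bsp_loss(target, target)
print(f"relative: {loss_plain:.4f}, log10: {loss_log:.4f}, self: {loss_self:.4f}")
```

Passing a learned array as `bin_weights` is where per-bin adaptive weights would enter in a self-adaptive variant; the field-space MSE shows almost no sign of the missing component here, while the binned spectral terms flag it clearly.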

In [None]:
# Import loss configurations
from configs.loss_config import (
    BASELINE_CONFIG, 
    BSP_CONFIG,
    LOG_BSP_CONFIG,
    SA_BSP_PERBIN_CONFIG,
    SA_BSP_GLOBAL_CONFIG,
    SA_BSP_COMBINED_CONFIG
)
from src.core.evaluation.loss_factory import create_loss

# Loss configuration map
loss_config_map = {
    'baseline': BASELINE_CONFIG,
    'bsp': BSP_CONFIG,
    'log-bsp': LOG_BSP_CONFIG,
    'sa-bsp-perbin': SA_BSP_PERBIN_CONFIG,
    'sa-bsp-global': SA_BSP_GLOBAL_CONFIG,
    'sa-bsp-combined': SA_BSP_COMBINED_CONFIG
}

# Storage dictionaries for results from all loss types
all_training_results = {}  # Key: f"{MODEL_ARCH}_{loss_type}"
all_trainers = {}
trained_models = {}

print("✓ Storage initialized for multi-loss training")
print("\nWill train with 6 loss types:")
print("  1. BASELINE:", BASELINE_CONFIG.description)
print("  2. BSP:", BSP_CONFIG.description)
print("  3. LOG-BSP:", LOG_BSP_CONFIG.description)
print("  4. SA-BSP-PERBIN:", SA_BSP_PERBIN_CONFIG.description)
print("  5. SA-BSP-GLOBAL:", SA_BSP_GLOBAL_CONFIG.description)
print("  6. SA-BSP-COMBINED:", SA_BSP_COMBINED_CONFIG.description)

## Cell 5: Sequential Training with All Loss Types

Train the same model architecture with all 6 loss functions sequentially:
1. **BASELINE** - Pure MSE baseline
2. **BSP** - Fixed spectral loss with k² weighting
3. **Log-BSP** - Spectral loss with log₁₀ energies and uniform weighting
4. **SA-BSP-PERBIN** - 32 adaptive weights (emphasize hard frequency bins)
5. **SA-BSP-GLOBAL** - 2 adaptive weights (learn MSE/BSP balance)
6. **SA-BSP-COMBINED** - 34 adaptive weights (full competitive dynamics)

In [None]:
# Train with all 6 loss types sequentially
loss_types_to_train = ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']

for LOSS_TYPE in loss_types_to_train:
    print(f"\n{'='*70}")
    print(f"Training {MODEL_ARCH.upper()} with {LOSS_TYPE.upper()} Loss")
    print(f"{'='*70}\n")
    
    # Select loss configuration
    selected_loss_config = loss_config_map[LOSS_TYPE]
    print(f"Loss config: {selected_loss_config.description}")
    
    # Create loss function
    criterion = create_loss(selected_loss_config)
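    print(f"✓ Loss function created: {type(criterion).__name__}")
    
    # Create a FRESH model for this loss type (important!), passing the configured
    # DeepONet activation so the trained model matches the configuration cell
    model_kwargs = {'config': {'activation': DEEPONET_ACTIVATION}} if MODEL_ARCH == 'deeponet' else {}
    model_for_loss = create_model(MODEL_ARCH, **model_kwargs)
    num_params = sum(p.numel() for p in model_for_loss.parameters() if p.requires_grad)
    print(f"✓ Fresh model created ({num_params:,} parameters)")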
    
    # Create training config
    # Select optimizer based on architecture
    # FNO has complex-valued Fourier layers incompatible with SOAP
    optimizer_type = 'adam' if MODEL_ARCH == 'fno' else 'soap'
    
    config = TrainingConfig(
        num_epochs=50,
        learning_rate=1e-3,
        optimizer_type=optimizer_type,  # Adam for FNO, SOAP for others
        batch_size=BATCH_SIZE,
        weight_decay=1e-4,
        scheduler_type='cosine',
        cosine_eta_min=1e-6,
        eval_metrics=['field_error', 'spectrum_error'],
        eval_frequency=1,
        checkpoint_dir=f'checkpoints/{MODEL_ARCH}_{LOSS_TYPE}',
        save_best=False,
        save_latest=False,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        num_workers=2,
        verbose=True
    )
    
    # Create trainer
    trainer = SimpleTrainer(
        model=model_for_loss,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        loss_config=selected_loss_config,
        experiment_name=f'{MODEL_ARCH}_{LOSS_TYPE}'
    )
    
    print(f"✓ Trainer initialized")
    print(f"  Device: {trainer.device}")
    print(f"  Optimizer: {type(trainer.optimizer).__name__}")
    
    # Check for weight optimizer (SA-BSP variants only)
    if 'sa-bsp' in LOSS_TYPE:
        if trainer.weight_optimizer is not None:
            adapt_mode = trainer.adapt_mode
            print(f"  Weight optimizer: ✓ Created for SA-BSP ({adapt_mode} mode)")
        else:
            print(f"  ⚠ WARNING: SA-BSP but no weight_optimizer!")
    
    print(f"\n🚀 Starting training...\n")
    
    # Train
    results = trainer.train()
    
    # Store results
    key = f"{MODEL_ARCH}_{LOSS_TYPE}"
    all_training_results[key] = results
    all_trainers[key] = trainer
    trained_models[key] = model_for_loss
    
    print(f"\n✅ {LOSS_TYPE.upper()} training complete!")
    print(f"   Best val loss: {results['best_val_loss']:.6f}")
    print(f"   Final val loss: {results['val_history'][-1]['loss']:.6f}")

print(f"\n{'='*70}")
print(f"ALL TRAINING COMPLETE!")
print(f"{'='*70}")
print(f"Trained {len(all_training_results)} models with different loss functions")
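
The training config above requests `scheduler_type='cosine'` with `cosine_eta_min=1e-6`. How `SimpleTrainer` turns this into a scheduler is not shown here; assuming it follows standard cosine annealing over `num_epochs` with no restarts, the learning rate at each epoch has the closed form sketched below (`cosine_lr` is a hypothetical helper).

```python
import math

def cosine_lr(epoch, num_epochs=50, base_lr=1e-3, eta_min=1e-6):
    """Closed-form cosine annealing from base_lr (epoch 0) to eta_min (epoch num_epochs)."""
    cos_term = 1.0 + math.cos(math.pi * epoch / num_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term

schedule = [cosine_lr(epoch) for epoch in range(51)]
for epoch in (0, 10, 25, 40, 50):
    print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.3e}")
```

The rate changes slowly at the start and end and fastest around the midpoint, so most of the decay happens in the middle third of training.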

## Cell 6: Multi-Loss Training Comparison

Compare training metrics across all 6 loss functions.

In [None]:
# Create multi-loss comparison plots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Color scheme for loss types
colors = {
    'baseline': '#1f77b4',       # Blue
    'bsp': '#ff7f0e',             # Orange
    'log-bsp': '#2ca02c',         # Green
    'sa-bsp-perbin': '#d62728',   # Red
    'sa-bsp-global': '#9467bd',   # Purple
    'sa-bsp-combined': '#17becf'  # Cyan
}
linestyles = {
    'baseline': '-', 
    'bsp': '--', 
    'log-bsp': '-.', 
    'sa-bsp-perbin': ':', 
    'sa-bsp-global': '-',
    'sa-bsp-combined': '--'
}
markers = {
    'baseline': 'o', 
    'bsp': 's', 
    'log-bsp': '^', 
    'sa-bsp-perbin': 'D', 
    'sa-bsp-global': 'v',
    'sa-bsp-combined': 'p'
}

for loss_type in ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']:
    key = f"{MODEL_ARCH}_{loss_type}"
    results = all_training_results[key]
    
    # Extract metrics
    val_losses = [h['loss'] for h in results['val_history']]
    val_field_errors = [h['field_error'] for h in results['val_history']]
    val_spectrum_errors = [h['spectrum_error'] for h in results['val_history']]
    epochs = range(1, len(val_losses) + 1)
    
    # Create label with short name
    label_map = {
        'baseline': 'BASELINE',
        'bsp': 'BSP',
        'log-bsp': 'Log-BSP',
        'sa-bsp-perbin': 'SA-BSP (Per-bin)',
        'sa-bsp-global': 'SA-BSP (Global)',
        'sa-bsp-combined': 'SA-BSP (Combined)'
    }
    label = label_map[loss_type]
    
    # Plot on all 3 axes
    axes[0].plot(epochs, val_losses, label=label, 
                color=colors[loss_type], linestyle=linestyles[loss_type],
                linewidth=2, alpha=0.9, marker=markers[loss_type], markersize=4, markevery=5)
    
    axes[1].plot(epochs, val_field_errors, label=label,
                color=colors[loss_type], linestyle=linestyles[loss_type],
                linewidth=2, alpha=0.9, marker=markers[loss_type], markersize=4, markevery=5)
    
    axes[2].plot(epochs, val_spectrum_errors, label=label,
                color=colors[loss_type], linestyle=linestyles[loss_type],
                linewidth=2, alpha=0.9, marker=markers[loss_type], markersize=4, markevery=5)

# Configure axes with LOG SCALE on y-axis
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Validation Loss', fontsize=12)
axes[0].set_title('Validation Loss Comparison', fontsize=14, fontweight='bold')
axes[0].set_yscale('log')  # LOG SCALE
axes[0].legend(fontsize=9, loc='best')
axes[0].grid(True, alpha=0.3, which='both')

axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Field Error', fontsize=12)
axes[1].set_title('Field Error (Real Space)', fontsize=14, fontweight='bold')
axes[1].set_yscale('log')  # LOG SCALE
axes[1].legend(fontsize=9, loc='best')
axes[1].grid(True, alpha=0.3, which='both')

axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Spectrum Error', fontsize=12)
axes[2].set_title('Spectrum Error (Frequency Space)', fontsize=14, fontweight='bold')
axes[2].set_yscale('log')  # LOG SCALE
axes[2].legend(fontsize=9, loc='best')
axes[2].grid(True, alpha=0.3, which='both')

plt.suptitle(f'{MODEL_ARCH.upper()}: Loss Function Comparison (6 Variants)', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Print final metrics table
print(f"\n{'='*70}")
print("Final Metrics Summary")
print(f"{'='*70}")
print(f"{'Loss Type':<25} {'Val Loss':<12} {'Field Error':<15} {'Spectrum Error':<15}")
print("-"*70)

for loss_type in ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']:
    key = f"{MODEL_ARCH}_{loss_type}"
    results = all_training_results[key]
    final_val = results['val_history'][-1]
    
    label = loss_type.upper()
    print(f"{label:<25} {final_val['loss']:<12.6f} "
          f"{final_val['field_error']:<15.6f} {final_val['spectrum_error']:<15.6f}")
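
The table above reports final-epoch metrics only. Because each entry of `val_history` is a dict keyed by metric name, the best epoch per metric can also be extracted. The helper below is a hypothetical sketch, and the history it runs on is made up for illustration, not output from this notebook.

```python
def best_epochs(val_history, metrics=("loss", "field_error", "spectrum_error")):
    """Return {metric: (1-based epoch, value)} at the minimum of each metric."""
    summary = {}
    for metric in metrics:
        values = [entry[metric] for entry in val_history]
        best_idx = min(range(len(values)), key=values.__getitem__)
        summary[metric] = (best_idx + 1, values[best_idx])
    return summary

# Made-up history for illustration only.
toy_history = [
    {"loss": 0.50, "field_error": 0.40, "spectrum_error": 0.90},
    {"loss": 0.20, "field_error": 0.25, "spectrum_error": 0.60},
    {"loss": 0.25, "field_error": 0.22, "spectrum_error": 0.65},
]

summary = best_epochs(toy_history)
for metric, (epoch, value) in summary.items():
    print(f"{metric:<15} best at epoch {epoch}: {value:.2f}")
```

With real runs this would be called as `best_epochs(results['val_history'])` for each entry in `all_training_results`.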

## Cell 7: Spectral Bias Visualization (Energy Spectrum)

Visualize E(k) vs wavenumber to identify spectral bias in trained models.

In [None]:
import torch.fft as fft
from src.core.visualization.spectral_analysis import compute_unbinned_spectrum, compute_cached_true_spectrum
from configs.visualization_config import SPECTRUM_CACHE_FILENAME, CACHE_DIR

# Get validation batch for energy spectrum analysis
print("Computing energy spectra for all trained models...")

# Check if models have been trained
if 'trained_models' not in globals() or len(trained_models) == 0:
    print("\n⚠️  WARNING: No trained models found!")
    print("   Please run Cell 5 (training) first before running this cell.")
    print("   This cell requires the 'trained_models' dictionary to be populated.\n")
else:
    print(f"✓ Found {len(trained_models)} trained models")
    print(f"  Keys: {list(trained_models.keys())}\n")

val_batch_input, val_batch_target = next(iter(val_loader))

# Move to device for inference
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
val_batch_input = val_batch_input.to(device)
val_batch_target = val_batch_target.to(device)

# Compute ground truth spectrum with percentile-based uncertainty bands from cache
cache_path = f'{CACHE_DIR}/{SPECTRUM_CACHE_FILENAME}'
print(f"Loading true spectrum from cache: {cache_path}")
cached = np.load(cache_path)
k_true = cached['unbinned_frequencies']  # Full FFT resolution (~2000 frequencies)
E_true_median = cached['unbinned_energy_median']  # Median (50th percentile)
E_true_p16 = cached['unbinned_energy_p16']        # Lower bound (16th percentile ≈ -1σ)
E_true_p84 = cached['unbinned_energy_p84']        # Upper bound (84th percentile ≈ +1σ)
print(f"✓ True spectrum loaded ({len(k_true)} frequencies, unbinned)")
print(f"  Using percentile-based uncertainty bands (16th-84th ≈ ±1σ)")

# Collect ALL validation predictions for uncertainty bands
print("\nComputing unbinned spectra with percentile-based uncertainty bands for all models...")
spectra = {}

# Store true spectrum with percentile uncertainty bounds
spectra['True'] = {
    'frequencies': k_true,
    'energy_median': E_true_median,
    'energy_p16': E_true_p16,
    'energy_p84': E_true_p84
}

for loss_type in ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']:
    key = f"{MODEL_ARCH}_{loss_type}"
    
    # Check if model exists
    if key not in trained_models:
        print(f"  ⚠️  Skipping {loss_type.upper()}: model key '{key}' not found in trained_models")
        continue
    
    model_trained = trained_models[key]
    model_trained.eval()
    model_trained.to(device)
    
    try:
        # Collect predictions from ALL validation batches for uncertainty bands
        all_preds = []
        print(f"  Processing {loss_type.upper()}...", end='')
        
        with torch.no_grad():
            for val_input, _ in val_loader:
                val_input = val_input.to(device)
                pred = model_trained(val_input)
                all_preds.append(pred.cpu())
        
        # Stack all predictions: [total_val_samples, C, T]
        all_preds_tensor = torch.cat(all_preds, dim=0)
        
        # Compute unbinned spectrum with percentile-based uncertainty bands
        k_pred, E_pred_median, E_pred_p16, E_pred_p84 = compute_unbinned_spectrum(all_preds_tensor)
        
        # Create display label
        label_map = {
            'baseline': 'BASELINE',
            'bsp': 'BSP',
            'log-bsp': 'Log-BSP',
            'sa-bsp-perbin': 'SA-BSP (Per-bin)',
            'sa-bsp-global': 'SA-BSP (Global)',
            'sa-bsp-combined': 'SA-BSP (Combined)'
        }
        spec_key = f"{MODEL_ARCH.upper()} + {label_map[loss_type]}"
        
        spectra[spec_key] = {
            'frequencies': k_pred,
            'energy_median': E_pred_median,
            'energy_p16': E_pred_p16,
            'energy_p84': E_pred_p84
        }
        
        print(f" ✓ ({all_preds_tensor.shape[0]} samples)")
    except Exception as e:
        print(f" ❌ Error: {e}")
        continue

print(f"\n✓ Spectra computed for {len(spectra)} entries with percentile-based uncertainty bands\n")

# Plot energy spectrum with percentile-based uncertainty bands (safe for log scale!)
fig, ax = plt.subplots(figsize=(14, 9))

# Color scheme for loss types
colors_plot = {
    'True': '#000000',  # Black for ground truth
    'baseline': '#1f77b4',
    'bsp': '#ff7f0e',
    'log-bsp': '#2ca02c',
    'sa-bsp-perbin': '#d62728',
    'sa-bsp-global': '#9467bd',
    'sa-bsp-combined': '#17becf'
}

# Plot ground truth with uncertainty band (black)
if 'True' in spectra:
    data = spectra['True']
    k = data['frequencies']
    E_median = data['energy_median']
    E_p16 = data['energy_p16']
    E_p84 = data['energy_p84']
    
    # Plot median line
    ax.loglog(k, E_median, color=colors_plot['True'], linewidth=3, 
             label='True (Real Data)', zorder=10, alpha=0.9)
    
    # Plot percentile-based uncertainty band (16th-84th percentiles ≈ ±1σ)
    # Percentiles of spectral power are never negative (unlike mean minus std),
    # so the band stays usable on a log scale
    ax.fill_between(k, E_p16, E_p84,
                     color=colors_plot['True'], alpha=0.15, zorder=9,
                     label='True (16th-84th percentile)')

# Plot model predictions with percentile-based uncertainty bands
for loss_type in ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']:
    label_map = {
        'baseline': 'BASELINE',
        'bsp': 'BSP',
        'log-bsp': 'Log-BSP',
        'sa-bsp-perbin': 'SA-BSP (Per-bin)',
        'sa-bsp-global': 'SA-BSP (Global)',
        'sa-bsp-combined': 'SA-BSP (Combined)'
    }
    label_key = f"{MODEL_ARCH.upper()} + {label_map[loss_type]}"
    
    # Check if spectrum exists before plotting
    if label_key not in spectra:
        print(f"  ⚠️  Skipping plot for {loss_type.upper()}: '{label_key}' not in spectra dictionary")
        continue
    
    data = spectra[label_key]
    k = data['frequencies']
    E_median = data['energy_median']
    E_p16 = data['energy_p16']
    E_p84 = data['energy_p84']
    
    color = colors_plot[loss_type]
    
    # Plot median line
    ax.loglog(k, E_median, color=color, linewidth=2.5, 
             alpha=0.85, label=label_key, zorder=5)
    
    # Plot percentile-based uncertainty band (never negative, so usable on a log scale)
    ax.fill_between(k, E_p16, E_p84,
                     color=color, alpha=0.12, zorder=4)

# Configure plot
ax.set_xlabel('Frequency (normalized)', fontsize=14, fontweight='bold')
ax.set_ylabel('E(k) - Spectral Power', fontsize=14, fontweight='bold')
ax.set_title(f'Energy Spectrum Comparison with Percentile Uncertainty Bands\n{MODEL_ARCH.upper()} Model (6 Loss Variants)', 
            fontsize=16, fontweight='bold')
ax.legend(fontsize=10, loc='best', framealpha=0.95, ncol=1)
ax.grid(True, alpha=0.3, which='both', linestyle='--')

# Set nice axis limits
ax.set_xlim(k_true.min() * 0.9, k_true.max() * 1.1)

plt.tight_layout()
plt.show()

print(f"\n✓ Energy spectrum plot complete")
print(f"  • Unbinned spectrum: Full FFT resolution (~{len(k_true)} frequencies)")
print(f"  • Uncertainty bands: 16th-84th percentiles (≈ ±1σ) across all validation samples")
print(f"  • Percentiles of spectral power are never negative, so the bands stay usable on a log scale")
print(f"  • This visualization shows spectral bias: deviation from ground truth at high frequencies")
print(f"  • Log-BSP and SA-BSP variants should show better high-frequency matching than baseline")
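
The implementation of `compute_unbinned_spectrum` is not shown in this notebook. The NumPy sketch below illustrates what a percentile-based unbinned spectrum could look like, assuming input of shape `[N, C, T]`, power defined as `|FFT|²`, and the same return order as the call in the cell above (frequencies, median, 16th percentile, 84th percentile). The name `percentile_spectrum` is hypothetical.

```python
import numpy as np

def percentile_spectrum(signals, percentiles=(16, 50, 84)):
    """signals: [N, C, T]. Returns (freqs, median, p16, p84) of per-frequency power."""
    n_time = signals.shape[-1]
    power = np.abs(np.fft.rfft(signals, axis=-1)) ** 2  # [N, C, T//2 + 1]
    power = power.mean(axis=1)[:, 1:]                    # average channels, drop k = 0
    freqs = np.fft.rfftfreq(n_time)[1:]                  # cycles per sample
    p16, p50, p84 = np.percentile(power, percentiles, axis=0)
    return freqs, p50, p16, p84

rng = np.random.default_rng(0)
signals = rng.standard_normal((64, 1, 4000))             # synthetic stand-in data
k, e_med, e_p16, e_p84 = percentile_spectrum(signals)
print(k.shape, e_med.shape)
```

Because the band edges are order statistics of a non-negative quantity, the lower edge cannot dip below zero the way a mean-minus-one-sigma band can, which is why this form suits log-log plots.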

## Cell 8: Spectral Bias Quantification

Compute spectral bias metrics to quantify how well each model captures high-frequency content.

In [None]:
from src.core.visualization.spectral_analysis import compute_spectral_bias_metric

print("="*70)
print("SPECTRAL BIAS METRICS")
print("="*70)
print("\nQuantifies how well each model captures different frequency ranges.")
print("Spectral Bias Ratio = High Freq Error / Low Freq Error")
print("  - Ratio > 2.0: Significant spectral bias (struggles with high frequencies)")
print("  - Ratio > 1.5: Moderate spectral bias")
print("  - Ratio ≤ 1.5: Low spectral bias (captures frequencies well)")
print("="*70)

# Compute metrics for each trained model
spectral_metrics = {}

for loss_type in ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']:
    key = f"{MODEL_ARCH}_{loss_type}"
    model_trained = trained_models[key]
    model_trained.eval()
    model_trained.to(device)
    
    with torch.no_grad():
        # Reuse the validation batch loaded (and moved to `device`) in Cell 7
        pred = model_trained(val_batch_input)
    
    metrics = compute_spectral_bias_metric(pred.cpu(), val_batch_target.cpu(), n_bins=32)
    spectral_metrics[loss_type] = metrics
    
    label_map = {
        'baseline': 'BASELINE',
        'bsp': 'BSP',
        'log-bsp': 'Log-BSP',
        'sa-bsp-perbin': 'SA-BSP (Per-bin)',
        'sa-bsp-global': 'SA-BSP (Global)',
        'sa-bsp-combined': 'SA-BSP (Combined)'
    }
    
    print(f"\n{MODEL_ARCH.upper()} + {label_map[loss_type]}:")
    print(f"  Low frequency error:   {metrics['low_freq_error']:.6f}")
    print(f"  Mid frequency error:   {metrics['mid_freq_error']:.6f}")
    print(f"  High frequency error:  {metrics['high_freq_error']:.6f}")
    print(f"  Spectral bias ratio:   {metrics['spectral_bias_ratio']:.4f}")
    
    # Interpretation
    if metrics['spectral_bias_ratio'] > 2.0:
        print(f"  → ⚠️  SIGNIFICANT spectral bias detected!")
        print(f"     Model struggles with high-frequency content")
    elif metrics['spectral_bias_ratio'] > 1.5:
        print(f"  → ⚡ MODERATE spectral bias")
        print(f"     Some difficulty with high frequencies")
    else:
        print(f"  → ✅ LOW spectral bias")
        print(f"     Model captures frequency content well")

# Create comparison visualization
print(f"\n{'='*70}")
print("Spectral Bias Comparison")
print(f"{'='*70}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

# Bar plot 1: Frequency errors
loss_types = ['baseline', 'bsp', 'log-bsp', 'sa-bsp-perbin', 'sa-bsp-global', 'sa-bsp-combined']
x = np.arange(len(loss_types))
width = 0.2

low_errors = [spectral_metrics[lt]['low_freq_error'] for lt in loss_types]
mid_errors = [spectral_metrics[lt]['mid_freq_error'] for lt in loss_types]
high_errors = [spectral_metrics[lt]['high_freq_error'] for lt in loss_types]

ax1.bar(x - width, low_errors, width, label='Low Freq', color='#2ca02c', alpha=0.8)
ax1.bar(x, mid_errors, width, label='Mid Freq', color='#ff7f0e', alpha=0.8)
ax1.bar(x + width, high_errors, width, label='High Freq', color='#d62728', alpha=0.8)

ax1.set_xlabel('Loss Type', fontsize=12, fontweight='bold')
ax1.set_ylabel('Frequency Error', fontsize=12, fontweight='bold')
ax1.set_title('Frequency Range Errors', fontsize=14, fontweight='bold')
ax1.set_yscale('log')  # LOG SCALE
ax1.set_xticks(x)
ax1.set_xticklabels(['BASE', 'BSP', 'Log-BSP', 'SA-Per', 'SA-Glob', 'SA-Comb'], rotation=15, ha='right')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3, axis='y', which='both')

# Bar plot 2: Spectral bias ratio
bias_ratios = [spectral_metrics[lt]['spectral_bias_ratio'] for lt in loss_types]
colors_bars = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#17becf']

bars = ax2.bar(x, bias_ratios, color=colors_bars, alpha=0.8, edgecolor='black', linewidth=1.5)

# Add threshold lines
ax2.axhline(y=2.0, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Significant bias threshold')
ax2.axhline(y=1.5, color='orange', linestyle='--', linewidth=2, alpha=0.7, label='Moderate bias threshold')

ax2.set_xlabel('Loss Type', fontsize=12, fontweight='bold')
ax2.set_ylabel('Spectral Bias Ratio', fontsize=12, fontweight='bold')
ax2.set_title('Spectral Bias Ratio (High/Low)', fontsize=14, fontweight='bold')
ax2.set_xticks(x)
ax2.set_xticklabels(['BASE', 'BSP', 'Log-BSP', 'SA-Per', 'SA-Glob', 'SA-Comb'], rotation=15, ha='right')
ax2.legend(fontsize=10, loc='upper right')
ax2.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for i, (bar, ratio) in enumerate(zip(bars, bias_ratios)):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 0.05,
            f'{ratio:.2f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.suptitle(f'{MODEL_ARCH.upper()}: Spectral Bias Analysis (6 Loss Variants)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\n{'='*70}")
print("✅ Spectral bias analysis complete!")
print(f"{'='*70}")
print("\nKey Findings:")
print("  • Baseline: Pure MSE - typically shows significant spectral bias")
print("  • BSP: Fixed spectral loss with k² weighting - moderate improvement")
print("  • Log-BSP: Log-domain spectral loss - addresses wide dynamic range")
print("  • SA-BSP (Per-bin): Adaptive per-bin weights - emphasize hard frequencies")
print("  • SA-BSP (Global): Adaptive MSE/BSP balance - optimize overall trade-off")
print("  • SA-BSP (Combined): Full competitive dynamics - most expressive approach")
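
The metric used in this cell comes from `compute_spectral_bias_metric`, whose internals are not shown. To make the ratio "high-frequency error / low-frequency error" concrete, here is a self-contained sketch under assumptions: spectral power is binned into 32 bins, the relative error per bin is averaged over the lowest, middle, and highest thirds of the bins, and the returned keys mirror those printed above. The function name and the band split are hypothetical.

```python
import numpy as np

def spectral_bias_ratio(pred, target, n_bins=32, eps=1e-12):
    """Relative binned-spectrum error in low/mid/high bands, plus the high/low ratio."""
    def binned(x):
        power = np.abs(np.fft.rfft(x, axis=-1)) ** 2
        # Average over all samples and channels, then drop the k = 0 component.
        power = power.reshape(-1, power.shape[-1]).mean(axis=0)[1:]
        return np.array([chunk.mean() for chunk in np.array_split(power, n_bins)])

    e_pred, e_true = binned(pred), binned(target)
    rel_err = np.abs(e_pred - e_true) / (e_true + eps)
    low, mid, high = (band.mean() for band in np.array_split(rel_err, 3))
    return {
        "low_freq_error": float(low),
        "mid_freq_error": float(mid),
        "high_freq_error": float(high),
        "spectral_bias_ratio": float(high / (low + eps)),
    }

# Synthetic check: a two-tap average acts as a low-pass filter, so the "prediction"
# loses high-frequency energy and should register a ratio well above 2.
rng = np.random.default_rng(0)
target = rng.standard_normal((16, 1, 4000))
pred = 0.5 * (target + np.roll(target, 1, axis=-1))

metrics = spectral_bias_ratio(pred, target)
print({name: round(value, 3) for name, value in metrics.items()})
```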

## Summary

This notebook demonstrated:
1. ✓ Loading real CDON data with proper normalization
2. ✓ Creating neural operator models (DeepONet, FNO, UNet)
3. ✓ **Sequential training with all 6 loss functions**:
   - **BASELINE**: Relative L2 loss only (MSE baseline)
   - **BSP**: MSE + fixed BSP loss with k² weighting
   - **Log-BSP**: MSE + BSP with log₁₀ spectral energies (uniform weighting)
   - **SA-BSP (Per-bin)**: MSE + 32 adaptive per-bin weights (negated gradients for frequency emphasis)
   - **SA-BSP (Global)**: MSE + 2 adaptive weights (w_mse + w_bsp, negated gradients for MSE/BSP balance)
   - **SA-BSP (Combined)**: MSE + 34 weights (w_mse + w_bsp + 32 per-bin, all negated gradients for full competitive dynamics)
4. ✓ **Multi-loss comparison plots** showing training metrics
5. ✓ **Energy spectrum visualization** (E(k) vs wavenumber) to identify spectral bias
6. ✓ **Spectral bias quantification** with metrics and comparison plots

**Key Results:**
- All 6 loss types trained on the same model architecture
- Direct comparison shows which loss function best mitigates spectral bias
- Energy spectrum plot reveals how well each model captures high-frequency content
- Quantitative metrics identify spectral bias ratio for each approach

**SA-PINNs Implementation:**
- **Per-bin mode**: Uses negated gradients (ascent) to emphasize difficult frequency bins
- **Global mode**: Uses negated gradients (ascent) to learn optimal MSE/BSP balance via competitive dynamics
- **Combined mode**: Full competitive dynamics with all weights (w_mse, w_bsp, and 32 per-bin) using negated gradients

**Experiment with different configurations:**
- **Cell 0**: Run to force reload modules after code changes
- **Cell 3**: Change `MODEL_ARCH` to try different models ('deeponet', 'fno', 'unet')
- **Cell 5**: Adjust hyperparameters (epochs, learning rate, etc.) in TrainingConfig
- Run all cells sequentially to train and compare all 6 loss types automatically!