# Neural Operator Training Demo: CDON Dataset

This notebook demonstrates end-to-end training of neural operator models (DeepONet, FNO, UNet) on the CDON dataset.

**Features:**
- Trains on **real CDON data** for **50 epochs**
- Minimal custom code - reuses existing codebase
- Includes visualizations of training progress and predictions
- Compatible with Google Colab

**Models available:**
- `deeponet`: Branch-trunk architecture (~235K params)
- `fno`: Fourier Neural Operator (~261K params)
- `unet`: Encoder-decoder with skip connections (~249K params)

## Cell 1: Setup & Imports

In [None]:
# Google Colab setup (uncomment if running in Colab)
# import sys
# if 'google.colab' in sys.modules:
#     !git clone https://github.com/YOUR_USERNAME/CMAME.git
#     %cd CMAME
#     !pip install -r requirements.txt -q

# Standard imports - ALL from existing codebase
import sys
from pathlib import Path

# Add project root to path (if not in Colab)
project_root = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
sys.path.insert(0, str(project_root))

import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader
from src.core.data_processing.cdon_dataset import CDONDataset
from src.core.models.model_factory import create_model
from src.core.training.simple_trainer import SimpleTrainer
from configs.training_config import TrainingConfig

print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Cell 2: Load Real CDON Data

Uses the existing `CDONDataset` class - no custom data loading code needed.

In [None]:
# Data directory (adjust path if needed)
DATA_DIR = project_root / 'data' / 'CDONData'

# Create datasets using existing CDONDataset class
train_dataset = CDONDataset(
    data_dir=str(DATA_DIR),
    split='train',
    normalize=True
)

val_dataset = CDONDataset(
    data_dir=str(DATA_DIR),
    split='test',
    normalize=True
)

# Create dataloaders
BATCH_SIZE = 16

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"✓ Data loaded successfully")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Val samples: {len(val_dataset)}")
print(f"  Batch size: {BATCH_SIZE}")

# Inspect a sample
sample_input, sample_target = train_dataset[0]
print(f"\nSample shapes:")
print(f"  Input: {sample_input.shape}")
print(f"  Target: {sample_target.shape}")

## Cell 3: Choose Model Architecture

**Change `MODEL_ARCH` below to try different models:**
- `'deeponet'`: Branch-trunk architecture
- `'fno'`: Fourier Neural Operator
- `'unet'`: U-Net encoder-decoder
- `'all'`: **Train all three models and compare** (NEW!)

In [None]:
# Choose model architecture (change this to experiment)
MODEL_ARCH = 'deeponet'  # Options: 'deeponet', 'fno', 'unet', 'all'

# Check if training all models
if MODEL_ARCH == 'all':
    models_to_train = ['deeponet', 'fno', 'unet']
    models = {}
    
    print(f"✓ Will train all {len(models_to_train)} models")
    print(f"  Models: {', '.join([m.upper() for m in models_to_train])}")
    
    # Create all models
    for arch in models_to_train:
        models[arch] = create_model(arch)
        num_params = sum(p.numel() for p in models[arch].parameters() if p.requires_grad)
        print(f"  {arch.upper()}: {num_params:,} parameters")
else:
    # Single model training
    model = create_model(MODEL_ARCH)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"✓ Created {MODEL_ARCH.upper()} model")
    print(f"  Parameters: {num_params:,}")
    print(f"\nModel architecture:")
    print(model)

## Cell 4: Configure Training

Uses existing `TrainingConfig` dataclass - all hyperparameters in one place.
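
As a rough illustration of what `scheduler_type='cosine'` with `cosine_eta_min=1e-6` implies, the sketch below evaluates the standard closed-form cosine annealing schedule. It assumes the codebase anneals over all `num_epochs` with no warm-up or restarts, which is not confirmed here; the `cosine_lr` helper is hypothetical and not part of `TrainingConfig`.

```python
import math

def cosine_lr(epoch: int, num_epochs: int = 50,
              base_lr: float = 1e-3, eta_min: float = 1e-6) -> float:
    """Closed-form cosine annealing from base_lr (epoch 0) to eta_min (epoch num_epochs)."""
    cos_term = 1.0 + math.cos(math.pi * epoch / num_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term

# Learning rate at the start of each epoch, plus the final value
schedule = [cosine_lr(e) for e in range(51)]
for e in (0, 10, 25, 40, 50):
    print(f"epoch {e:2d}: lr = {schedule[e]:.3e}")
```

Under these assumptions the rate starts at `learning_rate`, passes the midpoint of the two values halfway through training, and ends at `cosine_eta_min`.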

In [None]:
# Create training configuration using existing TrainingConfig
config = TrainingConfig(
    # Training
    num_epochs=50,
    learning_rate=1e-3,
    batch_size=BATCH_SIZE,
    weight_decay=1e-4,
    
    # Scheduler
    scheduler_type='cosine',
    cosine_eta_min=1e-6,
    
    # Evaluation
    eval_metrics=['field_error', 'spectrum_error'],
    eval_frequency=1,
    
    # Checkpointing (will be customized per model if training all)
    checkpoint_dir=f'checkpoints/{MODEL_ARCH}_real_50epochs',
    save_best=True,
    save_latest=True,
    
    # Device (use GPU if available)
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_workers=2,
    
    # Logging
    verbose=True
)

print(f"✓ Training configuration:")
print(f"  Epochs: {config.num_epochs}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Scheduler: {config.scheduler_type}")
print(f"  Device: {config.device}")
if MODEL_ARCH == 'all':
    print(f"  Training mode: ALL MODELS")
else:
    print(f"  Checkpoint dir: {config.checkpoint_dir}")

In [None]:
# Final summary before training
print(f"{'='*70}")
print(f"LOSS CONFIGURATION SUMMARY")
print(f"{'='*70}")

print(f"\n✓ Loss Type: {LOSS_TYPE.upper()}")
print(f"✓ Loss Module: {type(criterion).__name__}")
print(f"✓ Configuration Validated: ✓")

# Show loss components
from src.core.evaluation.loss_factory import CombinedLoss

if isinstance(criterion, CombinedLoss):
    print(f"\n📦 Combined Loss Components:")
    print(f"  1. Base loss: {type(criterion.base_loss).__name__}")
    print(f"  2. Spectral loss: {type(criterion.spectral_loss).__name__}")
    print(f"  3. Lambda spectral: {criterion.lambda_spectral}")
    
    print(f"\n💡 Total loss = Base loss + {criterion.lambda_spectral} × Spectral loss")
else:
    print(f"\n📦 Single Loss Component:")
    print(f"  {type(criterion).__name__}")

# Check if weight optimizer will be needed
requires_weight_optimizer = LOSS_TYPE == 'sa-bsp'

print(f"\n🔧 Trainer Configuration:")
print(f"  Model optimizer: Adam (for model parameters)")
if requires_weight_optimizer:
    print(f"  Weight optimizer: Adam (for adaptive weights) ← WILL BE CREATED")
    print(f"    → Separate optimizer for SA-BSP adaptive weights")
else:
    print(f"  Weight optimizer: None")

print(f"\n🎯 Expected Loss Behavior:")
if LOSS_TYPE == 'baseline':
    print(f"  - Loss scale: ~0.1 - 2.0 (Relative L2 normalized)")
    print(f"  - Focuses on: Overall field accuracy")
elif LOSS_TYPE == 'bsp':
    print(f"  - Loss scale: May be higher (combined MSE + spectral)")
    print(f"  - Focuses on: Field accuracy + frequency spectrum matching")
    print(f"  - Better for: Mitigating spectral bias")
elif LOSS_TYPE == 'sa-bsp':
    print(f"  - Loss scale: May be higher (combined MSE + adaptive spectral)")
    print(f"  - Focuses on: Field accuracy + adaptive frequency emphasis")
    print(f"  - Better for: Automatically learning which frequencies matter")
    print(f"  - Adaptive weights will evolve during training")

print(f"\n{'='*70}")
print(f"✓ Loss configuration complete - ready for trainer creation")
print(f"{'='*70}")

# Training logic - handles both single model and all models
if MODEL_ARCH == 'all':
    # Train all models sequentially
    training_results = {}
    trainers = {}
    
    for arch in models_to_train:
        print(f"\n{'='*70}")
        print(f"Training {arch.upper()}")
        print(f"{'='*70}\n")
        
        # Create model-specific config
        model_config = TrainingConfig(
            num_epochs=config.num_epochs,
            learning_rate=config.learning_rate,
            batch_size=config.batch_size,
            weight_decay=config.weight_decay,
            scheduler_type=config.scheduler_type,
            cosine_eta_min=config.cosine_eta_min,
            eval_metrics=config.eval_metrics,
            eval_frequency=config.eval_frequency,
            checkpoint_dir=f'checkpoints/{arch}_real_50epochs',
            save_best=True,
            save_latest=True,
            device=config.device,
            num_workers=config.num_workers,
            verbose=True
        )
        
        # Create trainer (NOW WITH REQUIRED LOSS_CONFIG)
        trainer = SimpleTrainer(
            model=models[arch],
            train_loader=train_loader,
            val_loader=val_loader,
            config=model_config,
            loss_config=selected_loss_config,  # ← NEW: Required parameter
            experiment_name=f'{arch}_real_50epochs'
        )
        
        # Verify SA-BSP weight optimizer if applicable
        if LOSS_TYPE == 'sa-bsp':
            assert trainer.weight_optimizer is not None, \
                "SA-BSP should create weight_optimizer"
            print(f"  ✓ Weight optimizer created for SA-BSP")
        
        # Train
        results = trainer.train()
        
        # Store results and trainer
        training_results[arch] = results
        trainers[arch] = trainer
        
        print(f"\n✓ {arch.upper()} training complete!")
        print(f"  Best val loss: {results['best_val_loss']:.6f}")
        print(f"  Checkpoints: {trainer.checkpoint_dir}")
    
    print(f"\n{'='*70}")
    print(f"ALL MODELS TRAINING COMPLETE!")
    print(f"{'='*70}")
    
else:
    # Single model training (WITH REQUIRED LOSS_CONFIG)
    trainer = SimpleTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        loss_config=selected_loss_config,  # ← NEW: Required parameter
        experiment_name=f'{MODEL_ARCH}_real_50epochs'
    )
    
    print(f"✓ Trainer initialized")
    print(f"  Device: {trainer.device}")
    print(f"  Optimizer: {type(trainer.optimizer).__name__}")
    print(f"  Scheduler: {type(trainer.scheduler).__name__}")
    print(f"  Loss function: {type(trainer.criterion).__name__}")
    
    # Verify SA-BSP weight optimizer if applicable
    if LOSS_TYPE == 'sa-bsp':
        if trainer.weight_optimizer is not None:
            print(f"  Weight optimizer: {type(trainer.weight_optimizer).__name__} ✓")
            print(f"    → Optimizing {sum(p.numel() for p in trainer.criterion.spectral_loss.adaptive_weights.parameters())} adaptive weight parameters")
        else:
            print(f"  ⚠ WARNING: SA-BSP selected but weight_optimizer is None!")
    else:
        print(f"  Weight optimizer: None (not needed for {LOSS_TYPE})")
    
    print(f"\nStarting training for {config.num_epochs} epochs...\n")
    
    # Train model (with rich progress bars)
    results = trainer.train()
    
    print(f"\n✓ Training complete!")
    print(f"  Best val loss: {results['best_val_loss']:.6f}")
    print(f"  Checkpoints saved to: {trainer.checkpoint_dir}")

In [None]:
# Only run this cell for BSP or SA-BSP loss types
if LOSS_TYPE in ['bsp', 'sa-bsp']:
    print(f"{'='*70}")
    print(f"BSP/SA-BSP Loss Inspection")
    print(f"{'='*70}")
    
    # Access the spectral loss component
    from src.core.evaluation.loss_factory import CombinedLoss
    
    if isinstance(criterion, CombinedLoss):
        spectral_loss = criterion.spectral_loss
        
        print(f"\n📊 Frequency Binning Configuration:")
        print(f"  Number of bins: {spectral_loss.n_bins}")
        print(f"  Binning mode: {spectral_loss.binning_mode}")
        print(f"  Lambda (spectral weight): {spectral_loss.lambda_bsp}")
        
        # Show bin edges (example for 4000 timesteps)
        timesteps = 4000
        n_freq = timesteps // 2 + 1  # rfft output size
        
        print(f"\n📏 Frequency Domain:")
        print(f"  Time domain length: {timesteps}")
        print(f"  Frequency domain length: {n_freq}")
        print(f"  Nyquist frequency index: {n_freq - 1}")
        
        # Calculate bin boundaries
        bin_size = n_freq / spectral_loss.n_bins
        print(f"\n🗂️  Bin Structure:")
        print(f"  Bin size (avg): {bin_size:.2f} frequency components per bin")
        
        # Show first few bins
        for i in range(min(5, spectral_loss.n_bins)):
            start_idx = int(i * bin_size)
            end_idx = int((i + 1) * bin_size)
            print(f"  Bin {i}: freq indices [{start_idx}, {end_idx})")
        
        if spectral_loss.n_bins > 5:
            print(f"  ... ({spectral_loss.n_bins - 5} more bins)")
        
        # SA-BSP specific: show adaptive weights
        if LOSS_TYPE == 'sa-bsp':
            print(f"\n🎯 Adaptive Weights (SA-BSP):")
            
            from src.core.evaluation.adaptive_spectral_loss import SelfAdaptiveBSPLoss
            
            if isinstance(spectral_loss, SelfAdaptiveBSPLoss):
                initial_weights = spectral_loss.adaptive_weights()
                
                print(f"  Weight mode: {spectral_loss.adaptive_weights.mode}")
                print(f"  Number of weight parameters: {initial_weights.numel()}")
                print(f"\n  Initial weights (first 10):")
                for i in range(min(10, len(initial_weights))):
                    print(f"    Bin {i}: {initial_weights[i].item():.4f}")
                
                if len(initial_weights) > 10:
                    print(f"    ... ({len(initial_weights) - 10} more)")
                
                print(f"\n  Weight statistics:")
                print(f"    Mean: {initial_weights.mean().item():.4f}")
                print(f"    Std:  {initial_weights.std().item():.4f}")
                print(f"    Min:  {initial_weights.min().item():.4f}")
                print(f"    Max:  {initial_weights.max().item():.4f}")
                
                # Plot initial weights
                import matplotlib.pyplot as plt
                
                fig, ax = plt.subplots(1, 1, figsize=(10, 4))
                weights_np = initial_weights.detach().cpu().numpy()
                ax.bar(range(len(weights_np)), weights_np, alpha=0.7, color='steelblue')
                ax.set_xlabel('Bin Index')
                ax.set_ylabel('Weight Value')
                ax.set_title('Initial Adaptive Weights (Before Training)')
                ax.grid(True, alpha=0.3)
                plt.tight_layout()
                plt.show()
                
                print(f"\n💡 These weights will be learned during training!")
        
        else:
            print(f"\n💡 BSP uses fixed equal weights for all bins")
    
    print(f"\n{'='*70}")
    
else:
    print(f"⏭️  Skipping BSP/SA-BSP inspection (using {LOSS_TYPE} loss)")

## Cell 4E: BSP/SA-BSP Specific Inspection (Conditional)

**Only runs if LOSS_TYPE is 'bsp' or 'sa-bsp'**

Inspects frequency binning configuration and adaptive weights.
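
To make the binning idea concrete, here is a minimal NumPy sketch of binned spectral power: take the one-sided FFT, square the magnitudes, and average them inside equal-width bins of frequency indices. It mirrors the integer-truncated bin edges printed by the inspection cell, but it is an illustration only; the `binned_power` helper and the test tone are assumptions and do not reproduce the codebase's spectral loss.

```python
import numpy as np

def binned_power(signal: np.ndarray, n_bins: int) -> np.ndarray:
    """Mean spectral power inside each equal-width bin of rfft indices."""
    power = np.abs(np.fft.rfft(signal)) ** 2              # length n // 2 + 1
    # Bin edges truncated to integers, like int(i * bin_size) in the inspection cell
    edges = np.linspace(0, power.size, n_bins + 1).astype(int)
    return np.array([power[a:b].mean() for a, b in zip(edges[:-1], edges[1:])])

timesteps = 4000
t = np.arange(timesteps)
tone = np.sin(2 * np.pi * 50 * t / timesteps)   # 50 cycles over the window
bp = binned_power(tone, n_bins=32)

print(f"rfft length: {timesteps // 2 + 1}")
print(f"bin with most power: {int(bp.argmax())}")
```

With 4000 timesteps the one-sided spectrum has 2001 entries, so 32 bins hold about 62.5 frequency indices each and the 50-cycle tone falls in the first bin.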

In [None]:
# Get one batch from the training data
print("Testing loss on real CDON data...")
print(f"{'='*70}")

sample_batch_input, sample_batch_target = next(iter(train_loader))

print(f"Batch shapes:")
print(f"  Input: {sample_batch_input.shape}")
print(f"  Target: {sample_batch_target.shape}")

# Compute loss on real data (CPU is fine for testing).
# The raw input stands in for a model prediction here; enable gradients on it
# so that backward() also works for losses without trainable parameters.
sample_batch_input.requires_grad_(True)
try:
    real_data_loss = criterion(sample_batch_input, sample_batch_target)
    
    print(f"\n✓ Loss computed on real data successfully")
    print(f"  Loss value: {real_data_loss.item():.6f}")
    print(f"  Loss is finite: {torch.isfinite(real_data_loss).item()}")
    
    # Test gradients
    real_data_loss.backward()
    print(f"✓ Gradients computed successfully")
    
    # Check gradient magnitudes (should not be zero or extreme)
    grad_magnitude = sample_batch_input.grad.abs().mean().item() if sample_batch_input.grad is not None else 0.0
    print(f"  Gradient magnitude (mean abs): {grad_magnitude:.6e}")
    
    if grad_magnitude > 0:
        print(f"  Gradients are non-zero (good for training)")
    
except Exception as e:
    print(f"\n❌ ERROR during loss computation on real data:")
    print(f"  {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()
    raise

print(f"{'='*70}")
print(f"✓ Loss function validated on real CDON data")
print(f"  Ready to proceed with training")

## Cell 4D: Test Loss on Sample Real Data

Test the loss function on actual CDON data to catch shape mismatches or NaN issues early.
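
The two failure modes named above, shape mismatches and non-finite values, can also be checked before any loss is evaluated. The sketch below is one possible pre-flight check written with NumPy arrays for brevity; `preflight_check` is a hypothetical helper, not something the codebase provides.

```python
import numpy as np

def preflight_check(pred: np.ndarray, target: np.ndarray) -> list:
    """Return human-readable problems; an empty list means the pair looks sane."""
    problems = []
    if pred.shape != target.shape:
        problems.append(f"shape mismatch: {pred.shape} vs {target.shape}")
    for name, arr in (("pred", pred), ("target", target)):
        if not np.isfinite(arr).all():
            problems.append(f"{name} contains NaN or Inf")
    return problems

good = preflight_check(np.zeros((4, 1, 1000)), np.ones((4, 1, 1000)))
bad = preflight_check(np.zeros((4, 1, 1000)), np.full((4, 1, 4000), np.nan))

print("good pair:", good)
print("bad pair: ", bad)
```

Returning a list rather than raising on the first problem reports every issue in one pass, which is convenient when debugging a new dataset.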

In [None]:
# Create loss function using the factory
criterion = create_loss(selected_loss_config)

print(f"✓ Loss function created successfully")
print(f"\nLoss function type: {type(criterion).__name__}")
print(f"\nLoss module structure:")
print(criterion)

# Validate with dummy tensors
print(f"\n{'='*70}")
print("Validation Test: Computing loss on dummy data")
print('='*70)

# Create dummy tensors [batch=4, channels=1, timesteps=1000]
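# requires_grad=True so the backward pass below also works for parameter-free losses
dummy_pred = torch.randn(4, 1, 1000, requires_grad=True)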
dummy_target = torch.randn(4, 1, 1000)

try:
    # Compute loss
    test_loss = criterion(dummy_pred, dummy_target)
    
    # Check if loss is finite
    if torch.isfinite(test_loss):
        print(f"✓ Loss computation successful")
        print(f"  Dummy loss value: {test_loss.item():.6f}")
        print(f"  Loss is finite: True")
        
        # Test backward pass
        test_loss.backward()
        print(f"✓ Gradient computation successful")
        print(f"  Dummy gradients computed without errors")
    else:
        print(f"⚠ WARNING: Loss is not finite (NaN or Inf)")
        print(f"  Loss value: {test_loss.item()}")
        
except Exception as e:
    print(f"❌ ERROR during loss computation:")
    print(f"  {type(e).__name__}: {e}")
    raise

print(f"\n{'='*70}")
print("✓ Loss function validation complete - ready for training")
print('='*70)

## Cell 4C: Create and Validate Loss Function

Test that the loss function is created correctly and computes without errors.

In [None]:
# Choose loss type (CHANGE THIS TO EXPERIMENT)
LOSS_TYPE = 'baseline'  # Options: 'baseline', 'bsp', 'sa-bsp'

# Map loss type to configuration
loss_config_map = {
    'baseline': BASELINE_CONFIG,
    'bsp': BSP_CONFIG,
    'sa-bsp': SA_BSP_CONFIG
}

# Validate selection
if LOSS_TYPE not in loss_config_map:
    raise ValueError(f"Invalid LOSS_TYPE: '{LOSS_TYPE}'. Must be one of {list(loss_config_map.keys())}")

# Get selected configuration
selected_loss_config = loss_config_map[LOSS_TYPE]

print(f"✓ Selected loss type: {LOSS_TYPE.upper()}")
print(f"\nConfiguration:")
print(f"  Description: {selected_loss_config.description}")
print(f"  Loss type: {selected_loss_config.loss_type}")
print(f"  Parameters:")
for key, value in selected_loss_config.loss_params.items():
    print(f"    {key}: {value}")

# Additional info based on loss type
if LOSS_TYPE == 'bsp':
    print(f"\n💡 BSP Loss will combine:")
    print(f"   - Base loss (MSE in real space)")
    print(f"   - Spectral loss (MSPE on frequency bins)")
    print(f"   - Weighted by lambda_spectral = {selected_loss_config.loss_params.get('lambda_spectral', 'N/A')}")
elif LOSS_TYPE == 'sa-bsp':
    print(f"\n💡 SA-BSP Loss features:")
    print(f"   - Adaptive per-bin weights (trainable)")
    print(f"   - Separate weight optimizer will be created")
    print(f"   - Weights adapt during training to emphasize difficult frequencies")

In [None]:
# Only run for SA-BSP loss
if LOSS_TYPE == 'sa-bsp':
    print(f"{'='*70}")
    print("SA-BSP Adaptive Weight Evolution Analysis")
    print(f"{'='*70}")
    
    from src.core.evaluation.adaptive_spectral_loss import SelfAdaptiveBSPLoss
    
    # Get the trained spectral loss module
    if MODEL_ARCH == 'all':
        # Use first model's trainer as example
        example_trainer = trainers[models_to_train[0]]
        example_arch = models_to_train[0]
    else:
        example_trainer = trainer
        example_arch = MODEL_ARCH
    
    spectral_loss = example_trainer.criterion.spectral_loss
    
    if isinstance(spectral_loss, SelfAdaptiveBSPLoss):
        # Get final adaptive weights
        final_weights = spectral_loss.adaptive_weights()
        final_weights_np = final_weights.detach().cpu().numpy()
        
        print(f"\n📊 Adaptive Weights After Training:")
        print(f"  Model: {example_arch.upper()}")
        print(f"  Number of bins: {len(final_weights_np)}")
        
        print(f"\n  Final weight statistics:")
        print(f"    Mean: {final_weights_np.mean():.4f}")
        print(f"    Std:  {final_weights_np.std():.4f}")
        print(f"    Min:  {final_weights_np.min():.4f}")
        print(f"    Max:  {final_weights_np.max():.4f}")
        print(f"    Range: {final_weights_np.max() - final_weights_np.min():.4f}")
        
        # Find emphasized bins
        mean_weight = final_weights_np.mean()
        std_weight = final_weights_np.std()
        emphasized_bins = np.where(final_weights_np > mean_weight + std_weight)[0]
        deemphasized_bins = np.where(final_weights_np < mean_weight - std_weight)[0]
        
        print(f"\n  Frequency emphasis:")
        print(f"    High-weight bins (>μ+σ): {len(emphasized_bins)} bins → {list(emphasized_bins)[:10]}")
        print(f"    Low-weight bins (<μ-σ):  {len(deemphasized_bins)} bins → {list(deemphasized_bins)[:10]}")
        
        # Visualize weight evolution
        fig, axes = plt.subplots(1, 2, figsize=(16, 5))
        
        # Plot 1: Final weights bar chart
        ax = axes[0]
        ax.bar(range(len(final_weights_np)), final_weights_np, alpha=0.7, color='steelblue')
        ax.axhline(y=mean_weight, color='red', linestyle='--', label=f'Mean ({mean_weight:.2f})', linewidth=2)
        ax.axhline(y=mean_weight + std_weight, color='orange', linestyle=':', label=f'Mean+Std', linewidth=1.5)
        ax.axhline(y=mean_weight - std_weight, color='orange', linestyle=':', label=f'Mean-Std', linewidth=1.5)
        ax.set_xlabel('Bin Index (Low→High Frequency)', fontsize=12)
        ax.set_ylabel('Weight Value', fontsize=12)
        ax.set_title('Final Adaptive Weights After Training', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Plot 2: Weight distribution histogram
        ax = axes[1]
        ax.hist(final_weights_np, bins=20, alpha=0.7, color='steelblue', edgecolor='black')
        ax.axvline(x=mean_weight, color='red', linestyle='--', label=f'Mean', linewidth=2)
        ax.set_xlabel('Weight Value', fontsize=12)
        ax.set_ylabel('Frequency (# of bins)', fontsize=12)
        ax.set_title('Weight Value Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        
        plt.suptitle(f'SA-BSP Adaptive Weight Analysis ({example_arch.upper()})', 
                     fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.show()
        
        # Interpretation
        print(f"\n💡 Interpretation:")
        
        weight_range = final_weights_np.max() - final_weights_np.min()
        if weight_range < 0.5:
            print(f"  ⚡ Low weight variation (range={weight_range:.2f})")
            print(f"     → Model found relatively uniform importance across frequencies")
        elif weight_range < 1.5:
            print(f"  ⚡ Moderate weight variation (range={weight_range:.2f})")
            print(f"     → Model identified some frequency-specific challenges")
        else:
            print(f"  ⚡ High weight variation (range={weight_range:.2f})")
            print(f"     → Model strongly emphasized certain frequency ranges")
        
        # Check if high frequencies are emphasized
        n_bins = len(final_weights_np)
        low_freq_mean = final_weights_np[:n_bins//3].mean()
        mid_freq_mean = final_weights_np[n_bins//3:2*n_bins//3].mean()
        high_freq_mean = final_weights_np[2*n_bins//3:].mean()
        
        print(f"\n  Frequency range emphasis:")
        print(f"    Low frequencies  (bins 0-{n_bins//3}):     avg weight = {low_freq_mean:.4f}")
        print(f"    Mid frequencies  (bins {n_bins//3}-{2*n_bins//3}):   avg weight = {mid_freq_mean:.4f}")
        print(f"    High frequencies (bins {2*n_bins//3}-{n_bins}): avg weight = {high_freq_mean:.4f}")
        
        if high_freq_mean > low_freq_mean * 1.2:
            print(f"     → Model emphasized HIGH frequencies (spectral bias detected)")
        elif low_freq_mean > high_freq_mean * 1.2:
            print(f"     → Model emphasized LOW frequencies")
        else:
            print(f"     → Relatively balanced frequency emphasis")
    
    print(f"\n{'='*70}")
    
else:
    print(f"⏭️  Skipping SA-BSP weight evolution (using {LOSS_TYPE} loss)")

## Cell 10: SA-BSP Adaptive Weight Evolution (Conditional)

**Only runs if LOSS_TYPE is 'sa-bsp'**

Visualizes how adaptive weights changed during training.

## Cell 4B: Select Loss Type

**Change `LOSS_TYPE` below to experiment with different loss functions:**
- `'baseline'`: Standard Relative L2 loss (default)
- `'bsp'`: BSP loss - better for spectral bias mitigation
- `'sa-bsp'`: Self-Adaptive BSP - learns frequency weights during training

In [None]:
# Import loss configurations from existing codebase
from configs.loss_config import BASELINE_CONFIG, BSP_CONFIG, SA_BSP_CONFIG
from src.core.evaluation.loss_factory import create_loss

print("✓ Loss configurations imported successfully")
print("\nAvailable loss types:")
print(f"  1. BASELINE: {BASELINE_CONFIG.description}")
print(f"  2. BSP:      {BSP_CONFIG.description}")
print(f"  3. SA-BSP:   {SA_BSP_CONFIG.description}")

# Show configuration details
print("\n" + "="*70)
print("Configuration Details:")
print("="*70)

# Use loss_cfg as the loop variable so the TrainingConfig named `config` is not overwritten
for name, loss_cfg in [('BASELINE', BASELINE_CONFIG), ('BSP', BSP_CONFIG), ('SA-BSP', SA_BSP_CONFIG)]:
    print(f"\n{name}:")
    print(f"  Loss type: {loss_cfg.loss_type}")
    print(f"  Parameters: {loss_cfg.loss_params}")
    
print("\n" + "="*70)

## Cell 4A: Import Loss Configurations

**NEW: Configurable Loss Functions**

Import loss configurations to enable:
- **Baseline**: Standard Relative L2 loss
- **BSP**: Binned Spectral Power loss (mitigates spectral bias)
- **SA-BSP**: Self-Adaptive BSP with learnable weights
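
For reference, the baseline metric can be sketched in a few lines. The version below uses one common definition of relative L2 error, the norm of the residual divided by the norm of the target, averaged over the batch. Whether the codebase reduces over the same axis or uses the same `eps` guard is an assumption, and `relative_l2` is a hypothetical name.

```python
import numpy as np

def relative_l2(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Batch mean of ||pred - target||_2 / ||target||_2, norms taken along the last axis."""
    err = np.linalg.norm(pred - target, axis=-1)
    ref = np.linalg.norm(target, axis=-1)
    return float(np.mean(err / (ref + eps)))   # eps guards against all-zero targets

target = np.ones((2, 8))
loss_small = relative_l2(1.1 * target, target)   # uniform 10% overshoot
loss_zero = relative_l2(target, target)          # perfect prediction

print(f"10% overshoot: {loss_small:.4f}")
print(f"perfect match: {loss_zero:.4f}")
```

Because the error is normalized by the target norm, the value is scale-free: a uniform 10% overshoot gives roughly 0.1 regardless of the signal amplitude.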

## Cell 5: Create Trainer and Train

Uses existing `SimpleTrainer` class - handles all training logic with rich progress bars.
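
The training cell reports `results['best_val_loss']` and the configuration sets `save_best=True`. A plausible reading, sketched below in plain Python, is that the trainer keeps the lowest validation loss seen so far and writes a checkpoint whenever it improves. This is an assumption about `SimpleTrainer`'s internals, and the loss values are made up for illustration.

```python
def best_epochs(val_losses):
    """Yield (epoch, loss) each time the validation loss improves on the best so far."""
    best = float("inf")
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            yield epoch, loss   # a trainer would write its "best" checkpoint here

history = [0.90, 0.55, 0.60, 0.41, 0.43, 0.39]   # hypothetical validation losses
improvements = list(best_epochs(history))

for epoch, loss in improvements:
    print(f"epoch {epoch}: new best val loss {loss:.2f}")
```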

In [None]:
# Training logic - handles both single model and all models
if MODEL_ARCH == 'all':
    # Train all models sequentially
    training_results = {}
    trainers = {}
    
    for arch in models_to_train:
        print(f"\n{'='*70}")
        print(f"Training {arch.upper()}")
        print(f"{'='*70}\n")
        
        # Create model-specific config
        model_config = TrainingConfig(
            num_epochs=config.num_epochs,
            learning_rate=config.learning_rate,
            batch_size=config.batch_size,
            weight_decay=config.weight_decay,
            scheduler_type=config.scheduler_type,
            cosine_eta_min=config.cosine_eta_min,
            eval_metrics=config.eval_metrics,
            eval_frequency=config.eval_frequency,
            checkpoint_dir=f'checkpoints/{arch}_real_50epochs',
            save_best=True,
            save_latest=True,
            device=config.device,
            num_workers=config.num_workers,
            verbose=True
        )
        
        # Create trainer
        trainer = SimpleTrainer(
            model=models[arch],
            train_loader=train_loader,
            val_loader=val_loader,
            config=model_config,
            loss_config=selected_loss_config,  # required by SimpleTrainer (see Cell 4B)
            experiment_name=f'{arch}_real_50epochs'
        )
        
        # Train
        results = trainer.train()
        
        # Store results and trainer
        training_results[arch] = results
        trainers[arch] = trainer
        
        print(f"\n✓ {arch.upper()} training complete!")
        print(f"  Best val loss: {results['best_val_loss']:.6f}")
        print(f"  Checkpoints: {trainer.checkpoint_dir}")
    
    print(f"\n{'='*70}")
    print(f"ALL MODELS TRAINING COMPLETE!")
    print(f"{'='*70}")
    
else:
    # Single model training (original code)
    trainer = SimpleTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        loss_config=selected_loss_config,  # required by SimpleTrainer (see Cell 4B)
        experiment_name=f'{MODEL_ARCH}_real_50epochs'
    )
    
    print(f"✓ Trainer initialized")
    print(f"  Device: {trainer.device}")
    print(f"  Optimizer: {type(trainer.optimizer).__name__}")
    print(f"  Scheduler: {type(trainer.scheduler).__name__}")
    print(f"\nStarting training for {config.num_epochs} epochs...\n")
    
    # Train model (with rich progress bars)
    results = trainer.train()
    
    print(f"\n✓ Training complete!")
    print(f"  Best val loss: {results['best_val_loss']:.6f}")
    print(f"  Checkpoints saved to: {trainer.checkpoint_dir}")

## Cell 6: Plot Training History

In [None]:
# Plot training history - handles both single model and comparison
if MODEL_ARCH == 'all':
    # OVERLAY COMPARISON PLOT for all models
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Color scheme
    colors = {'deeponet': '#1f77b4', 'fno': '#ff7f0e', 'unet': '#2ca02c'}
    
    for arch in models_to_train:
        results_arch = training_results[arch]
        train_losses = [h['loss'] for h in results_arch['train_history']]
        val_losses = [h['loss'] for h in results_arch['val_history']]
        val_field_errors = [h['field_error'] for h in results_arch['val_history']]
        val_spectrum_errors = [h['spectrum_error'] for h in results_arch['val_history']]
        epochs = range(1, len(train_losses) + 1)
        
        color = colors[arch]
        label = arch.upper()
        
        # Plot 1: Validation Loss
        axes[0].plot(epochs, val_losses, label=label, color=color, linewidth=2, alpha=0.9)
        
        # Plot 2: Field Error
        axes[1].plot(epochs, val_field_errors, label=label, color=color, linewidth=2, alpha=0.9)
        
        # Plot 3: Spectrum Error
        axes[2].plot(epochs, val_spectrum_errors, label=label, color=color, linewidth=2, alpha=0.9)
    
    # Configure Plot 1
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Validation Loss', fontsize=12)
    axes[0].set_title('Validation Loss Comparison', fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)
    
    # Configure Plot 2
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Field Error', fontsize=12)
    axes[1].set_title('Field Error Comparison', fontsize=14, fontweight='bold')
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)
    
    # Configure Plot 3
    axes[2].set_xlabel('Epoch', fontsize=12)
    axes[2].set_ylabel('Spectrum Error', fontsize=12)
    axes[2].set_title('Spectrum Error Comparison', fontsize=14, fontweight='bold')
    axes[2].legend(fontsize=10)
    axes[2].grid(True, alpha=0.3)
    
    plt.suptitle('Multi-Model Training Comparison on Real CDON Data', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Print summary table
    print("\nFinal Metrics Summary:")
    print(f"{'Model':<12} {'Val Loss':<12} {'Field Error':<15} {'Spectrum Error':<15}")
    print("-" * 60)
    for arch in models_to_train:
        results_arch = training_results[arch]
        final_val = results_arch['val_history'][-1]
        print(f"{arch.upper():<12} {final_val['loss']:<12.6f} {final_val['field_error']:<15.6f} {final_val['spectrum_error']:<15.6f}")
    
else:
    # SINGLE MODEL PLOT (original code)
    # Extract metrics from results
    train_losses = [h['loss'] for h in results['train_history']]
    val_losses = [h['loss'] for h in results['val_history']]
    train_field_errors = [h['field_error'] for h in results['train_history']]
    val_field_errors = [h['field_error'] for h in results['val_history']]
    val_spectrum_errors = [h['spectrum_error'] for h in results['val_history']]
    epochs = range(1, len(train_losses) + 1)
    
    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Plot 1: Loss
    axes[0].plot(epochs, train_losses, label='Train Loss', marker='o', markersize=3)
    axes[0].plot(epochs, val_losses, label='Val Loss', marker='s', markersize=3)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Relative L2 Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Plot 2: Field Error
    axes[1].plot(epochs, train_field_errors, label='Train Field Error', marker='o', markersize=3)
    axes[1].plot(epochs, val_field_errors, label='Val Field Error', marker='s', markersize=3)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Field Error')
    axes[1].set_title('Field Error (Real Space)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Plot 3: Spectrum Error
    axes[2].plot(epochs, val_spectrum_errors, label='Val Spectrum Error', marker='s', markersize=3, color='green')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Spectrum Error')
    axes[2].set_title('Spectrum Error (Frequency Space)')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.suptitle(f'{MODEL_ARCH.upper()} Training on Real CDON Data ({config.num_epochs} Epochs)', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Print final metrics
    print(f"\nFinal Metrics:")
    print(f"  Train Loss: {train_losses[-1]:.6f}")
    print(f"  Val Loss: {val_losses[-1]:.6f}")
    print(f"  Val Field Error: {val_field_errors[-1]:.6f}")
    print(f"  Val Spectrum Error: {val_spectrum_errors[-1]:.6f}")

## Cell 7: Evaluate Best Model on Test Set

In [None]:
# Load best checkpoint
best_checkpoint_path = trainer.checkpoint_dir / 'best_model.pt'

if best_checkpoint_path.exists():
    epoch = trainer.load_checkpoint(str(best_checkpoint_path))
    print(f"✓ Loaded best model from epoch {epoch}")
    
    # Evaluate on the held-out split (val_dataset was built with split='test')
    test_metrics = trainer.validate()
    
    print(f"\nBest Model Test Results:")
    print(f"  Loss: {test_metrics['loss']:.6f}")
    print(f"  Field Error: {test_metrics['field_error']:.6f}")
    print(f"  Spectrum Error: {test_metrics['spectrum_error']:.6f}")
else:
    print("No checkpoint found. Using current model state.")
    test_metrics = trainer.validate()
    print(f"\nCurrent Model Test Results:")
    print(f"  Loss: {test_metrics['loss']:.6f}")
    print(f"  Field Error: {test_metrics['field_error']:.6f}")
    print(f"  Spectrum Error: {test_metrics['spectrum_error']:.6f}")

## Cell 8: Visualize Sample Predictions

In [None]:
# Get a batch from validation set
model.eval()
sample_inputs, sample_targets = next(iter(val_loader))
sample_inputs = sample_inputs.to(config.device)
sample_targets = sample_targets.to(config.device)

# Make predictions
with torch.no_grad():
    sample_preds = model(sample_inputs)

# Move to CPU for plotting
sample_inputs = sample_inputs.cpu().numpy()
sample_targets = sample_targets.cpu().numpy()
sample_preds = sample_preds.cpu().numpy()

# Plot 3 samples
num_samples = min(3, len(sample_inputs))
fig, axes = plt.subplots(num_samples, 1, figsize=(14, 4 * num_samples))

if num_samples == 1:
    axes = [axes]

for idx in range(num_samples):
    ax = axes[idx]
    
    # Extract data
    target = sample_targets[idx, 0, :]
    pred = sample_preds[idx, 0, :]
    
    # Compute error
    error = np.abs(target - pred)
    relative_error = np.linalg.norm(target - pred) / np.linalg.norm(target)
    
    # Plot
    timesteps = np.arange(len(target))
    ax.plot(timesteps, target, label='Ground Truth', alpha=0.8, linewidth=1.5)
    ax.plot(timesteps, pred, label='Prediction', alpha=0.8, linewidth=1.5, linestyle='--')
    
    ax.set_xlabel('Timestep')
    ax.set_ylabel('Displacement')
    ax.set_title(f'Sample {idx + 1}: Prediction vs Ground Truth (Relative Error: {relative_error:.4f})')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_ARCH.upper()} Sample Predictions', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Print statistics
print(f"\nPrediction Statistics (first {num_samples} samples):")
for idx in range(num_samples):
    target = sample_targets[idx, 0, :]
    pred = sample_preds[idx, 0, :]
    rel_error = np.linalg.norm(target - pred) / np.linalg.norm(target)
    print(f"  Sample {idx + 1}: Relative Error = {rel_error:.6f}")
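
The per-sample relative error in the cell above divides by the norm of the target, which is undefined for an all-zero target. Below is a minimal sketch of a guarded variant, assuming a small floor on the denominator is acceptable. The name `relative_l2_error` and the `eps` default are illustrative and not part of the existing codebase.

```python
import numpy as np


def relative_l2_error(target, pred, eps=1e-12):
    """Return ||target - pred|| / max(||target||, eps) as a Python float."""
    target = np.asarray(target, dtype=float)
    pred = np.asarray(pred, dtype=float)
    # Floor the denominator so an all-zero target does not divide by zero.
    denom = max(np.linalg.norm(target), eps)
    return float(np.linalg.norm(target - pred) / denom)
```

It could replace the two inline `np.linalg.norm(...) / np.linalg.norm(...)` expressions, for example `relative_l2_error(sample_targets[idx, 0, :], sample_preds[idx, 0, :])`.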

In [None]:
# Import spectral analysis functions
from src.core.visualization.spectral_analysis import (
    plot_spectral_bias_comparison,
    compute_spectral_bias_metric,
    plot_spectral_bias_metrics
)

# Get predictions for spectral analysis
if MODEL_ARCH == 'all':
    # Use trained models to generate predictions
    predictions_for_spectral = {}
    
    # Get a validation batch
    sample_batch_input, sample_batch_target = next(iter(val_loader))
    sample_batch_input = sample_batch_input.to(config.device)
    
    # Generate predictions from all models
    for arch in models_to_train:
        models[arch].eval()
        with torch.no_grad():
            pred = models[arch](sample_batch_input)
            predictions_for_spectral[arch] = pred.cpu()
    
    ground_truth_spectral = sample_batch_target
    
    # Plot spectral bias comparison
    print("\nGenerating spectral bias comparison plot...")
    plot_spectral_bias_comparison(
        predictions=predictions_for_spectral,
        ground_truth=ground_truth_spectral,
        title='Frequency Spectrum: Model Predictions vs Ground Truth',
        n_bins=32,
        show_uncertainty=True
    )
    
    # Compute quantitative metrics
    print("\nSpectral Bias Metrics:")
    metrics_spectral = {}
    for arch in models_to_train:
        metrics = compute_spectral_bias_metric(
            prediction=predictions_for_spectral[arch],
            ground_truth=ground_truth_spectral,
            n_bins=32
        )
        metrics_spectral[arch] = metrics
        
        print(f"\n{arch.upper()}:")
        print(f"  Low freq error:  {metrics['low_freq_error']:.4f}")
        print(f"  Mid freq error:  {metrics['mid_freq_error']:.4f}")
        print(f"  High freq error: {metrics['high_freq_error']:.4f}")
        print(f"  Spectral bias ratio: {metrics['spectral_bias_ratio']:.4f}")
    
    # Plot metrics comparison
    plot_spectral_bias_metrics(metrics_spectral)
    
else:
    # Single model spectral analysis
    model.eval()
    
    # Get a validation batch
    sample_batch_input, sample_batch_target = next(iter(val_loader))
    sample_batch_input = sample_batch_input.to(config.device)
    
    # Generate prediction
    with torch.no_grad():
        pred = model(sample_batch_input)
    
    # Create predictions dict
    predictions_for_spectral = {MODEL_ARCH: pred.cpu()}
    ground_truth_spectral = sample_batch_target
    
    # Plot spectral comparison
    print("\nGenerating frequency spectrum plot...")
    plot_spectral_bias_comparison(
        predictions=predictions_for_spectral,
        ground_truth=ground_truth_spectral,
        title=f'{MODEL_ARCH.upper()} Frequency Spectrum Analysis',
        n_bins=32,
        show_uncertainty=True
    )
    
    # Compute metrics
    metrics = compute_spectral_bias_metric(
        prediction=pred.cpu(),
        ground_truth=ground_truth_spectral,
        n_bins=32
    )
    
    print(f"\nSpectral Bias Metrics for {MODEL_ARCH.upper()}:")
    print(f"  Low freq error:  {metrics['low_freq_error']:.4f}")
    print(f"  Mid freq error:  {metrics['mid_freq_error']:.4f}")
    print(f"  High freq error: {metrics['high_freq_error']:.4f}")
    print(f"  Spectral bias ratio: {metrics['spectral_bias_ratio']:.4f}")
    print(f"\nInterpretation:")
    if metrics['spectral_bias_ratio'] > 2.0:
        print(f"  ⚠ Significant spectral bias (ratio > 2.0)")
        print(f"  Model struggles with high frequencies")
    elif metrics['spectral_bias_ratio'] > 1.5:
        print(f"  ⚡ Moderate spectral bias (ratio > 1.5)")
    else:
        print(f"  ✓ Low spectral bias (ratio ≤ 1.5)")
        print(f"  Model captures frequencies well")
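
Only the single-model branch above prints an interpretation of the spectral bias ratio. The same thresholds (2.0 and 1.5) could be wrapped in a small helper so that the multi-model branch reports a label per architecture as well. This is a sketch: `interpret_spectral_bias` is a hypothetical name, not a function from `src.core.visualization.spectral_analysis`.

```python
def interpret_spectral_bias(ratio, moderate=1.5, significant=2.0):
    """Map a spectral bias ratio to the labels used in the single-model branch."""
    if ratio > significant:
        return "significant"  # model struggles with high frequencies
    if ratio > moderate:
        return "moderate"
    return "low"  # ratio <= 1.5: model captures frequencies well
```

In the multi-model branch this could be called as `interpret_spectral_bias(metrics_spectral[arch]['spectral_bias_ratio'])` inside the existing loop over `models_to_train`.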

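## Cell 9: Spectral Bias Analysis

The preceding code cell analyzes the frequency spectrum of model predictions to identify spectral bias, reporting low-, mid-, and high-frequency errors and a spectral bias ratio.
When `MODEL_ARCH == 'all'`, it shows which models capture high-frequency content better.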
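
To make the idea of a band-wise spectral error concrete, here is a minimal NumPy sketch. It is not the implementation of `compute_spectral_bias_metric`, whose internals are not shown in this notebook. The assumptions are: signals of shape `(batch, length)`, batch-averaged amplitude spectra from a real FFT, three equal-width frequency bands, and a ratio defined as high-band error over low-band error. The function name `band_spectrum_errors` is hypothetical.

```python
import numpy as np


def band_spectrum_errors(pred, target, eps=1e-12):
    """Relative amplitude-spectrum error in low/mid/high frequency bands.

    pred, target: real-valued arrays of shape (batch, length).
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)

    # Batch-averaged amplitude spectra from the real FFT along the last axis.
    pred_amp = np.abs(np.fft.rfft(pred, axis=-1)).mean(axis=0)
    target_amp = np.abs(np.fft.rfft(target, axis=-1)).mean(axis=0)

    # Split the frequency bins into three contiguous bands (an assumed split).
    bands = np.array_split(np.arange(target_amp.size), 3)
    errors = {}
    for name, idx in zip(("low", "mid", "high"), bands):
        diff = np.linalg.norm(pred_amp[idx] - target_amp[idx])
        errors[name] = float(diff / max(np.linalg.norm(target_amp[idx]), eps))

    # A ratio above 1 means the high band is fit worse than the low band.
    errors["ratio"] = errors["high"] / max(errors["low"], eps)
    return errors


# Synthetic check: the prediction slightly underestimates the low-frequency
# component and misses the high-frequency component entirely.
t = np.linspace(0.0, 1.0, 256, endpoint=False)
target = (np.sin(2 * np.pi * 4 * t) + 0.5 * np.sin(2 * np.pi * 100 * t))[None, :]
pred = (0.9 * np.sin(2 * np.pi * 4 * t))[None, :]
result = band_spectrum_errors(pred, target)
```

On this synthetic signal the high band is fit far worse than the low band, which is the pattern the notebook labels as spectral bias.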

## Summary

This notebook demonstrated:
1. ✓ Loading real CDON data using existing `CDONDataset`
2. ✓ Creating models using existing `create_model` factory
3. ✓ Training with existing `SimpleTrainer` class (50 epochs)
4. ✓ Visualizing training progress and predictions

**Next steps:**
- Try different model architectures by changing `MODEL_ARCH`
- Experiment with hyperparameters in `TrainingConfig`
- Train for more epochs for better convergence
- Compare results across different models
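
For the model comparison step, the final validation metrics can be ranked programmatically. The sketch below assumes a `training_results` dict shaped like the one used in the multi-model plotting cell, where `training_results[arch]['val_history']` is a list of per-epoch metric dicts. `rank_models` is a hypothetical helper, and the numbers in `toy_results` are placeholders for illustration, not measured results.

```python
def rank_models(training_results, metric="loss"):
    """Return (arch, final_value) pairs sorted from best (lowest) to worst."""
    finals = {
        arch: res["val_history"][-1][metric]
        for arch, res in training_results.items()
    }
    return sorted(finals.items(), key=lambda item: item[1])


# Placeholder histories for illustration only (not real training output).
toy_results = {
    "deeponet": {"val_history": [{"loss": 0.5}, {"loss": 0.30}]},
    "fno": {"val_history": [{"loss": 0.4}, {"loss": 0.10}]},
    "unet": {"val_history": [{"loss": 0.6}, {"loss": 0.20}]},
}
ranking = rank_models(toy_results)
```

Passing `metric="field_error"` or `metric="spectrum_error"` ranks by the other keys that the summary table reads from `val_history`.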