# WP4 - PINN Training Demo

This notebook demonstrates loading (or training, via the provided script) and evaluating a Physics-Informed Neural Network (PINN) for 6-DOF rocket trajectory prediction.


In [None]:
import json
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml

# Add project root to path.
# Assumes the kernel's working directory is a direct child of the project root
# (e.g. <root>/notebooks/) — adjust if the notebook is launched from elsewhere.
project_root = Path().resolve().parent

# Put the project root and <root>/src at the front of sys.path (src ends up first,
# matching the previous insert order). Existing copies are removed first so that
# re-running this cell does not keep accumulating duplicate entries.
for extra_path in (str(project_root), str(project_root / 'src')):
    while extra_path in sys.path:
        sys.path.remove(extra_path)
    sys.path.insert(0, extra_path)

from src.models import PINN
from src.utils.loaders import create_dataloaders
from src.data.preprocess import load_scales, Scales
from src.eval.visualize_pinn import (
    evaluate_model,
    plot_trajectory_comparison,
    plot_loss_curves
)


## Load Data and Configuration


In [None]:
# Load configuration
config_path = project_root / 'configs' / 'train.yaml'
with open(config_path, 'r', encoding='utf-8') as f:
    # safe_load returns None for an empty file; fall back to an empty dict so
    # later `config.get(...)` lookups still work.
    config = yaml.safe_load(f) or {}

# Load scales
scales_path = project_root / 'configs' / 'scales.yaml'
scales_dict = load_scales(str(scales_path))
scales = Scales(**scales_dict)

# Create dataloaders.
# batch_size is hard-coded for this demo and is independent of train.yaml;
# num_workers=0 keeps loading in the notebook process.
data_dir = project_root / 'data' / 'processed'
train_loader, val_loader, test_loader = create_dataloaders(
    data_dir=str(data_dir),
    batch_size=8,
    num_workers=0
)

print(f"Train cases: {len(train_loader.dataset)}")
print(f"Val cases: {len(val_loader.dataset)}")
print(f"Test cases: {len(test_loader.dataset)}")
print(f"Context dimension: {train_loader.dataset.context_dim}")
print(f"Context dimension: {train_loader.dataset.context_dim}")


## Create Model


In [None]:
# Model configuration: hyperparameters come from the `model:` section of the
# loaded config; each `.get` default below is used when that key is absent.
# `or {}` also covers a present-but-empty `model:` entry, which YAML loads as None.
model_cfg = config.get('model') or {}
context_dim = train_loader.dataset.context_dim

model = PINN(
    context_dim=context_dim,
    n_hidden=model_cfg.get('n_hidden', 6),
    n_neurons=model_cfg.get('n_neurons', 128),
    activation=model_cfg.get('activation', 'tanh'),
    fourier_features=model_cfg.get('fourier_features', 8),
    layer_norm=model_cfg.get('layer_norm', True),
    dropout=model_cfg.get('dropout', 0.05)
)

# Prefer a GPU when one is visible to torch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Device: {device}")


## Load Trained Model (or Train)

To train a model, use:
```bash
./scripts/train_pinn.sh --config configs/train.yaml
```

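Here we load the checkpoint if one is available. Otherwise, the cell prints the training command, and the remaining cells run on the untrained model, so their metrics and figures are not meaningful.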


In [None]:
# Try to load the best checkpoint written by the training script.
checkpoint_path = project_root / 'experiments' / 'pinn_baseline' / 'checkpoints' / 'best.pt'

if checkpoint_path.exists():
    print(f"Loading checkpoint from {checkpoint_path}")
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # epoch / loss are informational metadata; don't fail if a checkpoint lacks them.
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
    if 'loss' in checkpoint:
        print(f"Validation loss: {checkpoint['loss']:.6f}")
else:
    print("No checkpoint found. Run training first:")
    print("  ./scripts/train_pinn.sh --config configs/train.yaml")
    # The following cells still execute, but against the untrained model — say so
    # loudly so the metrics and saved figures are not mistaken for real results.
    warnings.warn(
        f"No checkpoint at {checkpoint_path}; evaluation and plots below use an "
        "untrained model.",
        stacklevel=2,
    )


## Evaluate Model


In [None]:
# Score the model on the held-out test split; eval() switches layers such as
# dropout to inference behaviour before evaluate_model runs.
model.eval()
metrics = evaluate_model(model, test_loader, device, scales)

# Aggregate error first, then one line per entry of rmse_per_component.
print("Test Set Metrics:")
print(f"  Total RMSE: {metrics['rmse_total']:.6f}")
print("\nPer-component RMSE:")
per_component_rmse = metrics['rmse_per_component']
for component, component_rmse in per_component_rmse.items():
    print(f"  {component}: {component_rmse:.6f}")


## Visualize Sample Trajectories


In [None]:
# Plot predicted vs. true trajectories for the first few test cases
# (in test_loader iteration order).
N_PLOT_CASES = 3  # number of individual test cases to plot

fig_dir = project_root / 'experiments' / 'pinn_baseline' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)

model.eval()
n_plotted = 0
with torch.no_grad():
    for batch in test_loader:
        if n_plotted >= N_PLOT_CASES:
            break

        t = batch['t'].to(device)
        context = batch['context'].to(device)
        # Ground truth is only needed for plotting, so it is not moved to `device`.
        state_true = batch['state']

        # Add a trailing singleton axis when t arrives 2-D — presumably
        # (batch, n_times) -> (batch, n_times, 1) for PINN.forward; TODO confirm.
        if t.dim() == 2:
            t = t.unsqueeze(-1)

        state_pred = model(t, context)

        # Plot every case in the batch (not only element 0) until N_PLOT_CASES are done,
        # so "Case k" really is the k-th test case rather than the k-th batch.
        n_from_batch = min(t.shape[0], N_PLOT_CASES - n_plotted)
        for j in range(n_from_batch):
            case_idx = n_plotted

            # Convert to numpy
            t_np = t[j].cpu().squeeze(-1).numpy()
            pred_np = state_pred[j].cpu().numpy()
            true_np = state_true[j].cpu().numpy()

            plot_trajectory_comparison(
                t_np, pred_np, true_np, scales,
                save_path=str(fig_dir / f'trajectory_case_{case_idx}.png'),
                title=f'Case {case_idx}'
            )

            print(f"Saved trajectory plot for case {case_idx}")
            n_plotted += 1

print(f"\nFigures saved to: {fig_dir}")


## Plot Training Curves

If a training log is available, plot the loss curves.


In [None]:
# Load the JSON training log (if training has been run) and plot its loss curves.
log_path = project_root / 'experiments' / 'pinn_baseline' / 'train_log.json'

if log_path.exists():
    # JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
    with open(log_path, 'r', encoding='utf-8') as f:
        train_log = json.load(f)

    # fig_dir is defined (and created) in the trajectory-plotting cell.
    plot_loss_curves(
        train_log,
        save_path=str(fig_dir / 'loss_curves.png')
    )
    print("Loss curves saved")
else:
    print("Training log not found. Run training first.")
