# Phase 5b: Train Physics-Informed Neural Network

This notebook trains the PINN model to:
1. **Regress lens parameters**: M_vir, r_s, β_x, β_y, H₀
2. **Classify dark matter type**: CDM, WDM, SIDM
3. **Enforce physical constraints**: Lens equation via physics-informed loss

## Training Configuration

- **Model**: Dual-head CNN with encoder + dense layers
- **Loss**: MSE (params) + CrossEntropy (class) + Physics residual
- **Optimizer**: AdamW (lr = 1e-3, weight decay = 1e-4) with `ReduceLROnPlateau` learning-rate scheduling (factor 0.5, patience 5)
- **Augmentation**: Random rotation, flip, brightness
- **Batch size**: 32
- **Epochs**: up to 50 (early stopping with a patience of 10 epochs)
- **Device**: GPU if available, otherwise CPU

## Summary

**Training Complete! ✓**

The PINN model has been successfully trained with:
- ✓ Physics-informed loss enforcing lens equation
- ✓ Dual-head architecture for parameters + classification
- ✓ Learning rate scheduling with ReduceLROnPlateau
- ✓ Early stopping to prevent overfitting
- ✓ Best model saved for evaluation

**Next Steps:**
1. Comprehensive evaluation on test set (`phase5c_evaluate.ipynb`)
2. Analyze confusion matrix and calibration
3. Compute parameter errors (MAE, RMSE, MAPE)
4. Generate publication-quality plots

**Model saved at:** `../models/best_pinn_model.pth`

In [None]:
# Restore the best checkpoint written by the training loop.
# - map_location=device lets a checkpoint saved on a GPU machine load on a
#   CPU-only one (and vice versa).
# - weights_only=False: the checkpoint is a dict written locally by this
#   notebook and may carry non-tensor metadata (e.g. a NumPy scalar under
#   'val_loss'), which the weights_only=True default of PyTorch >= 2.6 refuses
#   to unpickle. Only use this on checkpoint files you produced yourself.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Take the first validation batch (val_loader is not shuffled, so this is
# the same batch on every run)
val_images, val_params, val_labels = next(iter(val_loader))
val_images = val_images.float().to(device)

# Forward pass without gradient tracking; the model returns
# (regressed parameters, class logits)
with torch.no_grad():
    pred_params, pred_class_logits = model(val_images)
    pred_probs = torch.softmax(pred_class_logits, dim=1)
    pred_classes = torch.argmax(pred_probs, dim=1)

# Move everything to CPU / NumPy for plotting
val_images_cpu = val_images.cpu().numpy()
val_params_cpu = val_params.cpu().numpy()
val_labels_cpu = val_labels.cpu().numpy()
pred_params_cpu = pred_params.cpu().numpy()
pred_classes_cpu = pred_classes.cpu().numpy()
pred_probs_cpu = pred_probs.cpu().numpy()

# Visualize up to 9 samples on a 3x3 grid
fig, axes = plt.subplots(3, 3, figsize=(14, 14))
class_names = ['CDM', 'WDM', 'SIDM']
param_names = ['M_vir', 'r_s', 'β_x', 'β_y', 'H₀']

# The batch can hold fewer than 9 samples (small validation split or small
# batch size); only index what actually exists and blank the unused panels.
n_show = min(axes.size, len(val_images_cpu))

for i, ax in enumerate(axes.flat):
    if i >= n_show:
        ax.axis('off')
        continue

    # Show image (first channel)
    ax.imshow(val_images_cpu[i, 0], cmap='viridis', origin='lower')

    # True vs predicted class, with the softmax probability of the prediction
    true_class = class_names[val_labels_cpu[i]]
    pred_class = class_names[pred_classes_cpu[i]]
    confidence = pred_probs_cpu[i, pred_classes_cpu[i]] * 100

    # Green title for a correct classification, red for a wrong one
    correct = "✓" if val_labels_cpu[i] == pred_classes_cpu[i] else "✗"
    color = 'green' if correct == "✓" else 'red'

    # Columns 0 and 4 of the parameter vector are shown as M and H₀,
    # formatted as "true → predicted"
    title = f"{correct} True: {true_class} | Pred: {pred_class} ({confidence:.0f}%)\n"
    title += f"M: {val_params_cpu[i,0]:.2e} → {pred_params_cpu[i,0]:.2e}\n"
    title += f"H₀: {val_params_cpu[i,4]:.1f} → {pred_params_cpu[i,4]:.1f}"

    ax.set_title(title, fontsize=9, color=color, fontweight='bold')
    ax.axis('off')

plt.suptitle('Validation Samples: Predictions vs Ground Truth', fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

# Classification accuracy over the whole batch (not just the plotted samples)
accuracy = np.mean(val_labels_cpu == pred_classes_cpu) * 100
print(f"\nBatch classification accuracy: {accuracy:.1f}%")

## 7. Quick Check on Validation Samples

In [None]:
# One x-axis shared by every curve: epochs numbered from 1
epochs_range = range(1, len(history['train_loss']) + 1)

# (train key, validation key, y-axis label, panel title) for each loss
# component, listed in row-major order of the 2x2 grid
panels = [
    ('train_loss', 'val_loss', 'Total Loss', 'Total Loss'),
    ('train_mse', 'val_mse', 'MSE Loss', 'Parameter Regression (MSE)'),
    ('train_ce', 'val_ce', 'Cross-Entropy Loss', 'Classification Loss'),
    ('train_physics', 'val_physics', 'Physics Residual', 'Physics Constraint Violation'),
]

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

for ax, (train_key, val_key, y_label, panel_title) in zip(axes.flat, panels):
    # Train in blue, validation in red, on the same axes
    ax.plot(epochs_range, history[train_key], 'b-', label='Train', linewidth=2)
    ax.plot(epochs_range, history[val_key], 'r-', label='Validation', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.suptitle('Training History - Physics-Informed Neural Network', fontsize=15, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

# Learning rate recorded at the end of each epoch, on a log scale so that
# the multiplicative reductions are visible as steps
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_range, history['lr'], 'g-', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Learning Rate', fontsize=12)
ax.set_title('Learning Rate Schedule', fontsize=13, fontweight='bold')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Visualize Training History

In [None]:
# Per-epoch history: the mean of every loss component for the training and
# validation passes, plus the learning rate in effect after the scheduler step.
history = {
    'train_loss': [], 'train_mse': [], 'train_ce': [], 'train_physics': [],
    'val_loss': [], 'val_mse': [], 'val_ce': [], 'val_physics': [],
    'lr': []
}

# Early-stopping / checkpoint bookkeeping
best_val_loss = float('inf')
patience_counter = 0  # epochs since the last validation improvement
best_model_path = '../models/best_pinn_model.pth'
Path(best_model_path).parent.mkdir(parents=True, exist_ok=True)

print("\n" + "="*70)
print(" "*25 + "TRAINING START")
print("="*70 + "\n")

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()

    # ---- Training phase ----
    model.train()
    train_losses = {'total': [], 'mse_params': [], 'ce_class': [], 'physics_residual': []}

    for batch_idx, (images, params, labels) in enumerate(train_loader):
        images = images.float().to(device)
        params = params.float().to(device)
        labels = labels.long().to(device)

        # train_step is expected to return a dict of scalar losses whose keys
        # are among those of train_losses (any other key raises KeyError below)
        losses = train_step(model, images, params, labels, optimizer, LAMBDA_PHYSICS, device)

        for key in losses:
            train_losses[key].append(losses[key])

    # Average training losses over the batches of this epoch.
    # Cast to built-in float: np.mean returns np.float64, and a NumPy scalar
    # stored in the checkpoint cannot be read back by torch.load under its
    # weights_only=True default (PyTorch >= 2.6).
    avg_train_loss = float(np.mean(train_losses['total']))
    avg_train_mse = float(np.mean(train_losses['mse_params']))
    avg_train_ce = float(np.mean(train_losses['ce_class']))
    avg_train_physics = float(np.mean(train_losses['physics_residual']))

    # ---- Validation phase ----
    model.eval()
    val_losses = {'total': [], 'mse_params': [], 'ce_class': [], 'physics_residual': []}

    for images, params, labels in val_loader:
        images = images.float().to(device)
        params = params.float().to(device)
        labels = labels.long().to(device)

        # NOTE(review): no torch.no_grad() here; presumably validate_step
        # handles gradient tracking itself — confirm in src/ml/pinn.py
        losses = validate_step(model, images, params, labels, LAMBDA_PHYSICS, device)

        for key in losses:
            val_losses[key].append(losses[key])

    # Average validation losses (plain floats, same reason as above)
    avg_val_loss = float(np.mean(val_losses['total']))
    avg_val_mse = float(np.mean(val_losses['mse_params']))
    avg_val_ce = float(np.mean(val_losses['ce_class']))
    avg_val_physics = float(np.mean(val_losses['physics_residual']))

    # Let the plateau scheduler react to the validation loss, then read back
    # the learning rate it left on the first parameter group
    scheduler.step(avg_val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Store history
    history['train_loss'].append(avg_train_loss)
    history['train_mse'].append(avg_train_mse)
    history['train_ce'].append(avg_train_ce)
    history['train_physics'].append(avg_train_physics)
    history['val_loss'].append(avg_val_loss)
    history['val_mse'].append(avg_val_mse)
    history['val_ce'].append(avg_val_ce)
    history['val_physics'].append(avg_val_physics)
    history['lr'].append(current_lr)

    epoch_time = time.time() - epoch_start

    # Print progress
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] ({epoch_time:.1f}s)")
    print(f"  Train - Loss: {avg_train_loss:.4f} | MSE: {avg_train_mse:.4f} | CE: {avg_train_ce:.4f} | Phys: {avg_train_physics:.4f}")
    print(f"  Val   - Loss: {avg_val_loss:.4f} | MSE: {avg_val_mse:.4f} | CE: {avg_val_ce:.4f} | Phys: {avg_val_physics:.4f}")
    print(f"  LR: {current_lr:.2e}")

    # Checkpoint on strict improvement of the validation loss; otherwise
    # count towards early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save({
            'epoch': epoch,  # zero-based index of the epoch that produced this checkpoint
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss,
        }, best_model_path)
        print(f"  ✓ Best model saved (val_loss: {best_val_loss:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n⚠ Early stopping triggered after {epoch+1} epochs")
            break

    print()

total_time = time.time() - start_time
print("="*70)
print(f"Training completed in {total_time/60:.1f} minutes")
print(f"Best validation loss: {best_val_loss:.4f}")
print("="*70)

## 5. Training Loop with Live Progress

In [None]:
# Training configuration
NUM_EPOCHS = 50        # upper bound; early stopping may end training sooner
LEARNING_RATE = 1e-3   # initial learning rate for AdamW
LAMBDA_PHYSICS = 0.1   # physics weight passed to train_step / validate_step
PATIENCE = 10          # epochs without validation improvement before early stopping

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

# Learning rate scheduler: halve the LR when the monitored value (the
# validation loss passed to scheduler.step) has not improved for 5 epochs.
# The deprecated `verbose` argument is intentionally not passed; the training
# loop prints the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

print("Training Configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Physics weight (λ): {LAMBDA_PHYSICS}")
print(f"  Optimizer: AdamW")
print(f"  Scheduler: ReduceLROnPlateau")
print(f"  Early stopping patience: {PATIENCE}")

## 4. Configure Training

In [None]:
# Build the network and place it on the selected device in one step
model = PhysicsInformedNN(input_size=64, dropout_rate=0.2).to(device)

# Tally all parameters and the trainable subset in a single pass
total_params = 0
trainable_params = 0
for tensor in model.parameters():
    n_elements = tensor.numel()
    total_params += n_elements
    if tensor.requires_grad:
        trainable_params += n_elements

# Report the parameter counts and the layer-by-layer structure
rule = "="*70
print(rule)
print(" "*20 + "MODEL ARCHITECTURE")
print(rule)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"\nModel structure:")
print(model)
print(rule)

## 3. Initialize Model

In [None]:
DATA_FILE = '../data/processed/lens_training_data.h5'
BATCH_SIZE = 32

# One dataset per split, both read from the same HDF5 file
train_dataset = LensDataset(DATA_FILE, split='train')
val_dataset = LensDataset(DATA_FILE, split='val')

# Loader settings common to both splits; only the training split is shuffled
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Training set: {len(train_dataset)} samples")
print(f"Validation set: {len(val_dataset)} samples")
print(f"Batches per epoch: {len(train_loader)}")

# Smoke test: pull one training batch and report the shape of each element
images, params, labels = next(iter(train_loader))
print(f"\nBatch shapes:")
for field_name, batch_tensor in (("Images", images), ("Params", params), ("Labels", labels)):
    print(f"  {field_name}: {batch_tensor.shape}")

## 2. Load Dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import time
import sys

sys.path.append('..')
from src.ml import PhysicsInformedNN
from src.ml.pinn import train_step, validate_step
from src.ml.generate_dataset import LensDataset

# Fix the random seeds so that weight initialisation, dropout and DataLoader
# shuffling are repeatable across "Restart & Run All".
# torch.manual_seed also seeds the CUDA generators; NumPy is seeded for any
# NumPy-based randomness. Bit-exact GPU reproducibility is not guaranteed.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    # Name and total memory (in GiB) of the first CUDA device
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Default figure size for every matplotlib figure that does not set its own
plt.rcParams['figure.figsize'] = (14, 8)
print("✓ All modules imported")

## 1. Import Libraries and Setup