## Summary

**Advanced Training Complete! ✓**

**New Features Demonstrated:**
1. ✓ **Data Augmentation**: Rotation, flip, brightness, noise
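2. ✓ **TensorBoard Logging**: Train/val losses and learning rate logged every epoch; predictions and weight histograms every `LOG_INTERVAL` epochs
3. ✓ **Learning Rate Scheduling**: Adaptive LR with ReduceLROnPlateau
4. ✓ **Early Stopping**: Training stops once validation loss fails to improve for `PATIENCE` consecutive epochs
5. ✓ **Model Checkpointing**: Checkpoint with the lowest validation loss saved automatically
6. ✓ **Gradient Clipping**: Built into `train_step` for stability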
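Feature 3 relies on `ReduceLROnPlateau`. As a rough sketch of its decision rule, here is a simplified pure-Python re-implementation written for illustration (not PyTorch's class; cooldown and the absolute-threshold mode are omitted), using the notebook's settings `factor=0.5, patience=5, min_lr=1e-6`:

```python
import math


class PlateauLRSketch:
    """Simplified ReduceLROnPlateau rule (mode='min', relative threshold, no cooldown).
    Illustration only; not torch.optim.lr_scheduler.ReduceLROnPlateau."""

    def __init__(self, lr, factor=0.5, patience=5, min_lr=1e-6, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss):
        # An epoch counts as an improvement only if it beats the best loss by a relative margin
        if val_loss < self.best * (1 - self.threshold):
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # The LR is cut once the bad-epoch count EXCEEDS patience, then the count resets
        if self.num_bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauLRSketch(lr=1e-3)
losses = [1.0, 0.9] + [0.9] * 6          # improvement, then six flat epochs
lrs = [sched.step(loss) for loss in losses]
print(lrs[-1])                            # halved on the 6th non-improving epoch
```

Because the comparison is "greater than `patience`", `patience=5` means the first reduction happens on the sixth consecutive epoch without improvement.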

**Benefits Over Basic Training:**
- Better generalization through data augmentation
- Real-time monitoring via TensorBoard
- More stable training with LR scheduling
- Early stopping halts training once validation loss stops improving, limiting overfitting
- A reusable training loop with logging, LR scheduling, and checkpointing built in

**View Results:**
Run in terminal: `tensorboard --logdir=../runs`
Then open: http://localhost:6006

**Next Steps:**
1. Compare augmented vs non-augmented performance
2. Tune augmentation parameters
3. Try different LR schedules (e.g. `CosineAnnealingWarmRestarts`, already defined as `scheduler_cosine`)
4. Experiment with the λ_physics weight (`LAMBDA_PHYSICS`)
5. Evaluate on test set (phase5c_evaluate.ipynb)

In [None]:
# Training configuration
NUM_EPOCHS = 50
LAMBDA_PHYSICS = 0.1
PATIENCE = 10
LOG_INTERVAL = 5  # Log prediction images and histograms every N epochs (epoch index 0, 5, 10, ...)

best_val_loss = float('inf')
patience_counter = 0
best_model_path = '../models/best_pinn_augmented.pth'
Path(best_model_path).parent.mkdir(parents=True, exist_ok=True)

print("\n" + "="*70)
print(" "*20 + "ADVANCED TRAINING START")
print("="*70)
print(f"Epochs: {NUM_EPOCHS} | Patience: {PATIENCE} | λ_physics: {LAMBDA_PHYSICS}")
print("Features: Augmentation ✓ | TensorBoard ✓ | LR Scheduling ✓")
print("="*70 + "\n")

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    
    # Training phase
    model.train()
    train_losses = {'total': [], 'mse_params': [], 'ce_class': [], 'physics_residual': []}
    
    for batch_idx, (images, params, labels) in enumerate(train_loader):
        # DataLoader's default collate already returns tensors; this conversion only
        # matters if a custom collate_fn yields NumPy arrays
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).float()
        images = images.to(device)
        params = params.float().to(device)
        labels = labels.long().to(device)
        
        # Training step
        losses = train_step(model, images, params, labels, optimizer, LAMBDA_PHYSICS, device)
        
        for key in losses:
            train_losses[key].append(losses[key])
    
    # Average training losses
    avg_train_losses = {key: np.mean(vals) for key, vals in train_losses.items()}
    
    # Validation phase
    model.eval()
    val_losses = {'total': [], 'mse_params': [], 'ce_class': [], 'physics_residual': []}
    
    for images, params, labels in val_loader:
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).float()
        images = images.to(device)
        params = params.float().to(device)
        labels = labels.long().to(device)
        
        losses = validate_step(model, images, params, labels, LAMBDA_PHYSICS, device)
        
        for key in losses:
            val_losses[key].append(losses[key])
    
    # Average validation losses
    avg_val_losses = {key: np.mean(vals) for key, vals in val_losses.items()}
    
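    # Update learning rate. ReduceLROnPlateau takes the monitored metric; epoch-based
    # schedulers such as scheduler_cosine must be stepped with no argument, because a
    # positional value would be interpreted as the epoch number.
    scheduler.step(avg_val_losses['total'])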
    current_lr = optimizer.param_groups[0]['lr']
    
    # Log to TensorBoard
    logger.log_training_metrics(
        epoch,
        avg_train_losses,
        avg_val_losses,
        current_lr
    )
    
    # Log predictions periodically
    if epoch % LOG_INTERVAL == 0:
        logger.log_predictions_comparison(model, val_loader, device, epoch, num_samples=6)
        logger.log_histograms(model, epoch)
    
    epoch_time = time.time() - epoch_start
    
    # Print progress
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] ({epoch_time:.1f}s) LR: {current_lr:.2e}")
    print(f"  Train - Loss: {avg_train_losses['total']:.4f} | " +
          f"MSE: {avg_train_losses['mse_params']:.4f} | " +
          f"CE: {avg_train_losses['ce_class']:.4f} | " +
          f"Phys: {avg_train_losses['physics_residual']:.4f}")
    print(f"  Val   - Loss: {avg_val_losses['total']:.4f} | " +
          f"MSE: {avg_val_losses['mse_params']:.4f} | " +
          f"CE: {avg_val_losses['ce_class']:.4f} | " +
          f"Phys: {avg_val_losses['physics_residual']:.4f}")
    
    # Save best model
    if avg_val_losses['total'] < best_val_loss:
        best_val_loss = avg_val_losses['total']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_losses['total'],
            'train_loss': avg_train_losses['total'],
        }, best_model_path)
        print(f"  ✓ Best model saved (val_loss: {best_val_loss:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n⚠ Early stopping triggered after {epoch+1} epochs")
            break
    
    print()

total_time = time.time() - start_time

print("="*70)
print(f"Training completed in {total_time/60:.1f} minutes")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Model saved: {best_model_path}")
print("="*70)

# Log final hyperparameters
hparams = {
    'batch_size': BATCH_SIZE,
    'learning_rate': 1e-3,
    'lambda_physics': LAMBDA_PHYSICS,
    'dropout': 0.2,
    'augmentation': 'Yes',
    'scheduler': 'ReduceLROnPlateau'
}
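metrics = {
    'final_val_loss': best_val_loss,  # best (checkpointed) validation loss, not the last epoch's
    'total_epochs': epoch + 1         # epochs actually run (fewer if early stopping fired)
}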
logger.log_hyperparameters(hparams, metrics)

logger.close()
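The checkpoint and early-stopping logic in the loop above can be replayed on a plain list of validation losses, without a model. The helper `replay_early_stopping` below is hypothetical, written only for illustration; it reproduces the loop's strict `<` comparison and its `patience_counter >= PATIENCE` test:

```python
def replay_early_stopping(val_losses, patience):
    """Hypothetical helper: replays the notebook's checkpoint/early-stopping rule
    on per-epoch validation losses. Returns (best_epoch, best_loss, epochs_run)."""
    best_loss = float("inf")
    best_epoch = None
    patience_counter = 0
    epochs_run = 0
    for epoch, loss in enumerate(val_losses):
        epochs_run = epoch + 1
        if loss < best_loss:                  # strict improvement: checkpoint this epoch
            best_loss, best_epoch = loss, epoch
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:  # stop after `patience` bad epochs in a row
                break
    return best_epoch, best_loss, epochs_run


curve = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.5]   # made-up validation losses
print(replay_early_stopping(curve, patience=3))        # (2, 0.7, 6)
```

In this made-up curve the run stops after six epochs and keeps the checkpoint from epoch index 2, even though a lower loss would have appeared later; that is the trade-off `PATIENCE` controls. Because `PATIENCE = 10` is larger than the scheduler's `patience=5`, the learning rate should be halved at least once before early stopping can trigger.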

## 7. Advanced Training Loop

Key improvements:
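- TensorBoard logging of all losses and the learning rate
- Learning-rate scheduling (ReduceLROnPlateau)
- Model checkpointing (save best model)
- Early stopping with patience
- Gradient clipping for stability (inside `train_step`)
- Prediction visualizations and weight histograms every `LOG_INTERVAL` epochs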
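`train_step` is not shown in this notebook, so how it clips gradients is not visible here. Assuming it uses global-norm clipping in the style of `torch.nn.utils.clip_grad_norm_` (an assumption), the rule is: compute the L2 norm of all gradients taken together and, if it exceeds `max_norm`, scale every gradient by the same factor. A NumPy sketch, where the `max_norm` value is purely illustrative:

```python
import numpy as np


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient arrays so their combined L2 norm is at most max_norm.
    Sketch of the clip_grad_norm_-style rule, NumPy only."""
    total_norm = float(np.sqrt(sum(np.sum(np.asarray(g, dtype=np.float64) ** 2) for g in grads)))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # One shared factor for every tensor keeps the update direction unchanged
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]    # global norm = 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = float(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))
print(norm_before, round(norm_after, 4))
```

Scaling all gradients by one shared factor only shortens the step; it does not change its direction, which is why it stabilizes training without favouring particular parameters.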

In [None]:
# Create model
model = PhysicsInformedNN(input_size=64, dropout_rate=0.2)
model = model.to(device)

# Log model architecture to TensorBoard
logger.log_model_graph(model, input_size=(1, 1, 64, 64))

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Optimizer with weight decay
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

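# Learning rate schedulers (two strategies are defined; only one is used)
# 1. Reduce on plateau: halve the LR when validation loss stops improving for 5 epochs.
#    (The `verbose` flag is deprecated in recent PyTorch; the training loop prints the LR each epoch.)
scheduler_plateau = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)

# 2. Cosine annealing with warm restarts: smooth decay, restarting after 10, 20, 40, ... epochs
scheduler_cosine = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

# We'll use ReduceLROnPlateau for this demo; scheduler_cosine is defined for comparison
# only. Step just one scheduler per optimizer.
scheduler = scheduler_plateau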

print("✓ Model and optimizers initialized")
print(f"  Initial LR: {optimizer.param_groups[0]['lr']:.2e}")
print(f"  Scheduler: ReduceLROnPlateau (factor=0.5, patience=5)")
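`scheduler_cosine` is defined but not used. To see what it would do, here is the schedule it follows, η_t = η_min + (η_max − η_min)(1 + cos(π·T_cur/T_i))/2, computed in plain Python for the cell's settings (`T_0=10, T_mult=2, eta_min=1e-6`, initial LR 1e-3), assuming the scheduler is stepped exactly once per epoch:

```python
import math


def cosine_warm_restarts_lr(epoch, base_lr=1e-3, eta_min=1e-6, T_0=10, T_mult=2):
    """LR at integer `epoch` for cosine annealing with warm restarts (SGDR),
    assuming the scheduler is stepped exactly once per epoch."""
    T_i, T_cur = T_0, epoch
    while T_cur >= T_i:          # skip completed cycles; each is T_mult times longer
        T_cur -= T_i
        T_i *= T_mult
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2


lrs = [cosine_warm_restarts_lr(e) for e in range(50)]
restarts = [e for e in range(1, 50) if lrs[e] > lrs[e - 1]]
print(restarts)   # [10, 30]
```

With `T_mult=2` the cycles last 10, 20, 40, ... epochs, so within the 50-epoch budget the LR jumps back to 1e-3 at epoch indices 10 and 30 and decays smoothly in between.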

## 6. Initialize Model and Optimizers

In [None]:
# Create TensorBoard logger
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_name = f'pinn_augmented_{timestamp}'

logger = PINNLogger(log_dir='../runs', experiment_name=experiment_name)

print("\n" + "="*70)
print("TensorBoard Logging Enabled!")
print("="*70)
print(f"Experiment: {experiment_name}")
print(f"\nTo view logs in real-time, run in terminal:")
print(f"  tensorboard --logdir=../runs")
print(f"\nThen open: http://localhost:6006")
print("="*70)

## 5. Initialize TensorBoard Logger

In [None]:
BATCH_SIZE = 32

# Training set WITH augmentation
train_dataset = LensDataset(DATA_FILE, split='train', transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Validation set WITHOUT augmentation, so validation loss measures performance on unmodified images
val_dataset = LensDataset(DATA_FILE, split='val', transform=None)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Training set: {len(train_dataset)} samples (WITH augmentation)")
print(f"Validation set: {len(val_dataset)} samples (NO augmentation)")
print(f"Batches per epoch: {len(train_loader)}")
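Why attach `train_transforms` only to the training split? Presumably the transform runs inside `__getitem__`, so each read of a training sample draws a fresh augmentation, while validation samples stay fixed and validation loss remains comparable across epochs. The pure-Python `TransformedDataset` below is a hypothetical stand-in for `LensDataset` (a brightness-style scaling plays the role of the transform), and `batches_per_epoch` reproduces the arithmetic behind `len(train_loader)`:

```python
import math
import random


class TransformedDataset:
    """Hypothetical stand-in for a dataset that applies `transform` lazily,
    inside __getitem__, so every read draws a fresh random augmentation."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    # Same arithmetic as len(DataLoader) for a map-style dataset
    return n_samples // batch_size if drop_last else math.ceil(n_samples / batch_size)


rng = random.Random(0)
samples = [(1.0, 0), (2.0, 1), (3.0, 0)]                   # toy (image, label) pairs
train_ds = TransformedDataset(samples, transform=lambda x: x * rng.uniform(0.8, 1.2))
val_ds = TransformedDataset(samples)                       # no transform

first_read, second_read = train_ds[0][0], train_ds[0][0]   # two different random scalings
print(val_ds[0], val_ds[0])                                # identical on every read
print(batches_per_epoch(1000, 32))                         # 32
```

With `drop_last=False` (the `DataLoader` default) the final partial batch is kept, hence the ceiling.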

## 4. Load Dataset with Augmentation

In [None]:
# Load a sample image
DATA_FILE = '../data/processed/lens_training_data.h5'
temp_dataset = LensDataset(DATA_FILE, split='train')
sample_image, _, _ = temp_dataset[0]

# Apply augmentation multiple times
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes = axes.ravel()

# Original
axes[0].imshow(sample_image[0], cmap='viridis', origin='lower')
axes[0].set_title('Original', fontsize=12, fontweight='bold')
axes[0].axis('off')

# Augmented versions
for i in range(1, 8):
    aug_image = train_transforms(sample_image.copy())
    axes[i].imshow(aug_image[0], cmap='viridis', origin='lower')
    axes[i].set_title(f'Augmented {i}', fontsize=12)
    axes[i].axis('off')

plt.suptitle('Data Augmentation Examples', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

print("Note: augmentations are re-sampled every time a training sample is loaded, so each epoch sees new variants.")
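The pipeline uses only multiples of 90° and axis-aligned flips, so its geometric part can produce at most eight distinct orientations of an image (the symmetry group of a square); variety beyond that comes from the brightness and noise steps. A NumPy check on a random stand-in image, assuming the rotations and flips behave like `np.rot90`, `np.fliplr` and `np.flipud`:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((8, 8))                    # stand-in for a 64x64 lens image

variants = []
for k in range(4):                            # 0°, 90°, 180°, 270°
    rotated = np.rot90(image, k)
    for flip_h in (False, True):
        for flip_v in (False, True):
            v = rotated
            if flip_h:
                v = np.fliplr(v)
            if flip_v:
                v = np.flipud(v)
            variants.append(v)

unique = {v.tobytes() for v in variants}
print(len(variants), len(unique))             # 16 combinations, 8 distinct orientations
```

One caveat worth checking: if any regression target encodes an orientation (for example, an angle describing the lens), rotating or flipping the image changes the true value of that target, so the label would need the same transform. Whether `LensDataset` or `get_training_transforms` handles this is not shown in this notebook.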

## 3. Visualize Augmentation Examples

In [None]:
# Create training transforms with augmentation
train_transforms = get_training_transforms(
    rotation=True,                # Random 90/180/270 degree rotations
    flip=True,                    # Random horizontal/vertical flips
    brightness=True,              # Random brightness adjustments
    noise=True,                   # Random Gaussian noise
    rotation_p=0.5,               # 50% probability of rotation
    flip_p=0.5,                   # 50% probability of flip
    brightness_p=0.5,             # 50% probability of brightness change
    noise_p=0.3,                  # 30% probability of noise
    brightness_range=(0.8, 1.2),  # ±20% brightness variation
    noise_std=0.01                # Noise standard deviation σ (kept small)
)

print("✓ Data augmentation pipeline created")
print("\nAugmentation Pipeline:")
print("  - Random Rotation (90°, 180°, 270°): p=0.5")
print("  - Random Flip (H/V): p=0.5")
print("  - Random Brightness (0.8-1.2x): p=0.5")
print("  - Random Gaussian Noise (σ=0.01): p=0.3")
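The internals of `get_training_transforms` live in `src.ml` and are not shown. As a rough illustration of how the parameters above could combine, here is a hypothetical NumPy function, `augment`, in which each step fires independently with its own probability. It assumes `(C, H, W)` arrays; the order of the steps and the absolute noise scale are assumptions, not the library's confirmed behaviour:

```python
import numpy as np


def augment(image, rng, rotation_p=0.5, flip_p=0.5, brightness_p=0.5, noise_p=0.3,
            brightness_range=(0.8, 1.2), noise_std=0.01):
    """Hypothetical NumPy stand-in for get_training_transforms(): each step fires
    independently with its own probability. Expects a (C, H, W) array."""
    out = image.copy()
    if rng.random() < rotation_p:
        k = int(rng.integers(1, 4))                   # 90°, 180° or 270°
        out = np.rot90(out, k, axes=(-2, -1))
    if rng.random() < flip_p:
        axis = -1 if rng.random() < 0.5 else -2       # horizontal or vertical flip
        out = np.flip(out, axis=axis)
    if rng.random() < brightness_p:
        out = out * rng.uniform(*brightness_range)    # global brightness scaling
    if rng.random() < noise_p:
        out = out + rng.normal(0.0, noise_std, size=out.shape)  # absolute Gaussian noise
    return np.ascontiguousarray(out, dtype=np.float32)


rng = np.random.default_rng(42)
sample = np.ones((1, 64, 64), dtype=np.float32)
augmented = [augment(sample, rng) for _ in range(7)]  # like the 7 panels in the figure
```

Under this reading `noise_std=0.01` is absolute, so it is only "small" if pixel values are normalized to roughly unit scale.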

## 2. Setup Data Augmentation

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import time
import sys

sys.path.append('..')
from src.ml import (
    PhysicsInformedNN, get_training_transforms,
    PINNLogger
)
from src.ml.pinn import train_step, validate_step
from src.ml.generate_dataset import LensDataset

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"  GPU: {torch.cuda.get_device_name(0)}")

plt.rcParams['figure.figsize'] = (14, 8)
print("✓ All modules imported")

## 1. Import Libraries

# Phase 5d: Advanced Training with Augmentation & TensorBoard

This notebook demonstrates **advanced training features**:

## New Features:
- ✓ **Data Augmentation**: Rotation, flip, brightness, noise
- ✓ **TensorBoard Logging**: Real-time training visualization
- ✓ **Learning Rate Scheduling**: Adaptive learning rate
- ✓ **Early Stopping**: Stop when validation loss stops improving, limiting overfitting
- ✓ **Gradient Clipping**: Training stability
- ✓ **Model Checkpointing**: Save best weights

## Improvements Over Basic Training:
- Better generalization through augmentation
- Real-time monitoring with TensorBoard
- More robust training: LR scheduling, early stopping, and checkpointing
- Losses, predictions, and hyperparameters logged for comparing runs in TensorBoard

Let's implement production-ready training!