# Dev Phase 5: Training Pipeline

Training notebook for the MIGT-TVDT model with:
- Mixed precision training (AMP)
- Gradient accumulation
- Learning rate scheduling with warmup
- Early stopping
- Checkpointing

**Tests:**
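1. Loss functions (pinball loss, asymmetric penalty, per-quantile and per-horizon breakdowns, combined loss, metrics, gradient flow)
2. Learning rate scheduler (linear warmup, cosine annealing with warm restarts, state save/load)
3. Training step (single batch, gradient flow, optimizer update, mixed precision)
4. Validation step (metrics computation, quantile non-crossing)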
5. Full training loop (mini run)
6. Checkpoint save/load
7. Early stopping behavior
8. Phase 3/4 integration

In [None]:
# Setup (Colab only): mount Google Drive and put the project's `src/` directory
# on the import path. The `!pip` line is an IPython shell escape, so this cell
# does not run as plain Python.
from google.colab import drive
drive.mount('/content/drive')

import sys
# Insert at position 0 so the project packages imported in the next cell
# (`training`, `model`, `data`) are found here before any installed package
# with the same name.
sys.path.insert(0, '/content/drive/MyDrive/Colab Notebooks/Transformers/FP/src')

# PyYAML is needed to read the YAML configs loaded further down.
!pip install pyyaml -q

In [None]:
# Imports
import torch
import torch.nn as nn
import numpy as np
import yaml
from pathlib import Path
import json

# Training imports
from training.loss_functions import PinballLoss, CombinedQuantileLoss
from training.scheduler import WarmupCosineScheduler, LinearWarmupScheduler
from training.trainer import Trainer, EarlyStopping, create_trainer

# Model imports
from model.migt_tvdt import MIGT_TVDT

# Data imports
from data.dataset import NQDataModule, collate_fn

# Record the runtime environment at the top of the test output.
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {vram_gb:.1f} GB")

In [None]:
# Load configurations.
# `model_config['model']` is passed to MIGT_TVDT(...) by the tests below;
# `train_config` supplies the training / optimizer / quantile settings.
BASE_DIR = Path('/content/drive/MyDrive/Colab Notebooks/Transformers/FP')


def _load_yaml(path):
    """Parse the YAML file at *path* with ``yaml.safe_load`` and return the result.

    The file is opened as UTF-8 explicitly so parsing does not depend on the
    platform's default locale encoding.
    """
    with open(path, encoding='utf-8') as fh:
        return yaml.safe_load(fh)


model_config = _load_yaml(BASE_DIR / 'configs/model_config.yaml')
train_config = _load_yaml(BASE_DIR / 'configs/training_config.yaml')

print("Model config loaded")
print(f"  d_model: {model_config['model']['d_model']}")
print(f"  n_variables: {model_config['model']['n_variables']}")

print("\nTraining config loaded")
print(f"  batch_size: {train_config['training']['batch_size']}")
print(f"  max_epochs: {train_config['training']['max_epochs']}")
print(f"  lr: {train_config['optimizer']['lr']}")

In [None]:
# Shared tensor dimensions used by the synthetic-data tests below.
B, H, Q = 4, 5, 7  # batch size, forecast horizons, predicted quantiles

# Prefer the GPU whenever one is visible.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Testing on device: {device}")

## Test 1: Loss Functions

In [None]:
def test_loss_functions():
    """Exercise PinballLoss and CombinedQuantileLoss on synthetic tensors.

    Sub-tests:
      1.1 exact predictions give zero loss;
      1.2 the pinball penalty is asymmetric for a single high quantile
          (tau=0.9) and balanced for the symmetric 7-quantile set;
      1.3/1.4 per-quantile and per-horizon breakdowns have Q and H entries;
      1.5/1.6 combined-loss components and interval metrics are reported;
      1.7 gradients reach the predictions and contain no NaNs.

    Relies on the notebook-level dimensions B (batch), H (horizons) and
    Q (quantiles). Raises AssertionError on the first failed check.
    """
    print("=" * 60)
    print("TEST 1: Loss Functions")
    print("=" * 60)

    quantiles = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]

    # 1.1: Basic pinball loss
    print("\n1.1 PinballLoss basic computation")
    loss_fn = PinballLoss(quantiles)

    # Perfect predictions (every quantile equals the target)
    predictions = torch.zeros(B, H, Q)
    targets = torch.zeros(B, H)

    loss = loss_fn(predictions, targets)
    print(f"  Loss with perfect predictions: {loss.item():.6f}")
    assert loss.item() == 0.0, "Perfect predictions should have zero loss"
    print("  [PASS] Zero loss for perfect predictions")

    # 1.2: Asymmetric penalty check.
    # Under the standard pinball definition, quantile tau costs tau * |err|
    # when the target is above the prediction and (1 - tau) * |err| when it is
    # below. For the symmetric 7-quantile set those weights average to 0.5 in
    # both directions, so under- vs over-prediction with that set cannot reveal
    # an asymmetry bug. Isolate tau = 0.9 instead.
    print("\n1.2 Asymmetric penalty (quantile tau=0.9)")
    tau_loss_fn = PinballLoss([0.9])
    target_one = torch.ones(B, H)

    pred_under = torch.zeros(B, H, 1)       # target 1.0 lies above prediction
    pred_over = torch.full((B, H, 1), 2.0)  # target 1.0 lies below prediction
    loss_under = tau_loss_fn(pred_under, target_one)
    loss_over = tau_loss_fn(pred_over, target_one)

    print(f"  Underprediction loss: {loss_under.item():.6f}")
    print(f"  Overprediction loss: {loss_over.item():.6f}")
    assert loss_under.item() > loss_over.item(), (
        "tau=0.9 should penalize underprediction more than overprediction"
    )

    # Equal-size errors in either direction must cost the same for the
    # symmetric quantile set (mean(tau) == mean(1 - tau) == 0.5).
    sym_under = loss_fn(torch.zeros(B, H, Q), target_one)
    sym_over = loss_fn(torch.full((B, H, Q), 2.0), target_one)
    print(f"  Symmetric-set under/over loss: {sym_under.item():.6f} / {sym_over.item():.6f}")
    assert torch.isclose(sym_under, sym_over, atol=1e-6), (
        "Symmetric quantile set should penalize equal errors equally"
    )
    print("  [PASS] Asymmetric losses computed")

    # 1.3: Per-quantile breakdown
    print("\n1.3 Per-quantile loss breakdown")
    predictions = torch.randn(B, H, Q)
    targets = torch.randn(B, H)

    q_losses = loss_fn.per_quantile_loss(predictions, targets)
    print(f"  Per-quantile losses: {q_losses}")
    assert len(q_losses) == Q, f"Expected {Q} quantile losses"
    print("  [PASS] Per-quantile breakdown")

    # 1.4: Per-horizon breakdown
    print("\n1.4 Per-horizon loss breakdown")
    h_losses = loss_fn.per_horizon_loss(predictions, targets)
    print(f"  Per-horizon losses: {h_losses}")
    assert len(h_losses) == H, f"Expected {H} horizon losses"
    print("  [PASS] Per-horizon breakdown")

    # 1.5: Combined loss -- check the keys exist before reading them so a
    # missing component fails with an assertion rather than a KeyError.
    print("\n1.5 CombinedQuantileLoss")
    combined_loss = CombinedQuantileLoss(quantiles)

    loss_dict = combined_loss(predictions, targets)
    assert 'total' in loss_dict and 'pinball' in loss_dict and 'crossing' in loss_dict
    print(f"  Total loss: {loss_dict['total'].item():.6f}")
    print(f"  Pinball loss: {loss_dict['pinball'].item():.6f}")
    print(f"  Crossing loss: {loss_dict['crossing'].item():.6f}")
    print("  [PASS] Combined loss components")

    # 1.6: Metrics computation
    print("\n1.6 Metrics computation")
    metrics = combined_loss.get_metrics(predictions, targets)
    assert 'picp_80' in metrics and 'coverage_q50' in metrics
    print(f"  PICP 80%: {metrics['picp_80']:.3f}")
    print(f"  Interval 80 mean: {metrics['interval_80_mean']:.4f}")
    print("  [PASS] Metrics computed")

    # 1.7: Gradient flow
    print("\n1.7 Gradient flow through loss")
    predictions = torch.randn(B, H, Q, requires_grad=True)
    targets = torch.randn(B, H)

    loss = loss_fn(predictions, targets)
    loss.backward()

    assert predictions.grad is not None, "Gradients not computed"
    assert not torch.isnan(predictions.grad).any(), "NaN gradients"
    print(f"  Gradient norm: {predictions.grad.norm().item():.6f}")
    print("  [PASS] Gradients flow correctly")

    print("\n" + "=" * 60)
    print("TEST 1 COMPLETE: All loss function tests passed")
    print("=" * 60)

test_loss_functions()

## Test 2: Learning Rate Scheduler

In [None]:
def test_scheduler():
    """Check warmup, cosine annealing with restarts, and state round-tripping.

    A throwaway ``nn.Linear`` model provides parameters so only the scheduler
    logic is exercised. Warmup is driven per batch via ``step_batch()`` and
    the cosine phase per epoch via ``step()``, matching how the assertions
    below read the optimizer's learning rate.
    """
    print("=" * 60)
    print("TEST 2: Learning Rate Scheduler")
    print("=" * 60)

    # Create dummy model and optimizer
    model = nn.Linear(10, 10)
    base_lr = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

    # 2.1: Warmup phase
    print("\n2.1 Warmup phase")
    scheduler = WarmupCosineScheduler(
        optimizer,
        warmup_steps=100,
        t_0=10,
        t_mult=2,
        eta_min=1e-6
    )

    # Simulate warmup: one step_batch() per training batch.
    warmup_lrs = []
    for _ in range(100):
        scheduler.step_batch()
        warmup_lrs.append(optimizer.param_groups[0]['lr'])

    print(f"  LR at step 0: {warmup_lrs[0]:.2e}")
    print(f"  LR at step 50: {warmup_lrs[50]:.2e}")
    print(f"  LR at step 99: {warmup_lrs[99]:.2e}")

    # Verify linear increase up to base_lr
    assert warmup_lrs[0] < warmup_lrs[50] < warmup_lrs[99], "LR should increase during warmup"
    assert abs(warmup_lrs[99] - base_lr) < 1e-6, (
        f"LR should reach base_lr at warmup end (got {warmup_lrs[99]:.2e})"
    )
    print("  [PASS] Linear warmup")

    # 2.2: Cosine annealing. With t_0=10 and t_mult=2 (SGDR-style cycles of 10
    # then 20 epochs, as the assertions assume) epoch index 10 is the first
    # restart.
    print("\n2.2 Cosine annealing phase")

    epoch_lrs = []
    for _ in range(30):
        scheduler.step()
        epoch_lrs.append(optimizer.param_groups[0]['lr'])

    print(f"  LR at epoch 0: {epoch_lrs[0]:.2e}")
    print(f"  LR at epoch 9 (end of first cycle): {epoch_lrs[9]:.2e}")
    print(f"  LR at epoch 10 (restart): {epoch_lrs[10]:.2e}")
    print(f"  LR at epoch 29: {epoch_lrs[29]:.2e}")

    assert epoch_lrs[0] > epoch_lrs[5], "LR should decrease in first cycle"
    assert epoch_lrs[9] < epoch_lrs[10], "LR should jump at restart (epoch 10)"
    print("  [PASS] Cosine annealing with restarts")

    # 2.3: State dict save/load
    print("\n2.3 State dict save/load")
    state = scheduler.state_dict()

    # Create a fresh optimizer/scheduler pair and load the saved state
    optimizer2 = torch.optim.AdamW(model.parameters(), lr=base_lr)
    scheduler2 = WarmupCosineScheduler(
        optimizer2, warmup_steps=100, t_0=10, t_mult=2, eta_min=1e-6
    )
    scheduler2.load_state_dict(state)

    assert scheduler2.warmup_finished == scheduler.warmup_finished
    assert scheduler2.epoch_in_cycle == scheduler.epoch_in_cycle

    # Matching attributes do not prove the restored scheduler *behaves* the
    # same: advance both one epoch and require identical learning rates.
    scheduler.step()
    scheduler2.step()
    lr_orig = optimizer.param_groups[0]['lr']
    lr_restored = optimizer2.param_groups[0]['lr']
    print(f"  LR after next epoch (original / restored): {lr_orig:.2e} / {lr_restored:.2e}")
    assert abs(lr_orig - lr_restored) <= 1e-12, "Restored scheduler should produce the same LR"
    print("  [PASS] State dict save/load")

    print("\n" + "=" * 60)
    print("TEST 2 COMPLETE: All scheduler tests passed")
    print("=" * 60)

test_scheduler()

## Test 3: Training Step

In [None]:
def test_training_step():
    """Run one optimisation step of MIGT_TVDT on a synthetic batch.

    Sub-tests:
      3.1 forward-pass output shape is (B, H, Q);
      3.2 the combined loss is finite and positive;
      3.3 gradients reach the parameters;
      3.4 an AdamW step changes the weights;
      3.5 the same step under automatic mixed precision with a GradScaler.
          Mixed precision is only enabled on CUDA; on CPU the step runs in
          FP32 and is reported as SKIP rather than PASS.

    Uses notebook globals: model_config, train_config, device, B, H, Q.
    """
    print("=" * 60)
    print("TEST 3: Training Step")
    print("=" * 60)

    model = MIGT_TVDT(model_config['model']).to(device)
    model.train()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_config['optimizer']['lr'],
        weight_decay=train_config['optimizer']['weight_decay']
    )

    loss_fn = CombinedQuantileLoss(
        quantiles=train_config['quantile_regression']['quantiles']
    )

    # Synthetic batch: T=288 time steps, random features/targets and random
    # calendar fields.
    T, V = 288, model_config['model']['n_variables']
    batch = {
        'features': torch.randn(B, T, V, device=device),
        'attention_mask': torch.ones(B, T, dtype=torch.bool, device=device),
        'targets': torch.randn(B, H, device=device),
        'bar_in_day': torch.arange(T).unsqueeze(0).expand(B, -1).to(device),
        'day_of_week': torch.randint(0, 5, (B,), device=device),
        'day_of_month': torch.randint(1, 32, (B,), device=device),
        'day_of_year': torch.randint(1, 366, (B,), device=device)
    }
    temporal_info = {
        key: batch[key]
        for key in ('bar_in_day', 'day_of_week', 'day_of_month', 'day_of_year')
    }

    # 3.1: Forward pass
    print("\n3.1 Forward pass")
    outputs = model(
        features=batch['features'],
        attention_mask=batch['attention_mask'],
        temporal_info=temporal_info
    )

    print(f"  Output shape: {outputs['quantiles'].shape}")
    assert outputs['quantiles'].shape == (B, H, Q), f"Expected ({B}, {H}, {Q})"
    print("  [PASS] Forward pass shape")

    # 3.2: Loss computation
    print("\n3.2 Loss computation")
    loss_dict = loss_fn(outputs['quantiles'], batch['targets'])
    loss = loss_dict['total']

    print(f"  Loss value: {loss.item():.6f}")
    assert torch.isfinite(loss), "Loss should be finite"
    assert loss.item() > 0, "Loss should be positive"
    print("  [PASS] Loss computed")

    # 3.3: Backward pass
    print("\n3.3 Backward pass")
    optimizer.zero_grad()
    loss.backward()

    grad_norms = [
        (name, param.grad.norm().item())
        for name, param in model.named_parameters()
        if param.grad is not None
    ]

    print(f"  Parameters with gradients: {len(grad_norms)}")
    print("  Sample gradient norms:")
    for name, norm in grad_norms[:3]:
        print(f"    {name}: {norm:.6f}")

    assert len(grad_norms) > 0, "No gradients computed"
    print("  [PASS] Gradients computed")

    # 3.4: Optimizer step
    print("\n3.4 Optimizer step")

    # NOTE(review): reaches into model internals -- assumes `output_pool[0]`
    # is a layer with a `.weight`; update if the output head is renamed.
    initial_weight = model.output_pool[0].weight.clone()

    optimizer.step()

    weight_diff = (model.output_pool[0].weight - initial_weight).abs().mean()
    print(f"  Weight change (mean abs): {weight_diff.item():.8f}")
    assert weight_diff > 0, "Weights should change after optimizer step"
    print("  [PASS] Optimizer step updated weights")

    # 3.5: Mixed precision. Autocast + loss scaling only take effect on CUDA.
    # Enable them explicitly there and disable them elsewhere. On CPU the
    # same code path then runs in FP32 without "CUDA not available" warnings,
    # and without claiming mixed precision was tested.
    print("\n3.5 Mixed precision training")
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', enabled=use_amp):
        outputs = model(
            features=batch['features'],
            attention_mask=batch['attention_mask'],
            temporal_info=temporal_info
        )
        loss_dict = loss_fn(outputs['quantiles'], batch['targets'])
        loss = loss_dict['total']

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"  AMP enabled: {use_amp}")
    print(f"  AMP loss: {loss.item():.6f}")
    print(f"  Scaler scale: {scaler.get_scale():.1f}")
    assert torch.isfinite(loss), "AMP loss should be finite"
    if use_amp:
        print("  [PASS] Mixed precision training")
    else:
        print("  [SKIP] Mixed precision requires CUDA; ran the same step in FP32")

    print("\n" + "=" * 60)
    print("TEST 3 COMPLETE: All training step tests passed")
    print("=" * 60)

test_training_step()

## Test 4: Validation Step

In [None]:
def test_validation_step():
    """Run one no-grad validation pass on a synthetic batch and inspect metrics.

    Checks that the combined loss can be computed without autograd, that
    ``get_metrics`` reports the expected keys, and that the predicted
    quantiles never decrease along the last axis (no quantile crossing).
    """
    rule = "=" * 60
    print(rule)
    print("TEST 4: Validation Step")
    print(rule)

    model = MIGT_TVDT(model_config['model']).to(device)
    model.eval()

    loss_fn = CombinedQuantileLoss(
        quantiles=train_config['quantile_regression']['quantiles']
    )

    # Synthetic validation batch (tensors drawn in the same order as test 3)
    seq_len = 288
    n_vars = model_config['model']['n_variables']
    batch = {
        'features': torch.randn(B, seq_len, n_vars, device=device),
        'attention_mask': torch.ones(B, seq_len, dtype=torch.bool, device=device),
        'targets': torch.randn(B, H, device=device),
        'bar_in_day': torch.arange(seq_len).unsqueeze(0).expand(B, -1).to(device),
        'day_of_week': torch.randint(0, 5, (B,), device=device),
        'day_of_month': torch.randint(1, 32, (B,), device=device),
        'day_of_year': torch.randint(1, 366, (B,), device=device),
    }
    calendar_keys = ('bar_in_day', 'day_of_week', 'day_of_month', 'day_of_year')
    temporal_info = {key: batch[key] for key in calendar_keys}

    # 4.1: Forward pass + loss with autograd disabled
    print("\n4.1 Validation forward pass")
    with torch.no_grad():
        outputs = model(
            features=batch['features'],
            attention_mask=batch['attention_mask'],
            temporal_info=temporal_info,
        )
        loss_dict = loss_fn(outputs['quantiles'], batch['targets'])

    print(f"  Val loss: {loss_dict['total'].item():.6f}")
    print("  [PASS] No-gradient validation")

    # 4.2: Detailed metrics on CPU copies
    print("\n4.2 Detailed metrics")
    predictions = outputs['quantiles'].cpu()
    metrics = loss_fn.get_metrics(predictions, batch['targets'].cpu())

    print(f"  PICP 80%: {metrics['picp_80']:.3f}")
    print(f"  Coverage q50: {metrics['coverage_q50']:.3f}")
    print(f"  Interval 80 mean: {metrics['interval_80_mean']:.4f}")
    print(f"  Loss 15m: {metrics['loss_15m']:.6f}")

    for required in ('picp_80', 'coverage_q50', 'interval_80_mean', 'loss_15m'):
        assert required in metrics, f"Missing metric: {required}"
    print("  [PASS] All metrics computed")

    # 4.3: Adjacent-quantile gaps must all be non-negative
    print("\n4.3 Quantile non-crossing verification")
    gaps = torch.diff(predictions, dim=-1)
    non_crossing = (gaps >= 0).all()

    print(f"  Min quantile diff: {gaps.min().item():.6f}")
    print(f"  All non-crossing: {non_crossing.item()}")
    assert non_crossing, "Quantiles should be non-crossing"
    print("  [PASS] Non-crossing quantiles verified")

    print("\n" + rule)
    print("TEST 4 COMPLETE: All validation tests passed")
    print(rule)

test_validation_step()

## Test 5: Full Training Loop (Mini)

In [None]:
def _make_synthetic_loaders(n_samples, seq_len, n_vars, batch_size=4):
    """Build (train_loader, val_loader) over random synthetic samples.

    The first half of the sample indices is the shuffled training split and
    the second half the validation split. Each batch is a dict with the keys
    the model tests use: features, attention_mask, targets and the four
    calendar fields. Targets have H (notebook global) horizons.
    """
    from torch.utils.data import DataLoader

    features = torch.randn(n_samples, seq_len, n_vars)
    masks = torch.ones(n_samples, seq_len, dtype=torch.bool)
    targets = torch.randn(n_samples, H)
    bar_in_day = torch.arange(seq_len).unsqueeze(0).expand(n_samples, -1)
    day_of_week = torch.randint(0, 5, (n_samples,))
    day_of_month = torch.randint(1, 32, (n_samples,))
    day_of_year = torch.randint(1, 366, (n_samples,))

    def synthetic_collate(batch):
        # The "dataset" is a list of integer indices; gather rows here.
        indices = torch.tensor(batch)
        return {
            'features': features[indices],
            'attention_mask': masks[indices],
            'targets': targets[indices],
            'bar_in_day': bar_in_day[indices],
            'day_of_week': day_of_week[indices],
            'day_of_month': day_of_month[indices],
            'day_of_year': day_of_year[indices]
        }

    split = n_samples // 2
    train_loader = DataLoader(
        list(range(split)),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=synthetic_collate
    )
    val_loader = DataLoader(
        list(range(split, n_samples)),
        batch_size=batch_size,
        collate_fn=synthetic_collate
    )
    return train_loader, val_loader


def test_training_loop_mini():
    """Run the Trainer end-to-end for a few epochs on synthetic data.

    Checks the history length, that every epoch loss is finite, and that the
    latest/best checkpoints and the history JSON are written. Outputs go to a
    fresh temporary directory, so files left over from an earlier aborted run
    cannot satisfy the existence checks. The directory is removed even if an
    assertion fails.
    """
    import shutil
    import tempfile

    print("=" * 60)
    print("TEST 5: Full Training Loop (Mini)")
    print("=" * 60)

    T, V = 288, model_config['model']['n_variables']
    train_loader, val_loader = _make_synthetic_loaders(32, T, V, batch_size=4)

    # Create model and trainer config
    model = MIGT_TVDT(model_config['model']).to(device)

    mini_config = {
        'training': {
            'batch_size': 4,
            'gradient_accumulation_steps': 1,
            'max_epochs': 3,
            'early_stopping_patience': 5,
            'mixed_precision': True
        },
        'optimizer': {
            'lr': 1e-4,
            'weight_decay': 0.01,
            'betas': [0.9, 0.999]
        },
        'scheduler': {
            'warmup_steps': 10,
            't_0': 2,
            't_mult': 2,
            'eta_min': 1e-6
        },
        'regularization': {
            'gradient_clip_norm': 1.0
        },
        'quantile_regression': {
            'quantiles': [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95],
            'crossing_weight': 0.0
        },
        'checkpointing': {
            'save_top_k': 2
        }
    }
    n_epochs = mini_config['training']['max_epochs']

    output_dir = Path(tempfile.mkdtemp(prefix='migt_test_outputs_'))
    try:
        # 5.1: Create trainer
        print("\n5.1 Create trainer")
        trainer = Trainer(
            model=model,
            config=mini_config,
            train_loader=train_loader,
            val_loader=val_loader,
            output_dir=output_dir,
            device=device
        )
        print("  [PASS] Trainer created")

        # 5.2: Run training
        print(f"\n5.2 Run training ({n_epochs} epochs)")
        history = trainer.train()

        print(f"\n  Final train loss: {history['train_loss'][-1]:.6f}")
        print(f"  Final val loss: {history['val_loss'][-1]:.6f}")
        print(f"  Best val loss: {trainer.best_val_loss:.6f}")

        assert len(history['train_loss']) == n_epochs, f"Should have {n_epochs} epochs"
        assert len(history['val_loss']) == n_epochs
        print("  [PASS] Training completed")

        # 5.3: With random targets the loss need not decrease; only require
        # every recorded loss to be finite (no NaN/inf).
        print("\n5.3 Loss trend")
        for epoch_loss in history['train_loss'] + history['val_loss']:
            assert np.isfinite(epoch_loss), "Loss should be finite"
        print(f"  Train losses: {history['train_loss']}")
        print("  [PASS] All losses finite")

        # 5.4: Check files created by this run
        print("\n5.4 Output files")
        for filename in ('checkpoint_latest.pt', 'checkpoint_best.pt', 'training_history.json'):
            assert (output_dir / filename).exists(), f"Missing output file: {filename}"
            print(f"  {filename}: exists")
        print("  [PASS] All output files created")
    finally:
        shutil.rmtree(output_dir, ignore_errors=True)

    print("\n" + "=" * 60)
    print("TEST 5 COMPLETE: Full training loop test passed")
    print("=" * 60)

test_training_loop_mini()

## Test 6: Checkpoint Save/Load

In [None]:
def test_checkpoint_save_load():
    """Round-trip model/optimizer/scheduler/scaler state through torch.save/load.

    Saves a checkpoint dict into a fresh temporary directory and restores it
    into newly constructed objects. It then checks that the optimizer learning
    rate was restored and that both models produce identical outputs. The
    temporary directory is always removed, even when an assertion fails.
    """
    import shutil
    import tempfile

    print("=" * 60)
    print("TEST 6: Checkpoint Save/Load")
    print("=" * 60)

    # GradScaler only does anything on CUDA; disable it explicitly elsewhere
    # instead of relying on its "CUDA not available" fallback warning.
    use_amp = device.type == 'cuda'

    output_dir = Path(tempfile.mkdtemp(prefix='migt_test_checkpoint_'))
    try:
        # Create and initialize model
        model = MIGT_TVDT(model_config['model']).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100)
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        # Simulate some training so the scheduler/optimizer LR is non-default
        for _ in range(50):
            scheduler.step_batch()

        # 6.1: Save checkpoint
        print("\n6.1 Save checkpoint")
        checkpoint = {
            'epoch': 5,
            'global_step': 500,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'val_loss': 0.123,
            'best_val_loss': 0.100,
            'config': model_config
        }

        checkpoint_path = output_dir / 'test_checkpoint.pt'
        torch.save(checkpoint, checkpoint_path)
        print(f"  Saved to: {checkpoint_path}")
        print(f"  File size: {checkpoint_path.stat().st_size / 1e6:.2f} MB")
        print("  [PASS] Checkpoint saved")

        # 6.2: Load checkpoint into fresh objects
        print("\n6.2 Load checkpoint")

        model2 = MIGT_TVDT(model_config['model']).to(device)
        optimizer2 = torch.optim.AdamW(model2.parameters(), lr=1e-4)
        scheduler2 = WarmupCosineScheduler(optimizer2, warmup_steps=100)
        scaler2 = torch.cuda.amp.GradScaler(enabled=use_amp)

        # weights_only=False: the checkpoint also pickles plain-Python config
        # and scheduler state, and PyTorch >= 2.6 defaults torch.load to
        # weights_only=True, which rejects objects outside its allowlist. This
        # is safe only because this test wrote the file moments ago -- never
        # load untrusted checkpoints this way.
        loaded = torch.load(checkpoint_path, map_location=device, weights_only=False)

        model2.load_state_dict(loaded['model_state_dict'])
        optimizer2.load_state_dict(loaded['optimizer_state_dict'])
        scheduler2.load_state_dict(loaded['scheduler_state_dict'])
        scaler2.load_state_dict(loaded['scaler_state_dict'])

        print(f"  Loaded epoch: {loaded['epoch']}")
        print(f"  Loaded global_step: {loaded['global_step']}")
        print(f"  Loaded val_loss: {loaded['val_loss']}")
        assert loaded['epoch'] == 5 and loaded['global_step'] == 500

        lr_saved = optimizer.param_groups[0]['lr']
        lr_restored = optimizer2.param_groups[0]['lr']
        print(f"  Restored LR: {lr_restored:.2e} (saved {lr_saved:.2e})")
        assert abs(lr_saved - lr_restored) < 1e-12, "Optimizer LR should be restored"
        print("  [PASS] Checkpoint loaded")

        # 6.3: Verify model outputs match
        print("\n6.3 Verify model outputs match")

        model.eval()
        model2.eval()

        T, V = 288, model_config['model']['n_variables']
        test_input = torch.randn(1, T, V, device=device)
        test_mask = torch.ones(1, T, dtype=torch.bool, device=device)
        temporal_info = {
            'bar_in_day': torch.arange(T).unsqueeze(0).to(device),
            'day_of_week': torch.tensor([0], device=device),
            'day_of_month': torch.tensor([1], device=device),
            'day_of_year': torch.tensor([1], device=device)
        }

        # Keyword arguments, matching the forward calls in the other tests.
        with torch.no_grad():
            out1 = model(
                features=test_input, attention_mask=test_mask, temporal_info=temporal_info
            )['quantiles']
            out2 = model2(
                features=test_input, attention_mask=test_mask, temporal_info=temporal_info
            )['quantiles']

        diff = (out1 - out2).abs().max().item()
        print(f"  Max output difference: {diff:.10f}")
        assert diff < 1e-6, f"Outputs should match, got diff={diff}"
        print("  [PASS] Model outputs match")
    finally:
        shutil.rmtree(output_dir, ignore_errors=True)

    print("\n" + "=" * 60)
    print("TEST 6 COMPLETE: Checkpoint save/load test passed")
    print("=" * 60)

test_checkpoint_save_load()

## Test 7: Early Stopping

In [None]:
def test_early_stopping():
    """Check EarlyStopping fires on plateaus and stays quiet while improving.

    Covers min mode with a plateau (7.1), min mode with steady improvement
    (7.2), and max mode for accuracy-style metrics (7.3).
    """
    def feed(stopper, values, verbose=False, on_stop=None):
        # Call `stopper` on each value until it signals a stop. Returns the
        # last signal `stopper` produced.
        signal = False
        for epoch, value in enumerate(values):
            signal = stopper(value)
            if verbose:
                print(f"  Epoch {epoch}: score={value:.2f}, counter={stopper.counter}, stop={signal}")
            if signal:
                if on_stop is not None:
                    on_stop(epoch)
                break
        return signal

    rule = "=" * 60
    print(rule)
    print("TEST 7: Early Stopping")
    print(rule)

    # 7.1: Improves for three epochs, then plateaus at 0.8
    print("\n7.1 Basic early stopping (patience=3)")
    es = EarlyStopping(patience=3, min_delta=0.001, mode='min')
    stopped = feed(
        es,
        [1.0, 0.9, 0.8, 0.8, 0.8, 0.8],
        verbose=True,
        on_stop=lambda epoch: print(f"  Stopped at epoch {epoch}"),
    )
    assert stopped, "Should have triggered early stopping"
    print("  [PASS] Early stopping triggered")

    # 7.2: Strictly improving scores must never trigger a stop
    print("\n7.2 No stopping with improvement")
    es2 = EarlyStopping(patience=3, min_delta=0.001, mode='min')
    stopped = feed(
        es2,
        [1.0, 0.9, 0.8, 0.7, 0.6, 0.5],
        on_stop=lambda epoch: print(f"  Unexpectedly stopped at epoch {epoch}"),
    )
    assert not stopped, "Should not have stopped"
    print(f"  Final best score: {es2.best_score:.2f}")
    print("  [PASS] No early stopping with improvement")

    # 7.3: Max mode -- higher is better, plateaus at 0.65
    print("\n7.3 Max mode (accuracy-like)")
    es3 = EarlyStopping(patience=2, min_delta=0.01, mode='max')
    stopped = feed(es3, [0.5, 0.6, 0.65, 0.65, 0.65])
    assert stopped, "Should stop in max mode"
    print(f"  Best score: {es3.best_score:.2f}")
    print("  [PASS] Max mode early stopping")

    print("\n" + rule)
    print("TEST 7 COMPLETE: Early stopping tests passed")
    print(rule)

test_early_stopping()

## Test 8: Phase 3/4 Integration

In [None]:
def test_phase_integration():
    """Smoke-test the Phase 3 data pipeline feeding the Phase 4 model and loss.

    Loads NQDataModule from the processed parquet under BASE_DIR and pulls one
    training batch. It then runs a no-grad forward pass and computes the
    combined quantile loss and interval metrics on that batch. If the parquet
    file is absent, it prints a SKIP banner and returns without failing.
    """
    rule = "=" * 60
    print(rule)
    print("TEST 8: Phase 3/4 Integration")
    print(rule)

    data_path = BASE_DIR / 'data/processed/nq_features_full.parquet'
    if not data_path.exists():
        print(f"\n  [SKIP] Data file not found: {data_path}")
        print("  Run Phase 3 preprocessing first.")
        print("\n" + rule)
        print("TEST 8 SKIPPED: Data not available")
        print(rule)
        return

    # 8.1: Data module
    print("\n8.1 Load NQDataModule")
    data_module = NQDataModule(
        data_path=data_path,
        batch_size=4,  # small batch keeps the smoke test cheap
        num_workers=0,
        pin_memory=False,
    )
    data_module.setup()

    print(f"  Train samples: {len(data_module.train_dataset):,}")
    print(f"  Val samples: {len(data_module.val_dataset):,}")
    print(f"  Test samples: {len(data_module.test_dataset):,}")
    print("  [PASS] Data module loaded")

    # 8.2: One batch from the training loader
    print("\n8.2 Get batch from dataloader")
    batch = next(iter(data_module.train_dataloader()))
    for key in ('features', 'attention_mask', 'targets', 'bar_in_day'):
        print(f"  {key} shape: {batch[key].shape}")
    print("  [PASS] Batch retrieved")

    # 8.3: Forward pass on the real batch
    print("\n8.3 Forward pass with real data")
    model = MIGT_TVDT(model_config['model']).to(device)
    model.eval()

    # Move tensors to the device; leave any non-tensor entries untouched.
    batch_device = {}
    for key, value in batch.items():
        batch_device[key] = value.to(device) if isinstance(value, torch.Tensor) else value
    temporal_info = {
        key: batch_device[key]
        for key in ('bar_in_day', 'day_of_week', 'day_of_month', 'day_of_year')
    }

    with torch.no_grad():
        outputs = model(
            features=batch_device['features'],
            attention_mask=batch_device['attention_mask'],
            temporal_info=temporal_info,
        )
    quantile_preds = outputs['quantiles']

    print(f"  Output shape: {quantile_preds.shape}")
    print(f"  Output sample (first horizon):\n{quantile_preds[0, 0]}")
    print("  [PASS] Forward pass successful")

    # 8.4: Loss on real targets
    print("\n8.4 Loss computation with real data")
    loss_fn = CombinedQuantileLoss(
        quantiles=train_config['quantile_regression']['quantiles']
    )
    loss_dict = loss_fn(quantile_preds, batch_device['targets'])
    total_loss = loss_dict['total'].item()

    print(f"  Total loss: {total_loss:.6f}")
    print(f"  Pinball loss: {loss_dict['pinball'].item():.6f}")
    assert total_loss > 0, "Loss should be positive"
    assert np.isfinite(total_loss), "Loss should be finite"
    print("  [PASS] Loss computed successfully")

    # 8.5: Interval metrics on CPU copies
    print("\n8.5 Metrics with real data")
    metrics = loss_fn.get_metrics(quantile_preds.cpu(), batch_device['targets'].cpu())

    print(f"  PICP 80%: {metrics['picp_80']:.3f}")
    print(f"  Coverage q50: {metrics['coverage_q50']:.3f}")
    print(f"  Interval 80 mean: {metrics['interval_80_mean']:.6f}")
    print("  [PASS] Metrics computed")

    print("\n" + rule)
    print("TEST 8 COMPLETE: Phase integration tests passed")
    print(rule)

test_phase_integration()

## Summary

In [None]:
# Final summary: list the delivered training modules and next steps.
print("\n".join([
    "\n" + "=" * 70,
    "DEV PHASE 5: TRAINING PIPELINE TESTS COMPLETE",
    "=" * 70,
    "\nAll tests passed. Training pipeline is ready for full training.",
    "\nDelivered components:",
    *(
        f"  - {path}"
        for path in (
            "src/training/loss_functions.py",
            "src/training/scheduler.py",
            "src/training/trainer.py",
            "src/training/__init__.py",
            "configs/training_config.yaml",
        )
    ),
    "\nNext steps for full training:",
    *(
        f"  {number}. {step}"
        for number, step in enumerate(
            (
                "Ensure Phase 3 data is preprocessed",
                "Copy data to VM: /content/data/",
                "Run full training with Trainer class",
                "Monitor with WandB or TensorBoard",
            ),
            start=1,
        )
    ),
]))