# Delphi Model Evaluation on Test Set

This notebook evaluates the trained Delphi model on the test dataset.

**Metrics:**
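- Cross-Entropy Loss (CE): Negative log-likelihood of the true next event under the model's predicted distribution (lower is better); it is not an accuracy.
- Time-to-Event Loss (DT): Loss on the predicted time until the next event (lower is better).
- Perplexity: exp(CE loss), roughly the effective number of equally likely next events the model is choosing between (lower means less uncertainty).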
- Combined Loss: CE + DT

In [None]:
import os
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from model import Delphi, DelphiConfig
from utils import get_p2i, get_batch

# Plot styling used by every figure below: seaborn's "whitegrid" theme and a
# 12x6-inch default figure size.
sns.set_style(style='whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Configuration and Setup

In [None]:
# Configuration
checkpoint_path = 'mc-med-gpt/ckpt_10000.pt'  # or 'mc-med-gpt/ckpt.pt'
data_dir = 'data/mc-med'

# Pick the best available accelerator instead of hard-coding one: a fixed
# 'mps' makes torch.load(map_location=device) fail on machines without Apple
# silicon. Override manually ('cuda', 'mps', or 'cpu') if needed.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

batch_size = 64
block_size = 128  # Should match training config
no_event_token_rate = 5
eval_iters = 100  # Number of batches to evaluate (increase for more accurate results)

# For reproducibility (batch sampling below uses torch.randint)
torch.manual_seed(42)
np.random.seed(42)

## 2. Load Checkpoint and Model

In [None]:
# Load the training checkpoint; tensors are mapped straight onto `device`.
print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=device)

# Hyper-parameters the model was constructed with during training.
model_args = checkpoint['model_args']
print("\nModel configuration:")
for arg_name, arg_value in model_args.items():
    print(f"  {arg_name}: {arg_value}")

# Training progress recorded alongside the weights.
print("\nTraining info:")
print(f"  Iteration: {checkpoint['iter_num']}")
print(f"  Best validation loss: {checkpoint['best_val_loss']:.4f}")

# Rebuild the architecture from the saved configuration.
gptconf = DelphiConfig(**model_args)
model = Delphi(gptconf)

# Parameter names may carry a '_orig_mod.' prefix (typically added when the
# model was wrapped with torch.compile). Strip it in place so the names match
# the plain module; keys are renamed in their original iteration order.
state_dict = checkpoint['model']
compiled_prefix = '_orig_mod.'
prefixed_names = [name for name in state_dict if name.startswith(compiled_prefix)]
for name in prefixed_names:
    state_dict[name[len(compiled_prefix):]] = state_dict.pop(name)
model.load_state_dict(state_dict)

# Move to the evaluation device and disable training-only behaviour.
model.to(device)
model.eval()
print(f"\nModel loaded successfully and moved to {device}")

## 3. Load Test Data

In [None]:
# Load the held-out test split: a read-only memory map of uint32 values,
# viewed as one 3-column row per event.
test_bin_path = os.path.join(data_dir, 'test.bin')
test_data = np.memmap(test_bin_path, dtype=np.uint32, mode='r').reshape(-1, 3)
test_p2i = get_p2i(test_data)

n_test_events, n_test_patients = len(test_data), len(test_p2i)
print("Test dataset:")
print(f"  Total events: {n_test_events:,}")
print(f"  Total patients: {n_test_patients:,}")
print(f"  Avg events per patient: {n_test_events / n_test_patients:.2f}")

# Optional token-label table; the notebook works without it.
labels_path = os.path.join(data_dir, 'labels.csv')
labels_df = pd.read_csv(labels_path) if os.path.exists(labels_path) else None
if labels_df is not None:
    print(f"\nLoaded {len(labels_df)} event labels")
else:
    print("\nNo labels file found")

## 4. Evaluate on Test Set

In [None]:
def _summarize_losses(values):
    """Return mean/std/min/max/median of a 1-D array of per-batch losses."""
    return {
        'mean': values.mean(),
        'std': values.std(),
        'min': values.min(),
        'max': values.max(),
        'median': np.median(values),
    }


@torch.no_grad()
def evaluate_test_set(model, test_data, test_p2i, eval_iters, batch_size,
                      block_size, device, no_event_token_rate):
    """
    Monte-Carlo estimate of the model's losses on a held-out data split.

    Each of ``eval_iters`` iterations draws ``batch_size`` patient indices
    uniformly at random *with replacement* (``torch.randint``), builds a batch
    with ``get_batch(select='left', cut_batch=True)`` and records the
    ``loss_ce`` and ``loss_dt`` values the model returns in
    ``validation_loss_mode``.

    Args:
        model: Delphi model; switched to eval mode (and left there).
        test_data: event array of the split (as loaded from ``*.bin``).
        test_p2i: patient index for ``test_data`` from ``get_p2i``.
        eval_iters: number of random batches to evaluate (must be >= 1).
        batch_size: patients per batch.
        block_size, device, no_event_token_rate: forwarded to ``get_batch``.

    Returns:
        dict with
          - 'loss_ce', 'loss_dt', 'loss_total': summary dicts with keys
            'mean', 'std', 'min', 'max', 'median' over batches
            ('loss_total' is the per-batch sum CE + DT);
          - 'perplexity': exp(mean per-batch CE loss);
          - 'batch_losses_ce', 'batch_losses_dt', 'batch_losses_total':
            float64 arrays of length ``eval_iters``.

    Raises:
        ValueError: if ``eval_iters < 1`` or ``test_p2i`` is empty.
    """
    # Fail early with a clear message instead of an empty-array reduction
    # error (eval_iters == 0) or a torch.randint range error (no patients).
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")
    if len(test_p2i) == 0:
        raise ValueError("test_p2i is empty: no patients to sample from")

    model.eval()

    # Per-batch losses, filled in place.
    batch_losses_ce = np.empty(eval_iters)
    batch_losses_dt = np.empty(eval_iters)

    print(f"Evaluating on {eval_iters} batches...")
    for k in tqdm(range(eval_iters)):
        # Sample random batch of patients (with replacement)
        ix = torch.randint(len(test_p2i), (batch_size,))
        X, A, Y, B = get_batch(ix, test_data, test_p2i,
                               block_size=block_size,
                               device=device,
                               select='left',
                               no_event_token_rate=no_event_token_rate,
                               cut_batch=True)

        # Forward pass; only the loss dict is needed here.
        _, loss, _ = model(X, A, Y, B, validation_loss_mode=True)

        # Read each scalar once (each .item() is a device->host sync).
        batch_losses_ce[k] = loss['loss_ce'].item()
        batch_losses_dt[k] = loss['loss_dt'].item()

    batch_losses_total = batch_losses_ce + batch_losses_dt

    return {
        'loss_ce': _summarize_losses(batch_losses_ce),
        'loss_dt': _summarize_losses(batch_losses_dt),
        'loss_total': _summarize_losses(batch_losses_total),
        'perplexity': np.exp(batch_losses_ce.mean()),
        'batch_losses_ce': batch_losses_ce,
        'batch_losses_dt': batch_losses_dt,
        'batch_losses_total': batch_losses_total,
    }

# Monte-Carlo evaluation of the test split using the configuration above
# (arguments passed positionally in the function's parameter order).
results = evaluate_test_set(
    model, test_data, test_p2i, eval_iters, batch_size,
    block_size, device, no_event_token_rate,
)

## 5. Results Summary

In [None]:
# Human-readable summary of the test-set metrics.
heavy_rule = "=" * 70
print(heavy_rule)
print("TEST SET EVALUATION RESULTS")
print(heavy_rule)
print(f"\nCheckpoint: {checkpoint_path}")
print(f"Training iteration: {checkpoint['iter_num']}")
print(f"Test batches evaluated: {eval_iters}")
print(f"Total test samples evaluated: ~{eval_iters * batch_size:,}")
print("\n" + "-" * 70)

# The three loss components share the same mean/median/range layout.
loss_sections = (
    ("CROSS-ENTROPY LOSS (Next Event Prediction):", 'loss_ce'),
    ("TIME-TO-EVENT LOSS:", 'loss_dt'),
    ("COMBINED LOSS (CE + DT):", 'loss_total'),
)
for heading, key in loss_sections:
    stats = results[key]
    print("\n" + heading)
    print(f"  Mean:   {stats['mean']:.4f} ± {stats['std']:.4f}")
    print(f"  Median: {stats['median']:.4f}")
    print(f"  Range:  [{stats['min']:.4f}, {stats['max']:.4f}]")

print("\nPERPLEXITY (exp of CE loss):")
print(f"  {results['perplexity']:.2f}")
print("  (Lower is better - measures model uncertainty)")

print("\n" + heavy_rule)

## 6. Visualizations

In [None]:
# One histogram per loss component, with the mean (red) and median (green)
# marked as dashed vertical lines.
# Each entry: (summary key, per-batch key, bar colour, x-axis label, title)
distribution_panels = (
    ('loss_ce', 'batch_losses_ce', 'blue',
     'Cross-Entropy Loss', 'Distribution of CE Loss Across Batches'),
    ('loss_dt', 'batch_losses_dt', 'orange',
     'Time-to-Event Loss', 'Distribution of DT Loss Across Batches'),
    ('loss_total', 'batch_losses_total', 'purple',
     'Total Loss (CE + DT)', 'Distribution of Total Loss Across Batches'),
)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for panel_ax, (stat_key, batch_key, bar_color, x_label, panel_title) in zip(axes, distribution_panels):
    stats = results[stat_key]
    panel_ax.hist(results[batch_key], bins=30, alpha=0.7, color=bar_color, edgecolor='black')
    panel_ax.axvline(stats['mean'], color='red', linestyle='--', linewidth=2,
                     label=f"Mean: {stats['mean']:.4f}")
    panel_ax.axvline(stats['median'], color='green', linestyle='--', linewidth=2,
                     label=f"Median: {stats['median']:.4f}")
    panel_ax.set_xlabel(x_label)
    panel_ax.set_ylabel('Frequency')
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('test_evaluation_loss_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved plot to 'test_evaluation_loss_distributions.png'")

In [None]:
# Plot per-batch losses in the order the batches were drawn. Each batch is an
# independent random sample of patients, so any "trend" here reflects
# sampling noise rather than an ordering in the data.
fig, ax = plt.subplots(figsize=(14, 6))

# Derive the x-axis from the stored results rather than the global
# `eval_iters`: notebook cells can be re-run out of order, and a changed
# `eval_iters` would otherwise cause a length mismatch in ax.plot.
n_batches = len(results['batch_losses_total'])
batch_indices = np.arange(n_batches)
ax.plot(batch_indices, results['batch_losses_ce'], alpha=0.6, label='CE Loss', linewidth=1)
ax.plot(batch_indices, results['batch_losses_dt'], alpha=0.6, label='DT Loss', linewidth=1)
ax.plot(batch_indices, results['batch_losses_total'], alpha=0.6, label='Total Loss', linewidth=1)

# Add rolling average (skipped when there are too few batches for a window > 1)
window = min(10, n_batches // 10)
if window > 1:
    from scipy.ndimage import uniform_filter1d
    ax.plot(batch_indices, uniform_filter1d(results['batch_losses_total'], window),
            color='red', linewidth=2, label=f'Total Loss (MA-{window})')

ax.set_xlabel('Batch Index')
ax.set_ylabel('Loss')
ax.set_title('Loss Values Across Test Batches')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('test_evaluation_loss_trends.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved plot to 'test_evaluation_loss_trends.png'")

## 7. Save Results to File

In [None]:
# Create results summary dictionary.
# Every number is coerced to a built-in int/float: json cannot serialize
# torch tensors, and checkpoint values such as `best_val_loss` may have been
# saved as 0-d tensors by the training loop (NOTE(review): confirm against the
# training script). int()/float() accept Python numbers, NumPy scalars and
# 0-d tensors alike, so this is a no-op when the values are already plain.
# The batch count comes from the stored results so the JSON and CSV agree
# even if `eval_iters` was changed after evaluation.
n_batches = len(results['batch_losses_total'])
summary = {
    'checkpoint': checkpoint_path,
    'training_iteration': int(checkpoint['iter_num']),
    'best_val_loss': float(checkpoint['best_val_loss']),
    'test_batches': n_batches,
    'batch_size': batch_size,
    'test_samples_approx': n_batches * batch_size,
}
# Same key order as before: ce_loss_{mean,std,median}, dt_loss_..., total_loss_...
for summary_prefix, results_key in (('ce_loss', 'loss_ce'),
                                    ('dt_loss', 'loss_dt'),
                                    ('total_loss', 'loss_total')):
    for stat_name in ('mean', 'std', 'median'):
        summary[f'{summary_prefix}_{stat_name}'] = float(results[results_key][stat_name])
summary['perplexity'] = float(results['perplexity'])

# Save to JSON
import json
output_file = 'test_evaluation_results.json'
with open(output_file, 'w') as f:
    json.dump(summary, f, indent=2)
print(f"Saved results summary to '{output_file}'")

# Save detailed batch results to CSV
batch_results_df = pd.DataFrame({
    'batch_idx': np.arange(n_batches),
    'loss_ce': results['batch_losses_ce'],
    'loss_dt': results['batch_losses_dt'],
    'loss_total': results['batch_losses_total'],
})
batch_results_file = 'test_evaluation_batch_results.csv'
batch_results_df.to_csv(batch_results_file, index=False)
print(f"Saved batch-level results to '{batch_results_file}'")

## 8. Optional: Compare with Validation Set

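Compare test-set performance with the validation set to check generalization. The validation set may have been used to select the checkpoint (it tracks `best_val_loss`), so validation losses can be slightly optimistic. A test loss close to the validation loss indicates the model generalizes to unseen patients.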

In [None]:
# Evaluate the validation split with exactly the same sampling settings used
# for the test split, then compare the two side by side.
val_bin_path = os.path.join(data_dir, 'val.bin')
val_data = np.memmap(val_bin_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_p2i = get_p2i(val_data)

print("Evaluating on validation set...")
val_results = evaluate_test_set(
    model=model,
    test_data=val_data,
    test_p2i=val_p2i,
    eval_iters=eval_iters,
    batch_size=batch_size,
    block_size=block_size,
    device=device,
    no_event_token_rate=no_event_token_rate
)

# Rows of (heading, test value, validation value, decimals to print).
comparison_rows = [
    ("CE Loss", results['loss_ce']['mean'], val_results['loss_ce']['mean'], 4),
    ("DT Loss", results['loss_dt']['mean'], val_results['loss_dt']['mean'], 4),
    ("Total Loss", results['loss_total']['mean'], val_results['loss_total']['mean'], 4),
    ("Perplexity", results['perplexity'], val_results['perplexity'], 2),
]

print("\n" + "=" * 70)
print("COMPARISON: Test vs Validation")
print("=" * 70)
for heading, test_value, val_value, decimals in comparison_rows:
    print(f"\n{heading}:")
    print(f"  Test:       {test_value:.{decimals}f}")
    print(f"  Validation: {val_value:.{decimals}f}")
    print(f"  Difference: {test_value - val_value:.{decimals}f}")

# Visualization: grouped bars for the three losses, separate panel for perplexity
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

loss_rows = comparison_rows[:3]
metrics = [heading for heading, _, _, _ in loss_rows]
test_values = [test_value for _, test_value, _, _ in loss_rows]
val_values = [val_value for _, _, val_value, _ in loss_rows]

x = np.arange(len(metrics))
width = 0.35

loss_ax, ppl_ax = axes[0], axes[1]
loss_ax.bar(x - width / 2, test_values, width, label='Test', alpha=0.8)
loss_ax.bar(x + width / 2, val_values, width, label='Validation', alpha=0.8)
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Test vs Validation Loss Comparison')
loss_ax.set_xticks(x)
loss_ax.set_xticklabels(metrics)
loss_ax.legend()
loss_ax.grid(True, alpha=0.3, axis='y')

ppl_ax.bar(['Test', 'Validation'], [results['perplexity'], val_results['perplexity']],
           alpha=0.8, color=['blue', 'orange'])
ppl_ax.set_ylabel('Perplexity')
ppl_ax.set_title('Test vs Validation Perplexity')
ppl_ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('test_vs_validation_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nSaved comparison plot to 'test_vs_validation_comparison.png'")

## 9. Optional: Sample Predictions

Examine model predictions on a few test samples.

In [None]:
# Get a single small batch of random test patients for qualitative inspection
ix = torch.randint(len(test_p2i), (5,))  # Sample 5 patients
X, A, Y, B = get_batch(ix, test_data, test_p2i,
                      block_size=block_size,
                      device=device,
                      select='left',
                      no_event_token_rate=no_event_token_rate,
                      cut_batch=True)

# Get predictions: per-position distribution over the next token
with torch.no_grad():
    logits, loss, att = model(X, A, Y, B, validation_loss_mode=True)
    probs = torch.softmax(logits, dim=-1)
    predicted_tokens = torch.argmax(logits, dim=-1)

print("Sample Predictions (first patient in batch):\n")
print("Position | Input Token | Input Age | True Next | Predicted | Top-3 Predictions")
print("-" * 90)

patient_idx = 0
max_rows = 10  # Show the first 10 non-padding positions
# Drop padding (token 0) BEFORE taking the first rows. Slicing the first 10
# raw positions and skipping padding inside the loop prints an empty table
# whenever those 10 positions are all padding (e.g. a short sequence padded
# at the start).
valid_positions = (X[patient_idx] != 0).nonzero(as_tuple=True)[0][:max_rows].tolist()
for pos in valid_positions:
    input_token = X[patient_idx, pos].item()
    input_age = A[patient_idx, pos].item() / 365.25  # Convert to years (ages assumed to be in days)
    true_next = Y[patient_idx, pos].item()
    pred_next = predicted_tokens[patient_idx, pos].item()

    # Top-3 predicted tokens with their probabilities
    top3_probs, top3_tokens = torch.topk(probs[patient_idx, pos], 3)
    top3_str = ", ".join(f"{tok.item()}({prob.item():.2%})" for tok, prob in zip(top3_tokens, top3_probs))

    match = "✓" if pred_next == true_next else "✗"
    print(f"{pos:8d} | {input_token:11d} | {input_age:9.1f} | {true_next:9d} | {pred_next:9d} {match} | {top3_str}")

print("\nNote: Tokens are shifted by +1 (0 is padding)")

## Summary

This notebook evaluates the Delphi model on held-out data using standard language-model evaluation practice:

1. **Multiple metrics**: CE loss, DT loss, combined loss, and perplexity
2. **Statistical analysis**: Mean, std, median, min, and max over randomly sampled batches
3. **Visualization**: Loss distributions and per-batch losses
4. **Comparison**: Test vs validation to check generalization
5. **Reproducibility**: Fixed random seeds, with results saved to JSON and CSV files

**Key Metrics to Report:**
- Test Loss (CE + DT combined)
- Perplexity (measures model uncertainty)
- Comparison with validation loss (to check generalization)