# TD-MPC 2: Model-Based RL with Transformer World Model

This notebook demonstrates TD-MPC 2 (Temporal Difference Model Predictive Control 2) on the OpenScope air traffic control (ATC) simulator, combining:
- **Transformer-based world model** for dynamics prediction
- **Model Predictive Control (MPC)** for action planning
- **Q-learning** for long-term value estimation

## Key Concepts

**TD-MPC 2** is a sample-efficient model-based RL algorithm that:
- Learns a transformer-based world model in latent space
- Uses MPC with Cross-Entropy Method (CEM) for action planning
- Combines short-term planning (MPC) with long-term value (Q-function)
- Typically reaches a given level of performance with fewer environment interactions than model-free methods
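The split between short-term planning and long-term value can be made concrete. TD-MPC-style planners score a candidate action sequence by rolling it out in latent space, summing discounted predicted rewards over the horizon H, and bootstrapping everything beyond the horizon with the Q-function: G = Σ_{t=0}^{H-1} γ^t R(z_t, a_t) + γ^H Q(z_H, a_H). The sketch below uses fixed random linear maps as stand-ins for the learned networks. `encode`, `dynamics`, `reward`, `q_value` and `policy` are hypothetical placeholders, not the `TDMPC2Model` API.

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, LATENT_DIM, ACTION_DIM = 8, 4, 2

# Hypothetical stand-ins for learned networks: fixed random linear maps
W_enc = rng.normal(size=(LATENT_DIM, OBS_DIM)) * 0.3
W_dyn_z = rng.normal(size=(LATENT_DIM, LATENT_DIM)) * 0.3
W_dyn_a = rng.normal(size=(LATENT_DIM, ACTION_DIM)) * 0.3
w_rew = rng.normal(size=LATENT_DIM + ACTION_DIM)
w_q = rng.normal(size=LATENT_DIM + ACTION_DIM)

def encode(obs):
    """Observation -> latent state z."""
    return np.tanh(W_enc @ obs)

def dynamics(z, a):
    """Latent transition z' = d(z, a)."""
    return np.tanh(W_dyn_z @ z + W_dyn_a @ a)

def reward(z, a):
    """Predicted one-step reward R(z, a)."""
    return float(w_rew @ np.concatenate([z, a]))

def q_value(z, a):
    """Predicted long-term value Q(z, a)."""
    return float(w_q @ np.concatenate([z, a]))

def policy(z):
    """Stand-in policy prior that picks the bootstrap action at the horizon."""
    return np.zeros(ACTION_DIM)

def trajectory_value(obs, actions, gamma=0.99):
    """Discounted model rewards over the horizon plus a discounted terminal Q."""
    z = encode(obs)
    total, discount = 0.0, 1.0
    for a in actions:  # short-term part: roll the latent model forward
        total += discount * reward(z, a)
        z = dynamics(z, a)
        discount *= gamma
    # Long-term part: Q bootstraps the value of everything past the horizon
    return total + discount * q_value(z, policy(z))

obs = rng.normal(size=OBS_DIM)
plan = rng.uniform(-1, 1, size=(5, ACTION_DIM))  # horizon H = 5
print(f"G = {trajectory_value(obs, plan):.3f}")
```

With an empty plan the score reduces to Q alone. With a long horizon the model's rewards dominate. This is the trade-off that `planning_horizon` controls.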

**Workflow**:
1. **Collect initial data** from OpenScope (random or heuristic policy)
2. **Train world model** - Learn dynamics and reward prediction
3. **Train Q-function** - Learn long-term value estimates
4. **MPC planning** - Use learned model for action selection
5. **Online learning** - Collect data with MPC policy and continue training
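The five steps above can be sketched as a single loop. This is a minimal illustration that assumes a Gymnasium-style interface: `reset()` returns `(obs, info)`, as in Section 5, and `step()` returns a 5-tuple. `SimpleReplayBuffer`, `ToyEnv` and `plan_action` are hypothetical stand-ins, not the notebook's `ReplayBuffer` or `MPCPlanner`. The actual world-model and Q update is omitted.

```python
import random
from collections import deque

class SimpleReplayBuffer:
    """FIFO buffer of (obs, action, reward, next_obs, terminated) transitions."""
    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)  # oldest transitions are evicted first

    def add(self, *transition):
        self.data.append(transition)

    def sample(self, batch_size):
        return random.sample(list(self.data), batch_size)

    def __len__(self):
        return len(self.data)

class ToyEnv:
    """Hypothetical 1-D task: push the state toward 0; episodes last 20 steps."""
    def reset(self):
        self.x, self.t = random.uniform(-5, 5), 0
        return self.x, {}

    def step(self, action):
        self.x += action
        self.t += 1
        return self.x, -abs(self.x), False, self.t >= 20, {}

def plan_action(obs):
    """Placeholder for MPC planning with the learned model."""
    return max(-1.0, min(1.0, -obs))

def train(env, num_steps=500, min_buffer_size=100, batch_size=32, capacity=1000):
    buffer, updates = SimpleReplayBuffer(capacity), 0
    obs, _ = env.reset()
    for _ in range(num_steps):
        # Step 1: act randomly until the buffer holds min_buffer_size transitions
        if len(buffer) < min_buffer_size:
            action = random.uniform(-1, 1)
        else:  # Step 4: afterwards, act with the planner
            action = plan_action(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        buffer.add(obs, action, reward, next_obs, terminated)
        obs = env.reset()[0] if (terminated or truncated) else next_obs
        # Steps 2, 3, 5: a real trainer would update world model and Q on `batch`
        if len(buffer) >= min_buffer_size:
            batch = buffer.sample(batch_size)
            updates += 1
    return buffer, updates

random.seed(0)
buffer, updates = train(ToyEnv())
print(len(buffer), updates)
```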

## Prerequisites

- OpenScope server running at http://localhost:3003
- GPU recommended for faster training


## 📚 Learning Objectives

By the end of this notebook, you will understand:

1. **TD-MPC 2 Algorithm** - How transformer world models combine with MPC and Q-learning
2. **Latent State Representation** - Encoding observations to compact latent space
3. **MPC Planning** - Cross-Entropy Method for action optimization
4. **Joint Training** - Simultaneously learning dynamics, rewards, and Q-values
5. **Sample Efficiency** - How model-based RL reduces environment interactions

**Estimated Time**: 2-3 hours for full training (1M steps) | 30 min for quick demo (10k steps)  
**Prerequisites**: Understanding of transformers, Q-learning, model-based RL  
**Hardware**: GPU strongly recommended


## Section 1: Setup & Imports

Import dependencies and check the PyTorch version and available device.


In [None]:
import sys
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

from environment import PlaywrightEnv
from models.tdmpc2 import TDMPC2Model, TDMPC2Config
from training.tdmpc2_trainer import TDMPC2Trainer, TDMPC2TrainingConfig, ReplayBuffer
from training.tdmpc2_planner import MPCPlanner, MPCPlannerConfig
from environment.utils import get_device

print("✅ Imports successful!")
print(f"PyTorch: {torch.__version__}")
print(f"Device: {get_device()}")


## Section 2: Create Environment

Create the OpenScope environment for data collection and evaluation.


In [None]:
# Environment configuration
AIRPORT = "KLAS"
MAX_AIRCRAFT = 10
HEADLESS = True  # Set to False to see browser
TIMEWARP = 5

# Create environment
env = PlaywrightEnv(
    airport=AIRPORT,
    max_aircraft=MAX_AIRCRAFT,
    headless=HEADLESS,
    timewarp=TIMEWARP,
)

print(f"✅ Environment created: {AIRPORT} with max {MAX_AIRCRAFT} aircraft")
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")


## Section 3: Configure TD-MPC 2

Set up model and training configurations.


In [None]:
# Model configuration
model_config = TDMPC2Config(
    aircraft_feature_dim=14,
    global_feature_dim=4,
    max_aircraft=MAX_AIRCRAFT,
    latent_dim=512,
    encoder_hidden_dim=256,
    encoder_num_layers=4,
    encoder_num_heads=8,
    dynamics_hidden_dim=512,
    dynamics_num_layers=3,
    action_dim=5,  # aircraft_id, command_type, altitude, heading, speed
)

# Planner configuration
planner_config = MPCPlannerConfig(
    planning_horizon=5,
    num_samples=512,
    num_elites=64,
    num_iterations=6,
    gamma=0.99,
)

# Training configuration
training_config = TDMPC2TrainingConfig(
    model_config=model_config,
    planner_config=planner_config,
    num_steps=100000,  # Reduced for demo: use 1000000 for full training or 10000 for a quick run
    batch_size=64,
    learning_rate_model=1e-3,
    learning_rate_q=1e-3,
    buffer_capacity=100000,
    min_buffer_size=1000,
    eval_frequency=5000,
    eval_episodes=5,
    checkpoint_dir="checkpoints/tdmpc2",
    use_wandb=False,  # Set to True to enable WandB logging
)

print("✅ Configuration created")
print(f"Model latent dim: {model_config.latent_dim}")
print(f"Planning horizon: {planner_config.planning_horizon}")
print(f"Training steps: {training_config.num_steps}")
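The planner settings map onto the Cross-Entropy Method roughly as follows. Each of `num_iterations` rounds samples `num_samples` action sequences of length `planning_horizon` from a Gaussian, scores them, and refits the Gaussian to the `num_elites` best. The sketch below is a plain-CEM illustration in which a hypothetical `score_fn` stands in for the latent-model rollout. It is not the `MPCPlanner` implementation. TD-MPC-family planners, for example, use an MPPI-style score-weighted update and mix in samples from a learned policy.

```python
import numpy as np

def cem_plan(score_fn, action_dim, planning_horizon=5, num_samples=512,
             num_elites=64, num_iterations=6, init_std=1.0, min_std=0.05, seed=0):
    """Plain CEM over action sequences; returns the first action of the final mean.

    score_fn maps an array of shape (num_samples, planning_horizon, action_dim)
    to one score per sequence (higher is better).
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros((planning_horizon, action_dim))
    std = np.full((planning_horizon, action_dim), init_std)
    for _ in range(num_iterations):
        # Sample candidate sequences and clip to assumed action bounds [-1, 1]
        noise = rng.standard_normal((num_samples, planning_horizon, action_dim))
        samples = np.clip(mean + std * noise, -1.0, 1.0)
        scores = score_fn(samples)
        # Keep the highest-scoring sequences and refit the Gaussian to them
        elites = samples[np.argsort(scores)[-num_elites:]]
        mean = elites.mean(axis=0)
        std = np.maximum(elites.std(axis=0), min_std)
    # Receding horizon: execute only the first action, then replan next step
    return mean[0]

# Toy objective: discounted closeness to a fixed target action at every step
target = np.array([0.5, -0.3])
gamma = 0.99

def score_fn(seqs):
    discounts = gamma ** np.arange(seqs.shape[1])
    per_step = -np.sum((seqs - target) ** 2, axis=-1)  # (num_samples, horizon)
    return per_step @ discounts

first_action = cem_plan(score_fn, action_dim=2)
print(np.round(first_action, 2))
```

Planning cost grows roughly with `num_samples × num_iterations × planning_horizon` model evaluations per environment step. Those three parameters are therefore the main levers for trading plan quality against speed.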


## Section 4: Create Model and Trainer

Initialize the TD-MPC 2 model and trainer.


In [None]:
# Create model
model = TDMPC2Model(model_config)
print(f"✅ Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Create trainer
trainer = TDMPC2Trainer(env, training_config)
print("✅ Trainer initialized")


## Section 5: Test MPC Planning

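Test the MPC planner on a single observation. The model has not been trained yet, so this only checks that planning runs end to end and returns an action of the expected shape.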


In [None]:
# Reset environment and get observation
obs, info = env.reset()

# Convert to batched tensors on the same device as the model's parameters
device = next(model.parameters()).device
aircraft_tensor = torch.from_numpy(obs["aircraft"]).float().unsqueeze(0).to(device)
mask_tensor = torch.from_numpy(obs["aircraft_mask"]).bool().unsqueeze(0).to(device)
global_tensor = torch.from_numpy(obs["global_state"]).float().unsqueeze(0).to(device)

# Test planner
planner = MPCPlanner(model, planner_config)
with torch.no_grad():
    action = planner.plan(aircraft_tensor, mask_tensor, global_tensor)

print("✅ MPC planning successful")
print(f"Planned action shape: {action.shape}")
print(f"Action values: {action.squeeze().cpu().numpy()}")
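The planner consumes padded per-aircraft features, a boolean `aircraft_mask`, and global features. One common way to turn such a padded set into a fixed-size summary is masked pooling. Padded slots are excluded so that empty rows do not dilute the result. The sketch below shows masked mean pooling in NumPy. It assumes `True` in the mask marks a real aircraft. This notebook does not show whether `TDMPC2Model` pools this way or uses attention pooling. Note that PyTorch's `nn.MultiheadAttention` uses the opposite convention for `key_padding_mask`, where `True` means "ignore".

```python
import numpy as np

def masked_mean_pool(aircraft, mask):
    """Average per-aircraft features over real (mask=True) slots only.

    aircraft: (batch, max_aircraft, feature_dim) padded feature array
    mask:     (batch, max_aircraft) boolean, True where an aircraft is present
    """
    weights = mask.astype(aircraft.dtype)[..., None]  # (B, N, 1)
    summed = (aircraft * weights).sum(axis=1)         # (B, F)
    counts = np.maximum(weights.sum(axis=1), 1.0)     # avoid 0/0 for empty scenes
    return summed / counts

def build_state(aircraft, mask, global_state):
    """Concatenate pooled aircraft features with global features."""
    return np.concatenate([masked_mean_pool(aircraft, mask), global_state], axis=-1)

# Toy batch with the notebook's shapes: 10 slots, 14 aircraft features, 4 global features
B, N, F, G = 1, 10, 14, 4
aircraft = np.zeros((B, N, F), dtype=np.float32)
aircraft[0, :3] = 1.0   # three real aircraft with all-ones features
aircraft[0, 3:] = 99.0  # junk in padded slots must not leak into the summary
mask = np.zeros((B, N), dtype=bool)
mask[0, :3] = True
state = build_state(aircraft, mask, np.zeros((B, G), dtype=np.float32))
print(state.shape, state[0, :F].max())
```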


## Section 6: Train TD-MPC 2

Start training! This will:
1. Collect initial data to fill replay buffer
2. Train world model and Q-function jointly
3. Use MPC policy for data collection
4. Continue online learning
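Joint training in the TD-MPC family combines three terms over a short rollout of replayed transitions:

- a latent consistency loss, comparing the predicted latent with the encoder's latent of the actual next observation, treated as a fixed target;
- a reward prediction loss;
- a TD loss for Q.

Each term is weighted by ρ^t so that errors further into the rollout count less. The sketch below computes that weighted sum from precomputed arrays using MSE. The coefficient values are illustrative placeholders. TD-MPC 2 itself uses discrete (two-hot) regression for rewards and values rather than MSE, so treat this as a simplified version, not the trainer's loss.

```python
import numpy as np

def tdmpc_style_loss(z_pred, z_target, r_pred, r_target, q_pred, q_target,
                     rho=0.5, consistency_coef=1.0, reward_coef=1.0, value_coef=1.0):
    """Temporally weighted sum of consistency, reward and value losses.

    All arrays are indexed [t, ...] over a rollout of length H:
      z_pred/z_target: (H, latent_dim) predicted vs. encoded next latents
      r_pred/r_target: (H,) predicted vs. observed rewards
      q_pred/q_target: (H,) predicted Q vs. TD targets r + gamma * Q_target(z', a')
    Targets are plain arrays here; in a real trainer no gradient flows through them.
    """
    H = len(r_pred)
    weights = rho ** np.arange(H)                              # 1, rho, rho^2, ...
    consistency = np.mean((z_pred - z_target) ** 2, axis=-1)   # per-step latent MSE
    reward = (r_pred - r_target) ** 2
    value = (q_pred - q_target) ** 2
    per_step = (consistency_coef * consistency
                + reward_coef * reward
                + value_coef * value)
    return float(np.sum(weights * per_step) / H)  # averaged over the horizon

rng = np.random.default_rng(0)
H, latent_dim = 3, 4
loss = tdmpc_style_loss(rng.normal(size=(H, latent_dim)), rng.normal(size=(H, latent_dim)),
                        rng.normal(size=H), rng.normal(size=H),
                        rng.normal(size=H), rng.normal(size=H))
print(f"loss = {loss:.3f}")
```

Because the dynamics, reward head and Q-function share one latent space, a single optimizer step on this sum updates all of them together. That shared step is what "joint" training refers to here.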


In [None]:
# Start training
print("🚀 Starting TD-MPC 2 training...")
print("This will take a while. Progress will be shown below.")
print()

trainer.train()

print()
print("✅ Training complete!")


## Section 7: Evaluate Trained Model

Evaluate the trained model on the environment.


In [None]:
# Evaluate model (_evaluate is an internal trainer method, as the leading underscore indicates)
eval_metrics = trainer._evaluate()

print("📊 Evaluation Results:")
for key, value in eval_metrics.items():
    print(f"  {key}: {value:.2f}")
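An evaluation pass typically runs `eval_episodes` full episodes with the current policy and reports summary statistics of the episode returns. The sketch below assumes the Gymnasium-style `reset()`/`step()` interface. `evaluate_policy` and `CountdownEnv` are hypothetical, and the metric names are not necessarily the keys returned by `trainer._evaluate()`.

```python
import numpy as np

def evaluate_policy(env, policy, num_episodes=5):
    """Run full episodes and summarize undiscounted returns and lengths."""
    returns, lengths = [], []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        done, total, steps = False, 0.0, 0
        while not done:
            obs, reward, terminated, truncated, _ = env.step(policy(obs))
            total += reward
            steps += 1
            done = terminated or truncated
        returns.append(total)
        lengths.append(steps)
    return {
        "eval/return_mean": float(np.mean(returns)),
        "eval/return_std": float(np.std(returns)),
        "eval/episode_length": float(np.mean(lengths)),
    }

class CountdownEnv:
    """Hypothetical toy env: reward +1 per step, episode truncates after 10 steps."""
    def reset(self):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        return self.t, 1.0, False, self.t >= 10, {}

metrics = evaluate_policy(CountdownEnv(), policy=lambda obs: 0, num_episodes=5)
for key, value in metrics.items():
    print(f"  {key}: {value:.2f}")
```

Reporting the standard deviation alongside the mean matters with only a handful of episodes. With five episodes, a single unusual traffic scenario can move the mean noticeably.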


## Section 8: Visualize Training Progress

Print a summary of training progress. Detailed metric plots require WandB logging or manual metric tracking.


In [None]:
# If WandB was used, metrics are logged there
# Otherwise, you can track metrics manually during training

print("📈 Training metrics:")
print(f"  Training steps: {trainer.training_step}")
print(f"  Environment steps: {trainer.env_step}")
print(f"  Total episodes: {trainer.episode}")
print(f"  Replay buffer size: {len(trainer.replay_buffer)}")
print(f"  Best eval return: {trainer.best_eval_return:.2f}")

# Note: For detailed plots, enable WandB logging in training_config
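Without WandB, one lightweight option is to record scalars yourself and plot them with matplotlib afterwards. `MetricsTracker` is a hypothetical helper, not part of the trainer. You would call `log()` wherever you have access to per-step losses or episode returns.

```python
from collections import defaultdict
import numpy as np

class MetricsTracker:
    """Collects (step, value) pairs per metric name for later plotting."""
    def __init__(self):
        self.history = defaultdict(list)

    def log(self, step, **metrics):
        for name, value in metrics.items():
            self.history[name].append((step, float(value)))

    def series(self, name, smooth=1):
        """Return (steps, values) for a logged metric, optionally smoothed
        with a trailing moving average of width `smooth`."""
        steps, values = map(np.array, zip(*self.history[name]))
        if smooth > 1 and len(values) >= smooth:
            kernel = np.ones(smooth) / smooth
            values = np.convolve(values, kernel, mode="valid")
            steps = steps[smooth - 1:]  # align each average with its last step
        return steps, values

tracker = MetricsTracker()
for step in range(100):
    tracker.log(step, episode_return=step % 10)
steps, smoothed = tracker.series("episode_return", smooth=10)
print(steps[:3], smoothed[:3])
```

Using the matplotlib import from Section 1, `plt.plot(*tracker.series("episode_return", smooth=10))` then draws the smoothed curve. Smoothing helps because per-episode returns in a multi-aircraft environment tend to be noisy.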


## Section 9: Save and Load Checkpoints

Save the trained model and demonstrate loading.


In [None]:
# Save final checkpoint
trainer.save_checkpoint("final_tdmpc2_model.pt")
print("✅ Model saved")

# Example: load the checkpoint back (assumes save_checkpoint writes into training_config.checkpoint_dir)
# trainer.load_checkpoint("checkpoints/tdmpc2/final_tdmpc2_model.pt")
# print("✅ Model loaded")


## Summary

This notebook demonstrated:

1. **TD-MPC 2 Architecture** - Transformer-based world model with MPC planning
2. **Model Training** - Joint learning of dynamics, rewards, and Q-values
3. **MPC Planning** - Cross-Entropy Method for action optimization
4. **Online Learning** - Continuous improvement through environment interaction

**Next Steps**:
- Experiment with different model architectures (`latent_dim`, `encoder_num_layers`, `dynamics_num_layers`)
- Tune MPC parameters (planning_horizon, num_samples)
- Compare with other model-based methods (DreamerV3, Trajectory Transformer)
- Enable WandB logging for detailed metrics visualization

**References**:
- TD-MPC 2 Paper: https://arxiv.org/abs/2310.16828
- Original TD-MPC: https://arxiv.org/abs/2203.04955
