# Simple World Model for Fast Training

This notebook demonstrates how to train a lightweight MLP-based dynamics model for OpenScope ATC, enabling fast model-based RL training without browser overhead.

## Key Concepts

**Problem**: Browser-based training is slow (100-500ms per step), limiting RL iteration speed.

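**Solution**: Learn a simple dynamics model from collected trajectories, then train RL policies inside the learned model, where each step is a single MLP forward pass instead of a browser round trip (estimated at roughly 100x faster).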
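
To make the speed-up concrete, the following back-of-the-envelope calculation uses the per-step latencies quoted above. The 1,000,000-step budget is an illustrative assumption, not a figure from this project, and the 100x factor is the notebook's own estimate rather than a measurement.

```python
# Back-of-the-envelope wall-clock estimate for a fixed training budget.
# Assumption: 1,000,000 environment steps (illustrative only).
TOTAL_STEPS = 1_000_000
SPEEDUP = 100  # the notebook's estimate for the learned model, not a benchmark


def hours_for(steps: int, seconds_per_step: float) -> float:
    """Wall-clock hours needed to execute `steps` sequential steps."""
    return steps * seconds_per_step / 3600


for label, step_s in [("browser at 100 ms/step", 0.1),
                      ("browser at 500 ms/step", 0.5)]:
    browser_h = hours_for(TOTAL_STEPS, step_s)
    model_h = browser_h / SPEEDUP
    print(f"{label}: {browser_h:.1f} h in browser, {model_h:.2f} h in learned model")
```

Under these assumptions, a run that would occupy the browser for roughly 28 to 139 hours would take about 0.3 to 1.4 hours in the learned model.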

**Simple World Model**:
- MLP that predicts `next_state, reward = f(state, action)`
- Trained on offline trajectories from OpenScope
- Much simpler than Cosmos (no video, no foundation models)
- Still enables fast model-based RL training
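
Written out, a two-hidden-layer version of such a model could look as follows. The depth, the ReLU activations, and the separate linear output heads are assumptions made for illustration; the notebook does not show the project's actual architecture beyond it being an MLP. Here $[s_t; a_t]$ denotes the concatenation of the flattened state and action vectors.

$$
\begin{aligned}
h_1 &= \mathrm{ReLU}\left(W_1 [s_t; a_t] + b_1\right) \\
h_2 &= \mathrm{ReLU}\left(W_2 h_1 + b_2\right) \\
\hat{s}_{t+1} &= W_s h_2 + b_s \\
\hat{r}_t &= w_r^\top h_2 + b_r
\end{aligned}
$$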

## Workflow

1. **Collect trajectories** from OpenScope (or use existing dataset)
2. **Train simple dynamics model** - Learn state/action → next_state/reward mapping
3. **Evaluate model accuracy** - Test prediction quality
4. **Train RL in learned model** - Fast training without browser (listed under Next Steps; not run in this notebook)
5. **Compare performance** - Model-based vs browser-based training (also left as a next step)

## Prerequisites

- OpenScope server running at http://localhost:3003 (for data collection)
- GPU recommended for faster training
- Estimated time: 30-45 minutes (data collection + training)



## 📚 Learning Objectives

By the end of this notebook, you will understand:

1. **World Models** - Learning environment dynamics to predict future states and rewards
2. **Model-Based RL** - Training RL policies in learned models for faster iteration
3. **Dynamics Model Architecture** - Simple MLP that predicts state transitions
4. **Training Dynamics Models** - Supervised learning on trajectory data
5. **Trade-offs** - Simplicity vs accuracy: when simple models are sufficient

**Background knowledge**: Understanding of supervised learning, basic RL concepts


## Section 1: Setup and Imports

Let's set up the environment and import necessary modules.



In [None]:
import sys
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

from data.offline_dataset import OfflineDatasetCollector
from training.world_model_trainer import WorldModelTrainer, WorldModelConfig, DynamicsDataset, compute_state_dim, compute_action_dim
from environment.utils import get_device

print("✅ Imports successful!")
print(f"PyTorch: {torch.__version__}")
print(f"Device: {get_device()}")


## Section 2: Load Trajectories

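Load offline OpenScope trajectories from disk if a dataset already exists. If none is found, `episodes` stays empty and the data must be collected first (see "Collect Trajectory Data" below).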


In [None]:
from pathlib import Path
from data.offline_dataset import OfflineDatasetCollector

# Load if available
data_path = Path("../data/offline_data.pkl")
if data_path.exists():
    print(f"📦 Loading episodes from {data_path}")
    episodes = OfflineDatasetCollector.load_episodes(str(data_path))
else:
    # Nothing on disk: collect data first (see "Collect Trajectory Data")
    print(f"⚠️ No dataset found at {data_path}")
    episodes = []

print(f"Episodes: {len(episodes)}")


## Section 3: Build Dataset and Train World Model

We'll flatten observations/actions and train the MLP dynamics model.
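
As a rough illustration of what "flattening" means here, the sketch below pads a variable number of aircraft to a fixed `max_aircraft` and concatenates everything into one vector. The observation layout (four features per aircraft plus a two-value global block) and the helper name `flatten_obs_sketch` are assumptions made for illustration; the project's own flattening helpers and `compute_state_dim` may use a different layout.

```python
from typing import Dict, List

FEATURES_PER_AIRCRAFT = 4   # assumed layout: x, y, altitude, heading
GLOBAL_FEATURES = 2         # assumed layout: time, score


def flatten_obs_sketch(obs: Dict[str, list], max_aircraft: int) -> List[float]:
    """Flatten a dict observation into a fixed-length list of floats."""
    flat: List[float] = []
    aircraft = obs["aircraft"][:max_aircraft]          # drop aircraft beyond the cap
    for row in aircraft:
        flat.extend(float(v) for v in row)
    # Zero-pad the unused aircraft slots so the vector length never changes.
    missing = max_aircraft - len(aircraft)
    flat.extend([0.0] * (missing * FEATURES_PER_AIRCRAFT))
    flat.extend(float(v) for v in obs["global"])
    return flat


obs = {
    "aircraft": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],  # 2 aircraft present
    "global": [12.0, -3.0],
}
vec = flatten_obs_sketch(obs, max_aircraft=5)
print(len(vec))  # 5 * 4 + 2 = 22
```

A fixed length matters because an MLP has a fixed input width: the state dimension has to be the same whether 2 or 5 aircraft are on screen.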


In [None]:
if len(episodes) == 0:
    raise RuntimeError("No offline episodes available. Please load or collect data first.")

# Compute flattened dims (use defaults; adjust if needed)
state_dim = compute_state_dim(max_aircraft=20)
action_dim = compute_action_dim()

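# Split episodes for train/val (hold out ~10%, at least one episode)
total_eps = len(episodes)
if total_eps < 2:
    raise RuntimeError("Need at least 2 episodes to build separate train and validation sets.")
val_eps = max(1, int(0.1 * total_eps))
train_eps = total_eps - val_eps

train_episodes = episodes[:train_eps]
val_episodes = episodes[train_eps:]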

# Create datasets
train_ds = DynamicsDataset(train_episodes, state_dim=state_dim, action_dim=action_dim)
val_ds = DynamicsDataset(val_episodes, state_dim=state_dim, action_dim=action_dim)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)

# Configure and train
config = WorldModelConfig(num_epochs=10, batch_size=64, learning_rate=1e-3)
trainer = WorldModelTrainer(config, state_dim=state_dim, action_dim=action_dim)
history = trainer.train(train_loader, val_loader)

print("\nTraining complete!")
print({k: v[-1] for k, v in history.items() if len(v) > 0})


In [None]:
# Apply nest_asyncio for Jupyter compatibility
import nest_asyncio
nest_asyncio.apply()

import sys
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

from environment import PlaywrightEnv, create_default_config
from data.offline_dataset import OfflineDatasetCollector
from models.dynamics_model import (
    SimpleDynamicsModel,
    flatten_observation,
    flatten_action,
    compute_state_dim,
    compute_action_dim,
)
from training.world_model_trainer import (
    WorldModelTrainer,
    WorldModelConfig,
    DynamicsDataset,
)
from environment.utils import get_device
from torch.utils.data import DataLoader, random_split

print("✅ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {get_device()}")


## Section 2: Collect Trajectory Data

If no saved dataset is available, collect trajectories from OpenScope directly. We'll use a random policy to gather diverse state-action transitions.
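
The collection loop itself is simple. The sketch below shows the idea on a toy environment; `ToyEnv` and `collect_random_episode` are hypothetical names invented for this example and do not reflect the `PlaywrightEnv` or `OfflineDatasetCollector` interfaces, whose internals are not shown in this notebook.

```python
import random
from typing import List, Tuple


class ToyEnv:
    """Hypothetical 1-D environment used only to illustrate the loop."""

    def __init__(self, horizon: int = 10):
        self.horizon = horizon
        self.pos = 0.0
        self.t = 0

    def reset(self) -> float:
        self.pos, self.t = 0.0, 0
        return self.pos

    def step(self, action: int) -> Tuple[float, float, bool]:
        self.pos += 1.0 if action == 1 else -1.0   # action 1 moves right, 0 moves left
        self.t += 1
        reward = -abs(self.pos)                    # reward staying near the origin
        return self.pos, reward, self.t >= self.horizon


def collect_random_episode(env: ToyEnv, rng: random.Random, max_steps: int) -> List[tuple]:
    """Roll out a uniformly random policy and record every transition."""
    transitions = []
    state = env.reset()
    for _ in range(max_steps):
        action = rng.choice([0, 1])
        next_state, reward, done = env.step(action)
        transitions.append((state, action, next_state, reward))
        state = next_state
        if done:
            break
    return transitions


rng = random.Random(0)
episode = collect_random_episode(ToyEnv(horizon=10), rng, max_steps=100)
print(len(episode))
```

Random actions cover the state space broadly but rarely reach the states a competent policy would visit, which is worth keeping in mind when judging the trained model.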


In [None]:
# Create OpenScope environment
env = PlaywrightEnv(
    airport="KLAS",
    max_aircraft=5,  # Start with small number for faster data collection
    headless=True,
    timewarp=5,
    episode_length=600,  # 10 minutes
)

print("✅ Environment created")

# Initialize collector
collector = OfflineDatasetCollector(env)

# Collect episodes (use fewer for demo - increase for real training)
print("\n📊 Collecting trajectory data...")
print("   This will take a few minutes. Using 50 episodes for demo.")
episodes = collector.collect_random_episodes(
    num_episodes=50,  # Increase to 500-1000 for real training
    max_steps=100,
    verbose=True
)

print(f"\n✅ Collected {len(episodes)} episodes")
print(f"   Total timesteps: {sum(ep.length for ep in episodes)}")


## Section 3: Prepare Dataset for Dynamics Model Training

We need to convert episodes into (state, action, next_state, reward) transitions for supervised learning.
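
Conceptually, the conversion pairs each step with its successor. The sketch below assumes an episode is stored as parallel lists with one more state than actions; that layout and the names `Transition` and `episode_to_transitions` are hypothetical, since the internals of `DynamicsDataset` are not shown here.

```python
from typing import List, NamedTuple, Sequence


class Transition(NamedTuple):
    state: Sequence[float]
    action: Sequence[float]
    next_state: Sequence[float]
    reward: float


def episode_to_transitions(states, actions, rewards) -> List[Transition]:
    """Pair each step t with its successor state at t + 1.

    Expects T + 1 states, T actions and T rewards, i.e. the final
    state is included but has no action attached to it.
    """
    if len(states) != len(actions) + 1 or len(actions) != len(rewards):
        raise ValueError("expected T+1 states, T actions and T rewards")
    return [
        Transition(states[t], actions[t], states[t + 1], rewards[t])
        for t in range(len(actions))
    ]


# Toy episode: 4 states, 3 actions, 3 rewards.
states = [[0.0], [1.0], [2.0], [3.0]]
actions = [[1.0], [1.0], [1.0]]
rewards = [0.0, 0.5, 1.0]
transitions = episode_to_transitions(states, actions, rewards)
print(len(transitions), transitions[-1])
```

Each resulting tuple is one supervised example: `(state, action)` is the input and `(next_state, reward)` is the target.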


In [None]:
# Compute dimensions (max_aircraft must match the value passed to PlaywrightEnv above)
max_aircraft = 5
state_dim = compute_state_dim(max_aircraft)
action_dim = compute_action_dim()

print(f"State dimension (flattened): {state_dim}")
print(f"Action dimension: {action_dim}")

# Create dataset
dataset = DynamicsDataset(
    episodes=episodes,
    state_dim=state_dim,
    action_dim=action_dim,
)

# Split into train/val (seeded so the split is reproducible across runs)
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
)

print(f"\n✅ Dataset prepared:")
print(f"   Train samples: {len(train_dataset)}")
print(f"   Val samples: {len(val_dataset)}")


## Section 4: Train Dynamics Model

Now we'll train the MLP to predict next states and rewards from current states and actions.
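
The notebook does not show the loss that `WorldModelTrainer` optimizes. A common choice for this kind of model, given here only as an assumption, is a weighted sum of squared errors over the dataset $\mathcal{D}$ of transitions, where $(\hat{s}_{t+1}, \hat{r}_t) = f_\theta(s_t, a_t)$ are the MLP's predictions and $\lambda$ is a hypothetical weight balancing the state and reward terms:

$$
\mathcal{L}(\theta) = \frac{1}{|\mathcal{D}|} \sum_{(s_t, a_t, s_{t+1}, r_t) \in \mathcal{D}} \left( \lVert \hat{s}_{t+1} - s_{t+1} \rVert_2^2 + \lambda \, (\hat{r}_t - r_t)^2 \right)
$$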


In [None]:
# Configure training
config = WorldModelConfig(
    hidden_dim=256,
    num_layers=3,
    dropout=0.1,
    num_epochs=50,  # Reduce (e.g. to 10) for a quick demo run
    batch_size=64,
    learning_rate=1e-3,
    checkpoint_dir="../checkpoints/world_model",
    save_every=10,
    device=get_device(),
)

print("📋 Training Configuration:")
print(f"   Hidden dim: {config.hidden_dim}")
print(f"   Layers: {config.num_layers}")
print(f"   Epochs: {config.num_epochs}")
print(f"   Batch size: {config.batch_size}")
print(f"   Learning rate: {config.learning_rate}")
print(f"   Device: {config.device}")

# Create trainer
trainer = WorldModelTrainer(
    config=config,
    state_dim=state_dim,
    action_dim=action_dim,
)

print("\n✅ Trainer created")

# Train the model
print("\n🚀 Starting training...")
history = trainer.train(train_loader, val_loader)

print("\n✅ Training complete!")


## Section 5: Evaluate Model Accuracy

Let's evaluate how well the model predicts next states and rewards.
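
A raw MSE figure is hard to interpret in isolation. One common sanity check, sketched here on made-up numbers, is to compare the model against a "persistence" baseline that simply predicts `next_state = state`. The helper name and the toy data are hypothetical; with real data the same comparison would use the validation transitions.

```python
from typing import List, Sequence


def mse(pred: Sequence[Sequence[float]], true: Sequence[Sequence[float]]) -> float:
    """Mean squared error over all samples and all state dimensions."""
    total, count = 0.0, 0
    for p_row, t_row in zip(pred, true):
        for p, t in zip(p_row, t_row):
            total += (p - t) ** 2
            count += 1
    return total / count


# Toy validation data: 2-D states that drift by +1.0 in the first dimension.
states: List[List[float]] = [[0.0, 5.0], [1.0, 5.0], [2.0, 5.0]]
next_states: List[List[float]] = [[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]]

# Pretend model output: close to the truth, with a small error.
model_pred: List[List[float]] = [[0.9, 5.0], [2.1, 5.0], [3.0, 5.0]]

baseline_mse = mse(states, next_states)      # persistence: predict no change
model_mse = mse(model_pred, next_states)
print(f"baseline={baseline_mse:.4f} model={model_mse:.4f}")
```

If the trained model does not clearly beat this baseline, it has learned little beyond copying its input, however small the absolute MSE looks.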


In [None]:
# Evaluate on validation set
trainer.model.eval()

state_errors = []
reward_errors = []

with torch.no_grad():
    for batch in val_loader:
        states = batch["state"].to(config.device)
        actions = batch["action"].to(config.device)
        next_states_true = batch["next_state"].to(config.device)
        rewards_true = batch["reward"].to(config.device).unsqueeze(1)
        
        # Predict
        next_states_pred, rewards_pred = trainer.model(states, actions)
        
        # Compute per-sample errors; flatten rewards so shapes match even for a batch of size 1
        state_error = torch.mean((next_states_pred - next_states_true) ** 2, dim=1)
        reward_error = torch.abs(rewards_pred.reshape(-1) - rewards_true.reshape(-1))
        
        state_errors.extend(state_error.cpu().numpy())
        reward_errors.extend(reward_error.cpu().numpy())

print("📊 Model Evaluation Results:")
print(f"   Mean state prediction MSE: {np.mean(state_errors):.4f}")
print(f"   Mean reward absolute error (MAE): {np.mean(reward_errors):.4f}")
print(f"   Reward error std: {np.std(reward_errors):.4f}")

# Visualize training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history["train_losses"], label="Train")
if history.get("val_losses"):
    axes[0].plot(history["val_losses"], label="Val")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training and Validation Loss")
axes[0].legend()
axes[0].grid(True)

axes[1].hist(state_errors, bins=50, alpha=0.7)
axes[1].set_xlabel("State Prediction MSE")
axes[1].set_ylabel("Frequency")
axes[1].set_title("State Prediction Error Distribution")
axes[1].grid(True)

plt.tight_layout()
plt.show()


## Section 6: Analysis and Comparison

### Advantages of Simple World Model

1. **Fast Training**: Stepping the learned model avoids the 100-500ms browser round trip, so RL training can run on the order of 100x faster
2. **Simple Architecture**: Just an MLP - no complex video models or foundation models  
3. **Easy to Debug**: A small MLP with flat vector inputs and outputs is easy to inspect when predictions go wrong
4. **Lower Resource Requirements**: No need for GPU clusters or large models

### Limitations

1. **Prediction Accuracy**: May not capture all dynamics accurately (especially complex interactions)
2. **Distribution Shift**: Model trained on offline data (here, random-policy trajectories) may not generalize well to the states a trained policy visits or to new scenarios
3. **No Visual Information**: Only models state transitions, not visual cues
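
The accuracy limitation becomes more serious in multi-step rollouts, because each prediction is fed back in as the next input. The toy example below uses scalar linear dynamics purely for illustration (it is not a model of OpenScope) to show how a 1% one-step error grows over 50 steps.

```python
def rollout(gain: float, steps: int, start: float = 1.0) -> list:
    """Iterate s_{t+1} = gain * s_t and return every visited state."""
    states = [start]
    for _ in range(steps):
        states.append(gain * states[-1])
    return states


true_states = rollout(gain=1.00, steps=50)    # ground-truth dynamics
model_states = rollout(gain=1.01, steps=50)   # model with a 1% one-step error

one_step_error = abs(model_states[1] - true_states[1])
final_error = abs(model_states[-1] - true_states[-1])
print(f"error after 1 step: {one_step_error:.3f}, after 50 steps: {final_error:.3f}")
```

A one-step validation MSE therefore understates how far long imagined trajectories can drift, which is one reason to keep rollouts inside the learned model short.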

### When to Use

- ✅ You have enough offline data for accurate dynamics learning
- ✅ Training speed is more important than perfect accuracy
- ✅ You want to quickly iterate on RL algorithms
- ✅ Resource constraints prevent using Cosmos

### Comparison with Other Approaches

- **vs Browser Training**: 100x faster, but potential accuracy loss
- **vs Cosmos**: Much simpler and faster to train, but less accurate
- **vs Decision Transformer**: Different approach - DT learns policies, world models learn dynamics

## Summary

We've successfully trained a simple MLP-based world model that can predict next states and rewards from current states and actions. This enables fast model-based RL training without the overhead of browser automation.

**Next Steps:**
- Train an RL policy in the learned model
- Compare model-based vs browser-based training
- Improve model accuracy with more data or better architecture
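
For the first of these steps, the learned model has to be exposed through an environment-like interface. The sketch below shows one possible shape for such a wrapper. `LearnedModelEnv` and `fake_model` are hypothetical names: `fake_model` stands in for the trained network, and the fixed `max_steps` cutoff is an assumption, since the dynamics model in this notebook predicts only next state and reward, not episode termination.

```python
from typing import Callable, List, Tuple

State = List[float]
PredictFn = Callable[[State, int], Tuple[State, float]]


class LearnedModelEnv:
    """Hypothetical wrapper that steps a dynamics function instead of the browser."""

    def __init__(self, predict: PredictFn, initial_state: State, max_steps: int):
        self.predict = predict
        self.initial_state = list(initial_state)
        self.max_steps = max_steps
        self.state = list(initial_state)
        self.t = 0

    def reset(self) -> State:
        self.state, self.t = list(self.initial_state), 0
        return self.state

    def step(self, action: int) -> Tuple[State, float, bool]:
        # The learned model replaces the simulator: it supplies both
        # the next state and the reward.
        self.state, reward = self.predict(self.state, action)
        self.t += 1
        return self.state, reward, self.t >= self.max_steps


def fake_model(state: State, action: int) -> Tuple[State, float]:
    """Stand-in for a trained model: adds the action to every state entry."""
    next_state = [s + action for s in state]
    return next_state, -abs(next_state[0])


env = LearnedModelEnv(fake_model, initial_state=[0.0, 0.0], max_steps=3)
state, total_reward, done = env.reset(), 0.0, False
while not done:
    state, reward, done = env.step(1)   # trivial policy: always choose action 1
    total_reward += reward
print(state, total_reward)
```

With a real model, `predict` would convert the state and action to tensors and call the network under `torch.no_grad()`, and initial states would typically be drawn from the collected episodes rather than fixed.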
