# Trajectory Transformer Demo

This notebook demonstrates the Trajectory Transformer (TT) model, which models full trajectories autoregressively:

**Key Features:**
- Models sequence: s₀, a₀, r₀, s₁, a₁, r₁, ...
- Single transformer predicts all tokens (states, actions, rewards)
- Separate output heads for state, action, and reward prediction
- Can be used for both behavior cloning and planning
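
As a rough illustration of the sequence layout above, the sketch below interleaves per-timestep states, actions, and rewards into one flat token list. The `interleave` helper and the `(kind, timestep, value)` tuples are illustrative assumptions, not the input format of the `TrajectoryTransformer` class used later.

```python
from typing import List, Sequence, Tuple

Token = Tuple[str, int, object]  # (kind, timestep, value)

def interleave(states: Sequence, actions: Sequence, rewards: Sequence) -> List[Token]:
    """Flatten a trajectory into s0, a0, r0, s1, a1, r1, ... order."""
    if not (len(states) == len(actions) == len(rewards)):
        raise ValueError("states, actions, and rewards must have equal length")
    tokens: List[Token] = []
    for t, (s, a, r) in enumerate(zip(states, actions, rewards)):
        tokens.append(("state", t, s))
        tokens.append(("action", t, a))
        tokens.append(("reward", t, r))
    return tokens

# Two timesteps of made-up data, for illustration only.
tokens = interleave(states=[[0.1, 0.2], [0.3, 0.4]], actions=[7, 12], rewards=[0.0, -1.0])
kinds = [kind for kind, _, _ in tokens]
print(kinds)
```

An autoregressive model predicts token `i` from tokens `0..i-1` only, so the action at step `t` is conditioned on the state at step `t`, and the reward on both.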

**Sections:**
1. Generate synthetic trajectory data
2. Train Trajectory Transformer
3. Evaluate world model accuracy
4. Beam search planning
5. Compare planning strategies

## Learning Objectives

By the end of this notebook, you will understand:

1. **Autoregressive Trajectory Modeling** - Predicting full sequences (s₀, a₀, r₀, s₁, a₁, r₁, ...)
2. **World Models** - Learning environment dynamics to predict future states and rewards
3. **Beam Search Planning** - Looking ahead multiple steps to select better actions
4. **Multi-Task Learning** - Joint training on state, action, and reward prediction
5. **Planning vs Reactive Policies** - When lookahead improves performance

**Estimated Time**: 30-40 minutes (includes data collection and training)
**Prerequisites**: Understanding of transformers, sequence modeling, planning concepts
**Hardware**: GPU recommended for transformer training

In [None]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Add parent directory to path
sys.path.insert(0, '..')

from models import TrajectoryTransformer, TrajectoryTransformerConfig, create_trajectory_transformer
from training import TrajectoryTransformerTrainer, TrainingConfig, create_trainer
from training import BeamSearchPlanner, PlannerConfig, create_planner
from poc.atc_rl import Simple2DATCEnv
from environment import get_device

print("Imports successful!")
print(f"PyTorch version: {torch.__version__}")
device = get_device()  # Auto-detects CUDA, Metal (MPS), or CPU
print(f"Using device: {device}")
device_type = torch.device(device).type  # Works whether get_device() returns a str or a torch.device
if device_type == "cuda":
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
elif device_type == "mps":
    print("Metal Performance Shaders (Apple Silicon GPU) enabled")

## 1. Generate Trajectory Data

We'll use the Simple2DATCEnv to generate trajectory data for training.
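
The collection code in this section collapses the multi-discrete action into a single index built from `aircraft_id` and `command_type`. The sketch below shows that encoding together with its inverse, as a quick check that the 21 × 5 = 105 indices are unique. The names `encode_action` and `decode_action` are illustrative and are not part of the project code.

```python
NUM_AIRCRAFT_IDS = 21   # first entry of the action space [21, 5, 18, 13, 8]
NUM_COMMAND_TYPES = 5   # second entry

def encode_action(aircraft_id: int, command_type: int) -> int:
    """Map (aircraft_id, command_type) to a single index in [0, 105)."""
    return aircraft_id * NUM_COMMAND_TYPES + command_type

def decode_action(index: int) -> tuple:
    """Invert encode_action: returns (aircraft_id, command_type)."""
    return divmod(index, NUM_COMMAND_TYPES)

# Every (aircraft_id, command_type) pair should map to a distinct index.
all_indices = {
    encode_action(a, c)
    for a in range(NUM_AIRCRAFT_IDS)
    for c in range(NUM_COMMAND_TYPES)
}
print(len(all_indices))
```

The mapping is lossy: altitude, heading, and speed are dropped, which is why the rollout code later fills those three fields with zeros when converting an index back to an environment action.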

In [None]:
def flatten_observation(obs):
    """Flatten observation dict to single vector."""
    aircraft_flat = obs["aircraft"].flatten()  # (max_aircraft * 6,)
    global_state = obs["global_state"]  # (4,)
    return np.concatenate([aircraft_flat, global_state])

def action_to_index(action):
    """Convert multi-discrete action to single index."""
    # Action space: [21, 5, 18, 13, 8]
    # Total: 21 * 5 * 18 * 13 * 8 = 196,560 actions (too large)
    # Simplification: use just aircraft_id * 5 + command_type
    aircraft_id, command_type, altitude, heading, speed = action
    return aircraft_id * 5 + command_type

def collect_episodes(env, num_episodes=1000, max_steps=50):
    """Collect trajectory data from environment."""
    episodes = []
    
    for ep in tqdm(range(num_episodes), desc="Collecting episodes"):
        obs, info = env.reset()
        
        states = []
        actions = []
        rewards = []
        
        for step in range(max_steps):
            # Flatten observation
            state = flatten_observation(obs)
            states.append(state)
            
            # Random action
            action = env.action_space.sample()
            action_idx = action_to_index(action)
            actions.append(action_idx)
            
            # Step environment
            obs, reward, terminated, truncated, info = env.step(action)
            rewards.append(reward)
            
            if terminated or truncated:
                break
        
        episodes.append({
            "states": np.array(states),
            "actions": np.array(actions),
            "rewards": np.array(rewards),
        })
    
    return episodes

# Create environment
env = Simple2DATCEnv(max_aircraft=5, max_steps=50)

print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")

# Get dimensions
obs, _ = env.reset()
state_dim = flatten_observation(obs).shape[0]
action_dim = 21 * 5  # Simplified action space

print(f"\nState dimension: {state_dim}")
print(f"Action dimension: {action_dim}")

In [None]:
# Collect training data
# ⏱️ ~5-10 minutes for 500 episodes
print("Collecting training episodes...")
episodes = collect_episodes(env, num_episodes=500, max_steps=50)

print(f"\nCollected {len(episodes)} episodes")
print(f"Average episode length: {np.mean([len(ep['states']) for ep in episodes]):.1f}")
print(f"Average episode return: {np.mean([ep['rewards'].sum() for ep in episodes]):.2f}")

In [None]:
# Pad episodes to same length
max_steps = 50

states_array = np.zeros((len(episodes), max_steps, state_dim), dtype=np.float32)
actions_array = np.zeros((len(episodes), max_steps), dtype=np.int64)
rewards_array = np.zeros((len(episodes), max_steps), dtype=np.float32)

for i, ep in enumerate(episodes):
    ep_len = len(ep["states"])
    states_array[i, :ep_len] = ep["states"]
    actions_array[i, :ep_len] = ep["actions"]
    rewards_array[i, :ep_len] = ep["rewards"]

print(f"States shape: {states_array.shape}")
print(f"Actions shape: {actions_array.shape}")
print(f"Rewards shape: {rewards_array.shape}")

## 2. Train Trajectory Transformer

Train the model to predict states, actions, and rewards autoregressively.
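
The loss weights passed to `create_trainer` in this section suggest a weighted sum of per-task losses. The sketch below shows one plausible form, assuming mean squared error for states and rewards and cross-entropy for the discrete action; the trainer's actual loss may differ. It also averages only over real timesteps, which matters because episodes are zero-padded to `max_steps`. All function names here are illustrative.

```python
import math
from typing import List

def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error between two equal-length vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def step_loss(pred_state, true_state, action_logits, true_action,
              pred_reward, true_reward,
              w_state=1.0, w_action=1.0, w_reward=0.5) -> float:
    """Weighted sum of the three per-task losses for one timestep."""
    return (w_state * mse(pred_state, true_state)
            + w_action * cross_entropy(action_logits, true_action)
            + w_reward * (pred_reward - true_reward) ** 2)

def masked_mean(step_losses: List[float], episode_length: int) -> float:
    """Average only over real timesteps, skipping trailing zero padding."""
    valid = step_losses[:episode_length]
    return sum(valid) / len(valid)

example_loss = step_loss([0.0, 1.0], [0.0, 0.0], [0.0, 0.0], 1, 1.0, 0.0)
print(f"{example_loss:.4f}")
```

Without the mask, padded steps would reward the model for predicting zeros after an episode has ended.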

In [None]:
# Create trainer
trainer = create_trainer(
    state_dim=state_dim,
    action_dim=action_dim,
    embed_dim=128,
    num_layers=4,
    num_heads=4,
    ff_dim=512,
    context_length=20,
    batch_size=32,
    num_epochs=50,
    learning_rate=3e-4,
    state_loss_weight=1.0,
    action_loss_weight=1.0,
    reward_loss_weight=0.5,
)

print(f"Model parameters: {sum(p.numel() for p in trainer.model.parameters()):,}")

In [None]:
# Train the model
# ⏱️ ~10-15 minutes on GPU
print("Training Trajectory Transformer...")
trainer.train(states_array, actions_array, rewards_array)

In [None]:
# Plot training curves
metrics = trainer.get_metrics()

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Training loss
axes[0].plot(metrics["train_losses"])
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss")
axes[0].grid(True, alpha=0.3)

# Evaluation losses
if metrics["eval_losses"]:
    eval_steps = np.arange(len(metrics["eval_losses"]))
    total_losses = [l["total_loss"] for l in metrics["eval_losses"]]
    state_losses = [l["state_loss"] for l in metrics["eval_losses"]]
    action_losses = [l["action_loss"] for l in metrics["eval_losses"]]
    reward_losses = [l["reward_loss"] for l in metrics["eval_losses"]]
    
    axes[1].plot(eval_steps, total_losses, label="Total", linewidth=2)
    axes[1].plot(eval_steps, state_losses, label="State", alpha=0.7)
    axes[1].plot(eval_steps, action_losses, label="Action", alpha=0.7)
    axes[1].plot(eval_steps, reward_losses, label="Reward", alpha=0.7)
    axes[1].set_xlabel("Evaluation Step")
    axes[1].set_ylabel("Loss")
    axes[1].set_title("Validation Loss Components")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nBest validation loss: {metrics['best_loss']:.4f}")

## 3. Evaluate World Model Accuracy

Test how well the model predicts next states, actions, and rewards.

In [None]:
# Create planner for evaluation
# ⏱️ ~1-2 seconds per episode
planner = create_planner(
    trainer.model,
    beam_width=5,
    lookahead_steps=5,
)

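# Evaluate on a few episodes sampled from the collected data.
# Note: these come from the same arrays passed to trainer.train(), so they
# are not guaranteed to be held out.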
num_test = 10
test_indices = np.random.choice(len(states_array), num_test, replace=False)

eval_results = []
for idx in test_indices:
    states = torch.FloatTensor(states_array[idx:idx+1])
    actions = torch.LongTensor(actions_array[idx:idx+1])
    rewards = torch.FloatTensor(rewards_array[idx:idx+1]).unsqueeze(-1)
    
    result = planner.evaluate_trajectory(states, actions, rewards)
    eval_results.append(result)

# Print average metrics
print("World Model Evaluation:")
print("=" * 50)
for key in eval_results[0].keys():
    avg_value = np.mean([r[key] for r in eval_results])
    print(f"{key:25s}: {avg_value:.4f}")

In [None]:
# Visualize predictions for one episode
test_idx = test_indices[0]
states_test = torch.FloatTensor(states_array[test_idx:test_idx+1])
actions_test = torch.LongTensor(actions_array[test_idx:test_idx+1])
rewards_test = torch.FloatTensor(rewards_array[test_idx:test_idx+1]).unsqueeze(-1)

# Get predictions
trainer.model.eval()
with torch.no_grad():
    pred_states, pred_actions, pred_rewards = trainer.model(
        states_test.to(trainer.config.device),
        actions_test.to(trainer.config.device),
        rewards_test.to(trainer.config.device)
    )

# Convert to numpy
pred_rewards_np = pred_rewards.cpu().numpy()[0, :, 0]
actual_rewards_np = rewards_test.numpy()[0, :, 0]
pred_actions_np = pred_actions.argmax(dim=-1).cpu().numpy()[0]
actual_actions_np = actions_test.numpy()[0]

# Use the true (unpadded) episode length; inferring it from non-zero values
# would miss trailing steps where both the reward and the action index are 0.
max_step = len(episodes[test_idx]["states"])

fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Reward prediction
axes[0].plot(actual_rewards_np[:max_step], label="Actual", marker='o', alpha=0.7)
axes[0].plot(pred_rewards_np[:max_step], label="Predicted", marker='x', alpha=0.7)
axes[0].set_xlabel("Timestep")
axes[0].set_ylabel("Reward")
axes[0].set_title("Reward Prediction")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Action prediction
axes[1].plot(actual_actions_np[:max_step], label="Actual", marker='o', alpha=0.7)
axes[1].plot(pred_actions_np[:max_step], label="Predicted", marker='x', alpha=0.7)
axes[1].set_xlabel("Timestep")
axes[1].set_ylabel("Action Index")
axes[1].set_title("Action Prediction")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Compute accuracy for this episode
action_matches = (pred_actions_np[:max_step] == actual_actions_np[:max_step])
action_accuracy = action_matches.sum() / max_step
reward_mse = np.mean((pred_rewards_np[:max_step] - actual_rewards_np[:max_step]) ** 2)

print(f"Episode {test_idx}:")
print(f"  Action accuracy: {action_accuracy:.2%}")
print(f"  Reward MSE: {reward_mse:.4f}")

## 4. Beam Search Planning

Use the trained world model for planning with beam search.
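
To make the idea concrete before using the project's planner, here is a minimal beam search over a toy deterministic model. It is a sketch of the general technique, not the `BeamSearchPlanner` implementation: `beam_search_plan` and `toy_model` are hypothetical names, and beams are scored by cumulative predicted reward only.

```python
from typing import Callable, List, Sequence, Tuple

# (state, action) -> (next_state, reward)
Model = Callable[[float, int], Tuple[float, float]]

def beam_search_plan(state: float, actions: Sequence[int], model: Model,
                     beam_width: int = 5, lookahead_steps: int = 5) -> int:
    """Return the first action of the highest-scoring imagined action sequence."""
    # Each beam entry: (cumulative_reward, imagined_state, action_sequence)
    beams: List[Tuple[float, float, Tuple[int, ...]]] = [(0.0, state, ())]
    for _ in range(lookahead_steps):
        candidates = []
        for total, s, seq in beams:
            for a in actions:
                next_s, r = model(s, a)
                candidates.append((total + r, next_s, seq + (a,)))
        # Keep only the best `beam_width` partial plans.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam_width]
    return beams[0][2][0]

def toy_model(state: float, action: int) -> Tuple[float, float]:
    """Toy stand-in for a learned world model: move on a line, reward is closeness to 3."""
    next_state = state + action
    return next_state, -abs(next_state - 3.0)

first_action = beam_search_plan(0.0, actions=[-1, 0, 1], model=toy_model)
print(first_action)
```

Only the first action of the best sequence is executed; the search is then repeated from the next observed state, which limits how far model errors can compound.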

In [None]:
# Test beam search planning
# ⏱️ ~20-30 seconds for 10 episodes
def test_planning(env, planner, num_episodes=5, max_steps=50):
    """Test planning in environment."""
    episode_returns = []
    
    for ep in range(num_episodes):
        obs, info = env.reset()
        
        states_history = []
        actions_history = []
        rewards_history = []
        total_reward = 0
        
        for step in range(max_steps):
            # Flatten observation
            state = flatten_observation(obs)
            states_history.append(state)
            
            # Convert history to tensors
            if len(states_history) > 1:
                states_tensor = torch.FloatTensor(np.array(states_history)).unsqueeze(0)
                actions_tensor = torch.LongTensor(np.array(actions_history)).unsqueeze(0)
                rewards_tensor = torch.FloatTensor(np.array(rewards_history)).unsqueeze(0).unsqueeze(-1)
            else:
                states_tensor = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0)
                actions_tensor = None
                rewards_tensor = None
            
            # Plan action
            action_idx = planner.plan(states_tensor, actions_tensor, rewards_tensor).item()
            
            # Convert to environment action
            aircraft_id = action_idx // 5
            command_type = action_idx % 5
            action = np.array([aircraft_id, command_type, 0, 0, 0])
            
            actions_history.append(action_idx)
            
            # Step environment
            obs, reward, terminated, truncated, info = env.step(action)
            rewards_history.append(reward)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        episode_returns.append(total_reward)
        print(f"Episode {ep + 1}: Return = {total_reward:.2f}, Steps = {len(states_history)}")
    
    return episode_returns

print("Testing beam search planning...")
planning_returns = test_planning(env, planner, num_episodes=10)

print(f"\nPlanning Results:")
print(f"  Average return: {np.mean(planning_returns):.2f} +/- {np.std(planning_returns):.2f}")
print(f"  Max return: {np.max(planning_returns):.2f}")
print(f"  Min return: {np.min(planning_returns):.2f}")

## 5. Compare Planning Strategies

Compare beam search planning with random actions and greedy action selection.
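
With only 20 episodes per strategy, differences in mean return can be noise. One hedged way to check is a bootstrap confidence interval on the difference in means, sketched below with the standard library. The function name and the sample numbers are placeholders for illustration, not results from this notebook.

```python
import random
from statistics import mean
from typing import List, Tuple

def bootstrap_mean_diff(a: List[float], b: List[float], n_resamples: int = 2000,
                        seed: int = 0) -> Tuple[float, float, float]:
    """Return (mean(a) - mean(b), 2.5th percentile, 97.5th percentile)."""
    rng = random.Random(seed)
    diffs = []
    for _ in range(n_resamples):
        # Resample each group with replacement, keeping its original size.
        resample_a = [rng.choice(a) for _ in a]
        resample_b = [rng.choice(b) for _ in b]
        diffs.append(mean(resample_a) - mean(resample_b))
    diffs.sort()
    lo = diffs[int(0.025 * n_resamples)]
    hi = diffs[int(0.975 * n_resamples) - 1]
    return mean(a) - mean(b), lo, hi

# Placeholder returns for illustration only.
planner_returns = [4.0, 6.0, 5.0, 7.0, 5.5, 6.5]
baseline_returns = [1.0, 2.0, 0.5, 1.5, 2.5, 1.0]
diff, lo, hi = bootstrap_mean_diff(planner_returns, baseline_returns)
print(f"difference in mean return: {diff:.2f} (95% CI {lo:.2f} to {hi:.2f})")
```

If the interval includes zero, the comparison does not support a claim that one strategy is better than the other.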

In [None]:
# ⏱️ ~2-3 minutes total (3 strategies × 20 episodes)
def test_random_policy(env, num_episodes=10, max_steps=50):
    """Test random policy."""
    episode_returns = []
    
    for ep in range(num_episodes):
        obs, info = env.reset()
        total_reward = 0
        
        for step in range(max_steps):
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        episode_returns.append(total_reward)
    
    return episode_returns

def test_greedy_policy(env, model, num_episodes=10, max_steps=50):
    """Test greedy action selection (no planning)."""
    episode_returns = []
    
    for ep in range(num_episodes):
        obs, info = env.reset()
        
        states_history = []
        actions_history = []
        rewards_history = []
        total_reward = 0
        
        for step in range(max_steps):
            state = flatten_observation(obs)
            states_history.append(state)
            
            # Get action from model (no planning, just immediate prediction)
            if len(states_history) > 1:
                states_tensor = torch.FloatTensor(np.array(states_history)).unsqueeze(0)
                actions_tensor = torch.LongTensor(np.array(actions_history)).unsqueeze(0)
                rewards_tensor = torch.FloatTensor(np.array(rewards_history)).unsqueeze(0).unsqueeze(-1)
            else:
                states_tensor = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0)
                actions_tensor = None
                rewards_tensor = None
            
            action_idx = model.get_action(states_tensor, actions_tensor, rewards_tensor, sample=False).item()
            
            aircraft_id = action_idx // 5
            command_type = action_idx % 5
            action = np.array([aircraft_id, command_type, 0, 0, 0])
            
            actions_history.append(action_idx)
            
            obs, reward, terminated, truncated, info = env.step(action)
            rewards_history.append(reward)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        episode_returns.append(total_reward)
    
    return episode_returns

# Compare strategies
print("Testing different strategies...\n")

print("1. Random Policy:")
random_returns = test_random_policy(env, num_episodes=20)
print(f"   Mean: {np.mean(random_returns):.2f} +/- {np.std(random_returns):.2f}\n")

print("2. Greedy Policy (no planning):")
greedy_returns = test_greedy_policy(env, trainer.model, num_episodes=20)
print(f"   Mean: {np.mean(greedy_returns):.2f} +/- {np.std(greedy_returns):.2f}\n")

print("3. Beam Search Planning:")
planning_returns_full = test_planning(env, planner, num_episodes=20)
print(f"   Mean: {np.mean(planning_returns_full):.2f} +/- {np.std(planning_returns_full):.2f}")

In [None]:
# Visualize comparison
fig, ax = plt.subplots(figsize=(10, 6))

strategies = ['Random', 'Greedy', 'Beam Search']
returns = [random_returns, greedy_returns, planning_returns_full]
colors = ['#ff6b6b', '#4ecdc4', '#45b7d1']

bp = ax.boxplot(returns, patch_artist=True)
ax.set_xticklabels(strategies)  # Avoids the `labels` keyword, renamed to `tick_labels` in Matplotlib 3.9
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax.set_ylabel('Episode Return')
ax.set_title('Comparison of Action Selection Strategies')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Statistical summary
print("\nDetailed Comparison:")
print("=" * 60)
for strategy, returns_list in zip(strategies, returns):
    print(f"{strategy:15s}: {np.mean(returns_list):6.2f} +/- {np.std(returns_list):5.2f} "
          f"(min: {np.min(returns_list):6.2f}, max: {np.max(returns_list):6.2f})")

## Common Pitfalls & Troubleshooting

### Problem 1: "World model predictions diverge quickly"
**Solution**: Error accumulates during autoregressive prediction because each predicted step is fed back in as input to the next. Mitigate with:
- **Shorter planning horizons**: Reduce `lookahead_steps` from 5 to 3
- **More training data**: Collect 2000+ episodes (this notebook uses 500)
- **Larger model**: Increase layers from 4 to 6

```python
planner = create_planner(
    trainer.model,
    lookahead_steps=3,  # Shorter horizon
)
```
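
The toy example below illustrates why shorter horizons help. It compares an open-loop rollout of a slightly wrong model against the true dynamics. The scalar system and the 5% gain error are assumptions chosen for illustration and are unrelated to the ATC environment.

```python
def rollout(x0: float, gain: float, steps: int) -> list:
    """Roll out x_{t+1} = gain * x_t for `steps` steps, starting from x0."""
    xs = [x0]
    for _ in range(steps):
        xs.append(gain * xs[-1])
    return xs

TRUE_GAIN = 1.00    # toy "environment" dynamics
MODEL_GAIN = 1.05   # toy "world model" with a 5% one-step error

true_traj = rollout(1.0, TRUE_GAIN, steps=10)
model_traj = rollout(1.0, MODEL_GAIN, steps=10)
errors = [abs(m - t) for m, t in zip(model_traj, true_traj)]

for horizon in (1, 3, 5, 10):
    print(f"horizon {horizon:2d}: abs error {errors[horizon]:.3f}")
```

The one-step error is small, but it grows at every step because each prediction is built on the previous prediction rather than on a real observation.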

### Problem 2: Beam search worse than greedy selection
**Causes**:
- **World model inaccurate**: Predictions are too unreliable to plan with
- **Beam width too large**: The search explores low-probability branches
- **Reward model poor**: Incorrect reward predictions steer the search toward bad actions

**Solution**: Train the world model longer, or use a smaller beam width and horizon:
```python
config = PlannerConfig(
    beam_width=3,  # Reduce from 5
    lookahead_steps=3,  # Shorter horizon
)
```

### Problem 3: "Training is very slow (>1 hour for 50 epochs)"
**Solution**:
- Use GPU acceleration
- Reduce sequence length (`context_length`)
- Use a smaller model (fewer layers or heads)

```python
config = TrainingConfig(
    num_layers=3,  # Down from 4
    num_heads=2,   # Down from 4
)
```
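
Reducing sequence length can also apply at inference time. The rollout loops in this notebook pass the full history to the model, which can grow past `context_length`. If the model or planner does not already crop its inputs (an assumption to verify against the implementation), a helper along these lines keeps only the most recent steps. `crop_history` is a hypothetical name.

```python
from typing import List, Optional, Tuple

def crop_history(states: List, actions: List, rewards: List,
                 context_length: int) -> Tuple[List, Optional[List], Optional[List]]:
    """Keep the last `context_length` states and the actions/rewards between them.

    Assumes the layout used in this notebook's rollout loops: one more state
    than actions/rewards, because the newest state has no action yet.
    """
    if context_length < 1:
        raise ValueError("context_length must be at least 1")
    if len(states) != len(actions) + 1 or len(actions) != len(rewards):
        raise ValueError("expected len(states) == len(actions) + 1 == len(rewards) + 1")
    states = states[-context_length:]
    keep = len(states) - 1
    if keep == 0:
        # Matches the notebook's convention of passing None on the first step.
        return states, None, None
    return states, actions[-keep:], rewards[-keep:]

# 30 states and 29 actions/rewards, cropped to a 20-step context.
cropped_states, cropped_actions, cropped_rewards = crop_history(
    list(range(30)), list(range(29)), [0.0] * 29, context_length=20
)
print(len(cropped_states), len(cropped_actions), len(cropped_rewards))
```

The explicit `keep == 0` branch is needed because `actions[-0:]` would return the whole list rather than an empty one.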

### Problem 4: Action prediction accuracy low (<60%)
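**Causes**:
- **Random behavior policy**: Episodes here are collected with `env.action_space.sample()`, so actions are independent of the state and accuracy near chance (about 1/105 for the simplified action space) is expected
- **Action space too large**: Even the simplified space has 105 classes; reduce it further
- **Not enough training data**: Collect more episodes
- **Model underfitting**: Increase model capacity

**Solution**: Collect data from a non-random policy, or keep the action index limited to `aircraft_id` and `command_type`, as this notebook already does: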
```python
def action_to_index(action):
    # Use only aircraft_id and command_type (ignore parameters)
    return action[0] * 5 + action[1]
```
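
Before reading much into an accuracy number, it helps to compare it with trivial baselines. The sketch below computes the uniform-chance accuracy for the simplified action space and the accuracy of always predicting the most frequent action. The helper names are illustrative, and the uniform figure assumes `env.action_space.sample()` draws each component uniformly.

```python
from collections import Counter
from typing import Sequence

def chance_accuracy(num_actions: int) -> float:
    """Expected top-1 accuracy when targets are uniform over `num_actions`."""
    return 1.0 / num_actions

def majority_accuracy(actions: Sequence[int]) -> float:
    """Accuracy of always predicting the most frequent action in the data."""
    counts = Counter(actions)
    return counts.most_common(1)[0][1] / len(actions)

uniform_floor = chance_accuracy(21 * 5)               # simplified action space in this notebook
toy_majority = majority_accuracy([0, 0, 1, 2, 0, 3])  # illustrative data only

print(f"uniform chance: {uniform_floor:.4f}")
print(f"majority baseline on toy data: {toy_majority:.2f}")
```

A model trained on randomly sampled actions cannot be expected to beat these baselines by much, so low accuracy in that setting does not by itself indicate a bug.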

### Problem 5: State predictions look wrong visually
**This is expected!** States are high-dimensional and hard to predict exactly. Focus on:
- Reward prediction accuracy (more important for planning)
- Action prediction accuracy (indicates model understands task)
- Beam search performance (ultimate metric)

### Problem 6: "RuntimeError: sequence length mismatch"
**Solution**: Ensure consistent padding across states, actions, rewards:
```python
# All three arrays should share the same leading shape: (num_episodes, max_steps)
states_array = np.zeros((num_episodes, max_steps, state_dim), dtype=np.float32)
actions_array = np.zeros((num_episodes, max_steps), dtype=np.int64)
rewards_array = np.zeros((num_episodes, max_steps), dtype=np.float32)
```

### Debugging Tips:
1. **Visualize one-step predictions**: Check accuracy before multi-step
2. **Compare beam widths**: Plot performance vs beam_width (1, 3, 5, 10)
3. **Analyze planning failures**: When does beam search pick wrong actions?
4. **Test on simple scenarios**: Start with 2-3 aircraft before scaling

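**Need more help?** See the Trajectory Transformer paper (Janner et al., 2021, "Offline Reinforcement Learning as One Big Sequence Modeling Problem") or the model-based RL literature.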

## Summary

**Trajectory Transformer Key Insights:**

1. **Unified Model**: Single transformer predicts states, actions, and rewards
2. **World Model**: The same model predicts future states and rewards, with accuracy that degrades as the prediction horizon grows
3. **Planning**: Beam search looks ahead several steps to select better actions
4. **Multi-task Learning**: Joint training on all three prediction tasks can act as a regularizer

**Comparison with Other Approaches:**

- **vs. Decision Transformer**: TT models the full trajectory including states and rewards, enabling planning
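- **vs. PPO**: TT is trained offline on logged trajectories, whereas PPO learns online from environment interaction; TT may need more data
- **vs. Random**: TT learns patterns from data, but whether it beats random actions depends on data quality, so check the Section 5 comparison for your run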

**Next Steps:**

1. Train on higher quality data (from expert demonstrations or trained policies)
2. Experiment with different beam widths and lookahead depths
3. Fine-tune on specific scenarios or objectives
4. Use the world model for model-based RL or data augmentation