# Multi-Agent RL for Air Traffic Control

This notebook demonstrates multi-agent reinforcement learning where each aircraft is an independent cooperative agent.

## Key Concepts

**Multi-Agent Formulation:**
- Each aircraft = independent agent with shared policy
- Variable number of agents (aircraft spawn and exit dynamically)
- Agents communicate via attention mechanism
- Cooperative objective: maximize team reward
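Handling a variable number of agents usually comes down to padding to a fixed number of slots plus a boolean mask. The `aircraft_mask` observation in this notebook plays that role. Below is a minimal NumPy sketch. It assumes the same convention as the forward-pass cell, where True marks an active aircraft. `pad_aircraft` is a hypothetical helper and is not part of the `models` package:

```python
import numpy as np

def pad_aircraft(features_list, max_aircraft, feature_dim):
    """Pack a variable-length list of per-aircraft feature vectors into a
    fixed (max_aircraft, feature_dim) array plus a boolean mask.

    Hypothetical helper for illustration: True in the mask marks an active aircraft.
    """
    if len(features_list) > max_aircraft:
        raise ValueError(f"{len(features_list)} aircraft exceeds max_aircraft={max_aircraft}")
    padded = np.zeros((max_aircraft, feature_dim), dtype=np.float32)
    mask = np.zeros(max_aircraft, dtype=bool)
    for slot, feats in enumerate(features_list):
        padded[slot] = feats  # active aircraft fill the first slots
        mask[slot] = True
    return padded, mask

# Three aircraft currently in the sector; slots 3..9 are padding.
rng = np.random.default_rng(0)
aircraft = [rng.standard_normal(14) for _ in range(3)]
padded, mask = pad_aircraft(aircraft, max_aircraft=10, feature_dim=14)
print(padded.shape, mask.sum())  # (10, 14) 3
```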

**MAPPO Architecture:**
- **Centralized Critic**: Sees global state (all aircraft + environment)
- **Decentralized Actors**: See local observations + communication
- **Communication**: Self-attention between agents enables coordination
- **Training**: Centralized training, decentralized execution (CTDE)
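The communication step can be sketched as single-head, masked, scaled dot-product self-attention over agents. This NumPy version is only an illustration. The real `MultiAgentPolicy` uses multiple heads and several layers, and this notebook does not show its exact masking details:

```python
import numpy as np

def masked_self_attention(x, mask, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over agents.

    x: (num_agents, d) agent embeddings; mask: (num_agents,) bool, True = active.
    Padded agents are excluded as keys, so no agent attends to padding.
    Assumes at least one agent is active.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])                # (num_agents, num_agents)
    scores = np.where(mask[None, :], scores, -np.inf)      # block padded keys
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)         # softmax over active keys
    weights[~mask] = 0.0                                   # padded queries get no messages
    return weights @ v, weights

rng = np.random.default_rng(0)
d = 8
x = rng.standard_normal((5, d))
agent_mask = np.array([True, True, True, False, False])
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))
messages, attn = masked_self_attention(x, agent_mask, w_q, w_k, w_v)
print(attn[:3].sum(axis=1))  # rows for active agents sum to 1
```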

**Emergent Behaviors:**
- Coordination patterns emerge from communication
- Agents learn to share responsibility
- Conflict avoidance through implicit negotiation

## 📚 Learning Objectives

By the end of this notebook, you will understand:

1. **Multi-Agent Formulation** - Each aircraft as independent cooperative agent with shared policy
2. **CTDE (Centralized Training, Decentralized Execution)** - Centralized critic, decentralized actors
3. **Agent Communication** - Self-attention mechanism enabling coordination between aircraft
4. **Emergent Coordination** - How cooperative behaviors emerge from learned communication
5. **Scalability** - Parameter count doesn't grow with aircraft count (shared policy advantage)
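Objectives 2 and 5 can be made concrete with a toy CTDE setup. One actor weight matrix is shared by every aircraft, and the critic sees a pooled summary of all aircraft plus the global features. This is a hypothetical linear stand-in: the actor here omits communication, and `NUM_ACTIONS` is made up. Its only purpose is to show that the parameter count stays fixed while the number of aircraft varies:

```python
import numpy as np

rng = np.random.default_rng(0)
FEATURE_DIM, GLOBAL_DIM, NUM_ACTIONS = 14, 4, 5  # NUM_ACTIONS is hypothetical

# Decentralized actor: ONE weight matrix shared by every aircraft.
actor_w = rng.standard_normal((FEATURE_DIM, NUM_ACTIONS)) * 0.1
# Centralized critic: weights over a pooled summary of all aircraft + global features.
critic_w = rng.standard_normal(FEATURE_DIM + GLOBAL_DIM) * 0.1

def actor_logits(local_obs):
    """Row i of the output depends only on aircraft i's own observation."""
    return local_obs @ actor_w  # (num_agents, NUM_ACTIONS)

def centralized_value(local_obs, mask, global_state):
    """One scalar value from the masked mean of active aircraft plus global state."""
    pooled = local_obs[mask].mean(axis=0)  # (FEATURE_DIM,)
    return float(np.concatenate([pooled, global_state]) @ critic_w)

num_params = actor_w.size + critic_w.size  # 14*5 + 18 = 88, independent of aircraft count
for n in (3, 10, 50):
    local_obs_n = rng.standard_normal((n, FEATURE_DIM))
    mask_n = np.ones(n, dtype=bool)
    logits_n = actor_logits(local_obs_n)
    value_n = centralized_value(local_obs_n, mask_n, rng.standard_normal(GLOBAL_DIM))
    print(n, logits_n.shape, round(value_n, 3), num_params)
```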

- **Estimated Time**: 15-20 minutes (demonstration only, no training)
- **Prerequisites**: Familiarity with RL and attention mechanisms; prior exposure to multi-agent concepts is helpful
- **Hardware**: A CPU is sufficient for the forward passes

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Add parent directory to path
sys.path.insert(0, '..')

from models import MultiAgentPolicy, create_default_network_config
from training import MAPPOTrainer, MAPPOConfig

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Notebook settings
%load_ext autoreload
%autoreload 2
%matplotlib inline

plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

## 1. Multi-Agent Policy Architecture

Let's create and inspect the multi-agent policy architecture.

In [None]:
# Create network configuration
config = create_default_network_config(
    max_aircraft=10,
    aircraft_feature_dim=14,
    global_feature_dim=4,
    hidden_dim=256,
    num_encoder_layers=4,
    num_attention_heads=8
)

# Create multi-agent policy; eval() disables dropout so the demo forward passes are deterministic
policy = MultiAgentPolicy(config)
policy.eval()

# Print architecture
print("Multi-Agent Policy Architecture:")
print("=" * 60)
print(policy)
print("\n" + "=" * 60)

# Parameter counts
params = policy.count_parameters()
print("\nParameter Breakdown:")
print(f"  Total Parameters: {params['total_parameters']:,}")
print(f"  Encoder: {params['encoder_parameters']:,}")
print(f"  Communication: {params['communication_parameters']:,}")
print(f"  Actor: {params['actor_parameters']:,}")
print(f"  Critic: {params['critic_parameters']:,}")

## 2. Forward Pass Example

Let's see how the policy processes observations and produces actions.

In [None]:
# ⏱️ ~5 seconds

# Create sample observation
batch_size = 2
num_aircraft = 5

obs = {
    "aircraft": torch.randn(batch_size, 10, 14),  # (batch, max_aircraft, aircraft_feature_dim) from the config
    "aircraft_mask": torch.zeros(batch_size, 10, dtype=torch.bool),  # True = active aircraft
    "global_state": torch.randn(batch_size, 4)  # (batch, global_feature_dim)
}

# Set first 5 aircraft as active
obs["aircraft_mask"][:, :num_aircraft] = True

# Forward pass
with torch.no_grad():
    action_logits, value, comm_attention = policy(obs, return_communication=True)

print("\nForward Pass Results:")
print("=" * 60)
print(f"\nAction Logits:")
for key, logits in action_logits.items():
    print(f"  {key}: {logits.shape}")
print(f"\nValue: {value.shape}")
print(f"\nCommunication Attention Layers: {len(comm_attention)}")
if comm_attention:
    print(f"  Each layer shape: {comm_attention[0].shape}")
    print(f"  Format: (batch_size, num_heads, num_agents, num_agents)")
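The policy returns logits, not actions. Below is a minimal sketch of turning per-head logits into one action per active aircraft. It assumes each entry of `action_logits` (for a single batch item) has shape `(num_agents, num_choices)`. The head names `altitude` and `heading` are hypothetical, so check them against the keys printed above:

```python
import numpy as np

def sample_actions(action_logits, mask, rng):
    """Sample one action per active agent for each action head.

    action_logits: dict of head name -> (num_agents, num_choices) array.
    mask: (num_agents,) bool. Inactive agents get -1 as a sentinel.
    """
    actions = {}
    for head, logits in action_logits.items():
        z = logits - logits.max(axis=-1, keepdims=True)        # stable softmax
        probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
        chosen = np.array([rng.choice(len(p), p=p) for p in probs])
        actions[head] = np.where(mask, chosen, -1)             # ignore padded slots
    return actions

rng = np.random.default_rng(0)
active = np.array([True, True, True, False])
head_logits = {"altitude": rng.standard_normal((4, 3)), "heading": rng.standard_normal((4, 8))}
acts = sample_actions(head_logits, active, rng)
print(acts)
```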

## 3. Communication Visualization

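Visualize attention patterns to see which agents each agent attends to. Because the policy is untrained, these weights reflect random initialization rather than learned coordination.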

In [None]:
def visualize_communication(attention_weights, agent_mask, layer_idx=0, head_idx=0, batch_idx=0):
    """
    Visualize communication attention between agents.
    
    Args:
        attention_weights: List of attention weight tensors
        agent_mask: Boolean mask indicating active agents
        layer_idx: Which communication layer to visualize
        head_idx: Which attention head to visualize
        batch_idx: Which batch item to visualize
    """
    # Get attention for specified layer/head/batch
    attn = attention_weights[layer_idx][batch_idx, head_idx].cpu().numpy()
    
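    # Get active agents (boolean mask over agent slots)
    active = agent_mask[batch_idx].cpu().numpy().astype(bool)
    num_active = int(active.sum())
    
    # Keep only rows/columns of active agents (also correct if active slots are not contiguous)
    attn_active = attn[np.ix_(active, active)]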
    
    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Heatmap
    sns.heatmap(
        attn_active,
        annot=True,
        fmt='.3f',
        cmap='YlOrRd',
        square=True,
        cbar_kws={'label': 'Attention Weight'},
        ax=ax1
    )
    ax1.set_xlabel('Key Agent')
    ax1.set_ylabel('Query Agent')
    ax1.set_title(f'Communication Attention (Layer {layer_idx}, Head {head_idx})')
    
    # Communication graph
    positions = np.array([[np.cos(2*np.pi*i/num_active), np.sin(2*np.pi*i/num_active)] 
                          for i in range(num_active)])
    
    # Draw nodes
    ax2.scatter(positions[:, 0], positions[:, 1], s=500, c='skyblue', edgecolors='black', zorder=3)
    
    # Label nodes
    for i, pos in enumerate(positions):
        ax2.text(pos[0], pos[1], f'A{i}', ha='center', va='center', fontsize=12, fontweight='bold')
    
    # Draw edges (communication links)
    threshold = 0.1  # Only show strong attention
    for i in range(num_active):
        for j in range(num_active):
            if i != j and attn_active[i, j] > threshold:
                # Draw arrow from j to i (j is key, i is query)
                dx = positions[i, 0] - positions[j, 0]
                dy = positions[i, 1] - positions[j, 1]
                
                # Scale arrow by attention weight (width is in data units; axes span ±1.5)
                alpha = min(attn_active[i, j] * 2, 1.0)
                width = attn_active[i, j] * 0.05
                
                ax2.arrow(
                    positions[j, 0], positions[j, 1],
                    dx * 0.8, dy * 0.8,
                    width=width,
                    head_width=0.1,
                    head_length=0.1,
                    alpha=alpha,
                    color='red',
                    zorder=2
                )
    
    ax2.set_xlim(-1.5, 1.5)
    ax2.set_ylim(-1.5, 1.5)
    ax2.set_aspect('equal')
    ax2.axis('off')
    ax2.set_title('Communication Graph\n(arrows show attention flow)')
    
    plt.tight_layout()
    plt.show()

# Visualize communication
if comm_attention:
    visualize_communication(comm_attention, obs["aircraft_mask"], layer_idx=0, head_idx=0)
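When tracking coordination across training, a printed list of strong links is easier to compare than a plot. The sketch below applies the same 0.1 threshold and the same query/key convention as `visualize_communication`. `strong_links` is a hypothetical helper, and the attention matrix is made up:

```python
import numpy as np

def strong_links(attn, threshold=0.1):
    """Return (sender, receiver, weight) for off-diagonal attention above threshold.

    attn[i, j] is how much query agent i attends to key agent j, so information
    flows from j (sender) to i (receiver). Sorted by weight, strongest first.
    """
    links = []
    n = attn.shape[0]
    for i in range(n):
        for j in range(n):
            if i != j and attn[i, j] > threshold:
                links.append((j, i, float(attn[i, j])))
    return sorted(links, key=lambda link: link[2], reverse=True)

attn_example = np.array([[0.70, 0.25, 0.05],
                         [0.10, 0.60, 0.30],
                         [0.05, 0.05, 0.90]])
print(strong_links(attn_example))  # [(2, 1, 0.3), (1, 0, 0.25)]
```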

## 4. Training with MAPPO

Set up and run MAPPO training (requires environment).

In [None]:
# NOTE: This requires a working ATC environment
# Uncomment and modify the following to train:

# from environment import PlaywrightEnv, create_default_config

# # Create environment
# env_config = create_default_config(
#     airport="KSFO",
#     max_aircraft=10,
#     headless=True,
#     timewarp=10
# )
# env = PlaywrightEnv(**env_config.__dict__)

# # Create training config
# train_config = MAPPOConfig(
#     max_aircraft=10,
#     hidden_dim=256,
#     learning_rate=3e-4,
#     total_timesteps=100_000,
#     steps_per_rollout=2048,
#     log_dir="logs/mappo_demo",
#     save_dir="checkpoints/mappo_demo"
# )

# # Create trainer
# trainer = MAPPOTrainer(env, train_config)

# # Train
# trainer.train()

print("Training code ready (uncomment to run with environment)")
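During training, MAPPO computes advantages from the centralized critic's values. Below is a minimal sketch of standard Generalized Advantage Estimation (GAE) for one rollout. It assumes a single shared team reward per step. Whether `MAPPOTrainer` uses GAE with these exact defaults is an assumption:

```python
import numpy as np

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one rollout.

    rewards, dones: (T,) shared team reward and episode-end flags.
    values: (T + 1,) centralized critic values V(s_t), including bootstrap V(s_T).
    Returns (advantages, returns), both (T,). In this simplified setup every
    agent at step t shares the same advantage.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]  # no bootstrapping past the end of an episode
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    return advantages, advantages + values[:-1]

adv, ret = compute_gae(np.array([1.0, 1.0]), np.zeros(3), np.array([0.0, 0.0]), gamma=1.0, lam=1.0)
print(adv, ret)  # [2. 1.] [2. 1.]
```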

## 5. Analyzing Emergent Coordination

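Once the policy is trained, this analysis reveals emergent coordination patterns. Here it runs on the untrained policy, so the statistics describe random initialization and serve only as a baseline.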

In [None]:
# ⏱️ ~30-60 seconds

def analyze_coordination_patterns(policy, num_scenarios=10, num_aircraft=5):
    """
    Analyze coordination patterns across multiple scenarios.
    """
    all_attention = []
    
    for _ in range(num_scenarios):
        # Generate random scenario
        obs = {
            "aircraft": torch.randn(1, 10, 14),
            "aircraft_mask": torch.zeros(1, 10, dtype=torch.bool),
            "global_state": torch.randn(1, 4)
        }
        obs["aircraft_mask"][0, :num_aircraft] = True
        
        # Get communication
        with torch.no_grad():
            _, _, comm_attn = policy(obs, return_communication=True)
        
        # Store first layer attention
        if comm_attn:
            # Average over heads
            attn = comm_attn[0][0].mean(dim=0).cpu().numpy()  # (num_agents, num_agents)
            all_attention.append(attn[:num_aircraft, :num_aircraft])
    
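    if not all_attention:
        print("Policy returned no communication attention; nothing to analyze.")
        return
    
    # Compute statistics across scenarios: (num_scenarios, num_aircraft, num_aircraft)
    all_attention = np.array(all_attention)
    mean_attn = all_attention.mean(axis=0)
    std_attn = all_attention.std(axis=0)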
    
    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Mean attention
    sns.heatmap(
        mean_attn,
        annot=True,
        fmt='.3f',
        cmap='YlOrRd',
        square=True,
        ax=ax1
    )
    ax1.set_title(f'Mean Attention Across {num_scenarios} Scenarios')
    ax1.set_xlabel('Key Agent')
    ax1.set_ylabel('Query Agent')
    
    # Attention variance
    sns.heatmap(
        std_attn,
        annot=True,
        fmt='.3f',
        cmap='Blues',
        square=True,
        ax=ax2
    )
    ax2.set_title('Attention Std Dev (Consistency)')
    ax2.set_xlabel('Key Agent')
    ax2.set_ylabel('Query Agent')
    
    plt.tight_layout()
    plt.show()
    
    # Print insights
    print("\nCoordination Insights:")
    print("=" * 60)
    print(f"Mean self-attention (diagonal): {np.diag(mean_attn).mean():.3f}")
    print(f"Mean cross-attention (off-diagonal): {(mean_attn.sum() - np.diag(mean_attn).sum()) / (num_aircraft * (num_aircraft - 1)):.3f}")
    print(f"\nLowest std dev (most consistent attention entry): {std_attn.min():.3f}")
    print(f"Highest std dev (most variable attention entry): {std_attn.max():.3f}")

# Analyze coordination (with untrained policy)
analyze_coordination_patterns(policy, num_scenarios=20, num_aircraft=5)

## 6. Comparison: Single-Agent vs Multi-Agent

Compare the approaches conceptually.

In [None]:
comparison_data = {
    "Aspect": [
        "Decision Making",
        "Scalability",
        "Communication",
        "Training Complexity",
        "Action Space",
        "Coordination",
        "Sample Efficiency",
        "Emergent Behavior"
    ],
    "Single-Agent": [
        "Centralized: selects one aircraft at a time",
        "Difficult with many aircraft",
        "Not modeled",
        "Standard PPO",
        "Discrete (select aircraft + command)",
        "Learned implicitly",
        "Single learner; environment is stationary from its viewpoint",
        "Limited"
    ],
    "Multi-Agent": [
        "Decentralized: each aircraft chooses its own action, informed by attention over the others",
        "Same parameters for any number of aircraft (shared policy)",
        "Explicit via attention mechanism",
        "MAPPO (more complex)",
        "Per-agent actions (parallel)",
        "Learned explicitly through communication",
        "Experience pooled across aircraft; non-stationarity can slow learning",
        "Rich coordination patterns"
    ]
}

import pandas as pd

df = pd.DataFrame(comparison_data)
print("\nSingle-Agent vs Multi-Agent Comparison:")
print("=" * 100)
print(df.to_string(index=False))
print("\n" + "=" * 100)

## 7. Common Pitfalls & Troubleshooting

### Problem 1: "Non-stationary environment - training unstable"
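**Solution**: Some non-stationarity is inherent to multi-agent RL, because each agent's environment includes teammates whose policies keep changing. Mitigate it with:
- **Centralized critic**: Sees the global state, which stabilizes value estimates
- **Parameter sharing**: All agents update one policy, pooling experience and reducing variance
- **Larger rollouts and minibatches**: MAPPO is on-policy and has no replay buffer, so collect more steps per update instead

```python
from training import MAPPOConfig

config = MAPPOConfig(
    steps_per_rollout=4096,  # More on-policy data per update (the training cell above uses 2048)
    batch_size=256,          # Larger minibatches (field name assumed; check MAPPOConfig)
)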
```
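PPO's clipped objective also limits how far the shared policy moves in one update, so teammates do not change too abruptly from each agent's point of view. Below is a minimal NumPy sketch of the standard clipped surrogate loss. `clip_eps=0.2` is a common default and was not read from `MAPPOConfig`:

```python
import numpy as np

def ppo_clip_loss(log_probs_new, log_probs_old, advantages, clip_eps=0.2):
    """PPO clipped surrogate loss (to minimize), averaged over agent-steps."""
    ratio = np.exp(log_probs_new - log_probs_old)  # pi_new(a|o) / pi_old(a|o)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Taking the minimum removes any incentive to push the ratio outside the clip range
    return -np.mean(np.minimum(unclipped, clipped))

# A large policy change (ratio 2.0) on a positive advantage only counts as 1.2x.
loss = ppo_clip_loss(np.log([2.0]), np.log([1.0]), np.array([1.0]))
print(loss)  # -1.2
```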

### Problem 2: Agents ignore communication (all self-attention)
**Causes**:
- **Insufficient training**: Communication patterns emerge later
- **Reward not cooperative**: Agents have no incentive to coordinate
- **Attention not trained**: Check gradient flow to communication layers

**Solution**: Add cooperation bonus to reward:
```python
reward = base_reward + 0.1 * coordination_bonus
```
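This notebook does not define `coordination_bonus`. One simple way to make the reward more cooperative is to blend each aircraft's own reward with the team average. In the sketch below, `mix_rewards` and `team_weight` are hypothetical names used only for illustration:

```python
import numpy as np

def mix_rewards(individual_rewards, mask, team_weight=0.5):
    """Blend each aircraft's own reward with the team mean (hypothetical shaping).

    team_weight=0 -> purely individual; team_weight=1 -> fully shared team reward.
    Only active aircraft contribute to, and receive, the team term.
    """
    team_mean = individual_rewards[mask].mean()
    mixed = (1.0 - team_weight) * individual_rewards + team_weight * team_mean
    return np.where(mask, mixed, 0.0)

r = np.array([1.0, -1.0, 0.0, 0.0])
active = np.array([True, True, True, False])
print(mix_rewards(r, active, team_weight=0.5))
```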

### Problem 3: "Variable agent count causes shape mismatches"
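**Solution**: Always pad observations to `max_aircraft` slots and pass `aircraft_mask` (True = active aircraft):
```python
import torch

batch, max_aircraft, num_active = 2, 10, 5
mask = torch.zeros(batch, max_aircraft, dtype=torch.bool)
mask[:, :num_active] = True  # Critical! Inactive (padded) slots stay False

obs = {
    "aircraft": torch.randn(batch, max_aircraft, 14),  # Fixed shape regardless of live aircraft count
    "aircraft_mask": mask,
    "global_state": torch.randn(batch, 4),
}
```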

### Problem 4: Credit assignment problem - which agent to reward?
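**Inherent challenge in multi-agent RL**. MAPPO addresses it only partially:
- A global team reward shared by all agents, so every agent optimizes the same objective
- A centralized critic that conditions on the global state and serves as each agent's baseline
- Value decomposition (e.g., VDN, QMIX) and counterfactual baselines (e.g., COMA) are alternative methods that target per-agent credit more directly; they are not part of standard MAPPO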

**Tip**: Monitor individual agent metrics to debug credit assignment.
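To see what a counterfactual credit signal looks like, here is a toy sketch of difference rewards, D_i = G(all) - G(all except i). Difference rewards are a credit-assignment technique from the MARL literature and are not part of standard MAPPO. The conflict-counting `team_reward` (pairs closer than 1,000 ft in altitude) is a made-up stand-in for the environment's reward:

```python
import itertools
import numpy as np

def team_reward(altitudes, min_sep=1000.0):
    """Toy global reward: -1 per pair of aircraft closer than min_sep (feet)."""
    conflicts = sum(
        1 for a, b in itertools.combinations(altitudes, 2) if abs(a - b) < min_sep
    )
    return -float(conflicts)

def difference_rewards(altitudes, min_sep=1000.0):
    """D_i = G(all aircraft) - G(all aircraft except i)."""
    g_all = team_reward(altitudes, min_sep)
    return np.array([
        g_all - team_reward(np.delete(altitudes, i), min_sep)
        for i in range(len(altitudes))
    ])

alts = np.array([10000.0, 10500.0, 20000.0])  # aircraft 0 and 1 are in conflict
print(team_reward(alts), difference_rewards(alts))  # aircraft 2 gets zero blame
```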

### Problem 5: Communication graph shows weak/uniform attention
**Causes**:
- **Model not trained**: Attention strengthens during training
- **Too much regularization**: Reduce attention dropout
- **Homogeneous scenarios**: Test with diverse conflicts

### Problem 6: Scaling to many aircraft (20+) causes memory issues
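**Solution**: 
- Use gradient checkpointing (`torch.utils.checkpoint`) to trade extra compute for lower activation memory
- Reduce hidden dimensions and the number of attention heads
- Process agents in chunks (for example, compute attention for a subset of query agents at a time while keeping all agents as keys)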

```python
config = create_default_network_config(
    max_aircraft=20,         # Raise the agent limit
    hidden_dim=128,          # Reduce from 256
    num_attention_heads=4,   # Reduce from 8
)
```
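A back-of-envelope estimate shows why aircraft count matters. Each attention layer stores a (batch, heads, N, N) weight tensor, so that memory grows quadratically in N and linearly in the number of heads. The sketch below counts only those tensors, in float32, and assumes 4 layers and a minibatch of 2048. Real usage also includes activations, optimizer state, and framework overhead:

```python
def attention_map_bytes(batch, heads, num_agents, layers, bytes_per_elem=4):
    """Memory for the stored attention-weight tensors alone (float32 by default).

    Each layer keeps a (batch, heads, num_agents, num_agents) tensor for backprop.
    """
    return batch * heads * num_agents * num_agents * layers * bytes_per_elem

# Assumed: 2048-sample minibatch, 4 attention layers
for n, heads in [(10, 8), (20, 8), (20, 4)]:
    mib = attention_map_bytes(2048, heads, n, 4) / 2**20
    print(f"{n} aircraft, {heads} heads: {mib:.1f} MiB")  # 25.0, 100.0, 50.0 MiB
```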

### Debugging Tips:
1. **Start with 3 agents**: Easier to visualize communication
2. **Visualize attention every epoch**: Track evolution of coordination
3. **Compare to single-agent**: Verify multi-agent actually helps
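4. **Check gradient norms**: With a shared policy there is one set of weights, so compare norms per module (encoder, communication, actor, critic) to confirm each component receives gradient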

**Need more help?** See multi-agent RL surveys or the MAPPO paper (Yu et al., "The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games").

## 8. Key Insights

### Advantages of Multi-Agent Formulation:

1. **Natural Scalability**: The number of parameters doesn't grow with aircraft count (shared policy)
2. **Parallel Decision Making**: All aircraft decide simultaneously (more realistic)
3. **Explicit Communication**: The attention mechanism lets agents share information
4. **Emergent Coordination**: Cooperative behaviors can emerge from learned communication
5. **Decentralized Execution**: The centralized critic is needed only during training; at test time each aircraft acts on its own observation plus messages from other aircraft

### Challenges:

1. **Training Complexity**: MAPPO is more complex than standard PPO
2. **Non-Stationarity**: Each agent's environment changes as others learn
3. **Credit Assignment**: Hard to determine which agent deserves credit/blame
4. **Variable Agents**: Must handle aircraft spawning/exiting gracefully

### Future Improvements:

1. **Graph Neural Networks**: Use GNN for more sophisticated spatial reasoning
2. **Hierarchical Communication**: Multi-level coordination (local clusters + global)
3. **Curriculum Learning**: Start with few aircraft, gradually increase
4. **Self-Play**: Agents learn against versions of themselves
5. **Meta-Learning**: Learn to adapt to different traffic patterns

## 9. Inference Benchmark

Let's benchmark the policy's inference speed to understand real-time performance.

In [None]:
# ⏱️ ~10-15 seconds for 1000 iterations

import time

def benchmark_inference(policy, num_iterations=1000, num_aircraft=5):
    """
    Benchmark inference speed for the multi-agent policy.
    """
    # Prepare observation
    obs = {
        "aircraft": torch.randn(1, 10, 14),
        "aircraft_mask": torch.zeros(1, 10, dtype=torch.bool),
        "global_state": torch.randn(1, 4)
    }
    obs["aircraft_mask"][0, :num_aircraft] = True
    
    # Warm-up
    with torch.no_grad():
        for _ in range(10):
            _ = policy(obs, return_communication=False)
    
    # Benchmark (perf_counter is a monotonic, high-resolution clock suited to timing)
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_iterations):
            _ = policy(obs, return_communication=False)
    end_time = time.perf_counter()
    
    # Results
    total_time = end_time - start_time
    avg_time = total_time / num_iterations
    fps = num_iterations / total_time
    
    print(f"Inference Benchmark ({num_iterations} iterations, {num_aircraft} aircraft):")
    print("=" * 60)
    print(f"Total time: {total_time:.3f} seconds")
    print(f"Average time per inference: {avg_time*1000:.3f} ms")
    print(f"Throughput: {fps:.1f} inferences/second")
    print(f"\nReal-time capability: {'✓ Yes' if avg_time < 0.1 else '✗ No'} (target: <100ms)")

# Run benchmark
benchmark_inference(policy, num_iterations=1000, num_aircraft=5)
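A single mean can hide latency spikes that matter for real-time control. The sketch below times each call separately and reports the median and 95th percentile, using only the standard library. The stand-in workload is arbitrary. To time the policy, call the helper inside `with torch.no_grad():` and pass `lambda: policy(obs, return_communication=False)`:

```python
import statistics
import time

def latency_percentiles(fn, num_iterations=200, warmup=10):
    """Time fn() call by call and report median, p95, and max latency in milliseconds."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(num_iterations):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    cuts = statistics.quantiles(samples, n=100)  # 99 cut points; index 94 is the 95th percentile
    return {"p50_ms": statistics.median(samples), "p95_ms": cuts[94], "max_ms": max(samples)}

stats = latency_percentiles(lambda: sum(range(10_000)))
print(stats)
```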