# Phase 2 Robot Policy Learning - Async Training Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mensch72/empo/blob/main/notebooks/phase2_async_colab_demo.ipynb)

This notebook demonstrates **Phase 2** of the EMPO framework: learning a robot policy
that maximizes aggregate human power. It is based on equations (4)-(9) from the paper.

**Features:**
- GPU-accelerated training
- Async actor-learner architecture (optional)
- Model-based targets for stable Q-learning
- Warmup stages with buffer clearing

**Networks trained:**
- `Q_r`: Robot state-action value (eq. 4)
- `V_h^e`: Human goal achievement under robot policy (eq. 6)
- `X_h`: Aggregate goal achievement ability (eq. 7)
- `U_r`: Intrinsic robot reward (eq. 8)
- `V_r`: Robot state value (eq. 9)

## 1. Setup

Clone the repository and install dependencies.

In [None]:
# Clone the EMPO repository
!git clone --depth 1 https://github.com/mensch72/empo.git
%cd empo

In [None]:
# Install Python dependencies
!pip install -q -r requirements-colab.txt
print("Dependencies installed")

In [None]:
# Set up Python paths
import sys
import os

repo_root = os.getcwd()
sys.path.insert(0, os.path.join(repo_root, 'src'))
sys.path.insert(0, os.path.join(repo_root, 'vendor', 'multigrid'))

print("Python import paths configured")
print(f"  Repository root: {repo_root}")

## 2. Verify GPU and Imports

In [None]:
import torch
import numpy as np
import time
import random

print("=" * 60)
print("System Information")
print("=" * 60)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    device = 'cuda'
else:
    print("WARNING: No GPU available - training will be slower")
    print("   Go to Runtime -> Change runtime type -> GPU")
    device = 'cpu'

print(f"\nUsing device: {device}")

In [None]:
# Test EMPO imports
print("Testing imports...")

from gym_multigrid.multigrid import MultiGridEnv, World, SmallActions
print("  MultiGrid imports OK")

from empo.multigrid import MultiGridGoalSampler, ReachCellGoal
from empo.possible_goal import TabularGoalSampler
from empo.human_policy_prior import HeuristicPotentialPolicy
from empo.nn_based.multigrid import PathDistanceCalculator
print("  EMPO core imports OK")

from empo.nn_based.phase2.config import Phase2Config
from empo.nn_based.multigrid.phase2 import train_multigrid_phase2
print("  Phase 2 imports OK")

print("\nAll imports successful!")

## 3. Create Environment

A simple grid with 1 robot (grey) that should help 1 human (yellow) reach their goal.
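
To check coordinates such as the goal cells used later, the map string can be parsed without MultiGrid. The sketch below assumes that every cell is a two-character token separated by spaces and that positions are written as (x, y) = (column, row). `parse_grid` and `find_tokens` are hypothetical helpers, not part of the EMPO API.

```python
GRID_MAP = """
We We We We We We
We Ae Ro .. .. We
We We Ay We We We
We We We We We We
"""


def parse_grid(grid_map):
    """Split a map string into rows of two-character cell tokens."""
    return [line.split() for line in grid_map.strip().splitlines()]


def find_tokens(rows, token):
    """Return (x, y) positions of every cell holding `token` (x = column, y = row)."""
    return [
        (x, y)
        for y, row in enumerate(rows)
        for x, cell in enumerate(row)
        if cell == token
    ]


rows = parse_grid(GRID_MAP)
width, height = len(rows[0]), len(rows)
print(f"Grid size: {width} x {height}")
for token, label in [("Ae", "robot"), ("Ay", "human"), ("Ro", "rock")]:
    print(f"{label:>5}: {find_tokens(rows, token)}")
```

Under these assumptions the rock occupies cell (2, 1) and the human starts at cell (2, 2), which are the two cells used as goals in Section 4.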

In [None]:
# Environment configuration
GRID_MAP = """
We We We We We We
We Ae Ro .. .. We
We We Ay We We We
We We We We We We
"""
# Ae = grey robot (agent), Ay = yellow human, Ro = rock, We = wall, .. = empty cell

MAX_STEPS = 10
FINAL_BETA_R = 100.0  # Robot policy concentration


class Phase2DemoEnv(MultiGridEnv):
    """Simple grid environment for Phase 2 demo."""
    
    def __init__(self, max_steps=MAX_STEPS):
        super().__init__(
            map=GRID_MAP,
            max_steps=max_steps,
            partial_obs=False,
            objects_set=World,
            actions_set=SmallActions
        )
        self.num_humans = sum(1 for a in self.agents if a.color == 'yellow')
        self.num_robots = sum(1 for a in self.agents if a.color == 'grey')


# Create and inspect environment
env = Phase2DemoEnv()
env.reset()

print("=" * 60)
print("Environment")
print("=" * 60)
print(f"Grid size: {env.width} x {env.height}")
print(f"Max steps: {env.max_steps}")
print(f"Action space: {env.action_space.n} actions")

# Identify agents
human_indices = []
robot_indices = []
for i, agent in enumerate(env.agents):
    if agent.color == 'yellow':
        human_indices.append(i)
        print(f"Human {i}: pos={tuple(agent.pos)}")
    elif agent.color == 'grey':
        robot_indices.append(i)
        print(f"Robot {i}: pos={tuple(agent.pos)}")

print(f"\n{len(human_indices)} human(s), {len(robot_indices)} robot(s)")

In [None]:
# Visualize the environment
import matplotlib.pyplot as plt

img = env.render(mode='rgb_array', highlight=False)
plt.figure(figsize=(6, 4))
plt.imshow(img)
plt.title('Initial Environment State')
plt.axis('off')
plt.show()

## 4. Setup Human Policy and Goal Sampler

In [None]:
print("Setting up human policy and goal sampler...")

# Path calculator for heuristic policy
path_calc = PathDistanceCalculator(
    grid_height=env.height,
    grid_width=env.width,
    world_model=env
)
print("  Path calculator created")

# Human policy using HeuristicPotentialPolicy
human_policy = HeuristicPotentialPolicy(
    world_model=env,
    human_agent_indices=human_indices,
    path_calculator=path_calc,
    beta=10.0  # Quite deterministic
)
print("  Human policy created")

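# Goal sampler: the human wants to reach cell (2, 1) with probability 0.9
# and cell (2, 2) with probability 0.1.
# Use the detected human index instead of a hard-coded agent number.
goal_sampler = TabularGoalSampler([
    ReachCellGoal(env, human_indices[0], (2, 1)),
    ReachCellGoal(env, human_indices[0], (2, 2))
], probabilities=[0.9, 0.1])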
print("  Goal sampler created")

# Wrapper functions for trainer interface
def goal_sampler_fn(state, human_idx):
    goal, _ = goal_sampler.sample(state, human_idx)
    return goal

def human_policy_fn(state, human_idx, goal):
    return human_policy.sample(state, human_idx, goal)

print("\nAll components ready!")
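
The comment on `beta=10.0` ("quite deterministic") can be made concrete with a generic Boltzmann (softmax) rule, in which action probabilities are proportional to exp(beta * potential). This is only a sketch: the potentials are made-up numbers, `boltzmann_probs` is a hypothetical helper, and the actual formula inside `HeuristicPotentialPolicy` is not shown in this notebook and may differ.

```python
import math


def boltzmann_probs(potentials, beta):
    """Softmax over beta * potential, shifted by the max for numerical stability."""
    top = max(potentials)
    weights = [math.exp(beta * (p - top)) for p in potentials]
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical potentials, e.g. negative path distance to the goal after each action.
potentials = [-3.0, -3.0, -2.0, -4.0]

for beta in (0.0, 1.0, 10.0):
    probs = boltzmann_probs(potentials, beta)
    print(f"beta={beta:>4}: " + ", ".join(f"{p:.4f}" for p in probs))
```

With `beta=0` every action is equally likely. As `beta` grows, almost all probability mass moves to the action with the highest potential.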

## 5. Configure Training

Choose between **synchronous** and **async** training modes.

**Note:** Async training relies on multiprocessing and may not work reliably in Colab notebooks because of limitations of the `spawn` start method in notebook environments. If async training fails, the training cell falls back to synchronous training.
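
The actor-learner split can be sketched in a few lines. The example below deliberately swaps in threads for the separate processes that async training would use, so it runs inside a notebook without the `spawn` issues mentioned above. The dummy transitions, the `actor`, `learner`, and `run_demo` helpers, and the small sizes are all illustrative assumptions. The argument names only loosely mirror `num_actors`, `async_min_buffer_size`, and `async_queue_size`.

```python
import queue
import random
import threading


def actor(actor_id, transitions, stop_event):
    """Produce dummy transitions until the learner signals a stop."""
    rng = random.Random(actor_id)
    step = 0
    while not stop_event.is_set():
        transition = (actor_id, step, rng.random())  # stand-in for real experience
        try:
            transitions.put(transition, timeout=0.05)
            step += 1
        except queue.Full:
            pass  # learner is behind; re-check the stop flag and retry


def learner(transitions, min_buffer_size, num_updates):
    """Fill a replay buffer and count one placeholder update per new transition."""
    replay_buffer = []
    updates = 0
    while updates < num_updates:
        replay_buffer.append(transitions.get())
        if len(replay_buffer) >= min_buffer_size:
            updates += 1  # a real learner would sample a batch and take a gradient step
    return replay_buffer, updates


def run_demo(num_actors=2, min_buffer_size=20, queue_size=50, num_updates=30):
    """Start the actors, run the learner in the calling thread, then shut down."""
    transitions = queue.Queue(maxsize=queue_size)
    stop_event = threading.Event()
    actors = [
        threading.Thread(target=actor, args=(i, transitions, stop_event))
        for i in range(num_actors)
    ]
    for t in actors:
        t.start()
    replay_buffer, updates = learner(transitions, min_buffer_size, num_updates)
    stop_event.set()
    for t in actors:
        t.join()
    return replay_buffer, updates


replay_buffer, updates = run_demo()
print(f"Collected {len(replay_buffer)} transitions, performed {updates} updates")
```

The bounded queue gives back-pressure: when the learner falls behind, actors block briefly instead of filling memory, and no update happens before the buffer reaches its minimum size.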

In [None]:
# Training mode selection
USE_ASYNC = False  # Set to True to try async training (may not work in Colab)
QUICK_MODE = True  # Set to False for full training

if QUICK_MODE:
    NUM_EPISODES = 500
    HIDDEN_DIM = 16
    GOAL_FEATURE_DIM = 8
    AGENT_EMBEDDING_DIM = 4
    print("[QUICK MODE] Reduced episodes and network size")
else:
    NUM_EPISODES = 3000
    HIDDEN_DIM = 32
    GOAL_FEATURE_DIM = 16
    AGENT_EMBEDDING_DIM = 8
    print("[FULL MODE] Full training configuration")

# Create configuration
config = Phase2Config(
    # Discount factors
    gamma_r=0.95,
    gamma_h=0.95,
    
    # Preference parameters
    zeta=2.0,   # Risk aversion
    xi=1.0,     # Inter-human inequality aversion
    eta=1.1,    # Intertemporal inequality aversion
    
    # Policy concentration
    beta_r=FINAL_BETA_R,
    
    # Exploration
    epsilon_r_start=1.0,
    epsilon_r_end=0.1,
    epsilon_r_decay_steps=NUM_EPISODES * MAX_STEPS,  # decay over all environment steps
    
    # Learning rates
    lr_q_r=1e-4,
    lr_v_r=1e-4,
    lr_v_h_e=1e-3,  # Higher for V_h_e (critical network)
    lr_x_h=1e-4,
    lr_u_r=1e-4,
    
    # Buffer and batching
    buffer_size=10000,
    batch_size=16,
    x_h_batch_size=32,
    
    # Training loop
    num_episodes=NUM_EPISODES,
    steps_per_episode=MAX_STEPS,
    updates_per_step=1,
    goal_resample_prob=0.1,
    
    # Target networks
    v_h_target_update_freq=100,
    
    # Model-based targets (for stable Q-learning)
    use_model_based_targets=True,
    
    # Async training (optional)
    async_training=USE_ASYNC,
    num_actors=2,
    actor_sync_freq=50,
    async_min_buffer_size=200,
    async_queue_size=5000,
)

print("\n" + "=" * 60)
print("Training Configuration")
print("=" * 60)
print(f"Episodes: {config.num_episodes}")
print(f"Steps per episode: {config.steps_per_episode}")
print(f"Device: {device}")
print(f"Async training: {config.async_training}")
if config.async_training:
    print(f"  Actors: {config.num_actors}")
    print(f"  Sync frequency: {config.actor_sync_freq}")
print(f"Model-based targets: {config.use_model_based_targets}")
print(f"Network hidden dim: {HIDDEN_DIM}")

## 6. Train the Robot Policy

Training proceeds through four warmup stages, followed by a `beta_r` ramp and full training:
1. **Stage 1**: V_h_e only
2. **Stage 2**: V_h_e + X_h
3. **Stage 3**: V_h_e + X_h + U_r
4. **Stage 4**: V_h_e + X_h + U_r + Q_r
5. **beta_r ramping**: All networks with increasing policy concentration
6. **Full training**: All networks with learning rate decay
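
The staged schedule above can be pictured as a lookup from training step to the set of active networks. The sketch below is illustrative only: the stage length, the linear `beta_r` ramp, the point at which `V_r` becomes active, and the helpers `stage_at` and `beta_r_at` are all assumptions. The real trainer's stage durations and ramp shape may differ.

```python
STAGES = [
    ("stage 1", ["V_h_e"]),
    ("stage 2", ["V_h_e", "X_h"]),
    ("stage 3", ["V_h_e", "X_h", "U_r"]),
    ("stage 4", ["V_h_e", "X_h", "U_r", "Q_r"]),
]
# Assumption: V_r is not named in any warmup stage, so it is treated as active
# only once warmup is over.
ALL_NETWORKS = ["V_h_e", "X_h", "U_r", "Q_r", "V_r"]


def stage_at(step, steps_per_stage):
    """Return (stage name, active networks) for a training step."""
    index = step // steps_per_stage
    if index < len(STAGES):
        return STAGES[index]
    return ("post-warmup", ALL_NETWORKS)


def beta_r_at(step, warmup_steps, ramp_steps, final_beta_r):
    """Keep beta_r at 0 during warmup, then ramp linearly to its final value."""
    if step < warmup_steps:
        return 0.0
    progress = min(1.0, (step - warmup_steps) / ramp_steps)
    return progress * final_beta_r


steps_per_stage, ramp_steps, final_beta_r = 100, 200, 100.0  # hypothetical values
warmup_steps = steps_per_stage * len(STAGES)

previous = None
for step in range(0, 701, 50):
    name, active = stage_at(step, steps_per_stage)
    if previous is not None and name != previous:
        print(f"step {step:>3}: entering {name}, clearing replay buffer")
    previous = name
    beta = beta_r_at(step, warmup_steps, ramp_steps, final_beta_r)
    print(f"step {step:>3}: {name:<11} beta_r={beta:6.1f} active={active}")
```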

In [None]:
print("=" * 60)
print("Phase 2 Robot Policy Training")
print("=" * 60)
print()

# Output directory
output_dir = 'outputs/phase2_colab_demo'
os.makedirs(output_dir, exist_ok=True)
tensorboard_dir = os.path.join(output_dir, 'tensorboard')

t0 = time.time()

def run_training():
    """Call the Phase 2 trainer with the current `config`."""
    return train_multigrid_phase2(
        world_model=env,
        human_agent_indices=human_indices,
        robot_agent_indices=robot_indices,
        human_policy_prior=human_policy_fn,
        goal_sampler=goal_sampler_fn,
        config=config,
        hidden_dim=HIDDEN_DIM,
        goal_feature_dim=GOAL_FEATURE_DIM,
        agent_embedding_dim=AGENT_EMBEDDING_DIM,
        device=device,
        verbose=True,
        debug=False,
        tensorboard_dir=tensorboard_dir
    )


try:
    robot_q_network, networks, history = run_training()
    training_success = True

except Exception as e:
    if not config.async_training:
        print(f"\nTraining failed: {e}")
        training_success = False
        raise

    print(f"\nAsync training failed: {e}")
    print("Falling back to synchronous training...\n")

    # Retry with sync training and restart the timer
    config.async_training = False
    t0 = time.time()

    robot_q_network, networks, history = run_training()
    training_success = True

elapsed = time.time() - t0

print()
print("=" * 60)
print(f"Training completed in {elapsed:.1f} seconds")
print(f"  Mode: {'Async' if config.async_training else 'Synchronous'}")
print(f"  Device: {device}")
print(f"  Episodes: {len(history)}")
print("=" * 60)

In [None]:
# Show final loss values
if history:
    print("\nLoss history (last 5 episodes):")
    print("-" * 50)
    start = max(0, len(history) - 5)  # avoids negative episode numbers for short runs
    for episode_num, losses in enumerate(history[start:], start=start):
        loss_str = ", ".join(f"{k}={v:.4f}" for k, v in losses.items() if v > 0)
        print(f"Episode {episode_num}: {loss_str}")

## 7. Evaluate Learned Policy

Run rollouts with the learned robot policy and visualize.

In [None]:
# Action names for display
ACTION_NAMES = ['still', 'left', 'right', 'forward']

def run_evaluation_rollout(env, robot_q_network, human_policy, goal_sampler, 
                           human_indices, robot_indices, device, config):
    """Run one rollout and return frames with Q-values."""
    robot_q_network.eval()
    env.reset()
    
    # Sample goal
    state = env.get_state()
    human_goals = {}
    for h in human_indices:
        goal, _ = goal_sampler.sample(state, h)
        human_goals[h] = goal
    
    frames = []
    q_values_history = []
    
    for step in range(env.max_steps):
        state = env.get_state()
        
        # Get Q-values
        with torch.no_grad():
            q_values = robot_q_network.encode_and_forward(state, env, device)
            q_np = q_values.squeeze().cpu().numpy()
            q_values_history.append(q_np.copy())
        
        # Get actions
        actions = [0] * len(env.agents)
        
        for h in human_indices:
            actions[h] = human_policy.sample(state, h, human_goals[h])
        
        with torch.no_grad():
            robot_action = robot_q_network.sample_action(
                q_values, epsilon=0.0, beta_r=config.beta_r
            )
            for i, r in enumerate(robot_indices):
                if i < len(robot_action):
                    actions[r] = robot_action[i]
        
        # Render and step
        frame = env.render(mode='rgb_array', highlight=False)
        frames.append(frame)
        
        _, _, done, _ = env.step(actions)
        if done:
            break
    
    return frames, q_values_history

print("Running evaluation rollout...")
frames, q_history = run_evaluation_rollout(
    env, robot_q_network, human_policy, goal_sampler,
    human_indices, robot_indices, device, config
)
print(f"Rollout completed: {len(frames)} steps")

In [None]:
# Display Q-values over rollout
print("\nQ-values during rollout:")
print("-" * 60)
print(f"{'Step':>4}  " + "  ".join(f"{name:>8}" for name in ACTION_NAMES))
print("-" * 60)

for step, q_vals in enumerate(q_history):
    best_action = np.argmax(q_vals)
    q_str = "  ".join(
        f"{q_vals[i]:>8.3f}" + ("*" if i == best_action else " ")
        for i in range(len(ACTION_NAMES))
    )
    print(f"{step:>4}  {q_str}")

print("\n(* = action with the highest Q-value)")

In [None]:
# Visualize rollout frames
fig, axes = plt.subplots(1, min(len(frames), 6), figsize=(15, 3))
if len(frames) == 1:
    axes = [axes]

for i, ax in enumerate(axes):
    if i < len(frames):
        ax.imshow(frames[i])
        ax.set_title(f'Step {i}')
    ax.axis('off')

plt.suptitle('Evaluation Rollout', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Generate Movie (Optional)

Generate a video of multiple rollouts.

In [None]:
NUM_ROLLOUTS = 5
MOVIE_FPS = 2

print(f"Generating {NUM_ROLLOUTS} rollouts...")

env.start_video_recording()

for rollout_idx in range(NUM_ROLLOUTS):
    env.reset()
    
    # Sample goal
    state = env.get_state()
    human_goals = {h: goal_sampler_fn(state, h) for h in human_indices}
    
    env.render(mode='rgb_array', highlight=False)
    
    for step in range(env.max_steps):
        state = env.get_state()
        actions = [0] * len(env.agents)
        
        for h in human_indices:
            actions[h] = human_policy_fn(state, h, human_goals[h])
        
        with torch.no_grad():
            q_values = robot_q_network.encode_and_forward(state, env, device)
            robot_action = robot_q_network.sample_action(q_values, epsilon=0.0, beta_r=config.beta_r)
            for i, r in enumerate(robot_indices):
                if i < len(robot_action):
                    actions[r] = robot_action[i]
        
        _, _, done, _ = env.step(actions)
        env.render(mode='rgb_array', highlight=False)
        
        if done:
            break
    
    print(f"  Rollout {rollout_idx + 1}/{NUM_ROLLOUTS} complete")

# Save movie
movie_path = os.path.join(output_dir, 'phase2_demo.mp4')
env.save_video(movie_path, fps=MOVIE_FPS)

print(f"\nMovie saved to: {movie_path}")

In [None]:
# Display the video in Colab
from IPython.display import HTML, display
from base64 import b64encode

if os.path.exists(movie_path):
    with open(movie_path, 'rb') as f:
        video_data = f.read()
    video_b64 = b64encode(video_data).decode()
    
    html = f'''
    <video width="400" controls>
        <source src="data:video/mp4;base64,{video_b64}" type="video/mp4">
    </video>
    '''
    display(HTML(html))
else:
    print("Movie file not found")

## 9. TensorBoard (Optional)

View training metrics in TensorBoard.

In [None]:
# Load TensorBoard extension
%load_ext tensorboard
%tensorboard --logdir outputs/phase2_colab_demo/tensorboard

## Summary

This notebook demonstrated:

1. **Environment Setup**: Simple gridworld with robot and human agents
2. **Phase 2 Training**: Learning robot policy to maximize human power
3. **Warmup Stages**: Progressive network activation for stable training
4. **Model-Based Targets**: Using transition probabilities for consistent Q-values
5. **Policy Evaluation**: Visualizing learned behavior
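
Item 4 can be illustrated with a toy example. With known transition probabilities, the target for a state-action pair is an exact expectation over successor states, whereas a sample-based target uses only the single successor that happened to be observed. The numbers and the helpers `model_based_target` and `sampled_target` are made up for illustration and are not taken from the EMPO code.

```python
import random


def model_based_target(transition_probs, next_values, gamma):
    """Expected discounted next-state value under a known transition model."""
    return gamma * sum(p * next_values[s] for s, p in transition_probs.items())


def sampled_target(transition_probs, next_values, gamma, rng):
    """Discounted value of one successor drawn from the same model."""
    states = list(transition_probs)
    weights = [transition_probs[s] for s in states]
    next_state = rng.choices(states, weights=weights, k=1)[0]
    return gamma * next_values[next_state]


# Hypothetical successor distribution and value estimates for one (state, action) pair.
transition_probs = {"s1": 0.7, "s2": 0.3}
next_values = {"s1": -1.0, "s2": -4.0}
gamma = 0.95

exact = model_based_target(transition_probs, next_values, gamma)
rng = random.Random(0)
samples = [sampled_target(transition_probs, next_values, gamma, rng) for _ in range(5)]

print(f"model-based target: {exact:.3f}")
print("sampled targets:   ", [round(s, 3) for s in samples])
```

Each sampled target takes one of the two successor values, while the model-based target is the same on every call. That is the sense in which such targets are more consistent.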

**Key training features:**
- Buffer clearing at warmup stage transitions
- beta_r ramp-up for gradual policy concentration
- Target networks for stable value estimation
- Optional async training for scalability (works better as a script)

For more details, see the [EMPO documentation](https://github.com/mensch72/empo).