# Zelda PPO Controller Training

This notebook demonstrates PPO training for the controller agent that plays *The Legend of Zelda: Oracle of Seasons*.
It provides an interactive training loop with live monitoring and visualization.

In [None]:
import sys
import os
import time
import asyncio
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import deque
from IPython.display import clear_output

# Add project root to Python path. Guarded so re-running this cell does not
# keep appending duplicate entries to sys.path.
project_root = Path('../').resolve()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Set style. Matplotlib 3.6 renamed the bundled 'seaborn' style to
# 'seaborn-v0_8'; fall back to the old name on older matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_palette("husl")

print(f"Project root: {project_root}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Configuration

In [None]:
# Training configuration
CONFIG = {
    'rom_path': '../roms/zelda_oracle_of_seasons.gbc',  # Update with your ROM path (relative to the notebook's working dir)
    'total_timesteps': 100000,  # Reduced for notebook demo
    'rollout_length': 64,       # Steps per rollout
    'update_frequency': 1000,   # Update every N steps (not referenced by the cells below)
    'eval_frequency': 5000,     # Evaluate every N steps (not referenced by the cells below)
    'log_frequency': 500,       # Log every N steps
    'save_frequency': 10000,    # Save checkpoint every N steps
    'use_mock_planner': True,   # Use mock planner for demo
    'planner_frequency': 50,    # Call planner every N steps
}

# Check ROM path. If the configured file is missing, fall back to the first ROM
# found in <project_root>/roms. .gbc files are preferred over .gb, and each group
# is sorted by name so the choice is deterministic (glob order is arbitrary).
# When nothing is found, rom_path is set to None so later cells skip setup.
rom_path = Path(CONFIG['rom_path'])
if not rom_path.exists():
    rom_dir = project_root / 'roms'
    rom_files = sorted(rom_dir.glob('*.gbc')) + sorted(rom_dir.glob('*.gb'))
    if rom_files:
        CONFIG['rom_path'] = str(rom_files[0])
        print(f"Using ROM: {CONFIG['rom_path']}")
    else:
        print("❌ No ROM file found! Please add a ROM to the roms/ directory.")
        CONFIG['rom_path'] = None

print(f"Configuration: {CONFIG}")

## Initialize Environment and Agent

In [None]:
# Build the emulator environment and the hybrid PPO agent. When no ROM was
# resolved in the configuration cell, both are left as None so that the
# training, evaluation and cleanup cells can detect it and skip.
if not CONFIG['rom_path']:
    print("⏭️ Skipping initialization - no ROM file available")
    env = None
    agent = None
else:
    # Project-local imports are deferred until a ROM is known to exist.
    from emulator.zelda_env import ZeldaEnvironment
    from agents.controller import HybridAgent, ControllerConfig

    # Headless emulator for the configured ROM.
    env = ZeldaEnvironment(CONFIG['rom_path'], headless=True)

    # PPO hyperparameters. NOTE(review): use_planner is driven by the
    # use_mock_planner flag, so turning the mock off also disables the
    # planner entirely — confirm that is intended.
    agent_config = ControllerConfig(
        learning_rate=3e-4, gamma=0.99, gae_lambda=0.95,
        use_planner=CONFIG['use_mock_planner'],
        planner_frequency=CONFIG['planner_frequency'],
    )

    agent = HybridAgent(
        env, agent_config, use_mock_planner=CONFIG['use_mock_planner']
    )

    # Report what was built.
    print(f"Environment: {env}")
    print(f"Agent device: {agent.controller.device}")
    print(f"Observation space: {env.observation_space}")
    print(f"Action space: {env.action_space}")

## Training Utilities

In [None]:
class TrainingMonitor:
    """Monitor training progress with live plotting.

    Per-episode metrics are kept in rolling windows (``deque`` with
    ``maxlen=window_size``), so plots show only the most recent episodes.
    The ``steps`` history is an unbounded list, and ``total_episodes``
    counts every episode ever recorded, not just the ones still in the window.
    """

    # (metric key, panel title, y-axis label) for the five line-plot panels,
    # in subplot order.
    _PANELS = (
        ('episode_rewards', 'Episode Rewards', 'Reward'),
        ('episode_lengths', 'Episode Lengths', 'Steps'),
        ('policy_losses', 'Policy Loss', 'Loss'),
        ('value_losses', 'Value Loss', 'Loss'),
        ('rupees_collected', 'Rupees Collected', 'Rupees'),
    )

    def __init__(self, window_size=100):
        self.window_size = window_size
        self.metrics = {
            'episode_rewards': deque(maxlen=window_size),
            'episode_lengths': deque(maxlen=window_size),
            'policy_losses': deque(maxlen=window_size),
            'value_losses': deque(maxlen=window_size),
            'rupees_collected': deque(maxlen=window_size),
            'steps': []
        }
        # Lifetime episode count. len(self.metrics['episode_rewards']) saturates
        # at window_size, so it cannot serve as a total.
        self.total_episodes = 0

    def update(self, step, **kwargs):
        """Record the global step and any recognised metric values.

        Args:
            step: Global environment step at which these metrics were produced.
            **kwargs: Metric values keyed by metric name. Keys that are not in
                ``self.metrics`` are silently ignored.
        """
        self.metrics['steps'].append(step)
        if 'episode_rewards' in kwargs:
            self.total_episodes += 1
        for key, value in kwargs.items():
            if key in self.metrics:
                self.metrics[key].append(value)

    def plot(self, figsize=(15, 10)):
        """Plot the windowed training metrics on a 2x3 grid and show the figure.

        Five panels show line plots of the metrics that have data. The sixth
        panel shows a text summary: the lifetime episode count and the means
        of the last 20 rewards and lengths.
        """
        fig, axes = plt.subplots(2, 3, figsize=figsize)
        axes = axes.flatten()

        for ax, (key, title, ylabel) in zip(axes, self._PANELS):
            values = self.metrics[key]
            if not values:
                continue  # leave the panel empty until data arrives
            ax.plot(list(values))
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.grid(True)

        # Summary stats
        summary = axes[5]
        summary.text(0.1, 0.8, f"Total Episodes: {self.total_episodes}")
        if self.metrics['episode_rewards']:
            avg_reward = np.mean(list(self.metrics['episode_rewards'])[-20:])
            summary.text(0.1, 0.6, f"Avg Reward (last 20): {avg_reward:.2f}")
        if self.metrics['episode_lengths']:
            avg_length = np.mean(list(self.metrics['episode_lengths'])[-20:])
            summary.text(0.1, 0.4, f"Avg Length (last 20): {avg_length:.1f}")
        summary.set_title('Summary')
        summary.set_xlim(0, 1)
        summary.set_ylim(0, 1)
        summary.axis('off')

        plt.tight_layout()
        plt.show()

# Create monitor
monitor = TrainingMonitor()
print("Training monitor created")

## Training Loop

In [None]:
async def collect_rollout(env, agent, max_steps=64):
    """Collect a rollout of experience.

    Resets ``env`` and steps it for at most ``max_steps`` steps, stopping
    early when a step reports ``terminated`` or ``truncated``. Because every
    call starts from ``env.reset()``, a rollout is at most one (possibly cut
    short) episode.

    Each step's executed action comes from ``await agent.act(obs,
    structured_state)``. A log-prob and a value estimate are obtained
    separately from ``agent.controller.policy_net.get_action_and_value``.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and
            ``get_structured_state()``.
        agent: Hybrid agent exposing async ``act`` and a ``controller`` with
            ``device`` and ``policy_net``.
        max_steps: Maximum number of environment steps to take.

    Returns:
        Dict with per-step lists 'observations', 'actions', 'log_probs',
        'rewards', 'values' and 'dones', plus:
        'episode_reward' (sum of rewards), 'episode_length' (steps taken),
        and 'rupees_collected' (change in
        ``info['structured_state']['resources']['rupees']`` from reset to the
        last step; missing keys count as 0).
    """
    observations = []
    actions = []
    log_probs = []
    rewards = []
    values = []
    dones = []
    
    obs, info = env.reset()
    initial_rupees = info.get('structured_state', {}).get('resources', {}).get('rupees', 0)
    
    for step in range(max_steps):
        # Stored as a copy — presumably so that later env buffer reuse cannot
        # alter stored observations (TODO confirm env returns fresh arrays or not).
        observations.append(obs.copy())
        
        # Get action from agent (may be influenced by the planner).
        structured_state = env.get_structured_state()
        action = await agent.act(obs, structured_state)
        
        # Get policy outputs for training.
        # NOTE(review): get_action_and_value is called without the executed
        # action, and its returned action (action_tensor) is discarded. So
        # log_prob is presumably for a separately sampled action, not
        # necessarily for `action` above. PPO ratios computed later from
        # these log_probs may then not match the stored actions. Confirm
        # whether the policy API can score a given action.
        obs_tensor = torch.FloatTensor(obs).unsqueeze(0).to(agent.controller.device)
        with torch.no_grad():
            action_tensor, log_prob, value = agent.controller.policy_net.get_action_and_value(obs_tensor)
        
        actions.append(action)
        log_probs.append(log_prob.item())
        values.append(value.item())
        
        # Step environment
        obs, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        done = terminated or truncated
        dones.append(done)
        
        if done:
            break
    
    # NOTE(review): when the loop ends at max_steps without `done`, no value
    # for the final `obs` is returned for bootstrapping. Confirm that the
    # downstream GAE computation handles a truncated rollout correctly.

    # Calculate episode metrics
    final_rupees = info.get('structured_state', {}).get('resources', {}).get('rupees', 0)
    rupees_collected = final_rupees - initial_rupees
    
    return {
        'observations': observations,
        'actions': actions,
        'log_probs': log_probs,
        'rewards': rewards,
        'values': values,
        'dones': dones,
        'episode_reward': sum(rewards),
        'episode_length': len(rewards),
        'rupees_collected': rupees_collected
    }

async def training_loop():
    """Main training loop."""
    if not env or not agent:
        print("❌ Cannot start training - environment or agent not initialized")
        return
    
    print("🚀 Starting training loop...")
    
    global_step = 0
    episode_count = 0
    
    while global_step < CONFIG['total_timesteps']:
        # Collect rollout
        rollout_data = await collect_rollout(env, agent, CONFIG['rollout_length'])
        
        global_step += rollout_data['episode_length']
        episode_count += 1
        
        # Train agent
        if len(rollout_data['observations']) > 0:
            # Convert to tensors
            observations = torch.FloatTensor(rollout_data['observations']).to(agent.controller.device)
            actions = torch.LongTensor(rollout_data['actions']).to(agent.controller.device)
            old_log_probs = torch.FloatTensor(rollout_data['log_probs']).to(agent.controller.device)
            
            # Compute advantages and returns
            advantages, returns = agent.controller.compute_gae(
                rollout_data['rewards'], 
                rollout_data['values'], 
                rollout_data['dones']
            )
            advantages = torch.FloatTensor(advantages).to(agent.controller.device)
            returns = torch.FloatTensor(returns).to(agent.controller.device)
            
            # Prepare batch data
            batch_data = {
                'obs': observations,
                'actions': actions,
                'log_probs': old_log_probs,
                'advantages': advantages,
                'returns': returns
            }
            
            # Update policy
            training_metrics = agent.controller.update(batch_data, epochs=2)
        else:
            training_metrics = {'policy_loss': 0, 'value_loss': 0}
        
        # Update monitor
        monitor.update(
            step=global_step,
            episode_rewards=rollout_data['episode_reward'],
            episode_lengths=rollout_data['episode_length'],
            policy_losses=training_metrics['policy_loss'],
            value_losses=training_metrics['value_loss'],
            rupees_collected=rollout_data['rupees_collected']
        )
        
        # Logging
        if global_step % CONFIG['log_frequency'] == 0:
            clear_output(wait=True)
            print(f"Step {global_step}/{CONFIG['total_timesteps']} | Episode {episode_count}")
            print(f"Reward: {rollout_data['episode_reward']:.3f} | Length: {rollout_data['episode_length']}")
            print(f"Policy Loss: {training_metrics['policy_loss']:.4f} | Value Loss: {training_metrics['value_loss']:.4f}")
            print(f"Rupees: {rollout_data['rupees_collected']} | Agent metrics: {agent.get_metrics()}")
            
            # Plot metrics
            monitor.plot(figsize=(12, 8))
        
        # Save checkpoint
        if global_step % CONFIG['save_frequency'] == 0:
            checkpoint_path = f"../checkpoints/notebook_checkpoint_{global_step}.pt"
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            agent.controller.save_checkpoint(checkpoint_path)
            print(f"💾 Saved checkpoint: {checkpoint_path}")
    
    print("🎉 Training completed!")
    
    # Final plot
    monitor.plot(figsize=(15, 10))

# Run training loop
await training_loop()

## Evaluation

In [None]:
async def evaluate_agent(num_episodes=5):
    """Evaluate the trained agent."""
    if not env or not agent:
        print("❌ Cannot evaluate - environment or agent not initialized")
        return
    
    print(f"🔍 Evaluating agent for {num_episodes} episodes...")
    
    episode_rewards = []
    episode_lengths = []
    episode_metrics = []
    
    for episode in range(num_episodes):
        obs, info = env.reset()
        episode_reward = 0
        episode_length = 0
        
        initial_state = info.get('structured_state', {})
        initial_rupees = initial_state.get('resources', {}).get('rupees', 0)
        initial_health = initial_state.get('player', {}).get('health', 0)
        
        while True:
            # Use deterministic policy for evaluation
            action = agent.controller.act_deterministic(obs)
            obs, reward, terminated, truncated, info = env.step(action)
            
            episode_reward += reward
            episode_length += 1
            
            if terminated or truncated or episode_length >= 1000:
                break
        
        final_state = info.get('structured_state', {})
        final_rupees = final_state.get('resources', {}).get('rupees', 0)
        final_health = final_state.get('player', {}).get('health', 0)
        
        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_length)
        episode_metrics.append({
            'rupees_gained': final_rupees - initial_rupees,
            'health_change': final_health - initial_health,
            'terminated': terminated
        })
        
        print(f"Episode {episode+1}: Reward={episode_reward:.3f}, Length={episode_length}, "
              f"Rupees={final_rupees-initial_rupees}, Health={final_health}/{initial_health}")
    
    # Summary statistics
    print(f"\n📊 Evaluation Results:")
    print(f"Mean reward: {np.mean(episode_rewards):.3f} ± {np.std(episode_rewards):.3f}")
    print(f"Mean length: {np.mean(episode_lengths):.1f} ± {np.std(episode_lengths):.1f}")
    print(f"Mean rupees gained: {np.mean([m['rupees_gained'] for m in episode_metrics]):.1f}")
    print(f"Termination rate: {np.mean([m['terminated'] for m in episode_metrics]):.2%}")
    
    # Plot evaluation results
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    axes[0].bar(range(len(episode_rewards)), episode_rewards)
    axes[0].set_title('Episode Rewards')
    axes[0].set_xlabel('Episode')
    axes[0].set_ylabel('Reward')
    
    axes[1].bar(range(len(episode_lengths)), episode_lengths)
    axes[1].set_title('Episode Lengths')
    axes[1].set_xlabel('Episode')
    axes[1].set_ylabel('Steps')
    
    plt.tight_layout()
    plt.show()
    
    return {
        'rewards': episode_rewards,
        'lengths': episode_lengths,
        'metrics': episode_metrics
    }

# Run evaluation
eval_results = await evaluate_agent(num_episodes=3)

## Cleanup

In [None]:
# Clean up resources
if env:
    env.close()
    print("Environment closed")

if agent:
    await agent.close()
    print("Agent closed")

print("🧹 Cleanup complete")

## Summary

This notebook demonstrated:

1. **Interactive PPO Training**: Real-time monitoring with live plots
2. **Hybrid Agent**: Integration of PPO controller with mock LLM planner
3. **Game-specific Metrics**: Tracking rupees, health, exploration
4. **Evaluation**: Deterministic policy evaluation

### Key Observations:
- The agent learns basic movement and interaction
- Mock planner provides strategic guidance
- Reward shaping encourages exploration and resource collection

### Next Steps:
- Use `03_planner_grpo.ipynb` for LLM planner optimization
- Run longer training with `training/run_cleanrl.py`
- Deploy LLM planner on OpenShift for full system integration