<a href="https://colab.research.google.com/github/kuds/reinforce-tactics/blob/main/notebooks/ppo_baseline_benchmark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reinforce Tactics — PPO Baseline Training Benchmarks

This notebook trains a **MaskablePPO** agent against `SimpleBot` on the 6×6 beginner map
and records reference metrics at four training checkpoints:

| Checkpoint | Timesteps |
|------------|-----------|
| 1 | 10,000 |
| 2 | 50,000 |
| 3 | 200,000 |
| 4 | 1,000,000 |

At each checkpoint the agent is evaluated over **50 episodes** and we record:
- **Win rate** (% of games won against SimpleBot)
- **Average episode reward**
- **Average episode length** (steps)

The goal is to provide a **reference curve** so that users can run the same
notebook and compare their results to known-good training runs.

**Runtime:** CPU is fine (~20–40 min total). A GPU may speed up training.

---

### Why MaskablePPO?

The game has a `MultiDiscrete` action space where many action combinations
are invalid at any given time (e.g. you can’t attack a tile with no enemy).
**Action masking** prevents the agent from sampling these invalid actions,
which typically yields 2–3× faster convergence compared to plain PPO.
By default this notebook uses the `flat_discrete` action-space mode
(see Configuration), which provides exact per-action masks.
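
To make the idea concrete, here is a minimal NumPy sketch of masked sampling over a flat discrete action set. It is only an illustration of the technique, not the code `MaskablePPO` runs internally; the function name `masked_probs` and the example logits are made up for this sketch.

```python
import numpy as np

def masked_probs(logits, mask):
    """Softmax over logits, with invalid actions (mask == False) given probability 0."""
    logits = np.asarray(logits, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        raise ValueError("at least one action must be valid")
    # Invalid actions get a logit of -inf, so exp() maps them to exactly 0.
    masked = np.where(mask, logits, -np.inf)
    masked = masked - masked[mask].max()  # subtract the max valid logit for stability
    exp = np.exp(masked)
    return exp / exp.sum()

# Hypothetical policy output for four actions; actions 1 and 3 are invalid.
logits = np.array([1.0, 2.0, 0.5, 3.0])
mask = np.array([True, False, True, False])

probs = masked_probs(logits, mask)
rng = np.random.default_rng(0)
samples = rng.choice(len(probs), size=1000, p=probs)

print(probs)                           # probabilities of invalid actions are 0
print(sorted(set(samples.tolist())))   # only valid action indices are ever sampled
```

Because the invalid actions carry zero probability, the agent never spends samples (or gradient signal) on them.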

## 1. Setup

In [None]:
# Install dependencies (tqdm and rich are needed for progress_bar=True in model.learn)
!pip install -q gymnasium stable-baselines3 sb3-contrib tensorboard pandas numpy torch matplotlib tqdm rich

import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {DEVICE}")
if DEVICE == 'cuda':
    print(f"  GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Clone repo and install as a package
import os, sys
from pathlib import Path

REPO_DIR = Path('reinforce-tactics')
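if REPO_DIR.exists():
    os.chdir(REPO_DIR)
elif Path.cwd().name == 'notebooks':
    # Running from the repo's notebooks/ directory: move up to the repo root
    os.chdir('..')
elif Path('notebooks').exists():
    # Already at the repo root: nothing to do
    pass
else:
    print('Cloning repository...')
    !git clone https://github.com/kuds/reinforce-tactics.git
    os.chdir(REPO_DIR)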

# Install the package so all imports resolve
!pip install -q -e .

if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

print(f"Working directory: {os.getcwd()}")

## 2. Imports

In [None]:
import json
import time
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sb3_contrib import MaskablePPO

from reinforcetactics.rl.masking import make_maskable_env, make_maskable_vec_env
from reinforcetactics.core.game_state import GameState

print('All imports successful.')

## 3. Configuration

In [None]:
# --- Benchmark settings ---
MAP_FILE        = 'maps/1v1/beginner.csv'   # 6x6 beginner map
OPPONENT        = 'bot'                      # SimpleBot
MAX_STEPS       = 500                        # max steps per episode
N_ENVS          = 4                          # parallel training envs
SEED            = 42

# Action space mode:
#   'flat_discrete'  — exact per-action masks (recommended, eliminates invalid actions)
#   'multi_discrete' — per-dimension masks (over-approximation, original behaviour)
ACTION_SPACE    = 'flat_discrete'

# Checkpoints to evaluate
CHECKPOINTS     = [10_000, 50_000, 200_000, 1_000_000]
EVAL_EPISODES   = 50                         # episodes per evaluation

# PPO hyperparameters
PPO_CONFIG = dict(
    learning_rate = 3e-4,
    n_steps       = 2048,
    batch_size    = 64,
    n_epochs      = 10,
    gamma         = 0.99,
    gae_lambda    = 0.95,
    clip_range    = 0.2,
    ent_coef      = 0.01,
    vf_coef       = 0.5,
    max_grad_norm = 0.5,
)

# Replay settings — save evaluation game replays for later viewing
SAVE_REPLAYS           = True                # save evaluation game replays
REPLAYS_PER_CHECKPOINT = 5                   # number of replays to save per checkpoint

# Output paths
BENCHMARK_DIR = Path('benchmarks/ppo_vs_simplebot')
BENCHMARK_DIR.mkdir(parents=True, exist_ok=True)

REPLAY_DIR = BENCHMARK_DIR / 'replays'
REPLAY_DIR.mkdir(parents=True, exist_ok=True)

print(f'Map:          {MAP_FILE}')
print(f'Opponent:     {OPPONENT}')
print(f'Action space: {ACTION_SPACE}')
print(f'Checkpoints:  {CHECKPOINTS}')
print(f'Eval eps:     {EVAL_EPISODES}')
print(f'Save replays: {SAVE_REPLAYS} ({REPLAYS_PER_CHECKPOINT} per checkpoint)')
print(f'Output dir:   {BENCHMARK_DIR}')

## 4. Create environments

In [None]:
# Training envs (vectorized, headless)
vec_env = make_maskable_vec_env(
    n_envs=N_ENVS,
    map_file=MAP_FILE,
    opponent=OPPONENT,
    max_steps=MAX_STEPS,
    seed=SEED,
    use_subprocess=False,   # DummyVecEnv (safer in notebooks)
    action_space_type=ACTION_SPACE,
)

# Separate eval env (single env; actions are chosen deterministically at eval time)
eval_env = make_maskable_env(
    map_file=MAP_FILE,
    opponent=OPPONENT,
    max_steps=MAX_STEPS,
    action_space_type=ACTION_SPACE,
)

print(f'Observation space: {vec_env.observation_space}')
print(f'Action space:      {vec_env.action_space}')

## 5. Create MaskablePPO model

In [None]:
model = MaskablePPO(
    'MultiInputPolicy',
    vec_env,
    verbose=0,
    tensorboard_log=str(BENCHMARK_DIR / 'tensorboard'),
    device=DEVICE,
    seed=SEED,
    **PPO_CONFIG,
)

print('MaskablePPO model created.')
print(f'Policy:  {model.policy.__class__.__name__}')
print(f'Device:  {model.device}')

## 6. Evaluation helper

In [None]:
def evaluate_model(model, env, n_episodes=50, save_replays=False,
                   replay_dir=None, checkpoint_name=None, max_replays=5):
    """
    Evaluate a trained model and return summary statistics.
    Optionally save game replays to JSON files for later viewing.

    Returns dict with: win_rate, avg_reward, std_reward,
    avg_length, std_length, wins, losses, draws, replays_saved
    """
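    if save_replays and replay_dir is None:
        raise ValueError('replay_dir is required when save_replays=True')

    wins, losses, draws = 0, 0, 0
    rewards, lengths = [], []
    replays_saved = 0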

    for ep in range(n_episodes):
        obs, _ = env.reset()

        # Set up game state metadata for replay saving
        if save_replays and replays_saved < max_replays:
            gs = env.game_state
            gs.map_file_used = MAP_FILE
            gs.player_configs = [
                GameState.build_player_config(1, f'PPO-{checkpoint_name}', 'rl'),
                GameState.build_player_config(2, 'SimpleBot', 'bot'),
            ]

        done = False
        ep_reward = 0.0
        ep_len = 0

        while not done:
            masks = env.action_masks()
            action, _ = model.predict(obs, deterministic=True, action_masks=masks)
            obs, reward, terminated, truncated, info = env.step(action)
            ep_reward += reward
            ep_len += 1
            done = terminated or truncated

        rewards.append(ep_reward)
        lengths.append(ep_len)

        winner = info.get('winner')
        if winner == 1:
            wins += 1
        elif winner is not None:
            losses += 1
        else:
            draws += 1

        # Save replay if enabled
        if save_replays and replays_saved < max_replays:
            gs = env.game_state
            outcome = 'win' if winner == 1 else ('loss' if winner is not None else 'draw')
            replay_path = replay_dir / f'{checkpoint_name}_ep{ep}_{outcome}.json'
            gs.save_replay_to_file(str(replay_path))
            replays_saved += 1

    return {
        'win_rate':    wins / n_episodes,
        'avg_reward':  float(np.mean(rewards)),
        'std_reward':  float(np.std(rewards)),
        'avg_length':  float(np.mean(lengths)),
        'std_length':  float(np.std(lengths)),
        'wins':        wins,
        'losses':      losses,
        'draws':       draws,
        'replays_saved': replays_saved,
    }

print('evaluate_model() defined.')

## 7. Train and evaluate at each checkpoint

We train incrementally: 0 → 10K → 50K → 200K → 1M timesteps,
evaluating at each checkpoint.
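
Note that the checkpoint labels are targets rather than exact counts. The sketch below assumes that Stable-Baselines3 collects data in whole rollouts of `n_steps × n_envs` transitions and, with `reset_num_timesteps=False`, adds the requested budget to the current timestep counter. Under those assumptions (please verify against your installed version), the helper `actual_timesteps`, which is made up for this sketch, estimates where the counter really stands at each checkpoint.

```python
import math

def actual_timesteps(checkpoints, n_steps, n_envs):
    """Estimate the timestep counter at each checkpoint when training in whole rollouts."""
    rollout = n_steps * n_envs          # transitions collected per rollout
    counter = 0                         # simulated model.num_timesteps
    requested_so_far = 0                # mirrors trained_so_far in the training loop
    reached = []
    for target in checkpoints:
        extra = target - requested_so_far       # budget passed to learn()
        rollouts = math.ceil(extra / rollout)   # training stops only after a full rollout
        counter += rollouts * rollout
        requested_so_far = target
        reached.append(counter)
    return reached

reached = actual_timesteps([10_000, 50_000, 200_000, 1_000_000], n_steps=2048, n_envs=4)
print(reached)  # [16384, 57344, 212992, 1015808]
```

With this notebook's settings each rollout is 8,192 transitions, so the first checkpoint in particular is trained noticeably longer than its 10K label suggests.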

In [None]:
results = []
trained_so_far = 0
start_time = time.time()

for checkpoint_ts in CHECKPOINTS:
    steps_to_train = checkpoint_ts - trained_so_far
    print(f'\n{"="*60}')
    print(f'Training {trained_so_far:,} -> {checkpoint_ts:,} '
          f'({steps_to_train:,} steps)...')
    print(f'{"="*60}')

    t0 = time.time()
    model.learn(
        total_timesteps=steps_to_train,
        reset_num_timesteps=False,
        progress_bar=True,
    )
    train_time = time.time() - t0
    trained_so_far = checkpoint_ts

    # Save checkpoint
    ckpt_path = BENCHMARK_DIR / f'model_{checkpoint_ts}.zip'
    model.save(str(ckpt_path))
    print(f'Saved checkpoint: {ckpt_path}')

    # Evaluate
    print(f'Evaluating over {EVAL_EPISODES} episodes...')
    metrics = evaluate_model(
        model, eval_env, n_episodes=EVAL_EPISODES,
        save_replays=SAVE_REPLAYS,
        replay_dir=REPLAY_DIR,
        checkpoint_name=f'{checkpoint_ts}',
        max_replays=REPLAYS_PER_CHECKPOINT,
    )
    metrics['timesteps'] = checkpoint_ts
    metrics['train_time_s'] = round(train_time, 1)
    results.append(metrics)

    print(f'  Win rate:       {metrics["win_rate"]*100:.1f}%')
    print(f'  Avg reward:     {metrics["avg_reward"]:.1f} '
          f'(+/- {metrics["std_reward"]:.1f})')
    print(f'  Avg length:     {metrics["avg_length"]:.1f} '
          f'(+/- {metrics["std_length"]:.1f})')
    print(f'  W/L/D:          {metrics["wins"]}/{metrics["losses"]}/{metrics["draws"]}')
    print(f'  Training time:  {train_time:.1f}s')
    if SAVE_REPLAYS:
        print(f'  Replays saved:  {metrics["replays_saved"]}')

total_time = time.time() - start_time
print(f'\nTotal wall time: {total_time/60:.1f} minutes')

## 8. Results table

In [None]:
df = pd.DataFrame(results)
df['win_rate_pct'] = (df['win_rate'] * 100).round(1)
df['avg_reward'] = df['avg_reward'].round(1)
df['avg_length'] = df['avg_length'].round(1)

display_df = df[['timesteps', 'win_rate_pct', 'avg_reward', 'avg_length',
                  'wins', 'losses', 'draws', 'train_time_s']].copy()
display_df.columns = ['Timesteps', 'Win Rate (%)', 'Avg Reward',
                       'Avg Length', 'Wins', 'Losses', 'Draws',
                       'Train Time (s)']
display_df = display_df.set_index('Timesteps')
display_df

## 9. Training curves

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

ts = [r['timesteps'] for r in results]

# Win rate
ax = axes[0]
wr = [r['win_rate'] * 100 for r in results]
ax.plot(ts, wr, 'o-', color='#2196F3', linewidth=2, markersize=8)
ax.set_xlabel('Timesteps')
ax.set_ylabel('Win Rate (%)')
ax.set_title('Win Rate vs SimpleBot')
ax.set_xscale('log')
ax.set_ylim(-5, 105)
ax.axhline(y=70, color='green', linestyle='--', alpha=0.5, label='70% target')
ax.legend()
ax.grid(True, alpha=0.3)

# Average reward
ax = axes[1]
avg_r = [r['avg_reward'] for r in results]
std_r = [r['std_reward'] for r in results]
ax.plot(ts, avg_r, 'o-', color='#FF9800', linewidth=2, markersize=8)
ax.fill_between(ts,
                [a - s for a, s in zip(avg_r, std_r)],
                [a + s for a, s in zip(avg_r, std_r)],
                alpha=0.2, color='#FF9800')
ax.set_xlabel('Timesteps')
ax.set_ylabel('Average Reward')
ax.set_title('Average Episode Reward')
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

# Episode length
ax = axes[2]
avg_l = [r['avg_length'] for r in results]
std_l = [r['std_length'] for r in results]
ax.plot(ts, avg_l, 'o-', color='#4CAF50', linewidth=2, markersize=8)
ax.fill_between(ts,
                [a - s for a, s in zip(avg_l, std_l)],
                [a + s for a, s in zip(avg_l, std_l)],
                alpha=0.2, color='#4CAF50')
ax.set_xlabel('Timesteps')
ax.set_ylabel('Average Length (steps)')
ax.set_title('Average Episode Length')
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

fig.suptitle('PPO Baseline Benchmarks  |  6x6 beginner map  |  vs SimpleBot',
             fontsize=13, fontweight='bold', y=1.02)
fig.tight_layout()

fig.savefig(str(BENCHMARK_DIR / 'training_curves.png'),
            dpi=150, bbox_inches='tight')
print(f'Saved plot: {BENCHMARK_DIR / "training_curves.png"}')
plt.show()

## 10. Save results

In [None]:
# Save benchmark results as JSON
benchmark_data = {
    'metadata': {
        'date': datetime.now().isoformat(),
        'map': MAP_FILE,
        'opponent': OPPONENT,
        'max_steps': MAX_STEPS,
        'n_envs': N_ENVS,
        'eval_episodes': EVAL_EPISODES,
        'seed': SEED,
        'device': DEVICE,
        'action_space': ACTION_SPACE,
        'ppo_config': PPO_CONFIG,
    },
    'results': results,
}

results_path = BENCHMARK_DIR / 'benchmark_results.json'
with open(results_path, 'w') as f:
    json.dump(benchmark_data, f, indent=2)

print(f'Saved results:  {results_path}')

# Also save as CSV for easy viewing
csv_path = BENCHMARK_DIR / 'benchmark_results.csv'
df.to_csv(csv_path, index=False)
print(f'Saved CSV:      {csv_path}')

def format_size(size):
    """Return a human-readable file size (B, KB or MB)."""
    if size > 1024 * 1024:
        return f'{size / 1024 / 1024:.1f} MB'
    if size > 1024:
        return f'{size / 1024:.1f} KB'
    return f'{size} B'

# List all saved files
print('\nAll benchmark files:')
for p in sorted(BENCHMARK_DIR.iterdir()):
    if p.is_dir():
        # List files inside subdirectories (e.g. replays/)
        sub_files = sorted(p.iterdir())
        print(f'  {p.name}/ ({len(sub_files)} files)')
        for sp in sub_files[:5]:
            print(f'    {sp.name:38s}  {format_size(sp.stat().st_size)}')
        if len(sub_files) > 5:
            print(f'    ... and {len(sub_files) - 5} more')
    else:
        print(f'  {p.name:40s}  {format_size(p.stat().st_size)}')

## 11. Viewing saved replays

Replays are saved as JSON files in `benchmarks/ppo_vs_simplebot/replays/`.
Each file captures every action taken during an evaluation game, allowing
you to watch the PPO agent play at different training stages.

**To view a replay in the game UI:**
```bash
python main.py --replay benchmarks/ppo_vs_simplebot/replays/10000_ep0_win.json
```

**Replay filenames** follow the pattern `{timesteps}_ep{episode}_{outcome}.json`
where outcome is `win`, `loss`, or `draw`.
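
If you want to filter or group replays programmatically, the naming pattern above can be parsed with a small regular expression. This is a sketch; `parse_replay_name` is a helper made up for this example and is not part of the `reinforcetactics` package.

```python
import re
from pathlib import Path

# Matches file stems such as "10000_ep0_win" ({timesteps}_ep{episode}_{outcome}).
REPLAY_NAME = re.compile(
    r'^(?P<timesteps>\d+)_ep(?P<episode>\d+)_(?P<outcome>win|loss|draw)$'
)

def parse_replay_name(path):
    """Return the fields encoded in a replay filename, or None if it does not match."""
    match = REPLAY_NAME.match(Path(path).stem)
    if match is None:
        return None
    return {
        'timesteps': int(match['timesteps']),
        'episode': int(match['episode']),
        'outcome': match['outcome'],
    }

info = parse_replay_name('benchmarks/ppo_vs_simplebot/replays/10000_ep0_win.json')
print(info)  # {'timesteps': 10000, 'episode': 0, 'outcome': 'win'}
```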

In [None]:
# List saved replays by checkpoint
if SAVE_REPLAYS and REPLAY_DIR.exists():
    replay_files = sorted(REPLAY_DIR.glob('*.json'))
    print(f'Total replays saved: {len(replay_files)}')
    print()
    for ckpt in CHECKPOINTS:
        ckpt_replays = [f for f in replay_files if f.name.startswith(f'{ckpt}_')]
        if ckpt_replays:
            outcomes = [f.stem.rsplit('_', 1)[-1] for f in ckpt_replays]
            print(f'  {ckpt:>10,} timesteps: {len(ckpt_replays)} replays '
                  f'({outcomes.count("win")}W / {outcomes.count("loss")}L / {outcomes.count("draw")}D)')
            for f in ckpt_replays:
                print(f'    {f.name}')
else:
    print('No replays saved (SAVE_REPLAYS is False).')

## 12. TensorBoard (optional)

Launch TensorBoard to inspect detailed training metrics (loss, entropy, etc.).

In [None]:
# Uncomment to launch TensorBoard inline:
# %load_ext tensorboard
# %tensorboard --logdir benchmarks/ppo_vs_simplebot/tensorboard

print('To view TensorBoard locally, run:')
print(f'  tensorboard --logdir {BENCHMARK_DIR / "tensorboard"}')

## 13. Interpreting the results

### What to expect

| Timesteps | Expected Win Rate | Notes |
|-----------|-------------------|-------|
| 10K | 0-15% | Agent is mostly random, learning basic actions |
| 50K | 15-40% | Agent starts making meaningful moves |
| 200K | 40-70% | Competent play, learns unit creation and combat |
| 1M | 60-90%+ | Strong play against SimpleBot |

**Note:** Exact numbers depend on hardware and random seed, and with only
50 evaluation episodes each win-rate estimate is noisy. The important
thing is that your curve has a similar *shape*: a broadly increasing
win rate with diminishing returns after ~200K steps.
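
To judge whether a difference from the table is meaningful, it helps to put an interval around each win rate. The sketch below uses the standard Wilson score interval for a binomial proportion. It treats the evaluation games as independent trials, which is an assumption, and `wilson_interval` is a helper written for this example.

```python
import math

def wilson_interval(wins, n, z=1.96):
    """Wilson score interval for a win rate (z=1.96 gives roughly 95% coverage)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = wins / n
    denom = 1 + z**2 / n
    centre = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return max(0.0, centre - half), min(1.0, centre + half)

# Example win counts out of 50 evaluation episodes (illustrative, not measured results).
for wins in (5, 25, 40):
    lo, hi = wilson_interval(wins, 50)
    print(f'{wins}/50 wins: {lo * 100:.0f}% to {hi * 100:.0f}%')
```

For example, 25 wins out of 50 is compatible with a true win rate anywhere from roughly 37% to 63%, so a dip of a few points between checkpoints is not by itself a sign of a problem.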

### If your results differ significantly

- **Much worse:** Check that action masking is working (the agent should
  rarely attempt invalid actions). Verify the map file path is correct.
- **Much better:** You may have found better hyperparameters! Consider
  contributing them back.
- **Unstable (oscillating win rate):** Try reducing the learning rate
  or increasing the batch size.
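
As a starting point for the last suggestion, the sketch below derives a more conservative configuration from the one used in this notebook. The halved learning rate and the batch size of 256 are arbitrary illustrative values, not tuned recommendations, and `stabilised_config` is a helper made up for this example. The divisibility check reflects the fact that a batch size that does not divide the rollout size (`n_steps × n_envs`) leaves a truncated final mini-batch.

```python
# Copy of the notebook's settings so this sketch runs on its own.
PPO_CONFIG = dict(
    learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10,
    gamma=0.99, gae_lambda=0.95, clip_range=0.2,
    ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
)
N_ENVS = 4

def stabilised_config(base, n_envs, lr_scale=0.5, batch_size=256):
    """Return a copy of `base` with a scaled learning rate and a new batch size."""
    cfg = dict(base)  # leave the original config untouched
    cfg['learning_rate'] = base['learning_rate'] * lr_scale
    cfg['batch_size'] = batch_size
    rollout = cfg['n_steps'] * n_envs
    if rollout % batch_size != 0:
        raise ValueError(
            f'batch_size={batch_size} does not divide the rollout size {rollout}'
        )
    return cfg

cfg = stabilised_config(PPO_CONFIG, N_ENVS)
print(cfg['learning_rate'], cfg['batch_size'])
```

The result can be passed to the model in place of `PPO_CONFIG`, for example `MaskablePPO('MultiInputPolicy', vec_env, **cfg)`.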

### Viewing replays

Use the saved replay files to visually compare how the agent plays at
each checkpoint. Early replays (10K) will show mostly random behaviour,
while later replays (1M) should demonstrate strategic play such as
creating units, capturing structures, and coordinated attacks.

### Next steps

1. **Try different maps:** Larger maps (10x10, 14x14) are harder
2. **Tune hyperparameters:** Adjust `ent_coef`, `learning_rate`, etc.
3. **Self-play training:** See `train/train_self_play.py`
4. **AlphaZero:** See `train/train_alphazero.py` for MCTS-based training

In [None]:
# Clean up environments
vec_env.close()
eval_env.close()
print('Done.')