# Reinforce Tactics — Kaggle Demo

Train a **MaskablePPO** agent to play a turn-based strategy game using
reinforcement learning, right here on Kaggle.

**What this notebook does:**

1. Installs the game and RL dependencies
2. Creates the Gymnasium environment (6×6 beginner map, headless)
3. Trains a MaskablePPO agent with action masking against a random opponent
4. Evaluates the agent at three checkpoints (10K, 50K, 200K timesteps)
5. Visualises training curves and replays agent behaviour

**Runtime:** CPU is fine (∼10–15 min). A GPU (P100/T4) may be somewhat faster, but is not required.

---

### About the game

Reinforce Tactics is a turn-based strategy game where players create units,
capture structures, and eliminate the opponent. The RL agent observes a grid
with terrain, units, and global features (gold, turn number, etc.), and
chooses from a flat-discrete action space with exact action masking
to avoid invalid moves.

| Unit | Cost | Move | HP | Role |
|------|------|------|----|------|
| Warrior | 200 | 3 | 15 | Melee tank |
| Mage | 300 | 2 | 10 | Ranged + Paralyze |
| Cleric | 200 | 2 | 8 | Heal / Cure |
| Archer | 250 | 3 | 15 | Long range |
| Knight | 350 | 4 | 18 | Charge bonus |
| Rogue | 350 | 4 | 12 | Flank + Evade |
| Sorcerer | 400 | 2 | 10 | Buffs |
| Barbarian | 400 | 5 | 20 | Fast melee |
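
To compare units at a glance, the table can be loaded as plain data. The sketch below copies the cost and HP columns from the table and ranks units by HP per 100 gold. This ratio is only an illustrative metric, not a value the game computes, and it ignores movement, range, and abilities.

```python
# Cost and HP copied from the unit table above.
UNITS = {
    'Warrior':   {'cost': 200, 'hp': 15},
    'Mage':      {'cost': 300, 'hp': 10},
    'Cleric':    {'cost': 200, 'hp': 8},
    'Archer':    {'cost': 250, 'hp': 15},
    'Knight':    {'cost': 350, 'hp': 18},
    'Rogue':     {'cost': 350, 'hp': 12},
    'Sorcerer':  {'cost': 400, 'hp': 10},
    'Barbarian': {'cost': 400, 'hp': 20},
}


def hp_per_100_gold(stats):
    """Durability bought per 100 gold (illustrative metric, not a game rule)."""
    return 100 * stats['hp'] / stats['cost']


# Most durable per gold first.
ranking = sorted(UNITS, key=lambda name: hp_per_100_gold(UNITS[name]), reverse=True)
for name in ranking:
    print(f'{name:10s} {hp_per_100_gold(UNITS[name]):5.2f} HP per 100 gold')
```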

**Repository:** [github.com/kuds/reinforce-tactics](https://github.com/kuds/reinforce-tactics)

## 1. Setup

In [None]:
# Install dependencies (tqdm and rich are needed for progress_bar=True in model.learn)
!pip install -q gymnasium stable-baselines3 sb3-contrib tensorboard pandas numpy torch matplotlib tqdm rich

import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {DEVICE}")
if DEVICE == 'cuda':
    print(f"  GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Clone repo and install as a package
import os, sys
from pathlib import Path

REPO_DIR = Path('reinforce-tactics')

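# Detect where we are and move to the repository root
if REPO_DIR.exists():
    os.chdir(REPO_DIR)
elif Path.cwd().name == 'notebooks':
    # Running from the repo's notebooks/ folder: step up to the root
    os.chdir('..')
elif Path('notebooks').exists():
    # Already at the repo root (e.g. this cell was re-run)
    pass
else:
    print('Cloning repository...')
    !git clone https://github.com/kuds/reinforce-tactics.git
    os.chdir(REPO_DIR)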

# Install the package so all imports resolve
!pip install -q -e .

if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

print(f"Working directory: {os.getcwd()}")

## 2. Imports

In [None]:
import json
import time
from collections import Counter
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from sb3_contrib import MaskablePPO

from reinforcetactics.rl.masking import make_maskable_env, make_maskable_vec_env

print('All imports successful.')

## 3. Configuration

All tuneable knobs are in this one cell. The defaults are chosen for a
quick demo (∼10–15 min on CPU). To train a stronger agent, append larger
timestep values to `CHECKPOINTS` and optionally switch `OPPONENT` to `'bot'`.

In [None]:
# --- Benchmark settings ---
MAP_FILE        = 'maps/1v1/beginner.csv'   # 6x6 beginner map
OPPONENT        = 'random'                   # 'random' for easy wins, 'bot' for SimpleBot
MAX_STEPS       = 200                        # Max steps per episode
N_ENVS          = 4                          # Parallel training envs
SEED            = 42

# Action space mode:
#   'flat_discrete'  - exact per-action masks (recommended)
#   'multi_discrete' - per-dimension masks (original, ~99% invalid actions)
ACTION_SPACE    = 'flat_discrete'

# Checkpoints to evaluate
CHECKPOINTS     = [10_000, 50_000, 200_000]
EVAL_EPISODES   = 30                         # Episodes per evaluation

# --- Reward configuration ---
REWARD_CONFIG = {
    # Terminal rewards
    'win':                1000.0,
    'loss':              -1000.0,
    'draw':              -200.0,

    # Potential-based shaping
    'income_diff':          0.05,
    'unit_diff':            0.3,
    'structure_control':    1.0,

    # Per-action rewards
    'create_unit':          1.0,
    'move':                 0.1,
    'damage_scale':         0.2,
    'kill':                15.0,
    'seize_progress':       1.0,
    'capture':             30.0,
    'cure':                 5.0,
    'heal_scale':           0.5,
    'paralyze':             8.0,
    'haste':                6.0,
    'defence_buff':         5.0,
    'attack_buff':          5.0,

    # Penalties
    'invalid_action':     -10.0,
    'turn_penalty':        -2.0,
}

# PPO hyperparameters
PPO_CONFIG = dict(
    learning_rate = 3e-4,
    n_steps       = 2048,
    batch_size    = 64,
    n_epochs      = 10,
    gamma         = 0.99,
    gae_lambda    = 0.95,
    clip_range    = 0.2,
    ent_coef      = 0.05,
    vf_coef       = 0.5,
    max_grad_norm = 0.5,
)

# Output directory
OUTPUT_DIR = Path('kaggle_demo_output')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f'Map:          {MAP_FILE}')
print(f'Opponent:     {OPPONENT}')
print(f'Action space: {ACTION_SPACE}')
print(f'Max steps:    {MAX_STEPS}')
print(f'Checkpoints:  {CHECKPOINTS}')
print(f'Eval eps:     {EVAL_EPISODES}')
print(f'Output dir:   {OUTPUT_DIR}')
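
The `Potential-based shaping` entries in `REWARD_CONFIG` weight differences between the two players. As a rough illustration, the sketch below uses the textbook form of potential-based shaping, `F = gamma * phi(s') - phi(s)`, with a potential built from those three weights. The feature values and the helper names `potential` and `shaping_reward` are hypothetical, and the environment's actual reward code may combine these terms differently.

```python
# Weights copied from the 'Potential-based shaping' block of REWARD_CONFIG.
SHAPING_WEIGHTS = {'income_diff': 0.05, 'unit_diff': 0.3, 'structure_control': 1.0}


def potential(features, weights=SHAPING_WEIGHTS):
    """phi(s): weighted sum of agent-minus-opponent features (hypothetical helper)."""
    return sum(weights[k] * features[k] for k in weights)


def shaping_reward(prev_features, next_features, gamma=0.99):
    """Textbook potential-based shaping term: gamma * phi(s') - phi(s)."""
    return gamma * potential(next_features) - potential(prev_features)


# Made-up example: the agent goes from parity to a one-unit lead.
before = {'income_diff': 0, 'unit_diff': 0, 'structure_control': 0}
after = {'income_diff': 0, 'unit_diff': 1, 'structure_control': 0}
bonus = shaping_reward(before, after)
print(f'Shaping bonus for gaining a one-unit lead: {bonus:.3f}')
```

Because the term depends only on the change in potential, an agent that gains and then loses the same advantage gets most of the bonus taken back, which is the usual argument for shaping rewards this way.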

## 4. Create environments

In [None]:
# Training envs (vectorized, headless)
vec_env = make_maskable_vec_env(
    n_envs=N_ENVS,
    map_file=MAP_FILE,
    opponent=OPPONENT,
    max_steps=MAX_STEPS,
    reward_config=REWARD_CONFIG,
    seed=SEED,
    use_subprocess=False,   # DummyVecEnv (safer in notebooks)
    action_space_type=ACTION_SPACE,
)

# Separate eval env (single, deterministic)
eval_env = make_maskable_env(
    map_file=MAP_FILE,
    opponent=OPPONENT,
    max_steps=MAX_STEPS,
    reward_config=REWARD_CONFIG,
    action_space_type=ACTION_SPACE,
)

print(f'Observation space: {vec_env.observation_space}')
print(f'Action space:      {vec_env.action_space}')

## 5. Explore the environment

Before training, let's take a quick look at what the agent sees and
what actions are available.

In [None]:
obs, info = eval_env.reset()

print('Observation keys and shapes:')
for key, val in obs.items():
    if hasattr(val, 'shape'):
        print(f'  {key:20s}  shape={val.shape}  dtype={val.dtype}')
    else:
        print(f'  {key:20s}  {val}')

masks = eval_env.action_masks()
n_legal = int(masks.sum())
n_total = len(masks)
print(f'\nAction mask: {n_legal} legal actions out of {n_total} total ({n_legal/n_total*100:.1f}%)')

# Show the grid observation as a heatmap
fig, axes = plt.subplots(1, 3, figsize=(12, 3.5))
channel_names = ['Terrain type', 'Owner', 'Structure HP']
for i, (ax, name) in enumerate(zip(axes, channel_names)):
    im = ax.imshow(obs['grid'][:, :, i], cmap='viridis', interpolation='nearest')
    ax.set_title(name)
    plt.colorbar(im, ax=ax, shrink=0.8)
fig.suptitle('Grid observation channels (initial state)', fontweight='bold')
fig.tight_layout()
plt.show()

## 6. Create MaskablePPO model

We use **MaskablePPO** from `sb3-contrib`, which supports action masking
natively. The flat-discrete action space maps each legal game action to
a unique integer, and the mask tells the policy which integers are valid
on each step.
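
Conceptually, masking removes invalid actions from the policy's distribution before sampling. The NumPy sketch below shows the idea on made-up logits: masked entries are set to negative infinity, so they receive zero probability after the softmax. It illustrates the technique only; `masked_probs` is a hypothetical helper, not part of `sb3-contrib`, whose internals may differ.

```python
import numpy as np


def masked_probs(logits, mask):
    """Softmax over logits with illegal actions forced to probability zero."""
    logits = np.asarray(logits, dtype=np.float64)
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        raise ValueError('at least one action must be legal')
    masked = np.where(mask, logits, -np.inf)  # illegal actions -> -inf
    masked = masked - masked.max()            # stabilise the exponentials
    weights = np.exp(masked)                  # exp(-inf) evaluates to 0.0
    return weights / weights.sum()


# Made-up example: four actions, of which only 0 and 2 are legal.
logits = [1.0, 2.0, 0.5, 3.0]
mask = [True, False, True, False]
probs = masked_probs(logits, mask)

rng = np.random.default_rng(0)
action = int(rng.choice(len(probs), p=probs))  # can only ever be 0 or 2
print(probs, action)
```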

In [None]:
model = MaskablePPO(
    'MultiInputPolicy',
    vec_env,
    verbose=0,
    tensorboard_log=str(OUTPUT_DIR / 'tensorboard'),
    device=DEVICE,
    seed=SEED,
    **PPO_CONFIG,
)

print('MaskablePPO model created.')
print(f'Policy:  {model.policy.__class__.__name__}')
print(f'Device:  {model.device}')

## 7. Evaluation helper

In [None]:
def evaluate_model(model, env, n_episodes=30):
    """
    Evaluate a trained model and return summary statistics.

    Returns dict with: win_rate, avg_reward, std_reward,
    avg_length, std_length, wins, losses, draws
    """
    wins, losses, draws = 0, 0, 0
    rewards, lengths = [], []

    for _ in range(n_episodes):
        obs, _ = env.reset()
        done = False
        ep_reward = 0.0
        ep_len = 0

        while not done:
            masks = env.action_masks()
            action, _ = model.predict(obs, deterministic=True, action_masks=masks)
            obs, reward, terminated, truncated, info = env.step(action)
            ep_reward += reward
            ep_len += 1
            done = terminated or truncated

        rewards.append(ep_reward)
        lengths.append(ep_len)

        winner = info.get('winner')
        if winner == 1:
            wins += 1
        elif winner is not None:
            losses += 1
        else:
            draws += 1

    return {
        'win_rate':    wins / n_episodes,
        'avg_reward':  float(np.mean(rewards)),
        'std_reward':  float(np.std(rewards)),
        'avg_length':  float(np.mean(lengths)),
        'std_length':  float(np.std(lengths)),
        'wins':        wins,
        'losses':      losses,
        'draws':       draws,
    }

print('evaluate_model() defined.')
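
With only `EVAL_EPISODES = 30` games per checkpoint, the reported win rate is a noisy estimate. One way to gauge that noise is a Wilson score interval for a binomial proportion. The helper below is an optional add-on with a hypothetical name (`wilson_interval`), not part of the notebook's pipeline. It treats episodes as independent and counts draws as non-wins, and the 21-of-30 figure is a made-up example rather than a measured result.

```python
import math


def wilson_interval(wins, n, z=1.96):
    """Approximate 95% Wilson score interval for a win rate of wins / n."""
    if n <= 0:
        raise ValueError('n must be positive')
    p = wins / n
    denom = 1 + z ** 2 / n
    centre = (p + z ** 2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z ** 2 / (4 * n ** 2)) / denom
    return centre - half, centre + half


# Made-up example: 21 wins out of 30 evaluation episodes.
low, high = wilson_interval(21, 30)
print(f'Win rate 70.0%, 95% interval roughly {low * 100:.1f}% to {high * 100:.1f}%')
```

An interval this wide suggests treating small differences between checkpoints with caution, or raising `EVAL_EPISODES` when time allows.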

## 8. Train and evaluate at each checkpoint

We train incrementally: 0 → 10K → 50K → 200K timesteps,
evaluating at each checkpoint.
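
The checkpoint values are nominal. PPO collects whole rollouts of `n_steps × n_envs` transitions, so the step counter normally overshoots each target. The sketch below simulates the loop in the next cell under one assumption, our reading of how Stable-Baselines3 behaves with `reset_num_timesteps=False`: each `learn()` call adds the requested steps to the current counter and keeps collecting full rollouts until that target is reached. `simulate_checkpoints` is a hypothetical helper for this arithmetic only.

```python
def simulate_checkpoints(checkpoints, n_steps, n_envs):
    """Estimate the real step counter after each nominal checkpoint."""
    rollout = n_steps * n_envs      # transitions gathered per rollout
    num_timesteps = 0               # mirrors the model's internal counter
    trained_so_far = 0              # nominal counter used by the notebook
    actual = []
    for checkpoint_ts in checkpoints:
        steps_to_train = checkpoint_ts - trained_so_far
        target = num_timesteps + steps_to_train
        while num_timesteps < target:
            num_timesteps += rollout
        trained_so_far = checkpoint_ts
        actual.append(num_timesteps)
    return actual


# Defaults from the configuration cell.
actual_steps = simulate_checkpoints([10_000, 50_000, 200_000], n_steps=2048, n_envs=4)
for nominal, real in zip([10_000, 50_000, 200_000], actual_steps):
    print(f'nominal {nominal:>7,}  ->  estimated actual {real:>7,}')
```

With the defaults, one rollout is 2048 × 4 = 8,192 transitions, so under this assumption the first evaluation happens after two rollouts rather than at exactly 10,000 steps. Reading `model.num_timesteps` after each `learn()` call would confirm the real values.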

In [None]:
results = []
trained_so_far = 0
start_time = time.time()

for checkpoint_ts in CHECKPOINTS:
    steps_to_train = checkpoint_ts - trained_so_far
    print(f'\n{"="*60}')
    print(f'Training {trained_so_far:,} -> {checkpoint_ts:,} '
          f'({steps_to_train:,} steps)...')
    print(f'{"="*60}')

    t0 = time.time()
    model.learn(
        total_timesteps=steps_to_train,
        reset_num_timesteps=False,
        progress_bar=True,
    )
    train_time = time.time() - t0
    trained_so_far = checkpoint_ts

    # Save checkpoint
    ckpt_path = OUTPUT_DIR / f'model_{checkpoint_ts}.zip'
    model.save(str(ckpt_path))
    print(f'Saved checkpoint: {ckpt_path}')

    # Evaluate
    print(f'Evaluating over {EVAL_EPISODES} episodes...')
    metrics = evaluate_model(model, eval_env, n_episodes=EVAL_EPISODES)
    metrics['timesteps'] = checkpoint_ts
    metrics['train_time_s'] = round(train_time, 1)
    results.append(metrics)

    print(f'  Win rate:       {metrics["win_rate"]*100:.1f}%')
    print(f'  Avg reward:     {metrics["avg_reward"]:.1f} '
          f'(+/- {metrics["std_reward"]:.1f})')
    print(f'  Avg length:     {metrics["avg_length"]:.1f} '
          f'(+/- {metrics["std_length"]:.1f})')
    print(f'  W/L/D:          {metrics["wins"]}/{metrics["losses"]}/{metrics["draws"]}')
    print(f'  Training time:  {train_time:.1f}s')

total_time = time.time() - start_time
print(f'\nTotal wall time: {total_time/60:.1f} minutes')

## 9. Results table

In [None]:
df = pd.DataFrame(results)
df['win_rate_pct'] = (df['win_rate'] * 100).round(1)
df['avg_reward'] = df['avg_reward'].round(1)
df['avg_length'] = df['avg_length'].round(1)

display_df = df[['timesteps', 'win_rate_pct', 'avg_reward', 'avg_length',
                  'wins', 'losses', 'draws', 'train_time_s']].copy()
display_df.columns = ['Timesteps', 'Win Rate (%)', 'Avg Reward',
                       'Avg Length', 'Wins', 'Losses', 'Draws',
                       'Train Time (s)']
display_df = display_df.set_index('Timesteps')
display_df

## 10. Training curves

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

ts = [r['timesteps'] for r in results]

# Win rate
ax = axes[0]
wr = [r['win_rate'] * 100 for r in results]
ax.plot(ts, wr, 'o-', color='#2196F3', linewidth=2, markersize=8)
ax.set_xlabel('Timesteps')
ax.set_ylabel('Win Rate (%)')
ax.set_title('Win Rate vs Opponent')
ax.set_xscale('log')
ax.set_ylim(-5, 105)
ax.axhline(y=70, color='green', linestyle='--', alpha=0.5, label='70% target')
ax.legend()
ax.grid(True, alpha=0.3)

# Average reward
ax = axes[1]
avg_r = [r['avg_reward'] for r in results]
std_r = [r['std_reward'] for r in results]
ax.plot(ts, avg_r, 'o-', color='#FF9800', linewidth=2, markersize=8)
ax.fill_between(ts,
                [a - s for a, s in zip(avg_r, std_r)],
                [a + s for a, s in zip(avg_r, std_r)],
                alpha=0.2, color='#FF9800')
ax.set_xlabel('Timesteps')
ax.set_ylabel('Average Reward')
ax.set_title('Average Episode Reward')
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

# Episode length
ax = axes[2]
avg_l = [r['avg_length'] for r in results]
std_l = [r['std_length'] for r in results]
ax.plot(ts, avg_l, 'o-', color='#4CAF50', linewidth=2, markersize=8)
ax.fill_between(ts,
                [a - s for a, s in zip(avg_l, std_l)],
                [a + s for a, s in zip(avg_l, std_l)],
                alpha=0.2, color='#4CAF50')
ax.set_xlabel('Timesteps')
ax.set_ylabel('Average Length (steps)')
ax.set_title('Average Episode Length')
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

fig.suptitle(f'MaskablePPO Training  |  6x6 beginner map  |  vs {OPPONENT}',
             fontsize=13, fontweight='bold', y=1.02)
fig.tight_layout()

fig.savefig(str(OUTPUT_DIR / 'training_curves.png'),
            dpi=150, bbox_inches='tight')
print(f'Saved plot: {OUTPUT_DIR / "training_curves.png"}')
plt.show()

## 11. Watch the agent play

Record a few evaluation episodes and inspect exactly what the agent does
each step: which actions it takes, how units and gold change over time,
and whether it's learning meaningful tactics.

In [None]:
N_REPLAY_EPISODES = 3

ACTION_NAMES = [
    'create_unit', 'move', 'attack', 'seize', 'heal',
    'end_turn', 'paralyze', 'haste', 'defence_buff', 'attack_buff',
]
UNIT_NAMES = ['W', 'M', 'C', 'A', 'K', 'R', 'S', 'B']


def _snapshot_game_state(env):
    """Capture a summary of the current game state."""
    gs = env.unwrapped.game_state
    ap = env.unwrapped.agent_player
    opp = 3 - ap
    return {
        'agent_gold': gs.player_gold.get(ap, 0),
        'opponent_gold': gs.player_gold.get(opp, 0),
        'agent_units': sum(1 for u in gs.units if u.player == ap),
        'opponent_units': sum(1 for u in gs.units if u.player == opp),
        'agent_structures': len(gs.grid.get_capturable_tiles(player=ap)),
        'opponent_structures': len(gs.grid.get_capturable_tiles(player=opp)),
        'turn': gs.turn_number,
    }


def evaluate_with_replay(model, env, n_episodes=3):
    """Run evaluation episodes and record every action."""
    episodes = []

    for ep_idx in range(n_episodes):
        obs, _ = env.reset()
        done = False
        ep_reward = 0.0
        steps = []
        step_num = 0

        while not done:
            masks = env.action_masks()
            action, _ = model.predict(obs, deterministic=True, action_masks=masks)

            # Decode action before stepping
            raw_action = action
            if ACTION_SPACE == 'flat_discrete':
                action_idx = int(action)
                inner = env.unwrapped
                if 0 <= action_idx < len(inner._current_actions):
                    action_arr = inner._current_actions[action_idx]
                else:
                    action_arr = np.array([5, 0, 0, 0, 0, 0])
            else:
                action_arr = np.asarray(action)

            action_type = int(action_arr[0])
            action_name = ACTION_NAMES[action_type] if action_type < len(ACTION_NAMES) else f'unknown_{action_type}'
            unit_type = UNIT_NAMES[int(action_arr[1]) % 8]
            from_pos = [int(action_arr[2]), int(action_arr[3])]
            to_pos = [int(action_arr[4]), int(action_arr[5])]

            obs, reward, terminated, truncated, info = env.step(raw_action)
            ep_reward += reward
            step_num += 1
            done = terminated or truncated

            step_record = {
                'step': step_num,
                'action': action_name,
                'unit_type': unit_type if action_type == 0 else None,
                'from': from_pos,
                'to': to_pos,
                'reward': round(float(reward), 3),
                'cumulative_reward': round(float(ep_reward), 3),
                'valid': info.get('valid_action', True),
                'game_state': _snapshot_game_state(env),
            }
            steps.append(step_record)

        winner = info.get('winner')
        if winner == 1:
            outcome = 'win'
        elif winner is not None:
            outcome = 'loss'
        else:
            outcome = 'draw'

        final_turn = steps[-1]['game_state']['turn'] if steps else 0
        episodes.append({
            'episode': ep_idx,
            'outcome': outcome,
            'total_reward': round(float(ep_reward), 2),
            'length': step_num,
            'turns': final_turn,
            'steps': steps,
        })

    return episodes

print('Replay helpers defined.')

In [None]:
print(f'Recording {N_REPLAY_EPISODES} replay episodes from final model...\n')
replay_episodes = evaluate_with_replay(model, eval_env, n_episodes=N_REPLAY_EPISODES)

for ep in replay_episodes:
    final_turn = ep['steps'][-1]['game_state']['turn'] if ep['steps'] else 0

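    # Build the rule outside the f-string: backslashes inside f-string
    # expressions are a SyntaxError before Python 3.12.
    rule = '\u2500' * 55
    print(f'\n{rule}')
    print(f'Episode {ep["episode"]}  |  outcome={ep["outcome"]}  |  '
          f'length={ep["length"]}  |  turns={final_turn}  |  '
          f'reward={ep["total_reward"]}')
    print(rule)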

    # Action distribution
    action_counts = Counter(s['action'] for s in ep['steps'])
    total = len(ep['steps'])
    print(f'\n  Action distribution ({total} steps):')
    for action, count in action_counts.most_common():
        pct = count / total * 100
        bar = '#' * int(pct / 2)
        print(f'    {action:15s}  {count:4d}  ({pct:5.1f}%)  {bar}')

    # Unit creation breakdown
    create_steps = [s for s in ep['steps'] if s['action'] == 'create_unit']
    if create_steps:
        unit_counts = Counter(s['unit_type'] for s in create_steps)
        print(f'\n  Units created ({len(create_steps)} total):')
        for ut, count in unit_counts.most_common():
            bar = '#' * count
            print(f'    {ut:3s}  {count:3d}  {bar}')

    # Final state
    final = ep['steps'][-1]['game_state']
    print(f'\n  Final state (step {ep["length"]}, turn {final["turn"]}):')
    print(f'    Agent:    {final["agent_units"]} units, '
          f'{final["agent_structures"]} structures, '
          f'{final["agent_gold"]} gold')
    print(f'    Opponent: {final["opponent_units"]} units, '
          f'{final["opponent_structures"]} structures, '
          f'{final["opponent_gold"]} gold')

    # First 5 moves
    steps = ep['steps']
    print('\n  First 5 moves:')
    for s in steps[:5]:
        ut = f' ({s["unit_type"]})' if s['unit_type'] else ''
        print(f'    step {s["step"]:3d}  {s["action"]:15s}{ut:5s}  '
              f'{s["from"]}\u2192{s["to"]}  r={s["reward"]:+.1f}')

## 12. Replay visualisations

In [None]:
# Colour palette
ACTION_COLORS = {
    'move':         '#4CAF50',
    'attack':       '#F44336',
    'end_turn':     '#9E9E9E',
    'create_unit':  '#2196F3',
    'seize':        '#FF9800',
    'heal':         '#E91E63',
    'paralyze':     '#9C27B0',
    'haste':        '#00BCD4',
    'defence_buff': '#795548',
    'attack_buff':  '#FF5722',
}
AGENT_COLOR = '#2196F3'
OPP_COLOR   = '#F44336'

ORDERED_ACTIONS = [
    'move', 'attack', 'create_unit', 'end_turn',
    'seize', 'heal', 'paralyze', 'haste', 'defence_buff', 'attack_buff',
]


def _per_turn(steps, key):
    """Aggregate a game-state key per turn (value at end of turn)."""
    turn_vals = {}
    for s in steps:
        t = s['game_state']['turn']
        turn_vals[t] = s['game_state'][key]
    turns = sorted(turn_vals)
    return turns, [turn_vals[t] for t in turns]


def plot_episode(ep, ep_idx=None):
    """Create a 2x3 dashboard for a single replay episode."""
    steps = ep['steps']
    idx = ep_idx if ep_idx is not None else ep.get('episode', '?')
    outcome = ep['outcome'].upper()
    total_reward = ep['total_reward']
    length = ep['length']
    final_turn = steps[-1]['game_state']['turn'] if steps else 0

    fig = plt.figure(figsize=(16, 9))
    gs = gridspec.GridSpec(2, 3, hspace=0.35, wspace=0.35)

    outcome_color = {'WIN': '#4CAF50', 'LOSS': '#F44336', 'DRAW': '#FF9800'}
    fig.suptitle(
        f'Episode {idx}  \u2014  {outcome}  |  '
        f'{length} steps  |  {final_turn} turns  |  reward {total_reward:+.0f}',
        fontsize=14, fontweight='bold',
        color=outcome_color.get(outcome, 'black'),
    )

    # 1. Action distribution
    ax = fig.add_subplot(gs[0, 0])
    action_counts = Counter(s['action'] for s in steps)
    actions = [a for a in ORDERED_ACTIONS if action_counts.get(a, 0) > 0]
    counts = [action_counts[a] for a in actions]
    colors = [ACTION_COLORS.get(a, '#607D8B') for a in actions]
    bars = ax.barh(actions, counts, color=colors, edgecolor='white', linewidth=0.5)
    ax.set_xlabel('Count')
    ax.set_title('Action Distribution')
    ax.invert_yaxis()
    for bar, c in zip(bars, counts):
        pct = c / len(steps) * 100
        ax.text(bar.get_width() + max(counts) * 0.02, bar.get_y() + bar.get_height() / 2,
                f'{pct:.0f}%', va='center', fontsize=9, color='#555')

    # 2. Unit count over turns
    ax = fig.add_subplot(gs[0, 1])
    turns_a, agent_units = _per_turn(steps, 'agent_units')
    _, opp_units = _per_turn(steps, 'opponent_units')
    ax.plot(turns_a, agent_units, color=AGENT_COLOR, linewidth=1.5, label='Agent')
    ax.plot(turns_a, opp_units, color=OPP_COLOR, linewidth=1.5, label='Opponent')
    ax.fill_between(turns_a, agent_units, alpha=0.1, color=AGENT_COLOR)
    ax.fill_between(turns_a, opp_units, alpha=0.1, color=OPP_COLOR)
    ax.set_xlabel('Turn')
    ax.set_ylabel('Units')
    ax.set_title('Army Size')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.2)

    # 3. Gold over turns
    ax = fig.add_subplot(gs[0, 2])
    _, agent_gold = _per_turn(steps, 'agent_gold')
    _, opp_gold = _per_turn(steps, 'opponent_gold')
    ax.plot(turns_a, agent_gold, color=AGENT_COLOR, linewidth=1.5, label='Agent')
    ax.plot(turns_a, opp_gold, color=OPP_COLOR, linewidth=1.5, label='Opponent')
    ax.set_xlabel('Turn')
    ax.set_ylabel('Gold')
    ax.set_title('Economy')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.2)

    # 4. Cumulative reward
    ax = fig.add_subplot(gs[1, 0])
    cum_r = [s['cumulative_reward'] for s in steps]
    step_nums = [s['step'] for s in steps]
    ax.plot(step_nums, cum_r, color='#FF9800', linewidth=1.2)
    ax.axhline(0, color='grey', linewidth=0.5, linestyle='--')
    ax.set_xlabel('Step')
    ax.set_ylabel('Cumulative Reward')
    ax.set_title('Reward Curve')
    ax.grid(True, alpha=0.2)

    # 5. Steps per turn
    ax = fig.add_subplot(gs[1, 1])
    turn_steps = {}
    for s in steps:
        t = s['game_state']['turn']
        turn_steps[t] = turn_steps.get(t, 0) + 1
    ts_sorted = sorted(turn_steps)
    spt = [turn_steps[t] for t in ts_sorted]
    bar_colors = ['#F44336' if v == 1 else '#4CAF50' for v in spt]
    ax.bar(ts_sorted, spt, color=bar_colors, width=1.0, edgecolor='white', linewidth=0.3)
    avg_spt = np.mean(spt)
    ax.axhline(avg_spt, color='#2196F3', linewidth=1, linestyle='--',
               label=f'avg {avg_spt:.1f}')
    ax.set_xlabel('Turn')
    ax.set_ylabel('Steps')
    ax.set_title('Steps per Turn')
    ax.legend(fontsize=8)

    # 6. Unit creation breakdown
    ax = fig.add_subplot(gs[1, 2])
    create_steps = [s for s in steps if s['action'] == 'create_unit']
    if create_steps:
        unit_counts = Counter(s['unit_type'] for s in create_steps)
        labels = list(unit_counts.keys())
        sizes = list(unit_counts.values())
        wedge_colors = plt.cm.Set2(np.linspace(0, 1, len(labels)))
        ax.pie(sizes, labels=labels, autopct='%1.0f%%',
               colors=wedge_colors, startangle=90,
               textprops={'fontsize': 10})
        ax.set_title(f'Units Created ({sum(sizes)})')
    else:
        ax.text(0.5, 0.5, 'No units\ncreated', ha='center', va='center',
                fontsize=14, color='#999')
        ax.set_title('Units Created')

    plot_path = OUTPUT_DIR / f'replay_episode_{idx}.png'
    fig.savefig(str(plot_path), dpi=150, bbox_inches='tight')
    print(f'Saved: {plot_path}')
    plt.show()


print('Visualisation helpers defined.')

In [None]:
for ep in replay_episodes:
    plot_episode(ep)

## 13. Save results

In [None]:
# Save benchmark results as JSON
benchmark_data = {
    'metadata': {
        'date': datetime.now().isoformat(),
        'map': MAP_FILE,
        'opponent': OPPONENT,
        'max_steps': MAX_STEPS,
        'n_envs': N_ENVS,
        'eval_episodes': EVAL_EPISODES,
        'seed': SEED,
        'device': DEVICE,
        'ppo_config': PPO_CONFIG,
        'reward_config': REWARD_CONFIG,
    },
    'results': results,
}

results_path = OUTPUT_DIR / 'benchmark_results.json'
with open(results_path, 'w') as f:
    json.dump(benchmark_data, f, indent=2)
print(f'Saved results:  {results_path}')

csv_path = OUTPUT_DIR / 'benchmark_results.csv'
df.to_csv(csv_path, index=False)
print(f'Saved CSV:      {csv_path}')

# List all saved files
print(f'\nAll output files:')
for p in sorted(OUTPUT_DIR.iterdir()):
    if p.is_file():
        size = p.stat().st_size
        if size > 1024 * 1024:
            size_str = f'{size / 1024 / 1024:.1f} MB'
        elif size > 1024:
            size_str = f'{size / 1024:.1f} KB'
        else:
            size_str = f'{size} B'
        print(f'  {p.name:40s}  {size_str}')

## 14. What's next?

This demo trained a basic agent on the smallest map against a random
opponent. Here are ways to go further:

### Train longer / harder

| Change | How |
|--------|-----|
| More timesteps | Add `1_000_000` and `2_000_000` to `CHECKPOINTS` |
| Harder opponent | Set `OPPONENT = 'bot'` (SimpleBot) |
| Bigger map | Change `MAP_FILE` to a 10×10 or 14×14 map |
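
As a concrete example, the first two rows of the table could look like the settings below. Since the training loop trains for `checkpoint_ts - trained_so_far` steps at each stage, the list should stay strictly increasing. `check_checkpoints` is a hypothetical guard added for illustration and is not part of the repository.

```python
# Longer schedule and harder opponent, following the table above.
CHECKPOINTS = [10_000, 50_000, 200_000, 1_000_000, 2_000_000]
OPPONENT = 'bot'  # SimpleBot instead of the random opponent


def check_checkpoints(checkpoints):
    """Return the nominal steps per stage; raise if the list is not strictly increasing."""
    stages = []
    previous = 0
    for ts in checkpoints:
        if ts <= previous:
            raise ValueError(f'checkpoint {ts} must be greater than {previous}')
        stages.append(ts - previous)
        previous = ts
    return stages


stage_steps = check_checkpoints(CHECKPOINTS)
print(f'Nominal steps per stage: {stage_steps}')
```

Expect the wall time to grow roughly in proportion to the final checkpoint, so the longer schedule is better suited to a GPU session or an overnight run.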

### Other algorithms

The repo includes:
- **Self-play** (`train/train_self_play.py`) — train against past versions of itself
- **AlphaZero** (`train/train_alphazero.py`) — MCTS + neural network
- **Feudal RL** (`train/train_feudal_rl.py`) — hierarchical manager-worker

### Run a tournament

See `notebooks/bot_tournament.ipynb` to pit SimpleBot, MediumBot,
AdvancedBot, and your trained agent against each other with Elo ratings.

### Repository

[github.com/kuds/reinforce-tactics](https://github.com/kuds/reinforce-tactics)

In [None]:
# Clean up environments
vec_env.close()
eval_env.close()
print('Done.')