# Pokemon Red RL — Exploration Notebook

Use this notebook to:
- Verify the environment is set up correctly
- Inspect the contents of saved checkpoints
- Visualise training curves from TensorBoard logs
- Run quick experiments before committing to a full training run

In [None]:
import sys
sys.path.insert(0, '..')   # so src.* imports work from notebooks/

import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path

# Report the installed PyTorch version, then which device type is usable.
print('torch:', torch.__version__)
cuda_ready = torch.cuda.is_available()
print('device:', 'cuda' if cuda_ready else 'cpu')

## 1. Agent sanity check (no ROM required)

In [None]:
from src.agent.simple_agent import SimpleAgent

# Build an agent for a fixed observation shape and action count, then push a
# single random uint8 observation through it to confirm the forward pass runs.
# NOTE(review): (4, 84, 84) presumably means 4 stacked 84x84 frames — confirm
# against the environment wrapper.
obs_shape = (4, 84, 84)
num_actions = 8
agent = SimpleAgent(obs_shape=obs_shape, num_actions=num_actions)

param_count = agent.count_parameters()
print(f'Parameters: {param_count:,}')

# Batch of one observation with integer values drawn from [0, 256).
batched_shape = (1,) + obs_shape
dummy_obs = torch.randint(low=0, high=256, size=batched_shape, dtype=torch.uint8)

action, log_prob, entropy, value = agent.get_action_and_value(dummy_obs)
print(f'action={action.item()}  log_prob={log_prob.item():.4f}  entropy={entropy.item():.4f}  value={value.item():.4f}')

## 2. Environment smoke test (requires ROM)

Place `PokemonRed.gb` in `../roms/` before running this cell.

In [None]:
# Smoke-test the environment: reset it, take a handful of random actions and
# report the rewards. The whole cell is skipped when the ROM file is absent.
ROM_PATH = '../roms/PokemonRed.gb'

if not Path(ROM_PATH).exists():
    print(f'ROM not found at {ROM_PATH}. Skipping environment tests.')
else:
    from src.env.pokemon_env import make_env

    env = make_env(rom_path=ROM_PATH, max_steps=100, headless=True)
    rewards = []
    try:
        obs, info = env.reset()
        print(f'obs shape : {obs.shape}  dtype: {obs.dtype}')
        print(f'info      : {info}')

        # Take up to 10 random steps; stop early if the episode ends.
        for _ in range(10):
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            rewards.append(reward)
            if terminated or truncated:
                break
    finally:
        # Always release the environment, even if reset()/step() raised,
        # so a failed cell does not leave it open in the kernel.
        env.close()

    # Report the number of steps actually taken (may be < 10 on early stop).
    print(f'Rewards over {len(rewards)} steps: {[f"{r:.4f}" for r in rewards]}')
    print(f'Total reward: {sum(rewards):.4f}')

## 3. Visualise a screen frame

In [None]:
# Number of frames in one stacked observation; used both when building the
# environment and when laying out the subplot grid.
FRAME_STACK = 4


def show_frame_stack(frames, title_template, suptitle=None):
    """Draw the first FRAME_STACK frames of `frames` side by side in grayscale.

    Args:
        frames: Indexable collection where `frames[i]` is a 2-D image with
            values rendered on a 0-255 scale.
        title_template: Per-panel title; `{i}` is replaced by the frame index.
        suptitle: Optional figure-level title. Omitted when None.
    """
    fig, axes = plt.subplots(1, FRAME_STACK, figsize=(14, 3))
    for i, ax in enumerate(axes):
        ax.imshow(frames[i], cmap='gray', vmin=0, vmax=255)
        ax.set_title(title_template.format(i=i))
        ax.axis('off')
    if suptitle is not None:
        plt.suptitle(suptitle, y=1.02)
    plt.tight_layout()
    plt.show()


if Path(ROM_PATH).exists():
    from src.env.pokemon_env import make_env

    env = make_env(rom_path=ROM_PATH, max_steps=50, headless=True, frame_stack=FRAME_STACK)
    try:
        obs, _ = env.reset()
        show_frame_stack(
            obs,
            'Frame {i}',
            suptitle=f'Stacked Grayscale Observation ({FRAME_STACK} frames)',
        )
    finally:
        # Close the environment even when reset() or plotting fails.
        env.close()
else:
    print('No ROM — showing random noise instead')
    fake_obs = np.random.randint(0, 256, (FRAME_STACK, 84, 84), dtype=np.uint8)
    show_frame_stack(fake_obs, 'Frame {i} (random)')

## 4. Inspect a saved checkpoint

In [None]:
import glob
import os
import re


def checkpoint_sort_key(path):
    """Return a natural-sort key for a checkpoint path.

    Digit runs in the file name are compared as integers, so 'ckpt_10.pt'
    sorts after 'ckpt_9.pt' (plain string sorting puts it before). Only the
    file name is considered; the directory part is ignored.

    Args:
        path: Path string of a checkpoint file.

    Returns:
        A list alternating text chunks (str) and numeric chunks (int).
    """
    name = os.path.basename(path)
    # re.split with a capturing group yields text, digits, text, ... so odd
    # positions are always digit runs and same-index items share a type.
    return [int(tok) if idx % 2 else tok
            for idx, tok in enumerate(re.split(r'(\d+)', name))]


# NOTE(review): "latest" assumes file names embed an increasing update/step
# counter — confirm against the naming used by the training script.
checkpoints = sorted(glob.glob('../checkpoints/*.pt'), key=checkpoint_sort_key)
if not checkpoints:
    print('No checkpoints found. Run training first: python scripts/train_local.py')
else:
    latest = checkpoints[-1]
    print(f'Latest checkpoint: {latest}')

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a source you trust.
    ckpt = torch.load(latest, map_location='cpu')
    print(f'  update       : {ckpt["update"]}')
    print(f'  global_step  : {ckpt["global_step"]:,}')
    print(f'  config keys  : {list(ckpt["config"].keys())}')

    # Parameter norms
    from src.agent.simple_agent import SimpleAgent
    # NOTE(review): the agent is built with default constructor arguments;
    # presumably these match the architecture used in training — confirm.
    agent = SimpleAgent()
    agent.load_state_dict(ckpt['agent_state'])
    for name, p in agent.named_parameters():
        print(f'  {name:40s}  norm={p.norm().item():.4f}')

## 5. TensorBoard inline (optional)

If you prefer not to open a browser, TensorBoard can render inline in Jupyter:

In [None]:
# Uncomment to load TensorBoard inline:
# %load_ext tensorboard
# %tensorboard --logdir ../runs