In [27]:
from pyboy import PyBoy
from PIL import Image
import numpy as np
import time
from IPython.display import display

import sys
print("Python:", sys.executable)

import torch
print("Torch:", torch.__version__)

import gymnasium as gym
from gymnasium import spaces
print("Gymnasium:", gym.__version__)

from stable_baselines3 import PPO
print("SB3 imported OK")

# --- quick CartPole sanity test (optional, can comment out later) ---
# Confirms that Gymnasium and Stable-Baselines3 work together before the
# emulator-backed environment is involved.
env_test = gym.make("CartPole-v1")
try:
    obs, info = env_test.reset()
    print("CartPole obs:", obs)

    # Tiny rollout/batch sizes: this only has to prove that one PPO update runs.
    model_test = PPO("MlpPolicy", env_test, n_steps=64, batch_size=32, verbose=0)
    model_test.learn(total_timesteps=100)
    print("PPO sanity test passed.")
finally:
    # Release the environment even when the sanity test itself fails.
    env_test.close()

Python: c:\Users\toanc\Develop\PyBoy\venv\Scripts\python.exe
Torch: 2.9.1+cpu
Gymnasium: 1.2.2
SB3 imported OK
CartPole obs: [ 0.02623104 -0.00920921  0.03467976 -0.01302149]
PPO sanity test passed.


In [None]:

# --- memory addresses read by the environment ---
# presumably taken from a Pokémon Red RAM map — verify against the ROM in use
MAP_ID_ADDR   = 0xD35E
PLAYER_Y_ADDR = 0xD361
PLAYER_X_ADDR = 0xD362

# Button names, in action-index order; index 0 is reserved for "do nothing".
_BUTTON_NAMES = ("up", "down", "left", "right", "a", "b", "start")


def _make_press(button_name):
    """Return a one-argument callable that presses `button_name` on an emulator."""
    return lambda pb: pb.button(button_name)


# Action index -> callable taking the emulator instance.
ACTIONS = {0: lambda pb: None}
for _index, _button_name in enumerate(_BUTTON_NAMES, start=1):
    ACTIONS[_index] = _make_press(_button_name)

NUM_ACTIONS = len(ACTIONS)


class PokemonRedEnv:
    """
    Minimal Pokémon Red environment driven through PyBoy.

    Task:
      - Start from the save state at ``state_path`` (inside the first house).
      - Leave the house (the map id read from MAP_ID_ADDR changes).

    Rewards (per step, see the class constants below):
      - step penalty:              -STEP_PENALTY     (-0.001)
      - move to a different tile:  +MOVE_REWARD      (+0.1)
      - visit a new tile:          +NEW_TILE_REWARD  (+5, added on top of the move reward)
      - leave starting map:        +EXIT_REWARD      (+15.0, episode done)

    The episode also ends once ``max_steps`` steps have been taken.

    Args:
        rom_path: Path of the ROM file handed to PyBoy.
        state_path: Path of the PyBoy save state loaded on every reset.
        max_steps: Number of steps after which ``step`` reports ``done``.
    """

    # Reward shaping constants (kept in one place so docs and code cannot drift).
    STEP_PENALTY = 0.001
    MOVE_REWARD = 0.1
    NEW_TILE_REWARD = 5.0
    EXIT_REWARD = 15.0

    # Emulator frames advanced after loading the state / after each action.
    SETTLE_FRAMES = 10
    FRAMES_PER_STEP = 3

    def __init__(self, rom_path: str, state_path: str, max_steps: int = 300):
        self.rom_path = rom_path
        self.state_path = state_path
        self.max_steps = max_steps

        # Created lazily by reset(); None means "no emulator running".
        self.pyboy = None
        self.game = None

        self.start_map_id = None
        self.step_count = 0

        self.prev_player_pos = None  # (x, y)
        self.visited = set()         # {(map_id, x, y)}

    # ---- helpers ----

    def _init_emulator(self):
        """Create the headless emulator and cache its game wrapper."""
        self.pyboy = PyBoy(
            self.rom_path,
            window="null",    # headless for training
            no_input=False,   # IMPORTANT: allow input from ACTIONS
        )
        self.game = self.pyboy.game_wrapper

    def _get_map_id(self) -> int:
        """Return the byte stored at MAP_ID_ADDR as an int."""
        return int(self.pyboy.memory[MAP_ID_ADDR])

    def _get_player_pos(self) -> tuple[int, int]:
        """Return the player position as ``(x, y)``."""
        y = int(self.pyboy.memory[PLAYER_Y_ADDR])
        x = int(self.pyboy.memory[PLAYER_X_ADDR])
        return (x, y)  # (x, y) consistently

    def _get_obs(self) -> np.ndarray:
        """Return the game wrapper's ``game_area()`` converted to an int16 array."""
        return np.array(self.game.game_area(), dtype=np.int16)

    # ---- API ----

    def reset(self):
        """Reload the starting save state and return the first observation.

        The emulator is created on the first call and reused afterwards:
        loading the save state restores the machine, so tearing the emulator
        down and rebuilding it for every episode is unnecessary work.
        """
        if self.pyboy is None:
            self._init_emulator()

        # load your starting save state
        with open(self.state_path, "rb") as f:
            self.pyboy.load_state(f)

        # settle a bit
        for _ in range(self.SETTLE_FRAMES):
            self.pyboy.tick()

        self.start_map_id = self._get_map_id()
        self.step_count = 0

        pos = self._get_player_pos()
        self.prev_player_pos = pos

        # The starting tile counts as already visited.
        self.visited = {(self.start_map_id, *pos)}

        return self._get_obs()

    def step(self, action: int):
        """Apply one action, advance the emulator and compute the reward.

        Args:
            action: Key into ``ACTIONS``.

        Returns:
            ``(obs, reward, done, info)`` where ``done`` is True when the map id
            differs from the starting map or ``max_steps`` has been reached, and
            ``info`` holds ``"map_id"`` and ``"player_pos"``.

        Raises:
            RuntimeError: If called before ``reset()`` (or after ``close()``).
            KeyError: If ``action`` is not a key of ``ACTIONS``.
        """
        if self.pyboy is None:
            raise RuntimeError("PokemonRedEnv.step() called before reset()")

        self.step_count += 1

        # ---- apply action ----
        ACTIONS[int(action)](self.pyboy)

        for _ in range(self.FRAMES_PER_STEP):
            self.pyboy.tick()

        obs = self._get_obs()
        current_map = self._get_map_id()
        pos = self._get_player_pos()

        done = False
        reward = 0.0
        reward -= self.STEP_PENALTY

        # Movement is judged on (x, y) only; novelty is tracked per (map, x, y).
        if pos != self.prev_player_pos:
            reward += self.MOVE_REWARD
            key = (current_map, *pos)
            if key not in self.visited:
                reward += self.NEW_TILE_REWARD
                self.visited.add(key)

        self.prev_player_pos = pos

        if current_map != self.start_map_id:
            reward += self.EXIT_REWARD
            done = True

        if self.step_count >= self.max_steps:
            done = True

        info = {"map_id": current_map, "player_pos": pos}

        return obs, reward, done, info

    def close(self):
        """Stop the emulator, if one is running. Safe to call repeatedly."""
        if self.pyboy is not None:
            self.pyboy.stop()
            self.pyboy = None
            self.game = None

class VisualPokemonRedEnv(PokemonRedEnv):
    """PokemonRedEnv variant whose emulator opens an SDL2 window."""

    def _init_emulator(self):
        """Create the emulator with a GUI window and cache its game wrapper."""
        # no_input=False: input sent through .button() must still be accepted.
        emulator = PyBoy(self.rom_path, window="SDL2", no_input=False)
        self.pyboy = emulator
        self.game = emulator.game_wrapper

In [29]:
# ====== GYM WRAPPER (Gymnasium-compatible) ======

class PokemonRedGymWrapper(gym.Env):
    """
    Gymnasium-compatible wrapper around PokemonRedEnv.
    This is what you pass to Stable-Baselines3.

    Args:
        rom_path: Forwarded to ``PokemonRedEnv``.
        state_path: Forwarded to ``PokemonRedEnv``.
        max_steps: Episode step limit forwarded to ``PokemonRedEnv``.
    """

    metadata = {"render_modes": []}

    def __init__(self, rom_path, state_path, max_steps: int = 500):
        super().__init__()
        self.env = PokemonRedEnv(rom_path, state_path, max_steps=max_steps)

        # Observation: 18x20 integer grid
        # NOTE(review): PyBoy documents tile identifiers up to 383 on DMG;
        # confirm that 300 really bounds every value game_area() can return.
        self.observation_space = spaces.Box(
            low=0,
            high=300,
            shape=(18, 20),
            dtype=np.int16,
        )

        # Actions: one discrete action per entry of ACTIONS
        self.action_space = spaces.Discrete(NUM_ACTIONS)

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment.

        Returns:
            ``(obs, info)`` as required by Gymnasium; ``info`` is empty.
        """
        # Seeds self.np_random the Gymnasium way instead of touching NumPy's
        # global RNG. The wrapped env does not consume randomness itself.
        super().reset(seed=seed)
        obs = self.env.reset()
        return obs, {}

    def step(self, action):
        """Step the wrapped environment and split ``done`` into Gymnasium flags.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is True
            only when the map id differs from the starting map; an episode that
            ends because the step limit was hit is reported as ``truncated``.
        """
        obs, reward, done, info = self.env.step(int(action))

        # The inner env folds both end conditions into one `done` flag;
        # recover which one fired so time-limit endings are not treated as
        # true terminal states.
        terminated = bool(info["map_id"] != self.env.start_map_id)
        truncated = bool(done) and not terminated

        return obs, reward, terminated, truncated, info

    def close(self):
        """Stop the emulator and run Gymnasium's own cleanup."""
        self.env.close()
        super().close()



In [None]:
# ====== PPO TRAINING ON POKÉMON ======

# Fixed seed so a Restart & Run All reproduces the same training run.
SEED = 0

gym_env = PokemonRedGymWrapper("red.gb", "pokemon_red_start.state", max_steps=300)

model = PPO(
    "MlpPolicy",
    gym_env,
    verbose=1,
    learning_rate=3e-4,
    n_steps=256,
    batch_size=64,
    gamma=0.99,
    seed=SEED,
)

model.learn(total_timesteps=200_000)
print("PPO training on Pokémon completed.")

# Do not close gym_env here: the evaluation that follows reuses it.

def evaluate_agent(model, env, n_episodes=20, max_steps=300, deterministic=False):
    """Roll out `model` on `env` and report how often it leaves the starting map.

    Args:
        model: Object with ``predict(obs, deterministic=...)`` returning
            ``(action, state)``.
        env: Gymnasium-style environment whose ``reset()`` returns ``(obs, info)``
            and whose ``step()`` returns a 5-tuple. It must expose
            ``env.env.start_map_id`` and put ``"map_id"`` into the step ``info``.
        n_episodes: Number of episodes to run.
        max_steps: Upper bound on steps per episode.
        deterministic: Forwarded to ``model.predict``; the default keeps the
            original stochastic evaluation.

    Returns:
        ``(successes, mean_reward)``: number of episodes that ended on a map
        other than the starting one, and the mean total reward (NaN when
        ``n_episodes`` is 0).
    """
    successes = 0
    rewards = []

    for ep in range(n_episodes):
        obs, info = env.reset()
        start_map = env.env.start_map_id
        total_reward = 0.0

        for _ in range(max_steps):
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, info = env.step(int(action))
            total_reward += reward

            if terminated or truncated:
                break

        # reset() returns an empty info dict, so when no step was taken there
        # is no "map_id" yet; the agent is then still on the starting map.
        final_map = info.get("map_id", start_map)
        success = (final_map != start_map)
        if success:
            successes += 1

        rewards.append(total_reward)
        print(
            f"Episode {ep}: reward={total_reward:.3f}, "
            f"start_map={start_map}, final_map={final_map}, success={success}, "
            f"final_pos={info.get('player_pos')}"
        )

    # np.mean([]) warns and yields NaN; produce the NaN without the warning.
    mean_reward = float(np.mean(rewards)) if rewards else float("nan")

    print(f"\nSuccesses: {successes}/{n_episodes}")
    print(f"Average reward: {mean_reward:.3f}")

    return successes, mean_reward

# Evaluate the trained policy, then release the emulator even if evaluation fails.
try:
    evaluate_agent(model, gym_env, n_episodes=20, max_steps=300)
finally:
    gym_env.close()


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------
| time/              |     |
|    fps             | 953 |
|    iterations      | 1   |
|    time_elapsed    | 0   |
|    total_timesteps | 256 |
----------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 300         |
|    ep_rew_mean          | 20.2        |
| time/                   |             |
|    fps                  | 817         |
|    iterations           | 2           |
|    time_elapsed         | 0           |
|    total_timesteps      | 512         |
| train/                  |             |
|    approx_kl            | 0.010152247 |
|    clip_fraction        | 0.025       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.07       |
|    explained_variance   | -0.115      |
|    learning_rate        | 0.0003      |
|    loss                 | 1.4

In [None]:
# Watch the trained policy play a single episode in a visible emulator window.
vis_env = VisualPokemonRedEnv("red.gb", "pokemon_red_start.state", max_steps=2000)

try:
    obs = vis_env.reset()
    start_map_id = vis_env.start_map_id
    total_reward = 0.0

    for t in range(300):
        action, _state = model.predict(obs, deterministic=False)

        obs, reward, done, info = vis_env.step(int(action))
        total_reward += reward

        # Stop once the episode is over instead of stepping a finished game.
        if done:
            break

    print(
        f"Visual rollout: steps={t + 1}, reward={total_reward:.3f}, "
        f"left_start_map={info['map_id'] != start_map_id}"
    )
finally:
    # Close the window even if the rollout raised.
    vis_env.close()