In [1]:
from typing import Any, Dict, Optional

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
from gymnasium.spaces import Box, Discrete, Space
from poke_env.battle import AbstractBattle, Battle
from poke_env.environment import SingleAgentWrapper, SinglesEnv
from poke_env.player import RandomPlayer
from ray.rllib.algorithms import PPOConfig
from ray.rllib.core import Columns
from ray.rllib.core.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.env import ParallelPettingZooEnv
from ray.tune.registry import register_env


2025-09-02 13:34:52,609	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-09-02 13:34:54,321	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
class ExampleEnv(SinglesEnv[npt.NDArray[np.float32]]):
    """Singles Showdown environment with a 10-component float32 observation.

    Observation layout (built by ``embed_battle``):
      * [0:4] base power / 100 of each available move (-1 for an empty slot)
      * [4:8] type-effectiveness multiplier of each move against the opponent's
              active Pokémon (1 for an empty slot or when no opponent is known)
      * [8]   number of our fainted Pokémon / 6
      * [9]   number of the opponent's fainted Pokémon / 6
    """

    # Per-component lower/upper bounds of the observation vector, in the order
    # listed above. The training functions reuse these to declare the
    # RLModule's observation space.
    LOW = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
    HIGH = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
    # Number of move slots encoded in the observation.
    N_MOVES = 4

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Every agent observes the same bounded float32 vector.
        self.observation_spaces = {
            agent: Box(
                np.array(self.LOW, dtype=np.float32),
                np.array(self.HIGH, dtype=np.float32),
                dtype=np.float32,
            )
            for agent in self.possible_agents
        }

    @classmethod
    def _make_env(cls, config: Dict[str, Any]) -> "ExampleEnv":
        """Build a bare environment from an RLlib ``env_config`` dict.

        Shared by both factory methods so their construction arguments cannot
        drift apart. ``config`` must contain a ``"battle_format"`` key.
        """
        return cls(
            battle_format=config["battle_format"],
            # 25 sits between logging.INFO (20) and logging.WARNING (30).
            log_level=25,
            open_timeout=None,
            strict=False,
        )

    @classmethod
    def create_multi_agent_env(cls, config: Dict[str, Any]) -> ParallelPettingZooEnv:
        """RLlib env creator that exposes the env's agents via ``ParallelPettingZooEnv``."""
        return ParallelPettingZooEnv(cls._make_env(config))

    @classmethod
    def create_single_agent_env(cls, config: Dict[str, Any]) -> SingleAgentWrapper:
        """RLlib env creator for single-agent play against a ``RandomPlayer``."""
        # The env is built before the opponent, as in the original code.
        return SingleAgentWrapper(cls._make_env(config), RandomPlayer())

    def calc_reward(self, battle) -> float:
        """Delegate to poke_env's ``reward_computing_helper`` with fixed weights.

        Weights: fainted_value=2.0, hp_value=1.0, victory_value=30.0.
        """
        return self.reward_computing_helper(
            battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle: AbstractBattle) -> npt.NDArray[np.float32]:
        """Encode ``battle`` as the 10-component float32 observation vector.

        Raises:
            AssertionError: if ``battle`` is not a singles ``Battle``.
        """
        assert isinstance(battle, Battle)
        # -1 marks an empty move slot. A status move that is available keeps its
        # real base power (0), so -1 only means "no move in this slot".
        moves_base_power = -np.ones(self.N_MOVES)
        moves_dmg_multiplier = np.ones(self.N_MOVES)
        opponent = battle.opponent_active_pokemon
        # Bound the iteration so an unexpectedly long move list cannot index past
        # the fixed-size arrays.
        for i, move in enumerate(battle.available_moves[: self.N_MOVES]):
            # Simple rescaling to facilitate learning.
            moves_base_power[i] = move.base_power / 100
            if opponent is not None:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    opponent.type_1,
                    opponent.type_2,
                    type_chart=opponent._data.type_chart,
                )

        # Fraction of fainted Pokémon on each side, normalized by a team size of 6.
        fainted_mon_team = sum(1 for mon in battle.team.values() if mon.fainted) / 6
        fainted_mon_opponent = (
            sum(1 for mon in battle.opponent_team.values() if mon.fainted) / 6
        )

        # Final vector with 10 components, ordered as documented on the class.
        final_vector = np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )
        return final_vector.astype(np.float32)


In [3]:
class ActorCriticModule(TorchRLModule, ValueFunctionAPI):
    """Small shared-trunk actor-critic RLModule used for PPO.

    One hidden layer (``self.model`` followed by a ReLU) feeds two heads:
    ``self.actor``, which emits one logit per discrete action, and
    ``self.critic``, which emits a scalar value estimate. Layer widths are read
    from the spaces passed to the module, so they always agree with the
    ``RLModuleSpec``. The hidden width defaults to 100 and can be overridden
    with ``model_config={"hidden_dim": ...}``.
    """

    def __init__(
        self,
        observation_space: Space,
        action_space: Space,
        inference_only: bool,
        model_config: Dict[str, Any],
        catalog_class: Any,
    ):
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            inference_only=inference_only,
            model_config=model_config,
            catalog_class=catalog_class,
        )
        # Derive sizes from the spaces instead of hard-coding 10 / 26.
        # Assumes a flat Box observation space and a Discrete action space.
        obs_dim = int(observation_space.shape[-1])
        num_actions = int(action_space.n)
        config = model_config if isinstance(model_config, dict) else {}
        hidden_dim = int(config.get("hidden_dim", 100))
        self.model = nn.Linear(obs_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, num_actions)
        self.critic = nn.Linear(hidden_dim, 1)

    def _trunk(self, obs: torch.Tensor) -> torch.Tensor:
        """Shared hidden representation consumed by both heads."""
        # Without a nonlinearity, actor(model(x)) collapses to a single affine
        # map and the hidden layer adds no representational capacity.
        return torch.relu(self.model(obs))

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Compute action-distribution inputs (logits) for ``batch``.

        The trunk output is also returned under ``Columns.EMBEDDINGS`` so it can
        be handed back to ``compute_values`` instead of being recomputed.
        """
        embeddings = self._trunk(batch[Columns.OBS])
        logits = self.actor(embeddings)
        return {Columns.EMBEDDINGS: embeddings, Columns.ACTION_DIST_INPUTS: logits}

    def compute_values(
        self, batch: Dict[str, Any], embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return one value estimate per batch row.

        Args:
            batch: dict holding observations under ``Columns.OBS``.
            embeddings: optional precomputed trunk output. It is recomputed from
                ``batch`` when omitted.

        Returns:
            Tensor of shape ``(B,)`` for a ``(B, obs_dim)`` observation batch;
            the critic's trailing size-1 dimension is squeezed away.
        """
        if embeddings is None:
            embeddings = self._trunk(batch[Columns.OBS])
        return self.critic(embeddings).squeeze(-1)

In [4]:
def single_agent_train(num_iterations: int = 1) -> Optional[Dict[str, Any]]:
    """Train a PPO policy on ``ExampleEnv`` against a ``RandomPlayer`` opponent.

    Args:
        num_iterations: number of ``algo.train()`` calls to run. The default of
            1 matches the original single call.

    Returns:
        The result dict of the last training iteration, or ``None`` when
        ``num_iterations`` is 0.
    """
    # A dedicated registry name keeps this creator from being overwritten by the
    # multi-agent creator (and vice versa). That matters if env runners are
    # restarted after the other function has re-registered the name.
    env_name = "showdown_single_agent"
    register_env(env_name, ExampleEnv.create_single_agent_env)
    config = (
        PPOConfig()
        .environment(
            env_name,
            env_config={"battle_format": "gen9randombattle"},
            disable_env_checking=True,
        )
        .learners(num_learners=1)
        .rl_module(
            rl_module_spec=RLModuleSpec(
                module_class=ActorCriticModule,
                observation_space=Box(
                    np.array(ExampleEnv.LOW, dtype=np.float32),
                    np.array(ExampleEnv.HIGH, dtype=np.float32),
                    dtype=np.float32,
                ),
                action_space=Discrete(26),
                model_config={},
            )
        )
        .training(
            gamma=0.99, lr=1e-3, train_batch_size=1024, num_epochs=10, minibatch_size=64
        )
    )
    algo = config.build_algo()
    result = None
    try:
        for _ in range(num_iterations):
            result = algo.train()
    finally:
        # Release the Ray env-runner / learner actors even if training raises;
        # otherwise they (and presumably their Showdown connections) linger.
        algo.stop()
    return result

In [5]:
def multi_agent_train(num_iterations: int = 1) -> Optional[Dict[str, Any]]:
    """Train PPO on ``ExampleEnv`` with every agent sharing one policy.

    ``policy_mapping_fn`` maps every agent id to the single policy ``"p1"``, so
    both sides of a battle are driven by the same learning policy (self-play).

    Args:
        num_iterations: number of ``algo.train()`` calls to run. The default of
            1 matches the original single call.

    Returns:
        The result dict of the last training iteration, or ``None`` when
        ``num_iterations`` is 0.
    """
    # A dedicated registry name keeps this creator from being overwritten by the
    # single-agent creator (and vice versa). That matters if env runners are
    # restarted after the other function has re-registered the name.
    env_name = "showdown_multi_agent"
    register_env(env_name, ExampleEnv.create_multi_agent_env)
    config = (
        PPOConfig()
        .environment(
            env_name,
            env_config={"battle_format": "gen9randombattle"},
            disable_env_checking=True,
        )
        .learners(num_learners=1)
        .multi_agent(
            policies={"p1"},
            policy_mapping_fn=lambda agent_id, ep_type: "p1",
            policies_to_train=["p1"],
        )
        .rl_module(
            rl_module_spec=RLModuleSpec(
                module_class=ActorCriticModule,
                observation_space=Box(
                    np.array(ExampleEnv.LOW, dtype=np.float32),
                    np.array(ExampleEnv.HIGH, dtype=np.float32),
                    dtype=np.float32,
                ),
                action_space=Discrete(26),
                model_config={},
            )
        )
        .training(
            gamma=0.99, lr=1e-3, train_batch_size=1024, num_epochs=10, minibatch_size=64
        )
    )
    algo = config.build_algo()
    result = None
    try:
        for _ in range(num_iterations):
            result = algo.train()
    finally:
        # Release the Ray env-runner / learner actors even if training raises;
        # otherwise they (and presumably their Showdown connections) linger.
        algo.stop()
    return result

In [None]:
# Run single-agent training; call multi_agent_train() instead for the shared-policy variant.
single_agent_train();

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-09-02 13:35:16,864	INFO worker.py:1951 -- Started a local Ray instance.
[2025-09-02 13:35:17,893 E 84320 84320] core_worker.cc:2246: Actor with class name: 'SingleAgentEnvRunner' and ID: 'c960946e66b878ecd8fd4e0401000000' has constructor arguments in the object store and max_restarts > 0. If the a

[36m(_WrappedExecutable pid=84811)[0m [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0


2025-09-02 13:35:26,748	INFO trainable.py:161 -- Trainable.setup took 11.187 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
[36m(SingleAgentEnvRunner pid=84810)[0m 2025-09-02 13:36:16,765	ERROR actor_manager.py:187 -- Worker exception caught during `apply()`: Agent is not challenging
[36m(SingleAgentEnvRunner pid=84810)[0m   File "/home/bapti/code/blegeron/pokemon/poke-agent/.venv/lib/python3.12/site-packages/ray/rllib/utils/actor_manager.py", line 183, in apply
[36m(SingleAgentEnvRunner pid=84810)[0m     return func(self, *args, **kwargs)
[36m(SingleAgentEnvRunner pid=84810)[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[36m(SingleAgentEnvRunner pid=84810)[0m   File "/home/bapti/code/blegeron/pokemon/poke-agent/.venv/lib/python3.12/site-packages/ray/rllib/execution/rollout_ops.py", line 110, in <lambda>
[36m(SingleAgentEnvRunner pid=84810)[0m     else (lambda w: (w.sample(**random_action_kwargs), w.get_m

[33m(raylet)[0m A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff586a086dbe96eaf445d116a801000000 Worker ID: a4b8af2a4efddedea4b5469cfbf7b142866a95eb05f5b77e2e7d92ca Node ID: 4bf3cca88804bdfd14ef7ea9d0197bbd5ca69e9becb4177f749e879f Worker IP address: 172.29.92.4 Worker port: 46177 Worker PID: 84811 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.


[36m(MultiAgentEnvRunner pid=84814)[0m 2025-09-02 13:37:15,500 - ExampleEnv f9072 - CRITICAL - Error message received: |nametaken|ExampleEnv f9072|Your authentication token was invalid.
[36m(MultiAgentEnvRunner pid=84814)[0m 2025-09-02 13:37:15,500 - ExampleEnv f9072 - ERROR - Unhandled exception raised while handling message:
[36m(MultiAgentEnvRunner pid=84814)[0m |nametaken|ExampleEnv f9072|Your authentication token was invalid.
[36m(MultiAgentEnvRunner pid=84814)[0m Traceback (most recent call last):
[36m(MultiAgentEnvRunner pid=84814)[0m   File "/home/bapti/code/blegeron/pokemon/poke-agent/.venv/lib/python3.12/site-packages/poke_env/ps_client/ps_client.py", line 170, in _handle_message
[36m(MultiAgentEnvRunner pid=84814)[0m     raise ShowdownException("Error message received: %s", message)
[36m(MultiAgentEnvRunner pid=84814)[0m poke_env.exceptions.ShowdownException: ('Error message received: %s', '|nametaken|ExampleEnv f9072|Your authentication token was invalid.')
[

[33m(raylet)[0m A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffc960946e66b878ecd8fd4e0401000000 Worker ID: 8fa219badc97f38353b324259e37ced97d024094a81fb23009f8f9f6 Node ID: 4bf3cca88804bdfd14ef7ea9d0197bbd5ca69e9becb4177f749e879f Worker IP address: 172.29.92.4 Worker port: 36955 Worker PID: 84810 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker exits unexpectedly. Worker exits with an exit code 1.


2025-09-02 13:38:05,538	ERROR actor_manager.py:873 -- Ray error (The actor died because of an error raised in its creation task, [36mray::MultiAgentEnvRunner.__init__()[39m (pid=84814, ip=172.29.92.4, actor_id=3be107bdcb26f5a68f8c5a1b01000000, repr=<ray.rllib.env.multi_agent_env_runner.MultiAgentEnvRunner object at 0x70bccdbb4f50>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bapti/code/blegeron/pokemon/poke-agent/.venv/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 112, in __init__
    self.make_env()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bapti/code/blegeron/pokemon/poke-agent/.venv/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 825, in make_env
    self.env = make_vec(
               ^^^^^^^^^
  File "/home/bapti/code/blegeron/pokemon/poke-agent/.venv/lib/python3.12/site-packages/ray/rllib/env/vector/registration.py", line 69, in make_vec
    env 

IndexError: list index out of range