In [2]:
import asyncio
import os
import matplotlib.pyplot as plt
import numpy as np
from gymnasium.spaces import Box, Space
from gymnasium.utils.env_checker import check_env
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy
from tabulate import tabulate
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player import (
    Gen8EnvSinglePlayer,
    MaxBasePowerPlayer,
    ObsType,
    RandomPlayer,
    SimpleHeuristicsPlayer,
    background_cross_evaluate,
    background_evaluate_player,
)
from poke_env.data import GenData

type_chart = GenData(9).type_chart

class DQNPlayer(Gen8EnvSinglePlayer):
    def calc_reward(self, last_battle, current_battle) -> float:
        # Add reward for dealing damage and penalty for fainted Pokémon
        damage_dealt = (last_battle.opponent_active_pokemon.current_hp - 
                        current_battle.opponent_active_pokemon.current_hp)
        reward = self.reward_computing_helper(
            current_battle,
            fainted_value=2.0,
            hp_value=1.0,
            victory_value=30.0,
        )
        # Reward for damage dealt
        reward += damage_dealt * 0.5
        return reward

    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        # Embed battle with action masking for unavailable moves
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = move.base_power / 100
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=type_chart
                )

        fainted_mon_team = len([mon for mon in battle.team.values() if mon.fainted]) / 6
        fainted_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted]) / 6
        )

        final_vector = np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )
        return np.float32(final_vector)

    def describe_embedding(self) -> Space:
        low = [-1] * 4 + [0] * 4 + [0, 0]
        high = [3] * 4 + [4] * 4 + [1, 1]
        return Box(np.array(low, dtype=np.float32), np.array(high, dtype=np.float32), dtype=np.float32)

async def run_dqn_training():
    train_opponent = RandomPlayer(battle_format="gen8randombattle")
    train_player_env = DQNPlayer(
        battle_format="gen8randombattle", opponent=train_opponent, start_challenging=True
    )
    eval_opponent = RandomPlayer(battle_format="gen8randombattle")
    eval_player_env = DQNPlayer(
        battle_format="gen8randombattle", opponent=eval_opponent, start_challenging=True
    )

    n_actions = train_player_env.action_space.n
    input_shape = (1,) + train_player_env.observation_space.shape

    dqn_model = Sequential([
        Flatten(input_shape=input_shape),
        Dense(128, activation="elu"),
        Dense(64, activation="elu"),
        Dense(n_actions, activation="linear")
    ])

    memory = SequentialMemory(limit=50000, window_length=1)
    dqn_policy = LinearAnnealedPolicy(
        EpsGreedyQPolicy(),
        attr="eps",
        value_max=1.0,
        value_min=0.1,
        value_test=0.05,
        nb_steps=50000,
    )

    dqn_agent = DQNAgent(
        model=dqn_model,
        nb_actions=n_actions,
        policy=dqn_policy,
        memory=memory,
        nb_steps_warmup=1000,
        gamma=0.9,
        target_model_update=1000,
        enable_double_dqn=True,
    )
    dqn_agent.compile(Adam(learning_rate=0.001), metrics=["mae"])

    # Training with logging
    history = dqn_agent.fit(train_player_env, nb_steps=50000, visualize=False, verbose=2)
    train_player_env.close()

    # Save the model
    if not os.path.exists("models"):
        os.makedirs("models")
    dqn_agent.model.save("models/dqn_pokemon_model.h5")

    # Plot training rewards
    plt.plot(history.history["episode_reward"])
    plt.title("Training Rewards")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.show()

    # Evaluation
    print("Evaluating against random player...")
    dqn_agent.test(eval_player_env, nb_episodes=100, visualize=False)
    print(
        f"Victories: {eval_player_env.n_won_battles} / {eval_player_env.n_finished_battles}"
    )
    eval_player_env.close()

if __name__ == "__main__":
    import nest_asyncio

    nest_asyncio.apply()
    await run_dqn_training()

Matplotlib is building the font cache; this may take a moment.


Training for 50000 steps ...


ValueError: Error when checking input: expected flatten_input to have shape (1, 10) but got array with shape (1, 2)

In [None]:
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import DummyVecEnv
from gymnasium.spaces import Box
from poke_env.data import GenData
from poke_env.player import Gen9EnvSinglePlayer, RandomPlayer

# Define the custom RL player
class SimpleRLPlayer(Gen9EnvSinglePlayer):
    def embed_battle(self, battle):
        """
        Embeds the current battle state into a vector of 10 values.
        """
        # Initialize base power and damage multipliers
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = (
                move.base_power / 100  # Rescale base power
            )
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=GEN_9_DATA.type_chart,
                )

        # Count remaining Pokemon for both sides
        remaining_mon_team = (
            len([mon for mon in battle.team.values() if not mon.fainted]) / 6
        )
        remaining_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if not mon.fainted]) / 6
        )

        # Final embedding vector
        print(50*"="+"\n",[
                moves_base_power,
                moves_dmg_multiplier,
                [remaining_mon_team, remaining_mon_opponent],
            ],"\n"+50*"=")
        return np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [remaining_mon_team, remaining_mon_opponent],
            ]
        )

    def calc_reward(self, last_state, current_state) -> float:
        """
        Computes the reward for the agent based on the current state.
        """
        return self.reward_computing_helper(
            current_state, fainted_value=2, hp_value=1, victory_value=30
        )

    def describe_embedding(self):
        """
        Describes the observation space for Stable-Baselines3.
        """
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )

# Define a MaxDamagePlayer for evaluation
class MaxDamagePlayer(RandomPlayer):
    def choose_move(self, battle):
        # Prioritize moves with the highest base power
        if battle.available_moves:
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)
        else:
            return self.choose_random_move(battle)

# Training and evaluation settings
NB_TRAINING_STEPS = 100_000
TEST_EPISODES = 100
GEN_9_DATA = GenData.from_gen(9)

if __name__ == "__main__":
    # Create opponents
    random_opponent = RandomPlayer()
    max_damage_opponent = MaxDamagePlayer()

    # Create training environment
    env_player = SimpleRLPlayer(opponent=random_opponent)
    env = DummyVecEnv([lambda: env_player])  # Wrap in DummyVecEnv for Stable-Baselines3

    # Train the A2C model
    model = A2C("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=NB_TRAINING_STEPS)
    
    obs, reward, done, _, info = env_player.step(0)
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = env_player.step(action)

    finished_episodes = 0

    # Evaluate against RandomPlayer
    env_player.reset_battles()
    obs, _ = env_player.reset()
    finished_episodes = 0
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = env_player.step(action)

        if done:
            finished_episodes += 1
            if finished_episodes >= TEST_EPISODES:
                break
            obs, _ = env_player.reset()

    print("Evaluation against RandomPlayer: ", env_player.n_won_battles, "wins out of", TEST_EPISODES)

    # Evaluate against MaxDamagePlayer
    env_player._opponent = max_damage_opponent
    env_player.reset_battles()
    obs, _ = env_player.reset()
    finished_episodes = 0
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = env_player.step(action)

        if done:
            finished_episodes += 1
            obs, _ = env_player.reset()
            if finished_episodes >= TEST_EPISODES:
                break

    print("Evaluation against MaxDamagePlayer: ", env_player.n_won_battles, "wins out of", TEST_EPISODES)


Using cpu device
------------------------------------
| time/                 |          |
|    fps                | 165      |
|    iterations         | 100      |
|    time_elapsed       | 3        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -3.23    |
|    explained_variance | 0.246    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 1.51     |
|    value_loss         | 1.57     |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 188      |
|    iterations         | 200      |
|    time_elapsed       | 5        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -3.25    |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss        | 1.75     |
|    value_loss      