In [1]:
!pip install gym
!pip install stable_baselines3[extra]

Collecting stable_baselines3[extra]
  Downloading stable_baselines3-2.0.0-py3-none-any.whl (178 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m178.4/178.4 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gymnasium==0.28.1 (from stable_baselines3[extra])
  Downloading gymnasium-0.28.1-py3-none-any.whl (925 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m925.5/925.5 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting shimmy[atari]~=0.2.1 (from stable_baselines3[extra])
  Downloading Shimmy-0.2.1-py3-none-any.whl (25 kB)
Collecting autorom[accept-rom-license]~=0.6.0 (from stable_baselines3[extra])
  Downloading AutoROM-0.6.1-py3-none-any.whl (9.4 kB)
Collecting jax-jumpy>=1.0.0 (from gymnasium==0.28.1->stable_baselines3[extra])
  Downloading jax_jumpy-1.0.0-py3-none-any.whl (20 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium==0.28.1->stable_baselines3[extra])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5

In [2]:
from Pokemon import Pokemon, Pokemon_Move, Pokemon_Battle
import gymnasium as gym
from gymnasium import Env
from gymnasium.spaces import Discrete, Box, Tuple, MultiBinary, MultiDiscrete
from stable_baselines3.common.evaluation import evaluate_policy
import numpy as np
import random



In [14]:
class BattleEnv(gym.Env):
    """Gymnasium environment for a one-on-one Pokemon battle.

    The agent controls ``Pokemon1`` and chooses one of four moves per step;
    ``Pokemon2`` answers with a uniformly random move. Both Pokemon are drawn
    at random from ``PRE_BUILT_POKEMON`` on every ``reset()``.

    Observations are a ``Dict`` with the player's stats, types, move types,
    last damage value and used-move mask, plus the opponent's move types and
    used-move mask.

    Rewards: ``+50`` when the opponent faints, ``-50`` when the player faints,
    otherwise the current value of ``Pokemon1.damage`` (NaN is treated as 0).
    """

    # Roster used by get_random_pokemon: species name -> its four move names.
    PRE_BUILT_POKEMON = {
        "venusaur": ["vine-whip", "razor-leaf", "sludge-bomb", "leer"],
        "charizard": ["flamethrower", "wing-attack", "dragon-claw", "growl"],
        "blastoise": ["water-gun", "hydro-pump", "ice-beam", "bite"],
        "raichu": ["thunderbolt", "volt-switch", "tackle", "growl"],
        "wigglytuff": ["dazzling-gleam", "play-rough", "leer", "tackle"],
        "persian": ["slash", "night-slash", "swift", "leer"],
    }

    def __init__(self):
        """Define the action and observation spaces (no Pokemon exist until reset())."""
        # Four moves for Pokemon
        self.action_space = Discrete(4)
        self.done = False

        # Continuous attributes (Pokemon stats and move damage dealt).
        # Bounds are created as float32 so they match the Box dtype and do not
        # trigger gymnasium's "Box bound precision lowered" warning.
        self.num_stats = 6
        self.num_moves = 4
        self.observation_stats_low = np.zeros(self.num_stats, dtype=np.float32)
        self.observation_stats_high = np.full(self.num_stats, 255.0, dtype=np.float32)

        self.observation_damage_low = np.zeros(self.num_moves, dtype=np.float32)
        self.observation_damage_high = np.full(self.num_moves, 100.0, dtype=np.float32)

        # Discrete attributes (types, move types)
        self.num_types = 18
        self.num_damage_types = 3
        self.pokemon_types = ['normal', 'fighting', 'flying', 'poison', 'ground', 'rock', 'bug', 'ghost', 'steel', 'fire', 'water', 'grass', 'electric', 'psychic', 'ice', 'dragon', 'dark', 'fairy']
        self.type_to_index = {type_name: idx for idx, type_name in enumerate(self.pokemon_types)}
        self.damage_types = ['physical', 'special', 'status']

        # One type index (0..num_types-1) per move slot.
        move_type_nvec = [self.num_types] * self.num_moves

        # Combine the continuous and discrete attributes
        self.observation_space = gym.spaces.Dict({
            'player_stats': gym.spaces.Box(low=self.observation_stats_low, high=self.observation_stats_high, dtype=np.float32),
            'player_pokemon_types': gym.spaces.MultiBinary(self.num_types),
            'player_move_types': gym.spaces.MultiDiscrete(move_type_nvec),
            'player_move_damage': gym.spaces.Box(low=self.observation_damage_low, high=self.observation_damage_high, dtype=np.float32),
            'player_previous_moves': gym.spaces.MultiBinary(self.num_moves),
            'opponent_move_types': gym.spaces.MultiDiscrete(move_type_nvec),
            'opponent_previous_moves': gym.spaces.MultiBinary(self.num_moves),
        })

    def step(self, action):
        """Play one round: both Pokemon act once, the faster one first.

        Args:
            action: index (0-3) of the move the player's Pokemon uses.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; ``truncated``
            is always ``False`` and ``info`` is always empty.
        """
        opponent_action = self.action_space.sample()
        player, opponent = self.Pokemon1, self.Pokemon2

        if player.hp > 0 and opponent.hp > 0:
            # Check speed for who goes first; if speed is equal it is random.
            if player.speed > opponent.speed:
                player_first = True
            elif opponent.speed > player.speed:
                player_first = False
            else:
                player_first = random.choice([True, False])

            turns = [(player, opponent, action), (opponent, player, opponent_action)]
            if not player_first:
                turns.reverse()

            for attacker, defender, move in turns:
                self.battle.perform_turn(attacker, defender, move)
                if defender.hp <= 0:
                    # A fainted Pokemon does not get to act, and the episode
                    # ends in this very step (not one step later).
                    print(f'{defender.name.capitalize()} has fainted.')
                    self.done = True
                    break

            # Catch any other way a Pokemon may have reached 0 hp this round.
            if player.hp <= 0 or opponent.hp <= 0:
                self.done = True
        else:
            # step() was called on a battle that is already over.
            self.done = True

        observation = self.get_observation()
        reward = self.get_reward()
        info = {}

        return observation, reward, self.done, False, info

    def render(self):
        """Print both Pokemon's remaining hp while the battle is still running."""
        if not self.done:
            print(f'{self.Pokemon1.name.capitalize()} has {self.Pokemon1.hp} hp.')
            print(f'{self.Pokemon2.name.capitalize()} has {self.Pokemon2.hp} hp.')

    def reset(self, seed=None, options=None):
        """Start a new battle between two randomly chosen Pokemon.

        Args:
            seed: optional seed. It is forwarded to ``gym.Env.reset`` and also
                seeds the ``random`` module and the action space, which are the
                randomness sources this class draws from.
                NOTE(review): whether ``Pokemon_Battle`` uses the ``random``
                module too is not visible here - confirm if full
                reproducibility is needed.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)
            self.action_space.seed(seed)

        self.Pokemon1 = self.get_random_pokemon()  # Get a random player Pokemon
        self.Pokemon2 = self.get_random_pokemon()  # Get a random opponent Pokemon
        self.battle = Pokemon_Battle(self.Pokemon1, self.Pokemon2)
        self.done = False

        # Get the initial observation
        observation = self.get_observation()

        info = {}

        return observation, info

    def get_reward(self):
        """Return the reward for the current state as a float.

        -50 if the player's Pokemon has fainted, +50 if the opponent's has,
        0 if the episode ended for any other reason; while the battle is
        running, the current ``Pokemon1.damage`` (NaN counts as 0 so it cannot
        poison training).
        """
        if not self.done:
            return self._last_damage()

        if self.Pokemon1.hp <= 0:
            # Player Pokemon has fainted, so the agent lost the battle
            return -50.0
        if self.Pokemon2.hp <= 0:
            # Opponent Pokemon has fainted, so the agent won the battle
            return 50.0
        # The battle ended in a draw or some other unknown condition
        return 0.0

    def get_observation(self):
        """Build the observation dict matching ``self.observation_space``."""
        player, opponent = self.Pokemon1, self.Pokemon2

        # Player stats, clipped into the Box bounds so the observation is
        # always a valid member of the space.
        player_stats = np.clip(
            np.array([player.hp, player.attack, player.defense,
                      player.spattack, player.spdefense, player.speed],
                     dtype=np.float32),
            self.observation_stats_low,
            self.observation_stats_high,
        )

        # Multi-hot encoding of the player's type(s).
        player_types = np.zeros(self.num_types, dtype=np.int8)
        for type_name in player.types:
            player_types[self.type_to_index[type_name]] = 1

        # The only damage figure read here is Pokemon1.damage, so every move
        # slot carries that same (clipped) value.
        player_move_damage = np.zeros(self.num_moves, dtype=np.float32)
        player_move_damage[:len(player.moves)] = np.clip(self._last_damage(), 0.0, 100.0)

        # Opponent: only its move types and which moves it has already used
        # are exposed (its stats and types are not).
        observation = {
            'player_stats': player_stats,
            'player_pokemon_types': player_types,
            'player_move_types': self._move_type_indices(player),
            'player_move_damage': player_move_damage,
            'player_previous_moves': self._previous_move_mask(player),
            'opponent_move_types': self._move_type_indices(opponent),
            'opponent_previous_moves': self._previous_move_mask(opponent),
        }

        return observation

    def get_random_pokemon(self):
        """Create a random Pokemon from ``PRE_BUILT_POKEMON`` with its four moves."""
        PokeName = random.choice(list(self.PRE_BUILT_POKEMON.keys()))
        list_moves = [Pokemon_Move(move) for move in self.PRE_BUILT_POKEMON[PokeName]]
        return Pokemon(PokeName, list_moves)

    def _last_damage(self):
        """Return ``Pokemon1.damage`` as a float, mapping NaN to 0.0."""
        damage = self.Pokemon1.damage
        return 0.0 if np.isnan(damage) else float(damage)

    def _move_type_indices(self, pokemon):
        """Return the type index of each of ``pokemon``'s moves (int64 vector)."""
        indices = np.zeros(self.num_moves, dtype=np.int64)
        for move_idx, move in enumerate(pokemon.moves):
            indices[move_idx] = self.type_to_index[move.type]
        return indices

    def _previous_move_mask(self, pokemon):
        """Return a 0/1 vector marking which of ``pokemon``'s moves are in its previous_moves."""
        mask = np.zeros(self.num_moves, dtype=np.int8)
        for move_idx, move in enumerate(pokemon.moves):
            if move in pokemon.previous_moves:
                mask[move_idx] = 1
        return mask


  and should_run_async(code)


In [15]:
# Instantiate the battle environment; this only builds the spaces, Pokemon are created by reset().
env = BattleEnv()

  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [None]:
# Draw a random sample from the Dict observation space to inspect its keys, shapes and dtypes.
env.observation_space.sample()

In [None]:
# Start a new battle and display the returned (observation, info) pair.
env.reset()

In [None]:
from stable_baselines3.common.env_checker import check_env

In [None]:
# Validate env with Stable-Baselines3's environment checker; warn=True enables its additional warnings.
check_env(env, warn=True)

In [None]:
# Scratch check: sample one random action index from a fresh 4-way Discrete space.
Discrete(4).sample()

In [5]:
# Play a few battles with uniformly random actions as a smoke test of the environment.
episodes = 3
for episode in range(1, episodes + 1):
    # reset() returns an (observation, info) pair.
    obs, info = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action = env.action_space.sample()
        # step() returns five values: the episode is over when it is either
        # terminated (a Pokemon fainted) or truncated.
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        score += reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Blastoise has 139 hp.
Charizard has 138 hp.
Flamethrower: 15 pp, 90 power, 100 accuracy, fire
Wing-attack: 35 pp, 60 power, 100 accuracy, flying
Dragon-claw: 15 pp, 80 power, 100 accuracy, dragon
Growl: 40 pp, 0 power, 100 accuracy, normal
Not very effective
0.5
Charizard used Flamethrower!
Water-gun: 25 pp, 40 power, 100 accuracy, water
Hydro-pump: 5 pp, 110 power, 80 accuracy, water
Ice-beam: 10 pp, 90 power, 100 accuracy, ice
Bite: 25 pp, 60 power, 100 accuracy, dark
Hydro-pump missed!
Blastoise has 117.48 hp.
Charizard has 138 hp.
Flamethrower: 15 pp, 90 power, 100 accuracy, fire
Wing-attack: 35 pp, 60 power, 100 accuracy, flying
Dragon-claw: 15 pp, 80 power, 100 accuracy, dragon
Growl: 40 pp, 0 power, 100 accuracy, normal
1.0
Charizard used Growl!
Water-gun: 25 pp, 40 power, 100 accuracy, water
Hydro-pump: 5 pp, 110 power, 80 accuracy, water
Ice-beam: 10 pp, 90 power, 100 accuracy, ice
Bite: 25 pp, 60 power, 100 accuracy, dark
Super Effective!
2.0
Blastoise used Hydro-pump!
Blasto

In [6]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

  and should_run_async(code)


In [16]:
# Build a PPO agent on env; "MultiInputPolicy" is the policy class for Dict observation spaces.
# TensorBoard logs are written under the 'Models' directory.
model = PPO("MultiInputPolicy", env, verbose=1, tensorboard_log='Models')

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [17]:
# Train the agent for a budget of 40,000 environment steps.
model.learn(total_timesteps=40000)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Sludge-bomb: 10 pp, 90 power, 100 accuracy, poison
Leer: 30 pp, 0 power, 100 accuracy, normal
1.0
Venusaur used Vine-whip!
Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 100 accuracy, normal
Tackle: 35 pp, 40 power, 100 accuracy, normal
Not very effective
0.5
Wigglytuff used Dazzling-gleam!
Vine-whip: 25 pp, 45 power, 100 accuracy, grass
Razor-leaf: 25 pp, 55 power, 95 accuracy, grass
Sludge-bomb: 10 pp, 90 power, 100 accuracy, poison
Leer: 30 pp, 0 power, 100 accuracy, normal
1.0
Venusaur used Razor-leaf!
Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 100 accuracy, normal
Tackle: 35 pp, 40 power, 100 accuracy, normal
1.0
Wigglytuff used Leer!
Vine-whip: 25 pp, 45 power, 100 accuracy, grass
Razor-leaf: 25 pp, 55 power, 95 accuracy, grass
Sludge-bomb: 10 pp, 90 power, 100 a

KeyboardInterrupt: ignored

In [19]:
# Save the model to the path 'Models' (the tensorboard cell below refers to the resulting file as Models.zip).
model.save('Models')

In [22]:
# Evaluate the trained policy over 10 episodes; the displayed result is (mean reward, reward std).
evaluate_policy(model, env, n_eval_episodes=10, render=True)



Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 100 accuracy, normal
Tackle: 35 pp, 40 power, 100 accuracy, normal
1.0
Wigglytuff used Dazzling-gleam!
Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 100 accuracy, normal
Tackle: 35 pp, 40 power, 100 accuracy, normal
1.0
Wigglytuff used Tackle!
Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 100 accuracy, normal
Tackle: 35 pp, 40 power, 100 accuracy, normal
1.0
Wigglytuff used Play-rough!
Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 100 accuracy, normal
Tackle: 35 pp, 40 power, 100 accuracy, normal
Play-rough missed!
Dazzling-gleam: 10 pp, 80 power, 100 accuracy, fairy
Play-rough: 10 pp, 90 power, 90 accuracy, fairy
Leer: 30 pp, 0 power, 10



Flamethrower: 15 pp, 90 power, 100 accuracy, fire
Wing-attack: 35 pp, 60 power, 100 accuracy, flying
Dragon-claw: 15 pp, 80 power, 100 accuracy, dragon
Growl: 40 pp, 0 power, 100 accuracy, normal
Super Effective!
2.0
Charizard used Wing-attack!
Vine-whip: 25 pp, 45 power, 100 accuracy, grass
Razor-leaf: 25 pp, 55 power, 95 accuracy, grass
Sludge-bomb: 10 pp, 90 power, 100 accuracy, poison
Leer: 30 pp, 0 power, 100 accuracy, normal
Not very effective
Not very effective
0.25
Venusaur used Razor-leaf!
Flamethrower: 15 pp, 90 power, 100 accuracy, fire
Wing-attack: 35 pp, 60 power, 100 accuracy, flying
Dragon-claw: 15 pp, 80 power, 100 accuracy, dragon
Growl: 40 pp, 0 power, 100 accuracy, normal
1.0
Charizard used Dragon-claw!
Vine-whip: 25 pp, 45 power, 100 accuracy, grass
Razor-leaf: 25 pp, 55 power, 95 accuracy, grass
Sludge-bomb: 10 pp, 90 power, 100 accuracy, poison
Leer: 30 pp, 0 power, 100 accuracy, normal
Not very effective
Not very effective
0.25
Venusaur used Razor-leaf!
Flamethro

(107.98621234893798, 99.03240513744237)

In [24]:
!tensorboard --logdir Models.zip


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.12.3 at http://localhost:6006/ (Press CTRL+C to quit)
^C
