In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gymnasium import Env
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation, TimeLimit

from pokerl.env.pokemonblue import  PokemonBlueEnv
from pokerl.env.wrappers import (
    ObservationAddPokemonLevel,
    ObservationAddPosition,
    ObservationDict,
    RewardDecreasingNoChange,
    RewardDecreasingSteps,
    RewardHistoryToInfo,
    RewardIncreasingBadges,
    RewardIncreasingCapturePokemon,
    RewardIncreasingPokemonLevel,
    RewardIncreasingPositionExploration,
    RemoveSelectStartAction,
    ppFlattenInfo,
)

from pokerl.env.wrappers.rewards import RewardIncreasingLandedAttack,RewardDecreasingLostBattle



In [3]:
def create_env(interactive=False) -> Env:
    env = PokemonBlueEnv(interactive=interactive)
    # Setting observation
    env = ResizeObservation(env, 64)
    env = GrayScaleObservation(env)
    env = ObservationDict(env)
    env = ObservationAddPosition(env)
    env = ObservationAddPokemonLevel(env)
    env = RemoveSelectStartAction(env)
    # Setting reward
    env = RewardDecreasingNoChange(env, 0.01)
    env = RewardDecreasingSteps(env, .01)
    env = RewardIncreasingBadges(env, 100)
    env = RewardIncreasingCapturePokemon(env, 10)
    env = RewardIncreasingPokemonLevel(env, 10)
    env = RewardIncreasingLandedAttack(env, 0.05)
    env = RewardDecreasingLostBattle(env, 0.1)
    # env = RewardIncreasingPositionExploration(env, 1)
    env = RewardHistoryToInfo(env)
    # Post processing
    # env = TimeLimit(env, 300)
    # env = ppFlattenInfo(env)
    return env

In [4]:
env = create_env()

In [5]:
env.action_space

Discrete(7)

In [6]:
import torch
from stable_baselines3 import ppo
from stable_baselines3.common.env_util import make_vec_env


from pokerl.agent.tools import get_device

# env = make_vec_env(create_env, n_envs=8)

# model = ppo.PPO(
#     "MultiInputPolicy",
#     env,
#     device=get_device(),
#     verbose=1
#     )


In [7]:
from stable_baselines3.common.vec_env import SubprocVecEnv
from wandb.integration.sb3 import WandbCallback

def make_env(rank, seed=0):
    """
    Utility function for multiprocessed env.
    :param env_id: (str) the environment ID
    :param num_env: (int) the number of environments you wish to have in subprocesses
    :param seed: (int) the initial seed for RNG
    :param rank: (int) index of the subprocess
    """
    def _init():
        env = create_env()
        env.reset(seed=(seed + rank))
        return env
    return _init

# config = {
#     "policy_type": "MultiInputPolicy",
#     "total_timesteps": 5000,
#     "env_name": "PokemonBlueEnv-v1",
# }

# run = wandb.init(
#     project="sb3",
#     config=config,
#     sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
#     monitor_gym=True,  # auto-upload the videos of agents playing the game
#     save_code=True,  # optional
# )

nb_cpus = 16
ep_length = 1e3
subproc = SubprocVecEnv([make_env(i) for i in range(nb_cpus)])

model = ppo.PPO(
    "MultiInputPolicy",
    subproc,
    learning_rate=0.001,
    n_steps=int(ep_length*nb_cpus),
    batch_size=512,
    n_epochs=10,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.2,
    verbose=2,
    # callback=WandbCallback(),
)
model.learn(total_timesteps=ep_length*nb_cpus*nb_cpus, progress_bar=True)


Output()

Using cpu device


-------------------------------
| time/              |        |
|    fps             | 1941   |
|    iterations      | 1      |
|    time_elapsed    | 131    |
|    total_timesteps | 256000 |
-------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x2944f3190>

In [14]:
subproc.close()

In [20]:
test_env = create_env(interactive=True)
test_env.reset()

({'screen': array([[100, 100,  91, ...,   0,   0,   0],
         [163, 163,  79, ...,   0,   0,   0],
         [ 43,  63,  50, ...,   0,   0,   0],
         ...,
         [ 97,  97,  97, ...,   0,   0,   0],
         [ 97,  97,  97, ...,   0,   0,   0],
         [ 86,  86,  86, ...,   0,   0,   0]], dtype=uint8),
  'position': array([0., 0.], dtype=float16),
  'pokemon_level': array([0, 0, 0, 0, 0, 0], dtype=uint8)},
 {'tick': 0,
  'pokemon_level': array([5, 0, 0, 0, 0, 0]),
  'badges': array(0),
  'position': array([ 5,  5, 40]),
  'absolute_position': array([26, -1]),
  'owned_pokemon': array([8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
  'start_combat': False,
  'rewardHistory': deque([], maxlen=10)})

In [26]:
from tqdm import tqdm

test_env = create_env(interactive=False)
obs, _ = test_env.reset()

dict_actions_counter = {} # Dictionary to store the number of times each action is taken
model = ppo.PPO.load("../models/7cxs6l8i/model.zip")

for _ in tqdm(range(600)):
    action, _ = model.predict(obs)
    obs, reward, _, _, _ = test_env.step(action)
    if int(action) not in dict_actions_counter:
        dict_actions_counter[int(action)] = 1
    else:
        dict_actions_counter[int(action)] += 1

test_env.close()


  return self.fget.__get__(instance, owner)()
100%|██████████| 600/600 [00:01<00:00, 382.18it/s]


In [27]:
dict_actions_counter

{5: 600}

: 

In [140]:
env = create_env(interactive=False)
env.reset()

({'screen': array([[142, 142, 130, ...,   0,   0,   0],
         [232, 232, 112, ...,   0,   0,   0],
         [ 62,  90,  71, ...,   0,   0,   0],
         ...,
         [138, 138, 138, ...,   0,   0,   0],
         [138, 138, 138, ...,   0,   0,   0],
         [123, 123, 123, ...,   0,   0,   0]], dtype=uint8),
  'position': array([0., 0.], dtype=float16),
  'pokemon_level': array([0, 0, 0, 0, 0, 0], dtype=uint8)},
 {'tick': 0,
  'pokemon_level': array([5, 0, 0, 0, 0, 0]),
  'badges': array(0),
  'position': array([ 5,  5, 40]),
  'absolute_position': array([26, -1]),
  'owned_pokemon': array([8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
  'rewardHistory': deque([], maxlen=10000)})

In [12]:
env.helper.get_max_hp_pokemon(0)

  logger.warn(


19

In [15]:
from tqdm import tqdm

env = create_env(interactive=False)
env.reset()
for _ in tqdm(range(2000)):
    obs, reward, _, _, _ = env.step(2)
env.close()

100%|██████████| 2000/2000 [00:07<00:00, 278.22it/s]


In [20]:
NB_CPUS = 16
EP_LENGTH = 2**14 # 16384
NB_UPDATES = 2**8 # 256
TOTAL_TIMESTEPS = EP_LENGTH*NB_CPUS*NB_UPDATES
print(TOTAL_TIMESTEPS/(1800*3600))

10.356306172839506


In [21]:
0.99**200

0.13397967485796172