In [1]:
from gymnasium import Env
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation, TimeLimit

from pokerl.env.pokemonblue import PokemonBlueEnv
from pokerl.env.wrappers import (
    ObservationAddPokemonLevel,
    ObservationAddPosition,
    ObservationDict,
    RewardDecreasingNoChange,
    RewardDecreasingSteps,
    RewardHistoryToInfo,
    RewardIncreasingBadges,
    RewardIncreasingCapturePokemon,
    RewardIncreasingPokemonLevel,
    RewardIncreasingPositionExploration,
    ppFlattenInfo,
)



In [2]:
BATCH_SIZE = 2048
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
STEP_LIMIT = 10000

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:

def create_env(interactive=False) -> Env:
    env = PokemonBlueEnv(interactive=interactive)
    # Setting observation
    env = ResizeObservation(env, 64)
    env = GrayScaleObservation(env)
    env = ObservationDict(env)
    env = ObservationAddPosition(env)
    env = ObservationAddPokemonLevel(env)
    # Setting reward
    env = RewardDecreasingNoChange(env, 10)
    env = RewardDecreasingSteps(env, .01)
    env = RewardIncreasingBadges(env, 100)
    env = RewardIncreasingCapturePokemon(env, 10)
    env = RewardIncreasingPokemonLevel(env, 10)
    # env = RewardIncreasingPositionExploration(env, 1)
    env = RewardHistoryToInfo(env)
    # Post processing
    env = TimeLimit(env, 10000)
    env = ppFlattenInfo(env)
    return env

In [5]:
env = create_env()

In [6]:
import torch
from stable_baselines3 import ppo
from stable_baselines3.common.env_util import make_vec_env

from pokerl.agent.tools import get_device

policy_kwargs = {"activation_fn": torch.nn.ReLU, "net_arch": {"pi": [256, 256, 256, 256], "vf": [256, 256, 256, 256]}}

ppo = ppo.PPO(
    "MultiInputPolicy", 
    env,  
    device=get_device(), 
    verbose=1
    )


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [7]:
from wandb.integration.sb3 import WandbCallback
import wandb

config = {
    "policy_type": "MultiInputPolicy",
    "total_timesteps": 5000,
    "env_name": "PokemonBlueEnv-v1",
}

run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)

ppo.learn(total_timesteps=500, progress_bar=True, callback=WandbCallback())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/baptistepugnaire/.netrc


Output()

-----------------------------
| time/              |      |
|    fps             | 61   |
|    iterations      | 1    |
|    time_elapsed    | 33   |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x28e8ced70>

In [None]:
from matplotlib import pyplot as plt
from tqdm import tqdm

test_env = create_env(interactive=True)
obs, _ = test_env.reset()
for _ in range(200):
    ppo.predict(obs)
    action, _ = ppo.predict(obs)
    obs, reward, _, _, _ = test_env.step(action)


KeyboardInterrupt: 

In [22]:
from tqdm import tqdm

for _ in tqdm(range(2000)):

    obs, reward, _, _, _ = env.step(1)

100%|██████████| 2000/2000 [00:04<00:00, 437.77it/s]


In [20]:
env.pyboy._rendering(True)

  logger.warn(
