In [3]:
from gymnasium import Env
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation, TimeLimit

from pokerl.env.pokemonblue import PokemonBlueEnv
from pokerl.env.wrappers import (
    ObservationAddPokemonLevel,
    ObservationAddPosition,
    ObservationDict,
    RewardDecreasingNoChange,
    RewardDecreasingSteps,
    RewardHistoryToInfo,
    RewardIncreasingBadges,
    RewardIncreasingCapturePokemon,
    RewardIncreasingPokemonLevel,
    RewardIncreasingPositionExploration,
    ppFlattenInfo,
)

In [8]:
# --- Notebook configuration ------------------------------------------------
# NOTE(review): the EPS_* and TAU names look like epsilon-greedy / soft-update
# settings; they are not referenced by any cell visible here -- TODO confirm
# whether they are still needed.

# Episode horizon
STEP_LIMIT = 10_000

# Optimisation
BATCH_SIZE = 2_048
LR = 1e-4
GAMMA = 0.99
TAU = 0.005

# Exploration schedule
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1_000

In [4]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [23]:

def create_env(interactive=False, step_limit=None) -> Env:
    """Build the fully wrapped Pokemon Blue environment.

    Wrappers are stacked in three stages, innermost first: observation
    shaping, reward shaping, then post-processing.

    Args:
        interactive: Forwarded unchanged to ``PokemonBlueEnv(interactive=...)``.
        step_limit: Maximum number of steps per episode, passed to
            ``TimeLimit``. When ``None`` (the default) the notebook-level
            ``STEP_LIMIT`` constant is used instead of a duplicated literal.

    Returns:
        The wrapped environment.
    """
    if step_limit is None:
        # Resolved at call time so the config cell stays the single source of truth.
        step_limit = STEP_LIMIT

    env = PokemonBlueEnv(interactive=interactive)
    # Setting observation
    env = ResizeObservation(env, 64)
    env = GrayScaleObservation(env)
    env = ObservationDict(env)
    env = ObservationAddPosition(env)
    env = ObservationAddPokemonLevel(env)
    # Setting reward
    # NOTE(review): the numeric argument of each reward wrapper is presumably
    # its weight/scale -- confirm against pokerl.env.wrappers.
    env = RewardDecreasingNoChange(env, 10)
    env = RewardDecreasingSteps(env, .01)
    env = RewardIncreasingBadges(env, 100)
    env = RewardIncreasingCapturePokemon(env, 10)
    env = RewardIncreasingPokemonLevel(env, 10)
    # env = RewardIncreasingPositionExploration(env, 1)
    env = RewardHistoryToInfo(env)
    # Post processing
    env = TimeLimit(env, step_limit)
    env = ppFlattenInfo(env)
    return env

In [6]:
# Environment handed to the PPO model below; built with create_env's defaults
# (interactive=False).
env = create_env()

In [17]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from pokerl.agent.tools import get_device

# Separate MLP heads for the policy ("pi") and the value function ("vf"),
# each with four hidden layers of 256 units and ReLU activations.
policy_kwargs = {"activation_fn": torch.nn.ReLU, "net_arch": {"pi": [256, 256, 256, 256], "vf": [256, 256, 256, 256]}}

# The PPO class is imported directly so the model variable `ppo` (used by the
# training and evaluation cells) no longer shadows the `stable_baselines3.ppo`
# module it was created from.
ppo = PPO(
    "MultiInputPolicy",
    env,
    # Previously defined but never passed, so the default architecture was used.
    policy_kwargs=policy_kwargs,
    device=get_device(),
    verbose=1,
)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
from wandb.integration.sb3 import WandbCallback
import wandb

# Run metadata logged to Weights & Biases. `total_timesteps` is also the value
# actually passed to `learn()` below, so the logged config cannot drift from
# the real training budget.
config = {
    "policy_type": "MultiInputPolicy",
    "total_timesteps": 5000,
    "env_name": "PokemonBlueEnv-v1",
}

run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)

try:
    ppo.learn(
        total_timesteps=config["total_timesteps"],
        progress_bar=True,
        callback=WandbCallback(),
    )
finally:
    # Close the W&B run even when training is interrupted (e.g. KeyboardInterrupt),
    # so the run is not left dangling in the kernel.
    run.finish()

Output()

-----------------------------
| time/              |      |
|    fps             | 41   |
|    iterations      | 1    |
|    time_elapsed    | 49   |
|    total_timesteps | 2048 |
-----------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 41          |
|    iterations           | 2           |
|    time_elapsed         | 97          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.013846958 |
|    clip_fraction        | 0.0989      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.8        |
|    explained_variance   | -1.23e-05   |
|    learning_rate        | 0.0003      |
|    loss                 | 2.51e+03    |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.0105     |
|    value_loss           | 5.01e+03    |
-----------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 39          |
|    iterations           | 3           |
|    time_elapsed         | 154         |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.017450126 |
|    clip_fraction        | 0.212       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.76       |
|    explained_variance   | 7.15e-07    |
|    learning_rate        | 0.0003      |
|    loss                 | 3.38e+03    |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0207     |
|    value_loss           | 7.34e+03    |
-----------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 39          |
|    iterations           | 4           |
|    time_elapsed         | 208         |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.014876718 |
|    clip_fraction        | 0.0876      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.74       |
|    explained_variance   | -2.86e-06   |
|    learning_rate        | 0.0003      |
|    loss                 | 2.04e+03    |
|    n_updates            | 80          |
|    policy_gradient_loss | -0.00814    |
|    value_loss           | 3.66e+03    |
-----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1e+04        |
|    ep_rew_mean          | 4.21e+04     |
| time/                   |              |
|    fps                  | 38           |
|    iterations           | 5            |
|    time_elapsed         | 266          |
|    total_timesteps      | 10240        |
| train/                  |              |
|    approx_kl            | 0.0066578626 |
|    clip_fraction        | 0.0379       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.74        |
|    explained_variance   | -2.26e-05    |
|    learning_rate        | 0.0003       |
|    loss                 | 3.63e+03     |
|    n_updates            | 90           |
|    policy_gradient_loss | -0.00322     |
|    value_loss           | 7.34e+03     |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1e+04       |
|    ep_rew_mean          | 4.21e+04    |
| time/                   |             |
|    fps                  | 37          |
|    iterations           | 6           |
|    time_elapsed         | 327         |
|    total_timesteps      | 12288       |
| train/                  |             |
|    approx_kl            | 0.007261129 |
|    clip_fraction        | 0.0285      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.7        |
|    explained_variance   | 6.56e-07    |
|    learning_rate        | 0.0003      |
|    loss                 | 2.69e+03    |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.0021     |
|    value_loss           | 4.19e+03    |
-----------------------------------------


KeyboardInterrupt: 

In [None]:
from matplotlib import pyplot as plt
from tqdm import tqdm

# Number of environment steps to play back with the trained policy.
N_TEST_STEPS = 200

test_env = create_env(interactive=True)
try:
    obs, _ = test_env.reset()
    for _ in range(N_TEST_STEPS):
        # Single inference per step (the original called predict twice and
        # discarded the first result).
        action, _ = ppo.predict(obs)
        obs, reward, terminated, truncated, _ = test_env.step(action)
        if terminated or truncated:
            # Stepping a finished episode is undefined in Gymnasium; start a new one.
            obs, _ = test_env.reset()
finally:
    # Always release the interactive environment, even on interrupt.
    test_env.close()


KeyboardInterrupt: 