In [3]:
from gymnasium import Env
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation, TimeLimit

from pokerl.env.pokemonblue import PokemonBlueEnv
from pokerl.env.wrappers import (
    ObservationAddPokemonLevel,
    ObservationAddPosition,
    ObservationDict,
    RewardDecreasingNoChange,
    RewardDecreasingSteps,
    RewardHistoryToInfo,
    RewardIncreasingBadges,
    RewardIncreasingCapturePokemon,
    RewardIncreasingPokemonLevel,
    RewardIncreasingPositionExploration,
    ppFlattenInfo,
)

In [8]:
# Hyperparameters / run configuration.
# NOTE(review): none of these names are referenced by the code visible in this
# part of the notebook; EPS_* and TAU look like value-based (DQN-style)
# settings -- confirm whether they are still needed.
BATCH_SIZE = 2_048
GAMMA = 0.99
EPS_START, EPS_END = 0.9, 0.05
EPS_DECAY = 1_000
TAU = 5e-3
LR = 0.0001
STEP_LIMIT = 10_000

In [4]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to project code are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [23]:

def create_env(interactive=False, step_limit: int = 10000) -> Env:
    """Build the wrapped Pokemon Blue environment used in this notebook.

    The base ``PokemonBlueEnv`` is wrapped, in order, with observation
    wrappers, reward wrappers and post-processing wrappers. The order is
    significant because each wrapper receives the output of the previous one.

    Args:
        interactive: Forwarded unchanged to ``PokemonBlueEnv``.
        step_limit: Maximum number of steps per episode, passed to
            ``TimeLimit``. Defaults to 10000, the value that was previously
            hard-coded here.

    Returns:
        The fully wrapped environment.
    """
    env = PokemonBlueEnv(interactive=interactive)

    # Observation pipeline: resize to 64, convert to grayscale, then switch to
    # the project's dict observation and append position / level entries.
    # NOTE(review): the exact keys added by the project wrappers are not
    # visible here -- confirm against pokerl.env.wrappers.
    env = ResizeObservation(env, 64)
    env = GrayScaleObservation(env)
    env = ObservationDict(env)
    env = ObservationAddPosition(env)
    env = ObservationAddPokemonLevel(env)

    # Reward shaping; the numeric argument is each wrapper's own coefficient.
    env = RewardDecreasingNoChange(env, 10)
    env = RewardDecreasingSteps(env, .01)
    env = RewardIncreasingBadges(env, 100)
    env = RewardIncreasingCapturePokemon(env, 10)
    env = RewardIncreasingPokemonLevel(env, 10)
    # Exploration reward is currently disabled:
    # env = RewardIncreasingPositionExploration(env, 1)
    env = RewardHistoryToInfo(env)

    # Post processing: cap the episode length, then flatten the info dict.
    env = TimeLimit(env, max_episode_steps=step_limit)
    env = ppFlattenInfo(env)
    return env

In [6]:
# Non-interactive environment instance that is handed to the agent for training.
env = create_env(interactive=False)

In [17]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from pokerl.agent.tools import get_device

# Separate policy ("pi") and value ("vf") networks, each with four 256-unit
# layers and ReLU activations.
policy_kwargs = {
    "activation_fn": torch.nn.ReLU,
    "net_arch": {"pi": [256, 256, 256, 256], "vf": [256, 256, 256, 256]},
}

# The PPO class is imported directly: the previous `ppo = ppo.PPO(...)` rebound
# the imported submodule name to the model instance. The variable keeps the
# name `ppo` because later cells call `ppo.learn` / `ppo.predict`.
# `policy_kwargs` was previously built but never passed, so the architecture
# above was silently ignored in favour of the library defaults.
ppo = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs=policy_kwargs,
    device=get_device(),
    verbose=1,
)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [18]:
# Train for 100k environment steps with a progress bar; learn() returns the
# model, which is shown as this cell's output.
ppo.learn(progress_bar=True, total_timesteps=100_000)

Output()

-----------------------------
| time/              |      |
|    fps             | 40   |
|    iterations      | 1    |
|    time_elapsed    | 50   |
|    total_timesteps | 2048 |
-----------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 39          |
|    iterations           | 2           |
|    time_elapsed         | 105         |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.014815954 |
|    clip_fraction        | 0.0946      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.93       |
|    explained_variance   | 0.0116      |
|    learning_rate        | 0.0003      |
|    loss                 | 587         |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00848    |
|    value_loss           | 1.45e+03    |
-----------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 39          |
|    iterations           | 3           |
|    time_elapsed         | 156         |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.014936582 |
|    clip_fraction        | 0.063       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.93       |
|    explained_variance   | 0.00246     |
|    learning_rate        | 0.0003      |
|    loss                 | 605         |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.00505    |
|    value_loss           | 1.43e+03    |
-----------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 39          |
|    iterations           | 4           |
|    time_elapsed         | 205         |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.014406692 |
|    clip_fraction        | 0.147       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.92       |
|    explained_variance   | -0.000311   |
|    learning_rate        | 0.0003      |
|    loss                 | 1.73e+03    |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0167     |
|    value_loss           | 4.21e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1e+04       |
|    ep_rew_mean          | 3.07e+04    |
| time/                   |             |
|    fps                  | 38          |
|    iterations           | 5           |
|    time_elapsed         | 263         |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.019709725 |
|    clip_fraction        | 0.218       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.87       |
|    explained_variance   | -0.000145   |
|    learning_rate        | 0.0003      |
|    loss                 | 2.82e+03    |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0238     |
|    value_loss           | 5.85e+03    |
-----------------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x7f4ce43cd450>

In [24]:
from matplotlib import pyplot as plt
from tqdm import tqdm

# Roll the trained policy out for 200 steps in an interactive environment.
test_env = create_env(interactive=True)
try:
    obs, _info = test_env.reset()
    for _step in range(200):
        # One predict() per step; the previous version called it twice and
        # threw the first result away.
        action, _state = ppo.predict(obs)
        obs, reward, terminated, truncated, _info = test_env.step(action)
        if terminated or truncated:
            # The episode is over: start a new one instead of stepping a
            # finished environment.
            obs, _info = test_env.reset()
finally:
    # Release the environment even when the rollout is interrupted
    # (e.g. by a KeyboardInterrupt).
    test_env.close()


KeyboardInterrupt: 