In [2]:
from gymnasium import Env
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation, TimeLimit

from pokerl.env.pokemonblue import PokemonBlueEnv
from pokerl.env.wrappers import (
    ObservationAddPokemonLevel,
    ObservationAddPosition,
    ObservationDict,
    RewardDecreasingNoChange,
    RewardDecreasingSteps,
    RewardHistoryToInfo,
    RewardIncreasingBadges,
    RewardIncreasingCapturePokemon,
    RewardIncreasingPokemonLevel,
    RewardIncreasingPositionExploration,
    ppFlattenInfo,
)

In [3]:
# Tunable experiment constants, gathered in one cell.
# NOTE(review): BATCH_SIZE through LR are not referenced by any cell shown in
# this notebook — presumably leftovers from a different agent setup; confirm
# before relying on them.
BATCH_SIZE = 2_048
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1_000
TAU = 0.005
LR = 0.0001
STEP_LIMIT = 10_000

In [4]:
# Load IPython's autoreload extension and set it to mode 2 so that edits to
# imported modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:

def create_env(interactive=False, step_limit=None) -> Env:
    """Build a PokemonBlueEnv wrapped with the observation, reward and
    post-processing wrappers used in this notebook.

    Args:
        interactive: Forwarded unchanged to ``PokemonBlueEnv(interactive=...)``.
        step_limit: Maximum number of steps per episode, handed to
            ``TimeLimit``. When ``None`` (the default) the notebook-level
            ``STEP_LIMIT`` constant is used, so the limit is configured in a
            single place instead of being hard-coded here.

    Returns:
        The fully wrapped environment.
    """
    # Resolve at call time so that editing STEP_LIMIT in the config cell takes
    # effect without having to re-run this definition.
    max_steps = STEP_LIMIT if step_limit is None else step_limit

    env = PokemonBlueEnv(interactive=interactive)
    # Setting observation
    env = ResizeObservation(env, 64)
    env = GrayScaleObservation(env)
    env = ObservationDict(env)
    env = ObservationAddPosition(env)
    env = ObservationAddPokemonLevel(env)
    # Setting reward
    env = RewardDecreasingNoChange(env, 10)
    env = RewardDecreasingSteps(env, 0.01)
    env = RewardIncreasingBadges(env, 100)
    env = RewardIncreasingCapturePokemon(env, 10)
    env = RewardIncreasingPokemonLevel(env, 10)
    # env = RewardIncreasingPositionExploration(env, 1)
    env = RewardHistoryToInfo(env)
    # Post processing
    env = TimeLimit(env, max_steps)
    env = ppFlattenInfo(env)
    return env

In [6]:
# Build the environment with create_env's defaults; this `env` is the instance
# later handed to the PPO constructor.
env = create_env()

In [7]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from pokerl.agent.tools import get_device

# Network layout for the policy: ReLU activations with four 256-unit layers
# for the actor ("pi") and for the critic ("vf").
policy_kwargs = {"activation_fn": torch.nn.ReLU, "net_arch": {"pi": [256, 256, 256, 256], "vf": [256, 256, 256, 256]}}

# The class is imported directly (rather than the `ppo` module) so that binding
# the model to the name `ppo` does not shadow the import; the cell can then be
# re-run without failing. `policy_kwargs` is now actually passed to the model —
# previously it was defined but never used.
ppo = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs=policy_kwargs,
    device=get_device(),
    verbose=1,
)


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
import wandb
from wandb.integration.sb3 import WandbCallback

# Weights & Biases tracking is disabled for now. The scaffold is kept below;
# to enable it, uncomment the config/run block and the `callback` argument.
#
# config = {
#     "policy_type": "MultiInputPolicy",
#     "total_timesteps": 5000,
#     "env_name": "PokemonBlueEnv-v1",
# }
#
# run = wandb.init(
#     project="sb3",
#     config=config,
#     sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
#     monitor_gym=True,  # auto-upload the videos of agents playing the game
#     save_code=True,  # optional
# )

# Train the model for 50k timesteps, showing a progress bar.
ppo.learn(
    total_timesteps=50_000,
    progress_bar=True,
    # callback=WandbCallback(),
)

In [14]:
from matplotlib import pyplot as plt
from tqdm import tqdm

# Roll out the trained model for 1000 steps in a separate environment created
# with interactive=True.
test_env = create_env(interactive=True)
obs, _ = test_env.reset()
for _ in range(1000):
    # Query the policy exactly once per step (the earlier duplicate call threw
    # its result away and only doubled the inference cost).
    action, _ = ppo.predict(obs)
    obs, reward, terminated, truncated, _ = test_env.step(action)
    # Stepping an environment after its episode has ended is undefined in the
    # Gymnasium API, so start a fresh episode whenever one finishes.
    if terminated or truncated:
        obs, _ = test_env.reset()


In [22]:
from tqdm import tqdm

# Step the notebook-level `env` 2000 times, always sending action 1, with a
# tqdm progress bar. Only the observation and reward of each step are kept;
# the other three values returned by step() are discarded.
# NOTE(review): the episode-end flags are ignored here, so the loop keeps
# stepping even if an episode finishes — confirm that is intended.
for _ in tqdm(range(2000)):

    obs, reward, _, _, _ = env.step(1)

100%|██████████| 2000/2000 [00:04<00:00, 437.77it/s]


In [20]:
# Calls `_rendering(True)` on the `pyboy` object reached through the wrapped
# `env` — presumably this switches emulator rendering on; confirm.
# NOTE(review): `_rendering` is underscore-prefixed (not public API), and the
# saved output shows a `logger.warn(` fragment, so a warning was emitted here.
# Presumably it is about reaching `pyboy` through the wrapper stack — confirm,
# and consider `env.unwrapped.pyboy` if the attribute lives on the base env.
env.pyboy._rendering(True)

  logger.warn(
