In [11]:
STEP_LIMIT = 50

In [12]:
from typing import Any

from gymnasium import Env, Wrapper, spaces


class ObservationRemoveScreen(Wrapper):
    """Wrapper for reward based on pokemon level"""

    def __init__(self, env: Env):
        super().__init__(env)
        if isinstance(env.observation_space, spaces.Dict):
            # We add the position to the observation space dict
            d_obs_space = env.observation_space.spaces
            d_obs_space.pop("screen")
            self.observation_space = spaces.Dict(d_obs_space)
        else:
            raise Exception("You should wrap your env in ObservationDict before using ObservationAddPosition")

    def step(self, action):
        observation, reward, truncated, terminated, info = self.env.step(action)
        observation.pop("screen")
        return observation, reward, truncated, terminated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        observation, info = self.env.reset(seed=seed, options=options)
        observation.pop("screen")
        return observation, info

In [13]:
import gymnasium as gym
from gymnasium.wrappers import (
    AutoResetWrapper,
    FlattenObservation,
    GrayScaleObservation,
    NormalizeObservation,
    RecordEpisodeStatistics,
    ResizeObservation,
    TimeLimit,
)

from pokerl.env.pokemonblue import PokemonBlueEnv
from pokerl.env.wrappers import (
    ObservationAddPokemonLevel,
    ObservationAddPosition,
    ObservationDict,
    RemoveABAction,
    RemoveSelectStartAction,
    RewardCheckpoint,
    RewardDecreasingNoChange,
    RewardDecreasingSteps,
    RewardHistoryToInfo,
    RewardIncreasingBadges,
    RewardIncreasingCapturePokemon,
    RewardIncreasingPokemonLevel,
    RewardIncreasingPositionExploration,
    RewardStopCheckpoint,
    StopAtPokemon,
    ppFlattenInfo,
)


def create_env(interactive=False) -> gym.Env:
    env = PokemonBlueEnv(
        interactive=interactive,
        save_state="game_start"
    )
    # Setting observation
    env = ResizeObservation(env, 128)
    # env = GrayScaleObservation(env)
    env = ObservationDict(env)
    env = ObservationAddPosition(env)
    env = ObservationRemoveScreen(env)
    env = RemoveABAction(env)
    env = RemoveSelectStartAction(env)
    # env = ObservationAddPokemonLevel(env)
    # Setting reward
    # env = RewardDecreasingNoChange(env, .001)
    env = RewardDecreasingSteps(env, 1)
    # env = RewardIncreasingBadges(env, 100)
    # env = RewardIncreasingCapturePokemon(env, 10)
    # env = RewardIncreasingPokemonLevel(env, 10)
    # env = RewardIncreasingPositionExploration(env, 1)
    # env = RewardHistoryToInfo(env)
    # Post processing
    env = RewardStopCheckpoint(env)
    env = TimeLimit(env, STEP_LIMIT)
    # env = AutoResetWrapper(env)
    env = FlattenObservation(env)
    # env = NormalizeObservation(env)

    # env = ppFlattenInfo(env)
    return env

In [14]:
from stable_baselines3 import a2c, dqn, ppo, td3
from stable_baselines3.common.env_util import make_vec_env

from pokerl.agent.tools import get_device

env = create_env()

ppo = ppo.PPO(
    "MlpPolicy",
    env,
    device="cpu",
    verbose=1
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [15]:
ppo = ppo.learn(STEP_LIMIT*1000, progress_bar=True)

Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -50      |
| time/              |          |
|    fps             | 67       |
|    iterations      | 1        |
|    time_elapsed    | 30       |
|    total_timesteps | 2048     |
---------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 50           |
|    ep_rew_mean          | -50          |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 2            |
|    time_elapsed         | 61           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0011681395 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.61        |
|    explained_variance   | -0.000418    |
|    learning_rate        | 0.0003       |
|    loss                 | 9.24         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.000945    |
|    value_loss           | 88.8         |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 3           |
|    time_elapsed         | 92          |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.008303616 |
|    clip_fraction        | 0.00537     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.6        |
|    explained_variance   | 0.00854     |
|    learning_rate        | 0.0003      |
|    loss                 | 7           |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0011     |
|    value_loss           | 38.6        |
-----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 50           |
|    ep_rew_mean          | -50          |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 4            |
|    time_elapsed         | 123          |
|    total_timesteps      | 8192         |
| train/                  |              |
|    approx_kl            | 0.0001157996 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.59        |
|    explained_variance   | 0.00258      |
|    learning_rate        | 0.0003       |
|    loss                 | 5.91         |
|    n_updates            | 30           |
|    policy_gradient_loss | -6.65e-05    |
|    value_loss           | 42.8         |
------------------------------------------


-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 50            |
|    ep_rew_mean          | -50           |
| time/                   |               |
|    fps                  | 66            |
|    iterations           | 5             |
|    time_elapsed         | 154           |
|    total_timesteps      | 10240         |
| train/                  |               |
|    approx_kl            | 0.00020708132 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.59         |
|    explained_variance   | -0.000439     |
|    learning_rate        | 0.0003        |
|    loss                 | 6.56          |
|    n_updates            | 40            |
|    policy_gradient_loss | -0.000287     |
|    value_loss           | 34.2          |
-------------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 50           |
|    ep_rew_mean          | -50          |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 6            |
|    time_elapsed         | 186          |
|    total_timesteps      | 12288        |
| train/                  |              |
|    approx_kl            | 0.0003057632 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.6         |
|    explained_variance   | -0.00026     |
|    learning_rate        | 0.0003       |
|    loss                 | 3.66         |
|    n_updates            | 50           |
|    policy_gradient_loss | -0.000382    |
|    value_loss           | 26.7         |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.9        |
|    ep_rew_mean          | -49.9       |
| time/                   |             |
|    fps                  | 65          |
|    iterations           | 7           |
|    time_elapsed         | 217         |
|    total_timesteps      | 14336       |
| train/                  |             |
|    approx_kl            | 0.001553888 |
|    clip_fraction        | 0           |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.6        |
|    explained_variance   | 0.000735    |
|    learning_rate        | 0.0003      |
|    loss                 | 2.49        |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.000534   |
|    value_loss           | 20.1        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.9        |
|    ep_rew_mean          | -49.9       |
| time/                   |             |
|    fps                  | 65          |
|    iterations           | 8           |
|    time_elapsed         | 248         |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.006215332 |
|    clip_fraction        | 0.00674     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.6        |
|    explained_variance   | 0.00104     |
|    learning_rate        | 0.0003      |
|    loss                 | 2.46        |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0016     |
|    value_loss           | 15          |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.8        |
|    ep_rew_mean          | -49.8       |
| time/                   |             |
|    fps                  | 65          |
|    iterations           | 9           |
|    time_elapsed         | 280         |
|    total_timesteps      | 18432       |
| train/                  |             |
|    approx_kl            | 0.007993879 |
|    clip_fraction        | 0.0186      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.57       |
|    explained_variance   | 0.000982    |
|    learning_rate        | 0.0003      |
|    loss                 | 1.52        |
|    n_updates            | 80          |
|    policy_gradient_loss | -0.00212    |
|    value_loss           | 11.4        |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 49.9       |
|    ep_rew_mean          | -49.9      |
| time/                   |            |
|    fps                  | 65         |
|    iterations           | 10         |
|    time_elapsed         | 310        |
|    total_timesteps      | 20480      |
| train/                  |            |
|    approx_kl            | 0.00189585 |
|    clip_fraction        | 0.00298    |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.56      |
|    explained_variance   | 0.000865   |
|    learning_rate        | 0.0003     |
|    loss                 | 1.09       |
|    n_updates            | 90         |
|    policy_gradient_loss | -0.000228  |
|    value_loss           | 8.19       |
----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 50           |
|    ep_rew_mean          | -50          |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 11           |
|    time_elapsed         | 341          |
|    total_timesteps      | 22528        |
| train/                  |              |
|    approx_kl            | 0.0111182965 |
|    clip_fraction        | 0.0235       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.54        |
|    explained_variance   | 0.000572     |
|    learning_rate        | 0.0003       |
|    loss                 | 1.07         |
|    n_updates            | 100          |
|    policy_gradient_loss | -0.00242     |
|    value_loss           | 5.96         |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 12          |
|    time_elapsed         | 371         |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.017756991 |
|    clip_fraction        | 0.0387      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.52       |
|    explained_variance   | 0.000973    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.939       |
|    n_updates            | 110         |
|    policy_gradient_loss | -0.00376    |
|    value_loss           | 4.32        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 13          |
|    time_elapsed         | 402         |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.009048158 |
|    clip_fraction        | 0.00576     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.51       |
|    explained_variance   | 0.000538    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.598       |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.00104    |
|    value_loss           | 3.14        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 14          |
|    time_elapsed         | 432         |
|    total_timesteps      | 28672       |
| train/                  |             |
|    approx_kl            | 0.011180566 |
|    clip_fraction        | 0.0249      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.53       |
|    explained_variance   | 0.000868    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.411       |
|    n_updates            | 130         |
|    policy_gradient_loss | -0.00305    |
|    value_loss           | 2.28        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 15          |
|    time_elapsed         | 463         |
|    total_timesteps      | 30720       |
| train/                  |             |
|    approx_kl            | 0.009146142 |
|    clip_fraction        | 0.0114      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.56       |
|    explained_variance   | 0.00065     |
|    learning_rate        | 0.0003      |
|    loss                 | 0.313       |
|    n_updates            | 140         |
|    policy_gradient_loss | -0.00194    |
|    value_loss           | 1.64        |
-----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 49.6         |
|    ep_rew_mean          | -49.6        |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 16           |
|    time_elapsed         | 493          |
|    total_timesteps      | 32768        |
| train/                  |              |
|    approx_kl            | 0.0033254651 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.58        |
|    explained_variance   | 0.00108      |
|    learning_rate        | 0.0003       |
|    loss                 | 0.334        |
|    n_updates            | 150          |
|    policy_gradient_loss | -0.00058     |
|    value_loss           | 1.2          |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.6        |
|    ep_rew_mean          | -49.6       |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 17          |
|    time_elapsed         | 524         |
|    total_timesteps      | 34816       |
| train/                  |             |
|    approx_kl            | 0.013477413 |
|    clip_fraction        | 0.0484      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.56       |
|    explained_variance   | 0.00125     |
|    learning_rate        | 0.0003      |
|    loss                 | 0.234       |
|    n_updates            | 160         |
|    policy_gradient_loss | -0.00429    |
|    value_loss           | 0.885       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.8        |
|    ep_rew_mean          | -49.8       |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 18          |
|    time_elapsed         | 554         |
|    total_timesteps      | 36864       |
| train/                  |             |
|    approx_kl            | 0.013034158 |
|    clip_fraction        | 0.0381      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.53       |
|    explained_variance   | 0.000739    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.146       |
|    n_updates            | 170         |
|    policy_gradient_loss | -0.00444    |
|    value_loss           | 0.66        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 19          |
|    time_elapsed         | 585         |
|    total_timesteps      | 38912       |
| train/                  |             |
|    approx_kl            | 0.008742074 |
|    clip_fraction        | 0.0336      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.55       |
|    explained_variance   | 0.00104     |
|    learning_rate        | 0.0003      |
|    loss                 | 0.139       |
|    n_updates            | 180         |
|    policy_gradient_loss | -0.0043     |
|    value_loss           | 0.474       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 20          |
|    time_elapsed         | 615         |
|    total_timesteps      | 40960       |
| train/                  |             |
|    approx_kl            | 0.014233801 |
|    clip_fraction        | 0.00864     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.56       |
|    explained_variance   | 0.000974    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.101       |
|    n_updates            | 190         |
|    policy_gradient_loss | -0.0029     |
|    value_loss           | 0.352       |
-----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 50           |
|    ep_rew_mean          | -50          |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 21           |
|    time_elapsed         | 646          |
|    total_timesteps      | 43008        |
| train/                  |              |
|    approx_kl            | 0.0075585647 |
|    clip_fraction        | 0.0193       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.57        |
|    explained_variance   | 0.000922     |
|    learning_rate        | 0.0003       |
|    loss                 | 0.0453       |
|    n_updates            | 200          |
|    policy_gradient_loss | -0.00274     |
|    value_loss           | 0.266        |
------------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 50           |
|    ep_rew_mean          | -50          |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 22           |
|    time_elapsed         | 676          |
|    total_timesteps      | 45056        |
| train/                  |              |
|    approx_kl            | 0.0071426425 |
|    clip_fraction        | 0.0268       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.56        |
|    explained_variance   | 0.00122      |
|    learning_rate        | 0.0003       |
|    loss                 | 0.0342       |
|    n_updates            | 210          |
|    policy_gradient_loss | -0.00363     |
|    value_loss           | 0.205        |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 23          |
|    time_elapsed         | 706         |
|    total_timesteps      | 47104       |
| train/                  |             |
|    approx_kl            | 0.010538045 |
|    clip_fraction        | 0.0115      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.55       |
|    explained_variance   | 0.000434    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0585      |
|    n_updates            | 220         |
|    policy_gradient_loss | -0.00252    |
|    value_loss           | 0.155       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 24          |
|    time_elapsed         | 737         |
|    total_timesteps      | 49152       |
| train/                  |             |
|    approx_kl            | 0.009751833 |
|    clip_fraction        | 0.0539      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.56       |
|    explained_variance   | 0.000398    |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0266      |
|    n_updates            | 230         |
|    policy_gradient_loss | -0.00507    |
|    value_loss           | 0.117       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -50         |
| time/                   |             |
|    fps                  | 66          |
|    iterations           | 25          |
|    time_elapsed         | 767         |
|    total_timesteps      | 51200       |
| train/                  |             |
|    approx_kl            | 0.010780363 |
|    clip_fraction        | 0.0692      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.57       |
|    explained_variance   | 0.00103     |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0222      |
|    n_updates            | 240         |
|    policy_gradient_loss | -0.00707    |
|    value_loss           | 0.0895      |
-----------------------------------------


In [6]:
env = create_env(True)

In [7]:
from matplotlib import pyplot as plt
from tqdm import tqdm

obs, _ = env.reset()
for _ in range(200):
    ppo.predict(obs)
    action, _ = ppo.predict(obs)
    obs, reward, _, _, _ = env.step(action)
    print(f"Observation is: {obs}")
    print(f"Reward is: {reward}")

Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward is: -1
Observation is: [0 0]
Reward

In [None]:
last_info = env.tick()
while True:
    info = env.tick()
    for k, v in info.items():
        if "tick" in k:
            continue
        if (last_info[k] != v).any():
            print(f"{k}: {v}")
            last_info = info
    pass

  logger.warn(


position: [ 6  3 38]
position: [ 6  4 38]
position: [ 6  5 38]
position: [ 5  5 38]
position: [ 3  5 38]
position: [ 2  5 38]
position: [ 1  5 38]
position: [ 1  6 38]
position: [ 1  7 37]
position: [ 2  7 37]
position: [ 3  7 37]
position: [ 4  7 37]
position: [ 5  7 37]
position: [ 6  6 37]
position: [ 6  5 37]
position: [ 6  4 37]
position: [ 7  3 37]
position: [7 3 0]
position: [5 5 0]
position: [6 5 0]
position: [7 5 0]
position: [7 6 0]


KeyboardInterrupt: 

In [None]:
obs, info = env.reset()

In [None]:
info

{'tick': 0,
 'pokemon_level': array([0, 0, 0, 0, 0, 0]),
 'badges': array(0),
 'position': array([0, 0, 0]),
 'absolute_position': array([0, 0]),
 'owned_pokemon': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'rewardHistory': deque([], maxlen=10000)}

In [None]:
env = PokemonBlueEnv(interactive=True)

In [None]:
from tqdm import tqdm

observation, info = env.reset()
done = False
for i in tqdm(range(1000)):
    action, _states = ppo.predict(observation)
    observation, reward, truncated, terminated, info = env.step(action)


  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:11<00:00, 88.53it/s]


TypeError: '<' not supported between instances of 'int' and 'dict'

In [10]:
import gymnasium as gym
from stable_baselines3 import a2c, dqn, ppo, td3
from stable_baselines3.common.env_util import make_vec_env

from pokerl.agent.tools import get_device

env = gym.make("Acrobot-v1")

agent = ppo.PPO(
    "MlpPolicy",
    env,
    device="cpu",
    verbose=1
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [13]:
agent = agent.learn(100000)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 195      |
|    ep_rew_mean     | -194     |
| time/              |          |
|    fps             | 1795     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 193          |
|    ep_rew_mean          | -192         |
| time/                   |              |
|    fps                  | 1359         |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0059554596 |
|    clip_fraction        | 0.0559       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.972       |
|    explained_variance   | 0.304        |
|    learning_r

In [15]:
from matplotlib import pyplot as plt
from tqdm import tqdm

env = gym.make("Acrobot-v1", render_mode="human")
obs, info = env.reset()
print(obs, info)
for _ in range(500):
    action, states = agent.predict(obs)
    print(action, states)
    obs, reward, _, _, _ = env.step(action)
    print(f"Observation is: {obs}")
    print(f"Reward is: {reward}")

[1.0000000e+00 3.9915062e-06 9.9776387e-01 6.6837676e-02 3.5921879e-02
 9.4700865e-02] {}
2 None
Observation is: [ 0.99998754 -0.0049965   0.993651    0.11250681 -0.08324838  0.35477015]
Reward is: -1.0
2 None
Observation is: [ 0.9995197  -0.03099017  0.9795182   0.20135584 -0.16842666  0.5252706 ]
Reward is: -1.0
2 None
Observation is: [ 0.9976735  -0.06817307  0.9513012   0.30826288 -0.1929538   0.55648196]
Reward is: -1.0
2 None
Observation is: [ 0.9946283  -0.10351071  0.9149288   0.40361533 -0.15201782  0.44306067]
Reward is: -1.0
2 None
Observation is: [ 0.9921127  -0.12534912  0.88551056  0.46461922 -0.06212005  0.2213518 ]
Reward is: -1.0
2 None
Observation is: [ 0.99192476 -0.1268278   0.87726724  0.4800023   0.0476535  -0.04920846]
Reward is: -1.0
0 None
Observation is: [ 0.99662924 -0.08203755  0.92156583  0.38822216  0.39282143 -0.95148295]
Reward is: -1.0
0 None
Observation is: [ 0.99974364  0.02264208  0.99093354  0.1343528   0.62721616 -1.6304858 ]
Reward is: -1.0
0 None