In [1]:
import gymnasium as gym
import numpy as np

from gymnasium.envs.classic_control.cartpole import CartPoleEnv
from gymnasium.wrappers.time_limit import TimeLimit
from sb3_contrib import RecurrentPPO
from stable_baselines3 import PPO
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecEnv, VecFrameStack

In [2]:
class PartiallyObservableCartPole(CartPoleEnv):
    """CartPole variant whose observations contain only the two positions.

    The parent environment's 4-element observation is reduced to elements 0
    and 2 (unpacked below as ``xpos`` and ``thetapos``); elements 1 and 3 are
    dropped, so a single observation no longer carries the rate information.
    """

    def __init__(self):
        super().__init__()
        # Same bound convention as used here for the 2 retained components:
        # twice the termination thresholds exposed by the parent class.
        high = np.array([self.x_threshold * 2, self.theta_threshold_radians * 2], dtype=np.float32)
        self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32)

    @staticmethod
    def _pos_obs(full_obs):
        """Project a full 4-element observation onto ``[xpos, thetapos]``.

        Returns a float32 ``np.ndarray`` of shape ``(2,)`` so the value matches
        the dtype and shape of the declared ``Box`` observation space (a plain
        tuple would not).
        """
        xpos, _, thetapos, _ = full_obs
        return np.array([xpos, thetapos], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset the parent environment and return the reduced observation.

        ``seed`` and ``options`` are both forwarded to the parent ``reset`` and
        the parent's info dict is passed through unchanged.
        """
        full_obs, info = super().reset(seed=seed, options=options)
        return self._pos_obs(full_obs), info

    def step(self, action):
        """Step the parent environment and return the reduced observation.

        Reward, termination/truncation flags and info come straight from the
        parent ``step``.
        """
        full_obs, rew, term, trunc, info = super().step(action)
        return self._pos_obs(full_obs), rew, term, trunc, info

    def close(self):
        # Deliberately a no-op override of the parent's close().
        # NOTE(review): confirm the parent's cleanup is not needed here.
        pass

In [3]:
def make_env(observable: bool = True):
    """Create one CartPole environment wrapped in a 200-step ``TimeLimit``.

    Args:
        observable: When truthy, build the standard ``CartPoleEnv``;
            otherwise build ``PartiallyObservableCartPole``.

    Returns:
        The chosen environment wrapped in ``TimeLimit`` with
        ``max_episode_steps=200``.
    """
    env_cls = CartPoleEnv if observable else PartiallyObservableCartPole
    return TimeLimit(env_cls(), max_episode_steps=200)

In [4]:
def bugfix_seed():
    """Draw a random seed from NumPy's global RNG and return it as a Python int.

    The value is sampled as a ``uint32`` from the half-open range
    ``[0, np.iinfo(np.uint32).max)`` and converted with ``int`` before being
    returned.
    """
    upper_bound = np.iinfo(np.uint32).max
    raw_seed = np.random.randint(0, upper_bound, dtype=np.uint32)
    return int(raw_seed)

In [5]:
def make_envs(observable: bool = True, n_envs=8):
    """Build a vectorised environment from ``n_envs`` copies of ``make_env``.

    Args:
        observable: Forwarded to ``make_env`` for every copy.
        n_envs: Number of environment copies passed to ``make_vec_env``.

    Returns:
        Whatever ``make_vec_env`` returns for the given factory, ``n_envs``
        and a seed freshly drawn from ``bugfix_seed``.
    """
    def _single_env():
        return make_env(observable)

    vec_seed = bugfix_seed()
    return make_vec_env(_single_env, n_envs=n_envs, seed=vec_seed)

In [6]:
# Vectorised CartPole built with the make_envs defaults (observable=True, n_envs=8).
observable_unstacked_env = make_envs()
# Show the observation space as the cell output.
observable_unstacked_env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [7]:
# Vectorised PartiallyObservableCartPole: each observation holds only the two position values.
hidden_unstacked_env = make_envs(observable=False)
# Show the observation space as the cell output.
hidden_unstacked_env.observation_space

Box([-4.8        -0.41887903], [4.8        0.41887903], (2,), float32)

In [8]:
# Separate partially observable vec env, wrapped in VecFrameStack with a stack size of 10.
hidden_stacked_env = VecFrameStack(make_envs(observable=False), 10)
# Show the observation space as the cell output.
hidden_stacked_env.observation_space

Box([-4.8        -4.8        -4.8        -4.8        -4.8        -4.8
 -4.8        -4.8        -4.8        -4.8        -0.41887903 -0.41887903
 -0.41887903 -0.41887903 -0.41887903 -0.41887903 -0.41887903 -0.41887903
 -0.41887903 -0.41887903], [4.8        4.8        4.8        4.8        4.8        4.8
 4.8        4.8        4.8        4.8        0.41887903 0.41887903
 0.41887903 0.41887903 0.41887903 0.41887903 0.41887903 0.41887903
 0.41887903 0.41887903], (20,), float32)

In [9]:
def evaluate(model: BaseAlgorithm, eval_env: VecEnv):
    """Print the first value returned by ``evaluate_policy`` over 20 episodes.

    Args:
        model: Agent passed to ``evaluate_policy``.
        eval_env: Vectorised environment the agent is evaluated on.
    """
    mean_reward, *_ = evaluate_policy(model, eval_env, n_eval_episodes=20, warn=False)
    print(mean_reward)

def learn_and_evaluate(model: BaseAlgorithm, eval_env: VecEnv, n_steps: int = 5000, n_envs: int = 8):
    """Train ``model`` and then print its evaluation score on ``eval_env``.

    Args:
        model: Agent to train; its ``learn`` method is called once.
        eval_env: Environment handed to ``evaluate`` after training.
        n_steps: Multiplied by ``n_envs`` to form the value passed to ``learn``.
        n_envs: Multiplier applied to ``n_steps``.
    """
    timestep_budget = n_steps * n_envs
    model.learn(timestep_budget)
    evaluate(model, eval_env)

In [10]:
# PPO with an MLP policy on the fully observable environment.
observable_unstacked_non_recurrent = PPO("MlpPolicy", observable_unstacked_env, verbose=1)
# Baseline score before any call to learn().
evaluate(observable_unstacked_non_recurrent, observable_unstacked_env)

Using cuda device
9.25


In [11]:
# PPO with an MLP policy on the partially observable environment (no frame stacking).
hidden_unstacked_non_recurrent = PPO("MlpPolicy", hidden_unstacked_env, verbose=1)
# Baseline score before any call to learn().
evaluate(hidden_unstacked_non_recurrent, hidden_unstacked_env)

Using cuda device
8.8


In [12]:

# RecurrentPPO with an MLP+LSTM policy on the partially observable environment (no frame stacking).
hidden_unstacked_recurrent = RecurrentPPO("MlpLstmPolicy", hidden_unstacked_env, verbose=1)
# Baseline score before any call to learn().
evaluate(hidden_unstacked_recurrent, hidden_unstacked_env)

Using cuda device
9.25


In [13]:
# PPO with an MLP policy on the frame-stacked partially observable environment.
hidden_stacked_non_recurrent = PPO("MlpPolicy", hidden_stacked_env, verbose=1)
# Baseline score before any call to learn().
evaluate(hidden_stacked_non_recurrent, hidden_stacked_env)

Using cuda device
9.6


In [14]:
# Train the fully observable PPO agent with the default budget (5000 * 8 timesteps), then print its score.
learn_and_evaluate(observable_unstacked_non_recurrent, observable_unstacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.8     |
|    ep_rew_mean     | 22.8     |
| time/              |          |
|    fps             | 6133     |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 16384    |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 34.3       |
|    ep_rew_mean          | 34.3       |
| time/                   |            |
|    fps                  | 2278       |
|    iterations           | 2          |
|    time_elapsed         | 14         |
|    total_timesteps      | 32768      |
| train/                  |            |
|    approx_kl            | 0.01571438 |
|    clip_fraction        | 0.264      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.68      |
|    explained_variance   | 0.000233   |
|    learning_rate        | 0.0003     |
|   

In [15]:
# Train the partially observable, unstacked PPO agent with the default budget, then print its score.
learn_and_evaluate(hidden_unstacked_non_recurrent, hidden_unstacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.1     |
|    ep_rew_mean     | 22.1     |
| time/              |          |
|    fps             | 6014     |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 16384    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 26.5        |
|    ep_rew_mean          | 26.5        |
| time/                   |             |
|    fps                  | 2414        |
|    iterations           | 2           |
|    time_elapsed         | 13          |
|    total_timesteps      | 32768       |
| train/                  |             |
|    approx_kl            | 0.006602927 |
|    clip_fraction        | 0.0795      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.688      |
|    explained_variance   | 0.000113    |
|    learning_rate        | 0.

In [16]:
# Train the partially observable RecurrentPPO agent with the default budget, then print its score.
learn_and_evaluate(hidden_unstacked_recurrent, hidden_unstacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21.7     |
|    ep_rew_mean     | 21.7     |
| time/              |          |
|    fps             | 2845     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 1024     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 22.4         |
|    ep_rew_mean          | 22.4         |
| time/                   |              |
|    fps                  | 434          |
|    iterations           | 2            |
|    time_elapsed         | 4            |
|    total_timesteps      | 2048         |
| train/                  |              |
|    approx_kl            | 0.0018576242 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.693       |
|    explained_variance   | 0.000757     |
|    learning_r

In [17]:
# Train the frame-stacked, partially observable PPO agent with the default budget, then print its score.
learn_and_evaluate(hidden_stacked_non_recurrent, hidden_stacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21.8     |
|    ep_rew_mean     | 21.8     |
| time/              |          |
|    fps             | 6204     |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 16384    |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 25.7         |
|    ep_rew_mean          | 25.7         |
| time/                   |              |
|    fps                  | 2364         |
|    iterations           | 2            |
|    time_elapsed         | 13           |
|    total_timesteps      | 32768        |
| train/                  |              |
|    approx_kl            | 0.0065383962 |
|    clip_fraction        | 0.0775       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.687       |
|    explained_variance   | -0.00131     |
|    learning_r