In [396]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.evaluation import evaluate_policy
import gymnasium as gym
import os
import numpy as np

In [397]:
class CustomEnv(gym.Env):
    def __init__(self):
        super(CustomEnv, self).__init__()
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=float)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=float)
        self.state: float = np.random.rand() / 10.0

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.state = np.random.rand() / 10.0
        return np.array([self.state]), {}

    def step(self, action):
        self.state = self.state * 4 + action[0] + np.random.rand() / 100.0
        obs = np.array([self.state])
        reward = 1
        if obs < -0.9 or obs > 0.9:
            reward = -10
            done = True
        else:
            done = False
        truncated = False
        info = {}
        return obs, reward, done, truncated, info
    
    def render(self, mode='human'):
        if mode == 'human':
            print(f"State: {self.state}")

    def close(self):
        pass

In [398]:
env = CustomEnv()
env = TimeLimit(env, max_episode_steps=50)

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.86     |
|    ep_rew_mean     | -9.14    |
| time/              |          |
|    fps             | 482      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.97        |
|    ep_rew_mean          | -9.03       |
| time/                   |             |
|    fps                  | 399         |
|    iterations           | 2           |
|    time_elapsed         | 10          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.033229586 |
|    clip_fraction        | 0.152       |
|    clip_range           | 0.2         |
|    entropy_loss  

<stable_baselines3.ppo.ppo.PPO at 0x709d18209d10>

In [399]:
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)
print(f'{mean_reward} +/- {std_reward}')

50.0 +/- 0.0


In [401]:
vec_env = model.get_env()
for i in range(1):
    obs = vec_env.reset()
    done = False
    j = 0
    while not done:
        action = model.predict(obs, deterministic=True)
        obs, reward, done, info = vec_env.step(action)
        env.render()
        # if obs[0] <= 0:
        #     print(j, ('.'*int(abs(obs[0])*50)).rjust(50))
        # else:
        #     print(j, ' '*50 + '.'*int(obs[0]*50))
        # j += 1

State: [-0.02468174]
State: [-0.01319472]
State: [-0.01504221]
State: [-0.01620372]
State: [-0.01676523]
State: [-0.01069554]
State: [-0.01453485]
State: [-0.01079907]
State: [-0.01162424]
State: [-0.01670368]
State: [-0.01919581]
State: [-0.01916472]
State: [-0.01714253]
State: [-0.01740879]
State: [-0.01279667]
State: [-0.01144087]
State: [-0.01311047]
State: [-0.01712181]
State: [-0.01275099]
State: [-0.01706106]
State: [-0.0142115]
State: [-0.01830034]
State: [-0.01791967]
State: [-0.01620653]
State: [-0.01375387]
State: [-0.01426512]
State: [-0.01337674]
State: [-0.01525953]
State: [-0.01812968]
State: [-0.01227057]
State: [-0.01222229]
State: [-0.01213068]
State: [-0.01343221]
State: [-0.01181662]
State: [-0.01267438]
State: [-0.02035748]
State: [-0.01673805]
State: [-0.01915842]
State: [-0.01353774]
State: [-0.01708484]
State: [-0.01253509]
State: [-0.01262238]
State: [-0.01335819]
State: [-0.01543556]
State: [-0.01533253]
State: [-0.01429032]
State: [-0.01830551]
State: [-0.010