In [2]:
import numpy as np

from gymnasium.wrappers.time_limit import TimeLimit
from sb3_contrib import RecurrentPPO
from stable_baselines3 import PPO
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecEnv, VecFrameStack

import reinfocus.learning.focus_environment as env

In [3]:
def make_env(
    observable_type: env.ObservableType = env.ObservableType.FULL,
    max_episode_steps: int = 200,
):
    """Build one focus environment with a capped episode length.

    Args:
        observable_type: Forwarded unchanged to ``env.FocusEnvironment`` as its
            ``observable_type`` argument. Defaults to ``ObservableType.FULL``.
        max_episode_steps: Episode-length cap handed to the ``TimeLimit``
            wrapper. Defaults to 200, the value that used to be hard-coded
            here, so existing callers see no change.

    Returns:
        The ``FocusEnvironment`` instance wrapped in ``TimeLimit``.
    """
    focus_env = env.FocusEnvironment(observable_type=observable_type)
    # The cap is applied by the wrapper, not by the environment itself.
    return TimeLimit(focus_env, max_episode_steps=max_episode_steps)

In [4]:
def bugfix_seed():
    """Draw a fresh seed from NumPy's global random state.

    Returns:
        A built-in ``int`` sampled from ``[0, uint32 max)``; the upper bound is
        exclusive because that is how ``np.random.randint`` treats ``high``.

    NOTE(review): the name suggests this works around a seeding problem when
    no explicit seed is given to the vectorised-env factory -- confirm which
    upstream issue this targets.
    """
    upper_bound = np.iinfo(np.uint32).max
    sampled = np.random.randint(0, upper_bound, dtype=np.uint32)
    # Convert the NumPy scalar to a plain Python int before handing it on.
    return int(sampled)

In [5]:
def make_envs(observable_type: env.ObservableType, n_envs=8):
    """Create a vectorised set of focus environments.

    Args:
        observable_type: Passed to ``make_env`` for every sub-environment.
        n_envs: Number of sub-environments requested from ``make_vec_env``.

    Returns:
        Whatever ``make_vec_env`` returns for the given factory, ``n_envs``
        and a seed freshly drawn by ``bugfix_seed``.
    """

    def _single_env():
        # Zero-argument factory; make_vec_env invokes it to build each copy.
        return make_env(observable_type)

    return make_vec_env(_single_env, n_envs=n_envs, seed=bugfix_seed())

In [6]:
# Vectorised environment built with the FULL observable type; later cells
# use it both to train and to evaluate the fully-observable PPO model.
observable_unstacked_env = make_envs(env.ObservableType.FULL)
# Trailing expression: the cell output is the environment's observation space.
observable_unstacked_env.observation_space



Box(-1.0, 1.0, (3,), float32)

In [7]:
# Vectorised environment built with the NO_TARGET observable type; shared by
# the non-recurrent and recurrent "unstacked" models in later cells.
no_target_unstacked_env = make_envs(env.ObservableType.NO_TARGET)
# Trailing expression: the cell output is the environment's observation space.
no_target_unstacked_env.observation_space

Box(-1.0, 1.0, (2,), float32)

In [8]:
# A separate NO_TARGET vectorised environment wrapped in VecFrameStack with a
# stack size of 10 (second positional argument).
no_target_stacked_env = VecFrameStack(make_envs(env.ObservableType.NO_TARGET), 10)
# Trailing expression: the cell output is the stacked observation space.
no_target_stacked_env.observation_space

Box(-1.0, 1.0, (20,), float32)

In [9]:
def evaluate(model: BaseAlgorithm, eval_env: VecEnv):
    """Print the mean episode reward of ``model`` on ``eval_env``.

    Runs ``evaluate_policy`` for 20 evaluation episodes with its warning
    disabled and prints only the first element of the returned pair (the
    mean reward); the second element is discarded.

    Args:
        model: The algorithm instance whose policy is evaluated.
        eval_env: The vectorised environment to evaluate on.
    """
    mean_reward, _unused = evaluate_policy(
        model,
        eval_env,
        n_eval_episodes=20,
        warn=False,
    )
    print(mean_reward)

def learn_and_evaluate(model: BaseAlgorithm, eval_env: VecEnv, n_steps: int = 5000, n_envs: int = 8):
    """Train ``model`` and then print its mean evaluation reward.

    Args:
        model: The algorithm instance to train in place via ``model.learn``.
        eval_env: Environment handed to ``evaluate`` after training.
        n_steps: Multiplied by ``n_envs`` to form the timestep budget.
        n_envs: Multiplied by ``n_steps`` to form the timestep budget.
    """
    # Total number of timesteps requested from the learner.
    timestep_budget = n_steps * n_envs
    model.learn(timestep_budget)
    evaluate(model, eval_env)

In [10]:
# PPO with an MLP policy on the FULL-observable environment.
observable_unstacked_non_recurrent = PPO("MlpPolicy", observable_unstacked_env, verbose=1)
# Evaluated before any call to learn(): this is the untrained baseline.
evaluate(observable_unstacked_non_recurrent, observable_unstacked_env)

Using cuda device
-44.917272649999994


In [11]:
# PPO with an MLP policy on the NO_TARGET environment (no frame stacking).
no_target_unstacked_non_recurrent = PPO("MlpPolicy", no_target_unstacked_env, verbose=1)
# Evaluated before any call to learn(): this is the untrained baseline.
evaluate(no_target_unstacked_non_recurrent, no_target_unstacked_env)

Using cuda device
-93.05241855


In [12]:
# RecurrentPPO with an MLP+LSTM policy on the same NO_TARGET environment
# object that the non-recurrent unstacked model uses.
no_target_unstacked_recurrent = RecurrentPPO("MlpLstmPolicy", no_target_unstacked_env, verbose=1)
# Evaluated before any call to learn(): this is the untrained baseline.
evaluate(no_target_unstacked_recurrent, no_target_unstacked_env)

Using cuda device
-72.8706235


In [13]:
# PPO with an MLP policy on the frame-stacked NO_TARGET environment.
no_target_stacked_non_recurrent = PPO("MlpPolicy", no_target_stacked_env, verbose=1)
# Evaluated before any call to learn(): this is the untrained baseline.
evaluate(no_target_stacked_non_recurrent, no_target_stacked_env)

Using cuda device
-102.83595385000001


In [14]:
# Train with learn_and_evaluate's defaults (5000 * 8 = 40000 requested
# timesteps), then print the mean evaluation reward on the same environment.
learn_and_evaluate(observable_unstacked_non_recurrent, observable_unstacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 200      |
|    ep_rew_mean     | -88.2    |
| time/              |          |
|    fps             | 75       |
|    iterations      | 1        |
|    time_elapsed    | 217      |
|    total_timesteps | 16384    |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 200          |
|    ep_rew_mean          | -82.1        |
| time/                   |              |
|    fps                  | 66           |
|    iterations           | 2            |
|    time_elapsed         | 496          |
|    total_timesteps      | 32768        |
| train/                  |              |
|    approx_kl            | 0.0097805215 |
|    clip_fraction        | 0.121        |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.41        |
|    explained_variance   | -0.00987     |
|    learning_r

In [15]:
# Train with learn_and_evaluate's defaults (5000 * 8 = 40000 requested
# timesteps), then print the mean evaluation reward on the same environment.
learn_and_evaluate(no_target_unstacked_non_recurrent, no_target_unstacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 200      |
|    ep_rew_mean     | -88.2    |
| time/              |          |
|    fps             | 75       |
|    iterations      | 1        |
|    time_elapsed    | 215      |
|    total_timesteps | 16384    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 200         |
|    ep_rew_mean          | -86.1       |
| time/                   |             |
|    fps                  | 74          |
|    iterations           | 2           |
|    time_elapsed         | 442         |
|    total_timesteps      | 32768       |
| train/                  |             |
|    approx_kl            | 0.007191565 |
|    clip_fraction        | 0.0811      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.39       |
|    explained_variance   | -0.0152     |
|    learning_rate        | 0.

In [16]:
# Train the recurrent model with learn_and_evaluate's defaults (5000 * 8 =
# 40000 requested timesteps), then print its mean evaluation reward.
learn_and_evaluate(no_target_unstacked_recurrent, no_target_unstacked_env)

-----------------------------
| time/              |      |
|    fps             | 74   |
|    iterations      | 1    |
|    time_elapsed    | 13   |
|    total_timesteps | 1024 |
-----------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 200          |
|    ep_rew_mean          | -87.5        |
| time/                   |              |
|    fps                  | 56           |
|    iterations           | 2            |
|    time_elapsed         | 36           |
|    total_timesteps      | 2048         |
| train/                  |              |
|    approx_kl            | 0.0018123561 |
|    clip_fraction        | 0.000488     |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.42        |
|    explained_variance   | -0.00169     |
|    learning_rate        | 0.0003       |
|    loss                 | 4.12         |
|    n_updates            | 10           |
|    policy_grad

In [17]:
# Train the frame-stacked model with learn_and_evaluate's defaults (5000 * 8
# = 40000 requested timesteps), then print its mean evaluation reward.
learn_and_evaluate(no_target_stacked_non_recurrent, no_target_stacked_env)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 200      |
|    ep_rew_mean     | -88.1    |
| time/              |          |
|    fps             | 74       |
|    iterations      | 1        |
|    time_elapsed    | 220      |
|    total_timesteps | 16384    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 200         |
|    ep_rew_mean          | -81.3       |
| time/                   |             |
|    fps                  | 73          |
|    iterations           | 2           |
|    time_elapsed         | 447         |
|    total_timesteps      | 32768       |
| train/                  |             |
|    approx_kl            | 0.012090378 |
|    clip_fraction        | 0.128       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.39       |
|    explained_variance   | -0.0449     |
|    learning_rate        | 0.