In [6]:
import gym

from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

In [7]:
# Environment + evaluation setup for (re)training SAC on LunarLanderContinuous-v2.
ENV_ID = "LunarLanderContinuous-v2"
N_ENVS = 16  # parallel training environments inside the VecEnv

# EvalCallback's `eval_freq` counts callback *calls*; every call advances all
# N_ENVS environments by one step. A per-call value of 1000 with 16 envs thus
# means one evaluation every 16_000 timesteps. Express the cadence in timesteps
# so that changing N_ENVS does not silently change how often we evaluate.
EVAL_EVERY_TIMESTEPS = 16_000
N_EVAL_EPISODES = 10

# Write this run's best checkpoint and evaluations.npz to their own directory.
# EvalCallback starts with best_mean_reward = -inf, so its FIRST evaluation always
# saves best_model (and rewrites evaluations.npz). Pointing it at the directory the
# model is resumed from ("logs/continuous/sac") would overwrite that checkpoint
# and its evaluation history regardless of whether the new score is better.
EVAL_LOG_DIR = "./logs/continuous/sac_resumed"

train_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
eval_env = Monitor(gym.make(ENV_ID))  # Monitor so episode rewards/lengths are recorded
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=EVAL_LOG_DIR,
    log_path=EVAL_LOG_DIR,
    eval_freq=max(EVAL_EVERY_TIMESTEPS // N_ENVS, 1),
    n_eval_episodes=N_EVAL_EPISODES,
    deterministic=True,
    render=False,
)


In [8]:
# Resume from the best checkpoint written by an earlier EvalCallback run and
# attach the training VecEnv so that `learn` can continue on it.
# NOTE(review): this relies on an artifact produced in a previous session, so it
# fails on a fresh checkout / Restart-&-Run-All without that file.
# NOTE(review): SB3 keeps the replay buffer out of this zip (it has separate
# save/load_replay_buffer calls), so training presumably restarts with an empty
# buffer — confirm this is intended.
# NOTE(review): if an EvalCallback in this notebook saves to this same directory,
# its first evaluation will overwrite the checkpoint loaded here.
model = SAC.load(path="logs/continuous/sac/best_model", env=train_env)

In [9]:
# Train for another one million timesteps, running the evaluation callback.
# NOTE(review): reset_num_timesteps is left at its default (True), so the timestep
# counter restarts at zero and a new TensorBoard run is started instead of
# continuing the loaded model's curve; pass reset_num_timesteps=False if a
# continuous curve is wanted.
model.learn(
    callback=eval_callback,
    total_timesteps=1_000_000,
)

Logging to ./runs/sac_lunar_tensorboard/SAC_3
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 81.8     |
|    ep_rew_mean     | -155     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 4045     |
|    time_elapsed    | 0        |
|    total_timesteps | 1456     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 88.2     |
|    ep_rew_mean     | -191     |
| time/              |          |
|    episodes        | 8        |
|    fps             | 3973     |
|    time_elapsed    | 0        |
|    total_timesteps | 1536     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 94.2     |
|    ep_rew_mean     | -248     |
| time/              |          |
|    episodes        | 12       |
|    fps             | 3998     |
|    time_elapsed    | 0        |
| 

<stable_baselines3.sac.sac.SAC at 0x239a162b3a0>

In [10]:
# Persist the model exactly as it stands when `learn` returned, i.e. the final
# weights. The best-scoring checkpoint is the separate file EvalCallback wrote
# to its best_model_save_path during training.
# NOTE(review): the file name says "LunarLander-v2" although the agent was trained
# on LunarLanderContinuous-v2, and "train_best" could be read as "best checkpoint";
# confirm the intended name before downstream code depends on it.
model.save(path="data/policies/LunarLander-v2#sac#train_best")
# Drop the reference so no later cell silently reuses this in-memory model.
del model