In [1]:
from pathlib import Path

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

In [2]:
# Environment / evaluation configuration.
ENV_ID = "LunarLanderContinuous-v2"
N_ENVS = 16
SEED = 0  # seeds the training envs so runs are reproducible
LOG_DIR = "./logs/continuous/sac"  # EvalCallback writes best_model.zip and evaluations.npz here
# EvalCallback counts `eval_freq` in callback calls (one per vec-env step), and each
# vec-env step advances N_ENVS timesteps. Express the interval in timesteps and convert,
# so it stays correct if N_ENVS changes (16_000 // 16 == 1000, the original value).
EVAL_EVERY_TIMESTEPS = 16_000
N_EVAL_EPISODES = 10

# Training uses N_ENVS parallel copies of the env.
train_env = make_vec_env(ENV_ID, n_envs=N_ENVS, seed=SEED)
# Evaluation uses a single env wrapped in Monitor, so evaluations report the raw
# episode returns/lengths.
eval_env = Monitor(gym.make(ENV_ID))
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=LOG_DIR,
    log_path=LOG_DIR,
    eval_freq=max(EVAL_EVERY_TIMESTEPS // N_ENVS, 1),
    n_eval_episodes=N_EVAL_EPISODES,
    deterministic=True,
    render=False,
)


In [3]:
# Resume from the best checkpoint left by a previous run's EvalCallback, if one exists.
# Otherwise start a fresh agent, so the notebook still runs top-to-bottom on a clean checkout.
checkpoint_path = Path(eval_callback.best_model_save_path) / "best_model.zip"
if checkpoint_path.exists():
    model = SAC.load(checkpoint_path, env=train_env)
    # The callback will save new bests to this same file. Its best_mean_reward starts at
    # -inf, so the first evaluation would overwrite the checkpoint even if the resumed policy
    # then scores worse. Seed the threshold with the checkpoint's own score, using the
    # callback's evaluation settings, so it is only replaced by a genuine improvement.
    resumed_mean_reward, _ = evaluate_policy(
        model,
        eval_env,
        n_eval_episodes=eval_callback.n_eval_episodes,
        deterministic=eval_callback.deterministic,
    )
    eval_callback.best_mean_reward = resumed_mean_reward
else:
    # NOTE(review): default SAC hyperparameters. TODO: match the config that produced the
    # original checkpoint.
    model = SAC(
        "MlpPolicy",
        train_env,
        verbose=1,
        tensorboard_log="./runs/sac_lunar_tensorboard",
    )

In [4]:
# Train for one million environment timesteps; eval_callback evaluates the policy
# periodically during training.
# NOTE(review): learn() uses the default reset_num_timesteps=True, so timestep counters and
# the TensorBoard run restart from zero even when resuming from a checkpoint.
TOTAL_TIMESTEPS = 1_000_000
model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=eval_callback)

Logging to ./runs/sac_lunar_tensorboard/SAC_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 93.5     |
|    ep_rew_mean     | -245     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 5720     |
|    time_elapsed    | 0        |
|    total_timesteps | 1616     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 99.8     |
|    ep_rew_mean     | -209     |
| time/              |          |
|    episodes        | 8        |
|    fps             | 5565     |
|    time_elapsed    | 0        |
|    total_timesteps | 1728     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 110      |
|    ep_rew_mean     | -217     |
| time/              |          |
|    episodes        | 12       |
|    fps             | 5416     |
|    time_elapsed    | 0        |
| 

In [None]:
# After learn() finishes, `model` holds the *final* weights, which are not necessarily the
# best ones. EvalCallback saved the best-scoring policy as best_model.zip in its
# best_model_save_path, so export that one under the "train_best" name. If no evaluation
# ever ran, fall back to the final model.
best_model_path = Path(eval_callback.best_model_save_path) / "best_model.zip"
best_model = SAC.load(best_model_path) if best_model_path.exists() else model
# NOTE(review): the file name says LunarLander-v2, but training used
# LunarLanderContinuous-v2. Path kept unchanged in case downstream code relies on it.
best_model.save("data/policies/LunarLander-v2#sac#train_best")
del model, best_model
# Release the simulator resources held by the training and evaluation envs.
train_env.close()
eval_env.close()