In [1]:
import gym

from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback

In [3]:
# Training env: 16 copies of LunarLanderContinuous-v2 stacked into one VecEnv
# (make_vec_env wraps each copy in a Monitor). Seeded so that a fresh
# Restart & Run All reproduces the same episodes.
train_env = make_vec_env("LunarLanderContinuous-v2", n_envs=16, seed=0)

# Evaluation env: a separate single-copy VecEnv, also Monitor-wrapped, so the
# evaluation reports unmodified episode rewards/lengths and its type matches the
# training env. A different seed keeps evaluation episodes distinct from training.
eval_env = make_vec_env("LunarLanderContinuous-v2", n_envs=1, seed=1_000)

In [7]:
# Soft Actor-Critic agent. These hyperparameters appear to be the RL Zoo settings
# for LunarLanderContinuous, which were tuned for a single environment
# (one gradient update per collected transition) — TODO confirm against the zoo.
model = SAC(
    "MlpPolicy",
    train_env,
    learning_rate=7.3e-4,
    buffer_size=1_000_000,
    batch_size=256,
    ent_coef="auto",          # learn the entropy coefficient automatically
    gamma=0.99,
    tau=0.01,                 # soft target-network update coefficient
    train_freq=1,             # train after every vectorized env.step() ...
    # ... but with 16 parallel envs each step collects 16 transitions, so
    # gradient_steps=1 would give only 1 update per 16 transitions. -1 means
    # "as many gradient steps as transitions collected in the rollout",
    # restoring the 1:1 update-to-data ratio the hyperparameters assume.
    gradient_steps=-1,
    learning_starts=10_000,   # transitions to collect before the first update
    policy_kwargs=dict(net_arch=[400, 300]),
    seed=0,                   # seed the RNGs so training is reproducible
    verbose=1,
    tensorboard_log="./runs/sac_lunarcontinuous_tensorboard/",
)

Using cuda device


In [8]:
# EvalCallback counts `eval_freq` in callback calls, i.e. vectorized env.step()
# calls. With N parallel training envs every call advances N timesteps, so the
# previous eval_freq=1000 actually evaluated only every 16,000 timesteps.
# State the interval in timesteps and convert it to callback calls.
EVAL_EVERY_TIMESTEPS = 10_000

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/continuous/sac",   # best model by mean eval reward
    log_path="./logs/continuous/sac",               # evaluation results log
    eval_freq=max(EVAL_EVERY_TIMESTEPS // train_env.num_envs, 1),
    n_eval_episodes=10,
    deterministic=True,   # evaluate the deterministic policy, not sampled actions
    render=False,
)


In [9]:
# Train for 500k timesteps in total (counted across all parallel envs), evaluating
# periodically via eval_callback. The trailing `;` suppresses the
# `<SAC at 0x...>` repr of the model object that `learn()` returns.
model.learn(total_timesteps=500_000, callback=eval_callback);

Logging to ./runs/sac_lunar_tensorboard/SAC_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 90.2     |
|    ep_rew_mean     | -226     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 3684     |
|    time_elapsed    | 0        |
|    total_timesteps | 1520     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 96.1     |
|    ep_rew_mean     | -272     |
| time/              |          |
|    episodes        | 8        |
|    fps             | 3692     |
|    time_elapsed    | 0        |
|    total_timesteps | 1680     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 102      |
|    ep_rew_mean     | -263     |
| time/              |          |
|    episodes        | 12       |
|    fps             | 3771     |
|    time_elapsed    | 0        |
| 

<stable_baselines3.sac.sac.SAC at 0x25223967ca0>

In [10]:
# Persist the final policy. The best policy by evaluation reward is saved
# separately by EvalCallback under ./logs/continuous/sac.
model.save("data/policies/LunarLanderContinuous-v2#sac#training")

# Free the model and release the resources held by both environments; they
# are otherwise kept alive for the lifetime of the kernel.
del model
train_env.close()
eval_env.close()