Training a simple SAC (Soft Actor-Critic) agent

In [None]:
import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, EvalCallback, StopTrainingOnMaxEpisodes
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor

from gym_unbalanced_disk import UnbalancedDisk

# Run settings for this cell.
SEED = 42  # same seed as the Optuna trials further down
MAX_EPISODE_STEPS = 500
MAX_TRAIN_EPISODES = 100


def make_monitored_env():
    """Return an AC_UnbalancedDisk wrapped in TimeLimit and Monitor.

    TimeLimit truncates each episode after MAX_EPISODE_STEPS steps; Monitor
    records episode statistics (SB3 recommends Monitor on both the training
    and the evaluation env).
    """
    return Monitor(TimeLimit(AC_UnbalancedDisk(), max_episode_steps=MAX_EPISODE_STEPS))


# Training environment
env = make_monitored_env()

# Separate eval environment for unbiased performance tracking
eval_env = make_monitored_env()

# Stop after MAX_TRAIN_EPISODES episodes. Each episode lasts at most
# MAX_EPISODE_STEPS steps, so training ends after at most ~50_000 steps --
# long before the total_timesteps passed to learn() below.
stop_cb = StopTrainingOnMaxEpisodes(max_episodes=MAX_TRAIN_EPISODES, verbose=1)

# Evaluate every 5000 steps and save the best model based on mean reward
eval_cb = EvalCallback(
    eval_env,
    best_model_save_path="./best_sac_model",
    log_path="./logs",
    eval_freq=5000,
    deterministic=True,
    render=False,
)

# Chain both callbacks; training stops as soon as either one returns False
callback = CallbackList([stop_cb, eval_cb])

# Model
model_sac1 = SAC(
    policy="MlpPolicy",
    env=env,
    learning_rate=1e-3,
    verbose=1,
    ent_coef=1e-2,  # fixed entropy coefficient (no automatic tuning)
    seed=SEED,  # makes the run reproducible
)

# Train. total_timesteps is only an upper bound; stop_cb ends training first.
model_sac1.learn(
    total_timesteps=1_000_000,
    callback=callback,
)


In [None]:
# To use the best-performing model
from stable_baselines3 import SAC

N_ROLLOUT_STEPS = 5000  # number of environment steps to visualise

model_sac1 = SAC.load("./best_sac_model/best_model.zip")

env = AC_UnbalancedDisk()
try:
    obs, _ = env.reset()
    for _ in range(N_ROLLOUT_STEPS):
        # deterministic=True matches how EvalCallback scored this checkpoint;
        # predict() would otherwise sample from the stochastic policy.
        action, _states = model_sac1.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        print(f'theta = {obs[0]: .4f}, omega: {obs[1]: .4f}')
        if terminated or truncated:
            obs, _ = env.reset()
finally:
    # Close the render window even if the cell is interrupted mid-rollout.
    env.close()

Hyperparameter tuning of SAC with Optuna (a full run is expected to take over 400 minutes)

In [None]:
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
import torch
import numpy as np
import random

from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnMaxEpisodes, CallbackList
from gymnasium.wrappers import TimeLimit

class TrialEvalCallback(EvalCallback):
    """EvalCallback that reports every evaluation to an Optuna trial.

    After each evaluation (every ``eval_freq`` callback calls) the latest mean
    reward is reported with ``trial.report`` using the evaluation index as the
    step. If the study's pruner decides to prune, training is stopped and
    ``is_pruned`` is set so the objective can raise ``optuna.TrialPruned``.
    """

    def __init__(self, eval_env, trial, **kwargs):
        super().__init__(eval_env, **kwargs)
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            continue_training = super()._on_step()
            self.eval_idx += 1
            self.trial.report(self.last_mean_reward, self.eval_idx)
            if self.trial.should_prune():
                self.is_pruned = True
                return False
            return continue_training
        return True


def objective(trial):
    """Train one SAC agent with Optuna-sampled hyperparameters and score it.

    Training uses an env with randomized friction; evaluation uses an env
    without it. Evaluations are reported to ``trial`` so the study's pruner
    can stop unpromising trials early.

    Returns:
        The best mean evaluation reward seen during training.

    Raises:
        optuna.TrialPruned: if the pruner stopped the trial.
    """
    # Sample hyperparameters (ent_coef is left at SAC's default)
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    gamma = trial.suggest_float("gamma", 0.90, 0.9999)
    tau = trial.suggest_float("tau", 0.005, 0.02)
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256])
    buffer_size = trial.suggest_categorical("buffer_size", [100_000, 200_000, 500_000])

    # Training env
    env = Monitor(TimeLimit(AC_UnbalancedDisk1(randomize_friction=True), max_episode_steps=500))

    # Evaluation env (no random friction)
    eval_env = Monitor(TimeLimit(AC_UnbalancedDisk1(randomize_friction=False), max_episode_steps=500))

    # Callbacks
    stop_cb = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)
    eval_cb = TrialEvalCallback(
        eval_env,
        trial,
        best_model_save_path=f"./optuna_best_model_trial_{trial.number}",
        log_path=None,
        eval_freq=5000,
        deterministic=True,
        render=False,
    )
    callback = CallbackList([stop_cb, eval_cb])

    try:
        # Model
        model = SAC(
            "MlpPolicy",
            env,
            learning_rate=learning_rate,
            gamma=gamma,
            tau=tau,
            batch_size=batch_size,
            buffer_size=buffer_size,
            verbose=1,
            seed=42,
        )
        model.learn(total_timesteps=1_000_000, callback=callback)
    finally:
        # 20 trials each create two envs; release them even on error/prune.
        env.close()
        eval_env.close()

    if eval_cb.is_pruned:
        raise optuna.TrialPruned()

    return eval_cb.best_mean_reward

# Optuna study.
# - The sampler is seeded so the sequence of suggested hyperparameters is
#   reproducible.
# - MedianPruner only acts on intermediate values the objective reports via
#   trial.report(). n_warmup_steps keeps it from pruning on the first few
#   (noisy) RL evaluations.
study = optuna.create_study(
    direction="maximize",
    sampler=TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=3),
)
study.optimize(objective, n_trials=20)

# Show best result
print("Best trial value:", study.best_trial.value)
print("Best hyperparameters:", study.best_trial.params)


In [None]:
# To use the best-performing model from the Optuna study
from stable_baselines3 import SAC

# Trial number of the best run, hard-coded from a previous study so this cell
# also works after a kernel restart. If `study` is still in memory,
# study.best_trial.number can be used instead.
BEST_TRIAL_NUMBER = 17
N_ROLLOUT_STEPS = 1000  # number of environment steps to visualise

model = SAC.load(f"./optuna_best_model_trial_{BEST_TRIAL_NUMBER}/best_model.zip")

env = AC_UnbalancedDisk1()
try:
    obs, _ = env.reset()
    for _ in range(N_ROLLOUT_STEPS):
        # deterministic=True matches how the checkpoint was evaluated/selected;
        # predict() would otherwise sample from the stochastic policy.
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        print(f'theta = {obs[0]: .4f}, omega: {obs[1]: .4f}')
        if terminated or truncated:
            obs, _ = env.reset()
finally:
    # Close the render window even if the cell is interrupted mid-rollout.
    env.close()