In [1]:
import gymnasium as gym
import tianshou as ts

from tianshou.algorithm.modelfree.sac import SAC, SACPolicy, SACTrainingStats
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import CollectStats, Collector
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils.net.continuous import (
    # ContinuousActorDeterministic,
    ContinuousCritic,
    ContinuousActorProbabilistic,
)
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo
from tianshou.data import VectorReplayBuffer, ReplayBuffer
from tianshou.env import DummyVectorEnv

from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger

import matplotlib.pyplot as plt

In [2]:
gym.envs.registry.keys()

dict_keys(['CartPole-v0', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Pendulum-v1', 'Acrobot-v1', 'phys2d/CartPole-v0', 'phys2d/CartPole-v1', 'phys2d/Pendulum-v0', 'LunarLander-v3', 'LunarLanderContinuous-v3', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3', 'CarRacing-v3', 'Blackjack-v1', 'FrozenLake-v1', 'FrozenLake8x8-v1', 'CliffWalking-v1', 'CliffWalkingSlippery-v1', 'Taxi-v3', 'tabular/Blackjack-v0', 'tabular/CliffWalking-v0', 'Reacher-v2', 'Reacher-v4', 'Reacher-v5', 'Pusher-v2', 'Pusher-v4', 'Pusher-v5', 'InvertedPendulum-v2', 'InvertedPendulum-v4', 'InvertedPendulum-v5', 'InvertedDoublePendulum-v2', 'InvertedDoublePendulum-v4', 'InvertedDoublePendulum-v5', 'HalfCheetah-v2', 'HalfCheetah-v3', 'HalfCheetah-v4', 'HalfCheetah-v5', 'Hopper-v2', 'Hopper-v3', 'Hopper-v4', 'Hopper-v5', 'Swimmer-v2', 'Swimmer-v3', 'Swimmer-v4', 'Swimmer-v5', 'Walker2d-v2', 'Walker2d-v3', 'Walker2d-v4', 'Walker2d-v5', 'Ant-v2', 'Ant-v3', 'Ant-v4', 'Ant-v5', 'Humanoid-v2', 'Humanoid-v3', 

In [3]:
env = gym.make("Pendulum-v1")

In [4]:
task = "Pendulum-v1"

num_train_envs = 10
num_test_envs = 100

# SAC Parameters
hidden_sizes = [256, 256]
lr_actor = 3e-4
lr_critic = 3e-4
gamma = 0.99
tau = 0.005
n_step_warmup = 10_000
buffer_size = 100_000

# Training Parameters
max_epochs = 5
epoch_num_steps = 5000
batch_size = 256

In [5]:
env = gym.make(task)
space_info = SpaceInfo.from_env(env)
state_shape = space_info.observation_info.obs_shape
action_shape = space_info.action_info.action_shape

logger = TensorboardLogger(SummaryWriter("log/sac"))

train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])

In [6]:
# Actor
actor = ContinuousActorProbabilistic(
    preprocess_net=Net(
        state_shape=state_shape, action_shape=action_shape, hidden_sizes=hidden_sizes
    ),
    action_shape=action_shape,
)
actor_optim = AdamOptimizerFactory(lr=lr_actor)

critic = ContinuousCritic(
    preprocess_net=Net(
        state_shape=state_shape,
        action_shape=action_shape,
        hidden_sizes=hidden_sizes,
        concat=True,
    )
)
critic_optim = AdamOptimizerFactory(lr=lr_critic)


# Policy
policy = SACPolicy(
    actor=actor,
    exploration_noise="default",
    action_space=env.action_space,
    action_scaling=False,
)

# Algorithm
algorithm = SAC(
    policy=policy,
    policy_optim=actor_optim,
    critic=critic,
    critic_optim=critic_optim,
    tau=tau,
    gamma=gamma,
)

train_collector = Collector[CollectStats](
    policy=algorithm,
    env=env,
    buffer=VectorReplayBuffer(
        total_size=buffer_size,
        buffer_num=num_train_envs,
    ),
    exploration_noise=True,
)

test_collector = Collector[CollectStats](
    policy=algorithm,
    env=env,
    exploration_noise=False,
)

# Warm up
train_collector.reset()
train_collector.collect(n_step=n_step_warmup, random=True)


# Define stop condition
def stop_fn(mean_rewards: float) -> bool:
    if env.spec and env.spec.reward_threshold:
        return mean_rewards >= env.spec.reward_threshold
    return False



In [9]:
# Run the training
result = algorithm.run_training(
    OffPolicyTrainerParams(
        max_epochs=max_epochs,
        epoch_num_steps=epoch_num_steps,
        test_collector=test_collector,
        logger=logger,
        training_collector=train_collector,
        batch_size=batch_size,
        collection_step_num_env_steps=2500,
    )
)

Initial test step: test_reward: -1283.831694 ± 0.000000, best_reward: -1283.831694 ± 0.000000 in #0


Epoch #1: 100%|##########| 5000/5000 [00:26<00:00, 186.88it/s, env_episode=25, env_step=5000, len=200, n_ep=13, n_st=2500, rew=-272.16, update_step=2]


Epoch #1: test_reward: -242.020377 ± 0.000000, best_reward: -242.020377 ± 0.000000 in #1


Epoch #2: 100%|##########| 5000/5000 [00:26<00:00, 187.79it/s, env_episode=50, env_step=10000, len=200, n_ep=13, n_st=2500, rew=-658.17, update_step=4]

Epoch #2: test_reward: -123.440496 ± 0.000000, best_reward: -123.440496 ± 0.000000 in #2



Epoch #3: 100%|##########| 5000/5000 [00:26<00:00, 187.34it/s, env_episode=75, env_step=15000, len=200, n_ep=13, n_st=2500, rew=-303.18, update_step=6]

Epoch #3: test_reward: -367.777299 ± 0.000000, best_reward: -123.440496 ± 0.000000 in #2



Epoch #4: 100%|##########| 5000/5000 [00:26<00:00, 187.06it/s, env_episode=100, env_step=20000, len=200, n_ep=13, n_st=2500, rew=-402.84, update_step=8]

Epoch #4: test_reward: -126.188644 ± 0.000000, best_reward: -123.440496 ± 0.000000 in #2



Epoch #5: 100%|##########| 5000/5000 [00:26<00:00, 187.25it/s, env_episode=125, env_step=25000, len=200, n_ep=13, n_st=2500, rew=-407.40, update_step=10]

Epoch #5: test_reward: -479.226360 ± 0.000000, best_reward: -123.440496 ± 0.000000 in #2





In [12]:
# Visualize the results
render_env = gym.make(task, render_mode="human")
render_collector = Collector[CollectStats](algorithm, render_env)
render_collector.reset()
render_collector.collect(n_episode=3, render=1 / 60)
# render_collector.close()

CollectStats(n_collected_episodes=3, n_collected_steps=600, collect_time=41.67225098609924, collect_speed=14.398070317827182, returns=array([-126.10705059, -242.65335812, -241.96819234]), returns_stat=SequenceSummaryStats(mean=-203.5762003509542, std=54.779675283007585, max=-126.10705059331679, min=-242.65335812367138), lens=array([200, 200, 200]), lens_stat=SequenceSummaryStats(mean=200.0, std=0.0, max=200.0, min=200.0), pred_dist_std_array=array([[0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466521],
       [0.3466