In [29]:
import argparse
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net
from tianshou.utils.net.discrete import Actor, Critic



In [30]:
def get_args(argv=None):
    """Build the PPO hyper-parameter namespace from command-line style arguments.

    :param argv: list of argument strings to parse. ``None`` (the default) keeps the
        original behaviour of reading ``sys.argv[1:]``; passing an explicit list
        (e.g. ``[]``) makes the result independent of the flags the Jupyter kernel
        process was started with.
    :return: ``argparse.Namespace`` with the parsed options. Unrecognised arguments
        are discarded (``parse_known_args``), so extra kernel flags do not error out.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="LunarLander-v2")
    # None means "derive the threshold later" (per-task override or env spec).
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--buffer-size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--step-per-epoch", type=int, default=50000)
    parser.add_argument("--step-per-collect", type=int, default=2000)
    parser.add_argument("--repeat-per-collect", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--training-num", type=int, default=20)
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    # Defaults to CUDA only when a GPU is visible to torch.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    # ppo special
    parser.add_argument("--vf-coef", type=float, default=0.5)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--eps-clip", type=float, default=0.2)
    parser.add_argument("--max-grad-norm", type=float, default=0.5)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    # Integer switches (0 = off) used as booleans by the policy.
    parser.add_argument("--rew-norm", type=int, default=0)
    parser.add_argument("--norm-adv", type=int, default=0)
    parser.add_argument("--recompute-adv", type=int, default=0)
    parser.add_argument("--dual-clip", type=float, default=None)
    parser.add_argument("--value-clip", type=int, default=0)
    # parse_known_args(None) falls back to sys.argv[1:], matching the old behaviour.
    return parser.parse_known_args(argv)[0]



In [31]:
# Hyper-parameters (unrecognised kernel flags are ignored by get_args).
args = get_args()

# Reference environment: read for its observation/action spaces and its spec.
env = gym.make(args.task)
obs_space = env.observation_space
act_space = env.action_space
args.state_shape = obs_space.shape or obs_space.n
args.action_shape = act_space.shape or act_space.n

# Reward threshold precedence: explicit flag > per-task override > env spec value.
if args.reward_threshold is None:
    task_thresholds = {"CartPole-v0": 195}
    args.reward_threshold = task_thresholds.get(args.task, env.spec.reward_threshold)


def make_env():
    """Factory for one environment instance of ``args.task``.

    Handed to the vectorised wrapper below; tianshou.env.SubprocVectorEnv is an
    alternative wrapper to DummyVectorEnv.
    """
    return gym.make(args.task)


train_envs = DummyVectorEnv([make_env for _ in range(args.training_num)])
test_envs = DummyVectorEnv([make_env for _ in range(args.test_num)])

# Seeding: global numpy and torch RNGs first, then both environment groups.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# Model: one shared MLP trunk (`net`) feeding both the actor and the critic head.
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
# Choose the wrapping from the *requested* device, not from torch.cuda.is_available():
# on a GPU host run with `--device cpu`, the old check wrapped the CPU-resident model
# in DataParallelNet, which is meant for CUDA.
# NOTE(review): tianshou's DataParallelNet moves inputs with `.cuda()` - confirm for
# the installed version.
if torch.device(args.device).type == "cuda":
    actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device))
    critic = DataParallelNet(Critic(net, device=None).to(args.device))
else:
    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
    critic = Critic(net, device=args.device).to(args.device)
actor_critic = ActorCritic(actor, critic)
# orthogonal initialization: orthogonal weights, zero biases for every Linear layer
for m in actor_critic.modules():
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.orthogonal_(m.weight)
        # Linear(bias=False) layers have no bias tensor to reset.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
# A single Adam optimiser over the combined actor + critic parameters.
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
# Policy: PPO with a categorical distribution over the discrete actions.
dist = torch.distributions.Categorical
policy = PPOPolicy(
    actor=actor,
    critic=critic,
    optim=optim,
    dist_fn=dist,
    action_space=env.action_space,
    # Action scaling is only switched on for Box action spaces.
    action_scaling=isinstance(env.action_space, Box),
    deterministic_eval=True,
    # Return / advantage estimation.
    discount_factor=args.gamma,
    gae_lambda=args.gae_lambda,
    reward_normalization=args.rew_norm,
    advantage_normalization=args.norm_adv,
    recompute_advantage=args.recompute_adv,
    # Loss clipping, loss weights and gradient clipping.
    eps_clip=args.eps_clip,
    dual_clip=args.dual_clip,
    value_clip=args.value_clip,
    vf_coef=args.vf_coef,
    ent_coef=args.ent_coef,
    max_grad_norm=args.max_grad_norm,
)

# Collectors: the training collector stores into a VectorReplayBuffer built from
# `buffer_size` and the number of training envs; the test collector has no buffer.
train_buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
train_collector = Collector(policy, train_envs, train_buffer)
test_collector = Collector(policy, test_envs)

# Logging: TensorBoard writer rooted at <logdir>/<task>/ppo.
log_path = os.path.join(args.logdir, args.task, "ppo")
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

def save_best_fn(policy):
    """Checkpoint callback passed to the trainer as ``save_best_fn``.

    Serialises ``policy.state_dict()`` with ``torch.save`` to ``policy.pth`` inside
    the module-level ``log_path`` directory, overwriting any previous checkpoint.
    """
    checkpoint_path = os.path.join(log_path, "policy.pth")
    torch.save(policy.state_dict(), checkpoint_path)

def stop_fn(mean_rewards):
    """Early-stopping predicate passed to the trainer as ``stop_fn``.

    :param mean_rewards: mean reward supplied by the trainer (presumably from the
        latest test run).
    :return: True once ``mean_rewards`` reaches ``args.reward_threshold``. When no
        threshold is known (an env spec may report ``None``), training simply never
        stops early instead of raising ``TypeError`` on ``x >= None``.
    """
    threshold = args.reward_threshold
    if threshold is None:
        return False
    return mean_rewards >= threshold

# Trainer: on-policy loop with an early-stop predicate (stop_fn) and a
# best-policy checkpoint callback (save_best_fn); run() returns the summary.
trainer = OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    logger=logger,
    # Training schedule.
    max_epoch=args.epoch,
    step_per_epoch=args.step_per_epoch,
    step_per_collect=args.step_per_collect,
    repeat_per_collect=args.repeat_per_collect,
    batch_size=args.batch_size,
    episode_per_test=args.test_num,
    # Callbacks.
    stop_fn=stop_fn,
    save_best_fn=save_best_fn,
)
result = trainer.run()


Epoch #1: 50001it [00:40, 1227.01it/s, env_step=50000, gradient_step=800, len=180, n/ep=1, n/st=2000, rew=-60.82]                             


Epoch #1: test_reward: -252.066959 ± 26.117025, best_reward: -162.349102 ± 135.970389 in #0


Epoch #2: 50001it [00:41, 1213.69it/s, env_step=100000, gradient_step=1600, len=1000, n/ep=2, n/st=2000, rew=57.09]                           


Epoch #2: test_reward: -3.640651 ± 40.743170, best_reward: -3.640651 ± 40.743170 in #2


Epoch #3: 50001it [00:48, 1031.11it/s, env_step=150000, gradient_step=2400, len=629, n/ep=2, n/st=2000, rew=12.36]                             


Epoch #3: test_reward: 75.231847 ± 84.574756, best_reward: 75.231847 ± 84.574756 in #3


Epoch #4: 50001it [00:49, 1015.98it/s, env_step=200000, gradient_step=3200, len=873, n/ep=4, n/st=2000, rew=97.47]                             


Epoch #4: test_reward: -53.330257 ± 23.724646, best_reward: 75.231847 ± 84.574756 in #3


Epoch #5: 50001it [00:40, 1236.36it/s, env_step=250000, gradient_step=4000, len=973, n/ep=2, n/st=2000, rew=132.50]                            


Epoch #5: test_reward: -62.996039 ± 84.763657, best_reward: 75.231847 ± 84.574756 in #3


Epoch #6: 50001it [00:39, 1251.45it/s, env_step=300000, gradient_step=4800, len=651, n/ep=0, n/st=2000, rew=16.79]                            


Epoch #6: test_reward: 11.072115 ± 89.324507, best_reward: 75.231847 ± 84.574756 in #3


Epoch #7: 50001it [00:47, 1054.88it/s, env_step=350000, gradient_step=5600, len=1000, n/ep=3, n/st=2000, rew=97.85]                            


Epoch #7: test_reward: 69.825404 ± 102.039539, best_reward: 75.231847 ± 84.574756 in #3


Epoch #8: 50001it [00:40, 1238.86it/s, env_step=400000, gradient_step=6400, len=690, n/ep=4, n/st=2000, rew=91.85]                             


Epoch #8: test_reward: -4.331552 ± 111.778525, best_reward: 75.231847 ± 84.574756 in #3


Epoch #9: 50001it [00:40, 1236.03it/s, env_step=450000, gradient_step=7200, len=654, n/ep=2, n/st=2000, rew=141.86]                            


Epoch #9: test_reward: 51.896893 ± 119.259292, best_reward: 75.231847 ± 84.574756 in #3


Epoch #10: 50001it [00:45, 1096.20it/s, env_step=500000, gradient_step=8000, len=419, n/ep=5, n/st=2000, rew=97.42]                             


Epoch #10: test_reward: -70.558810 ± 226.807568, best_reward: 75.231847 ± 84.574756 in #3


In [32]:

# Let's watch its performance!
# Dedicated names keep this cell from clobbering `env` (the reference env the policy
# was built from) and `result` (the training summary from the trainer cell).
num_watch_episodes = 3
watch_env = gym.make(args.task, render_mode="human")
policy.eval()
try:
    for _ in range(num_watch_episodes):
        # A fresh collector per episode, as before.
        # NOTE(review): for PPO `exploration_noise=True` is presumably a no-op - confirm.
        watch_collector = Collector(policy, watch_env, exploration_noise=True)
        watch_stats = watch_collector.collect(n_episode=1, render=args.render)
        print(f"Final reward: {watch_stats.returns_stat.mean}, length: {watch_stats.lens_stat.mean}")
finally:
    # Release the render window even if collection raises part-way through.
    watch_env.close()

Final reward: -115.63431414646574, length: 866.0
Final reward: -39.589728014548726, length: 475.0
Final reward: -201.59155501379962, length: 478.0
