run_ppo.py
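"""Train PPO with tf2rl's OnPolicyTrainer on a gym environment.

Example invocation (assumes tf2rl and gym are installed):

    python run_ppo.py --env-name Pendulum-v0
"""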
import gym

from tf2rl.algos.ppo import PPO
from tf2rl.envs.utils import is_discrete, get_act_dim
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer

if __name__ == '__main__':
    # Build the CLI from the trainer's and PPO's shared arguments.
    parser = OnPolicyTrainer.get_argument()
    parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    # Defaults tuned for this on-policy PPO setup.
    parser.set_defaults(test_interval=20480)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=2048)
    parser.set_defaults(batch_size=64)
    parser.set_defaults(gpu=-1)  # -1 runs on CPU
    args = parser.parse_args()

    # Separate environments for training and evaluation.
    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)

    policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        # Continuous action spaces need the action bound; discrete ones do not.
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=[64, 64],
        critic_units=[64, 64],
        n_epoch=10,
        n_epoch_critic=10,
        lr_actor=3e-4,
        lr_critic=3e-4,
        discount=0.99,
        lam=0.95,  # GAE lambda
        horizon=args.horizon,
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
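
    # Optional post-training rollout (illustrative sketch, not part of the
    # original example). It assumes the trained `policy` exposes tf2rl's
    # usual `get_action(obs, test=...)` interface and the classic gym step
    # API; uncomment to run a single greedy evaluation episode.
    #
    #     obs = test_env.reset()
    #     done = False
    #     episode_return = 0.0
    #     while not done:
    #         action = policy.get_action(obs, test=True)
    #         obs, reward, done, _ = test_env.step(action)
    #         episode_return += reward
    #     print(f"Episode return: {episode_return:.2f}")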