# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import random

from utils.config import set_np_formatting, set_seed, get_args, parse_sim_params, load_cfg
from utils.parse_task import parse_task
from utils.process_sarl import *
from algorithms import REGISTRY
from algorithms.module import Actor, Critic
from utils.Logger import EpochLogger
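
# REGISTRY (imported above) maps an algorithm name to its trainer class, so
# REGISTRY[args.algo](...) below instantiates the selected safe-RL agent.
# A minimal sketch of the assumed shape (illustrative only -- the module
# paths and class names here are hypothetical; see algorithms/__init__.py
# for the actual mapping):
#
#     from algorithms.ppol import PPOL
#     from algorithms.cpo import CPO
#
#     REGISTRY = {
#         "ppol": PPOL,  # PPO with a Lagrangian cost constraint
#         "cpo": CPO,    # Constrained Policy Optimization
#     }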

def train(logdir):
    print("Algorithm: ", args.algo)
    logger = EpochLogger(args.algo, args.task, args.seed)
    agent_index = get_AgentIndex(cfg)

    if args.algo in ["ppol", "focops", "p3o", "pcpo", "cpo", "trpol", "ppo", "cppo_pid"]:
        task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index)

        learn_cfg = cfg_train["learn"]
        is_testing = learn_cfg["test"]

        # Override the testing flag and checkpoint path if a model directory
        # is passed on the command line.
        if args.model_dir != "":
            is_testing = True
            chkpt_path = args.model_dir

        logdir = logdir + "_seed{}".format(env.task.cfg["seed"])

        iterations = cfg_train["learn"]["max_iterations"]
        if args.max_iterations > 0:
            iterations = args.max_iterations

        # Set up the agent for training or inference.
        agent = REGISTRY[args.algo](vec_env=env,
                                    logger=logger,
                                    actor_class=Actor,
                                    critic_class=Critic,
                                    cost_critic_class=Critic,
                                    cost_lim=args.cost_lim,
                                    num_transitions_per_env=learn_cfg["nsteps"],
                                    num_learning_epochs=learn_cfg["noptepochs"],
                                    num_mini_batches=learn_cfg["nminibatches"],
                                    clip_param=learn_cfg["cliprange"],
                                    gamma=learn_cfg["gamma"],
                                    lam=learn_cfg["lam"],
                                    init_noise_std=learn_cfg.get("init_noise_std", 0.3),
                                    value_loss_coef=learn_cfg.get("value_loss_coef", 2.0),
                                    entropy_coef=learn_cfg["ent_coef"],
                                    learning_rate=learn_cfg["optim_stepsize"],
                                    max_grad_norm=learn_cfg.get("max_grad_norm", 2.0),
                                    use_clipped_value_loss=learn_cfg.get("use_clipped_value_loss", False),
                                    schedule=learn_cfg.get("schedule", "fixed"),
                                    desired_kl=learn_cfg.get("desired_kl", None),
                                    model_cfg=cfg_train["policy"],
                                    device=env.rl_device,
                                    sampler=learn_cfg.get("sampler", "sequential"),
                                    log_dir=logdir,
                                    is_testing=is_testing,
                                    print_log=learn_cfg["print_log"],
                                    apply_reset=False,
                                    asymmetric=(env.num_states > 0),
                                    debug=args.debug)

        if is_testing and args.model_dir != "":
            # Evaluate a trained checkpoint.
            print("Loading model from {}".format(chkpt_path))
            agent.test(chkpt_path)
        elif args.model_dir != "":
            # Resume training from a checkpoint.
            print("Loading model from {}".format(chkpt_path))
            agent.load(chkpt_path)

        agent.run(num_learning_iterations=iterations,
                  log_interval=cfg_train["learn"]["save_interval"])
    else:
        print("Unrecognized algorithm!\n"
              "Algorithm should be one of: [ppol, focops, p3o, pcpo, cpo, trpol, ppo, cppo_pid]")

if __name__ == '__main__':
    set_np_formatting()
    args = get_args()
    cfg, cfg_train, logdir = load_cfg(args)
    sim_params = parse_sim_params(args, cfg, cfg_train)
    set_seed(cfg_train.get("seed", -1), cfg_train.get("torch_deterministic", False))
    train(logdir)
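
# Example invocation (a sketch, not verified against this repo's CLI): the
# flag names below mirror the args fields used above (args.task, args.algo,
# args.seed, args.model_dir), but the exact flag syntax and the task name
# "ShadowHandOver" are assumptions -- see get_args in utils/config.py for
# the authoritative argument list:
#
#     python train.py --task=ShadowHandOver --algo=ppol --seed=0
#
# Passing --model_dir switches train() into test mode and evaluates the
# checkpoint instead of training:
#
#     python train.py --task=ShadowHandOver --algo=ppol --model_dir=<path/to/checkpoint>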