forked from floringogianu/categorical-dqn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
84 lines (62 loc) · 2.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gc, time # noqa
import gym, gym_fast_envs # noqa
import torch, numpy # noqa
import utils
from utils import Preprocessor
from utils import EvaluationMonitor
from agents import get_agent
def train_agent(cmdl):
    """Run the main training loop configured by the ``cmdl`` namespace.

    Builds a training environment plus a monitored evaluation environment,
    instantiates a training agent and a (non-learning) evaluation agent,
    then steps the training env until ``cmdl.training.step_no`` total
    environment steps have elapsed.  Stats are reported every
    ``cmdl.report_freq`` steps and an evaluation pass runs every
    ``cmdl.eval_freq`` steps.
    """
    start = time.time()
    total_steps, episodes = 0, 0

    env = utils.get_new_env(cmdl.env_name, cmdl)
    eval_env = EvaluationMonitor(gym.make(cmdl.env_name), cmdl)

    agent_name = cmdl.agent.name
    agent = get_agent(agent_name)(env.action_space, cmdl.agent)
    eval_agent = get_agent(agent_name)(eval_env.action_space, cmdl.agent, False)

    preprocess = Preprocessor(cmdl.env_class).transform
    agent.display_setup(env, cmdl)

    while total_steps < cmdl.training.step_no:
        episodes += 1
        obs, reward, done = env.reset(), 0, False
        state = preprocess(obs)

        while not done:
            action = agent.evaluate_policy(state)
            obs, reward, done, _ = env.step(action)

            # Keep the transition's source (state, action) before advancing.
            prev_state, prev_action = state, action
            state = preprocess(obs)
            agent.improve_policy(prev_state, prev_action, reward, state, done)

            total_steps += 1
            agent.gather_stats(reward, done)

            if total_steps % cmdl.report_freq == 0:
                agent.display_stats(start)
                agent.display_model_stats()
                gc.collect()

            if total_steps % cmdl.eval_freq == 0:
                evaluate_agent(total_steps, eval_env, eval_agent,
                               agent.policy, cmdl)

    agent.display_final_report(episodes, total_steps, time.time() - start)
def evaluate_agent(crt_training_step, eval_env, eval_agent, policy, cmdl):
    """Run a fixed-length evaluation pass with the current training policy.

    Copies the weights of ``policy`` into the evaluation agent, then steps
    ``eval_env`` for exactly ``cmdl.evaluator.eval_steps`` steps, resetting
    the environment whenever an episode terminates.
    """
    print("[Evaluator] Initializing at %d training steps:" %
          crt_training_step)
    eval_env.get_crt_step(crt_training_step)

    # Sync the evaluation policy with the latest training weights.
    eval_agent.policy_evaluation.policy.load_state_dict(policy.state_dict())
    preprocess = Preprocessor(cmdl.env_class).transform

    obs, reward, done = eval_env.reset(), 0, False
    for _ in range(cmdl.evaluator.eval_steps):
        state = preprocess(obs)
        action = eval_agent.evaluate_policy(state)
        obs, reward, done, _ = eval_env.step(action)
        if done:
            obs, reward, done = eval_env.reset(), 0, False
if __name__ == "__main__":
    # Parse cmdl args for the config file and return config as Namespace
    config = utils.parse_config_file(utils.parse_cmd_args())
    # Assuming everything in the config is deterministic already.
    # Seed both torch and numpy so runs are reproducible for a given config.
    torch.manual_seed(config.seed)
    numpy.random.seed(config.seed)
    # Let's do this!
    train_agent(config)