-
Notifications
You must be signed in to change notification settings - Fork 63
/
run_lunarlander_continuous_v2.py
79 lines (66 loc) · 2.24 KB
/
run_lunarlander_continuous_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
"""Train or test algorithms on LunarLanderContinuous-v2.
- Author: Curt Park
- Contact: curt.park@medipixel.io
"""
import argparse
import importlib
import gym
import algorithms.common.env.utils as env_utils
import algorithms.common.helper_functions as common_utils
# Command-line interface.
# NOTE: the arguments are parsed at import time, so the resulting ``args``
# namespace is a module-level global that ``main`` reads directly.
parser = argparse.ArgumentParser(description="Pytorch RL algorithms")

# --- experiment selection -------------------------------------------------
parser.add_argument(
    "--seed",
    default=777,
    type=int,
    help="random seed for reproducibility",
)
parser.add_argument(
    "--algo",
    default="ddpg",
    type=str,
    help="choose an algorithm",
)
parser.add_argument(
    "--test",
    action="store_true",
    help="test mode (no training)",
)
parser.add_argument(
    "--load-from",
    type=str,
    help="load the saved model and optimizer at the beginning",
)

# --- rendering / logging --------------------------------------------------
# ``--off-render`` stores False into ``render``; rendering is on by default.
parser.add_argument(
    "--off-render",
    dest="render",
    action="store_false",
    help="turn off rendering",
)
parser.add_argument(
    "--render-after",
    default=0,
    type=int,
    help="start rendering after the input number of episode",
)
parser.add_argument(
    "--log",
    action="store_true",
    help="turn on logging",
)

# --- run length / checkpointing -------------------------------------------
parser.add_argument(
    "--save-period",
    default=100,
    type=int,
    help="save model period",
)
parser.add_argument(
    "--episode-num",
    default=1500,
    type=int,
    help="total episode num",
)
parser.add_argument(
    "--max-episode-steps",
    default=300,
    type=int,
    help="max episode step",
)
parser.add_argument(
    "--interim-test-num",
    default=10,
    type=int,
    help="interim test number",
)

# --- demonstration data ---------------------------------------------------
parser.add_argument(
    "--demo-path",
    default="data/lunarlander_continuous_demo.pkl",
    type=str,
    help="demonstration path",
)

# Explicit defaults for the flag-style options and the optional checkpoint.
parser.set_defaults(test=False, load_from=None, render=True, log=False)

args = parser.parse_args()
def main():
    """Create the LunarLanderContinuous-v2 environment and run one algorithm.

    Uses the module-level ``args`` namespace. The environment is built with
    ``gym.make``, handed to ``env_utils.set_env`` together with ``args``, and
    seeded through ``common_utils.set_random_seed``. The module
    ``examples.lunarlander_continuous_v2.<args.algo>`` is then imported and
    its ``run(env, args, state_dim, action_dim)`` function is called.
    """
    # Build the environment and apply the command-line settings to it.
    env = gym.make("LunarLanderContinuous-v2")
    env_utils.set_env(env, args)

    # Sizes of the first axis of the observation and action spaces.
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.shape[0]

    # Seed everything before the algorithm module is loaded and run.
    common_utils.set_random_seed(args.seed, env)

    # The algorithm is resolved by name, so ``--algo`` must match a module
    # inside the ``examples.lunarlander_continuous_v2`` package.
    package_prefix = "examples.lunarlander_continuous_v2."
    algo_module = importlib.import_module(package_prefix + args.algo)
    algo_module.run(env, args, obs_size, act_size)
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()