-
Notifications
You must be signed in to change notification settings - Fork 48
/
maml_run_mujoco.py
131 lines (102 loc) · 4.39 KB
/
maml_run_mujoco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from meta_policy_search.baselines.linear_baseline import LinearFeatureBaseline
from meta_policy_search.envs.mujoco_envs.half_cheetah_rand_direc import HalfCheetahRandDirecEnv
from meta_policy_search.envs.normalized_env import normalize
from meta_policy_search.meta_algos.trpo_maml import TRPOMAML
from meta_policy_search.meta_trainer import Trainer
from meta_policy_search.samplers.meta_sampler import MetaSampler
from meta_policy_search.samplers.meta_sample_processor import MetaSampleProcessor
from meta_policy_search.policies.meta_gaussian_mlp_policy import MetaGaussianMLPPolicy
from meta_policy_search.utils import logger
from meta_policy_search.utils.utils import set_seed, ClassEncoder
import numpy as np
import os
import json
import argparse
import time
# Absolute path of the repository root: the parent of the directory holding
# this script. os.path.dirname is used instead of splitting on a literal '/'
# so the computation does not depend on the platform's path separator.
meta_policy_search_path = os.path.dirname(os.path.realpath(os.path.dirname(__file__)))
def main(config):
    """Build every ProMP/TRPO-MAML component from *config* and run meta-training.

    The baseline and environment classes are looked up by their string name in
    this module's global namespace, so the corresponding classes must be
    imported at the top of the file.
    """
    set_seed(config['seed'])

    # Resolve the configured class names and instantiate them; the env is
    # immediately wrapped with the normalization wrapper.
    value_baseline = globals()[config['baseline']]()
    task_env = normalize(globals()[config['env']]())

    meta_policy = MetaGaussianMLPPolicy(
        name="meta-policy",
        obs_dim=np.prod(task_env.observation_space.shape),
        action_dim=np.prod(task_env.action_space.shape),
        meta_batch_size=config['meta_batch_size'],
        hidden_sizes=config['hidden_sizes'],
    )

    # Collects rollouts for every sampled meta-task in each iteration.
    rollout_sampler = MetaSampler(
        env=task_env,
        policy=meta_policy,
        rollouts_per_meta_task=config['rollouts_per_meta_task'],
        meta_batch_size=config['meta_batch_size'],
        max_path_length=config['max_path_length'],
        parallel=config['parallel'],
    )

    # Turns raw trajectories into advantage-annotated training batches.
    processor = MetaSampleProcessor(
        baseline=value_baseline,
        discount=config['discount'],
        gae_lambda=config['gae_lambda'],
        normalize_adv=config['normalize_adv'],
    )

    maml_algo = TRPOMAML(
        policy=meta_policy,
        step_size=config['step_size'],
        inner_type=config['inner_type'],
        inner_lr=config['inner_lr'],
        meta_batch_size=config['meta_batch_size'],
        num_inner_grad_steps=config['num_inner_grad_steps'],
        exploration=False,
    )

    Trainer(
        algo=maml_algo,
        policy=meta_policy,
        env=task_env,
        sampler=rollout_sampler,
        sample_processor=processor,
        n_itr=config['n_itr'],
        num_inner_grad_steps=config['num_inner_grad_steps'],
    ).train()
if __name__ == "__main__":
    # Timestamp makes the default dump directory unique per run.
    idx = int(time.time())

    parser = argparse.ArgumentParser(description='ProMP: Proximal Meta-Policy Search')
    parser.add_argument('--config_file', type=str, default='', help='json file with run specifications')
    parser.add_argument('--dump_path', type=str, default=meta_policy_search_path + '/data/pro-mp/run_%d' % idx)
    args = parser.parse_args()

    if args.config_file:  # load configuration from json file
        with open(args.config_file, 'r') as f:
            config = json.load(f)
    else:  # use default config
        config = {
            'seed': 1,
            'baseline': 'LinearFeatureBaseline',
            'env': 'HalfCheetahRandDirecEnv',
            # sampler config
            'rollouts_per_meta_task': 20,
            'max_path_length': 100,
            'parallel': True,
            # sample processor config
            'discount': 0.99,
            'gae_lambda': 1,
            'normalize_adv': True,
            # policy config
            'hidden_sizes': (64, 64),
            'learn_std': True,  # whether to learn the standard deviation of the gaussian policy
            # E-MAML config
            'inner_lr': 0.1,  # adaptation step size
            'learning_rate': 1e-3,  # meta-policy gradient step size
            'step_size': 0.01,  # size of the TRPO trust-region
            'n_itr': 1001,  # number of overall training iterations
            'meta_batch_size': 40,  # number of sampled meta-tasks per iterations
            'num_inner_grad_steps': 1,  # number of inner / adaptation gradient steps
            'inner_type': 'log_likelihood',  # type of inner loss function used
        }

    # configure logger
    logger.configure(dir=args.dump_path, format_strs=['stdout', 'log', 'csv'],
                     snapshot_mode='last_gap')

    # Dump the run configuration before training starts. Create the directory
    # explicitly (rather than relying on the logger to have done so) and use a
    # context manager so the file handle is closed deterministically — the
    # original `json.dump(config, open(...))` leaked the handle.
    os.makedirs(args.dump_path, exist_ok=True)
    with open(os.path.join(args.dump_path, 'params.json'), 'w') as f:
        json.dump(config, f, cls=ClassEncoder)

    # start the actual algorithm
    main(config)