# a2c.py (forked from openai/baselines)
import time
import gym
import numpy as np
import tensorflow as tf
from stable_baselines import logger
from stable_baselines.common import explained_variance, tf_util, ActorCriticRLModel, SetVerbosity, TensorboardWriter
from stable_baselines.common.policies import ActorCriticPolicy, RecurrentActorCriticPolicy
from stable_baselines.common.runners import AbstractEnvRunner
from stable_baselines.common.schedules import Scheduler
from stable_baselines.common.tf_util import mse, total_episode_reward_logger
from stable_baselines.common.math_util import safe_mean


def discount_with_dones(rewards, dones, gamma):
    """
    Apply the discount value to the reward, where the environment is not done

    :param rewards: ([float]) The rewards
    :param dones: ([bool]) Whether an environment is done or not
    :param gamma: (float) The discount value
    :return: ([float]) The discounted rewards
    """
    discounted = []
    ret = 0  # Return: discounted reward
    for reward, done in zip(rewards[::-1], dones[::-1]):
        ret = reward + gamma * ret * (1. - done)  # fixed off by one bug
        discounted.append(ret)
    return discounted[::-1]
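
# Worked example (comment added for illustration, not part of the original module): for a
# 3-step episode that terminates on its last step,
#     discount_with_dones([1., 1., 1.], [False, False, True], gamma=0.99)
# returns [2.9701, 1.99, 1.0], i.e. [1 + 0.99 * 1.99, 1 + 0.99 * 1.0, 1.0]; a done flag
# zeroes the carried return, so nothing is discounted across episode boundaries.
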
class A2C(ActorCriticRLModel):
    """
    The A2C (Advantage Actor Critic) model class, https://arxiv.org/abs/1602.01783

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param gamma: (float) Discount factor
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param vf_coef: (float) Value function coefficient for the loss calculation
    :param ent_coef: (float) Entropy coefficient for the loss calculation
    :param max_grad_norm: (float) The maximum value for the gradient clipping
    :param learning_rate: (float) The learning rate
    :param alpha: (float) RMSProp decay parameter (default: 0.99)
    :param momentum: (float) RMSProp momentum parameter (default: 0.0)
    :param epsilon: (float) RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update)
        (default: 1e-5)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
        'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
        (used only for loading)
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
        WARNING: this logging can take a lot of space quickly
    :param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
        If None (default), use random seed. Note that if you want completely deterministic
        results, you must set `n_cpu_tf_sess` to 1.
    :param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations.
        If None, the number of CPUs of the current machine will be used.
    """
    def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5,
                 learning_rate=7e-4, alpha=0.99, momentum=0.0, epsilon=1e-5, lr_schedule='constant',
                 verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None,
                 full_tensorboard_log=False, seed=None, n_cpu_tf_sess=None):

        self.n_steps = n_steps
        self.gamma = gamma
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_grad_norm = max_grad_norm
        self.alpha = alpha
        self.momentum = momentum
        self.epsilon = epsilon
        self.lr_schedule = lr_schedule
        self.learning_rate = learning_rate
        self.tensorboard_log = tensorboard_log
        self.full_tensorboard_log = full_tensorboard_log

        self.learning_rate_ph = None
        self.n_batch = None
        self.actions_ph = None
        self.advs_ph = None
        self.rewards_ph = None
        self.pg_loss = None
        self.vf_loss = None
        self.entropy = None
        self.apply_backprop = None
        self.train_model = None
        self.step_model = None
        self.proba_step = None
        self.value = None
        self.initial_state = None
        self.learning_rate_schedule = None
        self.summary = None

        super(A2C, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
                                  _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs,
                                  seed=seed, n_cpu_tf_sess=n_cpu_tf_sess)

        # if we are loading, it is possible the environment is not known, however the obs and action space are known
        if _init_setup_model:
            self.setup_model()
    def _make_runner(self) -> AbstractEnvRunner:
        return A2CRunner(self.env, self, n_steps=self.n_steps, gamma=self.gamma)
    def _get_pretrain_placeholders(self):
        policy = self.train_model
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, self.actions_ph, policy.policy
        return policy.obs_ph, self.actions_ph, policy.deterministic_action
    def setup_model(self):
        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the A2C model must be an " \
                                                               "instance of common.policies.ActorCriticPolicy."

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.set_random_seed(self.seed)
                self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess, graph=self.graph)

                self.n_batch = self.n_envs * self.n_steps

                n_batch_step = None
                n_batch_train = None
                if issubclass(self.policy, RecurrentActorCriticPolicy):
                    n_batch_step = self.n_envs
                    n_batch_train = self.n_envs * self.n_steps

                step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                         n_batch_step, reuse=False, **self.policy_kwargs)

                with tf.variable_scope("train_model", reuse=True,
                                       custom_getter=tf_util.outer_scope_getter("train_model")):
                    train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
                                              self.n_steps, n_batch_train, reuse=True, **self.policy_kwargs)

                with tf.variable_scope("loss", reuse=False):
                    self.actions_ph = train_model.pdtype.sample_placeholder([None], name="action_ph")
                    self.advs_ph = tf.placeholder(tf.float32, [None], name="advs_ph")
                    self.rewards_ph = tf.placeholder(tf.float32, [None], name="rewards_ph")
                    self.learning_rate_ph = tf.placeholder(tf.float32, [], name="learning_rate_ph")

                    neglogpac = train_model.proba_distribution.neglogp(self.actions_ph)
                    self.entropy = tf.reduce_mean(train_model.proba_distribution.entropy())
                    self.pg_loss = tf.reduce_mean(self.advs_ph * neglogpac)
                    self.vf_loss = mse(tf.squeeze(train_model.value_flat), self.rewards_ph)
                    # https://arxiv.org/pdf/1708.04782.pdf#page=9, https://arxiv.org/pdf/1602.01783.pdf#page=4
                    # and https://github.com/dennybritz/reinforcement-learning/issues/34
                    # suggest to add an entropy component in order to improve exploration.
                    loss = self.pg_loss - self.entropy * self.ent_coef + self.vf_loss * self.vf_coef
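                    # Note added for reference (a restatement of the line above): the scalar objective minimized is
                    #     L = E[-log pi(a|s) * A] - ent_coef * E[H(pi(.|s))] + vf_coef * E[(V(s) - R)^2],
                    # where the advantage A = R - V(s) is computed by the caller and fed in through advs_ph.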

                    tf.summary.scalar('entropy_loss', self.entropy)
                    tf.summary.scalar('policy_gradient_loss', self.pg_loss)
                    tf.summary.scalar('value_function_loss', self.vf_loss)
                    tf.summary.scalar('loss', loss)

                    self.params = tf_util.get_trainable_vars("model")
                    grads = tf.gradients(loss, self.params)
                    if self.max_grad_norm is not None:
                        grads, _ = tf.clip_by_global_norm(grads, self.max_grad_norm)
                    grads = list(zip(grads, self.params))

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph))
                    tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph))
                    tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph))
                    if self.full_tensorboard_log:
                        tf.summary.histogram('discounted_rewards', self.rewards_ph)
                        tf.summary.histogram('learning_rate', self.learning_rate_ph)
                        tf.summary.histogram('advantage', self.advs_ph)
                        if tf_util.is_image(self.observation_space):
                            tf.summary.image('observation', train_model.obs_ph)
                        else:
                            tf.summary.histogram('observation', train_model.obs_ph)

                trainer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_ph, decay=self.alpha,
                                                    epsilon=self.epsilon, momentum=self.momentum)
                self.apply_backprop = trainer.apply_gradients(grads)

                self.train_model = train_model
                self.step_model = step_model
                self.step = step_model.step
                self.proba_step = step_model.proba_step
                self.value = step_model.value
                self.initial_state = step_model.initial_state
                tf.global_variables_initializer().run(session=self.sess)

                self.summary = tf.summary.merge_all()
    def _train_step(self, obs, states, rewards, masks, actions, values, update, writer=None):
        """
        applies a training step to the model

        :param obs: ([float]) The input observations
        :param states: ([float]) The states (used for recurrent policies)
        :param rewards: ([float]) The rewards from the environment
        :param masks: ([bool]) Whether or not the episode is over (used for recurrent policies)
        :param actions: ([float]) The actions taken
        :param values: ([float]) The value function estimates
        :param update: (int) the current step iteration
        :param writer: (TensorFlow Summary.writer) the writer for tensorboard
        :return: (float, float, float) policy loss, value loss, policy entropy
        """
        advs = rewards - values
        cur_lr = None
        for _ in range(len(obs)):
            cur_lr = self.learning_rate_schedule.value()
        assert cur_lr is not None, "Error: the observation input array cannot be empty"

        td_map = {self.train_model.obs_ph: obs, self.actions_ph: actions, self.advs_ph: advs,
                  self.rewards_ph: rewards, self.learning_rate_ph: cur_lr}
        if states is not None:
            td_map[self.train_model.states_ph] = states
            td_map[self.train_model.dones_ph] = masks

        if writer is not None:
            # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...)
            if self.full_tensorboard_log and (1 + update) % 10 == 0:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                summary, policy_loss, value_loss, policy_entropy, _ = self.sess.run(
                    [self.summary, self.pg_loss, self.vf_loss, self.entropy, self.apply_backprop],
                    td_map, options=run_options, run_metadata=run_metadata)
                writer.add_run_metadata(run_metadata, 'step%d' % (update * self.n_batch))
            else:
                summary, policy_loss, value_loss, policy_entropy, _ = self.sess.run(
                    [self.summary, self.pg_loss, self.vf_loss, self.entropy, self.apply_backprop], td_map)
            writer.add_summary(summary, update * self.n_batch)
        else:
            policy_loss, value_loss, policy_entropy, _ = self.sess.run(
                [self.pg_loss, self.vf_loss, self.entropy, self.apply_backprop], td_map)

        return policy_loss, value_loss, policy_entropy
    def set_env(self, env):
        super().set_env(env)
        self.n_batch = self.n_envs * self.n_steps
    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="A2C",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn()

            self.learning_rate_schedule = Scheduler(initial_value=self.learning_rate, n_values=total_timesteps,
                                                    schedule=self.lr_schedule)

            t_start = time.time()
            callback.on_training_start(locals(), globals())

            for update in range(1, total_timesteps // self.n_batch + 1):

                callback.on_rollout_start()
                # true_reward is the reward without discount
                rollout = self.runner.run(callback)
                # unpack
                obs, states, rewards, masks, actions, values, ep_infos, true_reward = rollout
                callback.update_locals(locals())
                callback.on_rollout_end()

                # Early stopping due to the callback
                if not self.runner.continue_training:
                    break

                self.ep_info_buf.extend(ep_infos)
                _, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values,
                                                                 self.num_timesteps // self.n_batch, writer)
                n_seconds = time.time() - t_start
                fps = int((update * self.n_batch) / n_seconds)

                if writer is not None:
                    total_episode_reward_logger(self.episode_reward,
                                                true_reward.reshape((self.n_envs, self.n_steps)),
                                                masks.reshape((self.n_envs, self.n_steps)),
                                                writer, self.num_timesteps)

                if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
                    explained_var = explained_variance(values, rewards)
                    logger.record_tabular("nupdates", update)
                    logger.record_tabular("total_timesteps", self.num_timesteps)
                    logger.record_tabular("fps", fps)
                    logger.record_tabular("policy_entropy", float(policy_entropy))
                    logger.record_tabular("value_loss", float(value_loss))
                    logger.record_tabular("explained_variance", float(explained_var))
                    if len(self.ep_info_buf) > 0 and len(self.ep_info_buf[0]) > 0:
                        logger.logkv('ep_reward_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buf]))
                        logger.logkv('ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buf]))
                    logger.dump_tabular()

        callback.on_training_end()
        return self
    def save(self, save_path, cloudpickle=False):
        data = {
            "gamma": self.gamma,
            "n_steps": self.n_steps,
            "vf_coef": self.vf_coef,
            "ent_coef": self.ent_coef,
            "max_grad_norm": self.max_grad_norm,
            "learning_rate": self.learning_rate,
            "alpha": self.alpha,
            "epsilon": self.epsilon,
            "lr_schedule": self.lr_schedule,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "n_cpu_tf_sess": self.n_cpu_tf_sess,
            "seed": self.seed,
            "_vectorize_action": self._vectorize_action,
            "policy_kwargs": self.policy_kwargs
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path, data=data, params=params_to_save, cloudpickle=cloudpickle)
class A2CRunner(AbstractEnvRunner):
    def __init__(self, env, model, n_steps=5, gamma=0.99):
        """
        A runner to learn the policy of an environment for an A2C model

        :param env: (Gym environment) The environment to learn from
        :param model: (Model) The model to learn
        :param n_steps: (int) The number of steps to run for each environment
        :param gamma: (float) Discount factor
        """
        super(A2CRunner, self).__init__(env=env, model=model, n_steps=n_steps)
        self.gamma = gamma
    def _run(self):
        """
        Run a learning step of the model

        :return: ([float], [float], [float], [bool], [float], [float], [dict], [float])
            observations, states, rewards, masks, actions, values, episode infos, true (undiscounted) rewards
        """
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
        mb_states = self.states
        ep_infos = []
        for _ in range(self.n_steps):
            actions, values, states, _ = self.model.step(self.obs, self.states, self.dones)  # pytype: disable=attribute-error
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(values)
            mb_dones.append(self.dones)
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.env.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
            obs, rewards, dones, infos = self.env.step(clipped_actions)

            self.model.num_timesteps += self.n_envs

            if self.callback is not None:
                # Abort training early
                self.callback.update_locals(locals())
                if self.callback.on_step() is False:
                    self.continue_training = False
                    # Return dummy values
                    return [None] * 8

            for info in infos:
                maybe_ep_info = info.get('episode')
                if maybe_ep_info is not None:
                    ep_infos.append(maybe_ep_info)

            self.states = states
            self.dones = dones
            self.obs = obs
            mb_rewards.append(rewards)
        mb_dones.append(self.dones)

        # batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(0, 1)
        mb_actions = np.asarray(mb_actions, dtype=self.env.action_space.dtype).swapaxes(0, 1)
        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(0, 1)
        mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(0, 1)
        mb_masks = mb_dones[:, :-1]
        mb_dones = mb_dones[:, 1:]
        true_rewards = np.copy(mb_rewards)
        last_values = self.model.value(self.obs, self.states, self.dones).tolist()  # pytype: disable=attribute-error
        # discount/bootstrap off value fn
        for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
            rewards = rewards.tolist()
            dones = dones.tolist()
            if dones[-1] == 0:
                rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
            else:
                rewards = discount_with_dones(rewards, dones, self.gamma)

            mb_rewards[n] = rewards

        # convert from [n_env, n_steps, ...] to [n_steps * n_env, ...]
        mb_rewards = mb_rewards.reshape(-1, *mb_rewards.shape[2:])
        mb_actions = mb_actions.reshape(-1, *mb_actions.shape[2:])
        mb_values = mb_values.reshape(-1, *mb_values.shape[2:])
        mb_masks = mb_masks.reshape(-1, *mb_masks.shape[2:])
        true_rewards = true_rewards.reshape(-1, *true_rewards.shape[2:])
        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, ep_infos, true_rewards
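

# Minimal usage sketch (comment added for illustration; the environment id and timestep budget
# below are arbitrary examples, not part of this module):
#
#     from stable_baselines import A2C
#
#     model = A2C('MlpPolicy', 'CartPole-v1', verbose=1)
#     model.learn(total_timesteps=25000)
#     model.save("a2c_cartpole")
#     loaded = A2C.load("a2c_cartpole")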