-
Notifications
You must be signed in to change notification settings - Fork 226
/
trpo.py
876 lines (742 loc) · 35.1 KB
/
trpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
import collections
import itertools
from logging import getLogger
import random
import chainer
import chainer.functions as F
import numpy as np
import chainerrl
from chainerrl import agent
from chainerrl.agents.ppo import _compute_explained_variance
from chainerrl.agents.ppo import _make_dataset
from chainerrl.agents.ppo import _make_dataset_recurrent
from chainerrl.agents.ppo import _yield_subset_of_sequences_with_fixed_number_of_items # NOQA
from chainerrl.misc.batch_states import batch_states
def _get_ordered_params(link):
    """Return the parameters of ``link`` sorted by their parameter names.

    Sorting by name fixes the order, so repeated flatten/unflatten round
    trips over the same link always use the same parameter layout.
    """
    by_name = sorted(link.namedparams(), key=lambda pair: pair[0])
    return [param for _, param in by_name]
def _flatten_and_concat_variables(vs):
    """Flatten each variable in ``vs`` and join them into one 1-D variable.

    Only ``chainer.functions`` ops are used, so the result stays attached to
    the computational graph and can be differentiated further.
    """
    flat_pieces = [F.flatten(v) for v in vs]
    return F.concat(flat_pieces, axis=0)
def _as_ndarray(x):
    """Return the raw array behind ``x``.

    A ``chainer.Variable`` is unwrapped through its ``.array`` attribute;
    any other object (e.g. a NumPy/CuPy ndarray) is returned unchanged.
    """
    return x.array if isinstance(x, chainer.Variable) else x
def _flatten_and_concat_ndarrays(vs):
    """Ravel arrays (or variables) and concatenate them into one 1-D ndarray.

    The array module (NumPy or CuPy) is picked from the first element, so
    ``vs`` must be non-empty. ``chainer.Variable`` elements are unwrapped to
    their data, so the result is a plain ndarray outside any graph.
    """
    xp = chainer.cuda.get_array_module(vs[0])
    raveled = [_as_ndarray(v).ravel() for v in vs]
    return xp.concatenate(raveled, axis=0)
def _split_and_reshape_to_ndarrays(flat_v, sizes, shapes):
    """Cut a flat vector into pieces and reshape each piece.

    Args:
        flat_v: Flat ndarray (NumPy or CuPy), e.g. one produced by
            ``_flatten_and_concat_ndarrays``.
        sizes: Number of elements of each piece, in order.
        shapes: Target shape of each piece, in the same order as ``sizes``.

    Returns:
        list: One reshaped ndarray per entry of ``shapes``.
    """
    xp = chainer.cuda.get_array_module(flat_v)
    # Split points are the running totals of ``sizes``; the trailing
    # remainder after the last split point has no matching shape and is
    # dropped by ``zip``.
    boundaries = np.cumsum(sizes)
    pieces = xp.split(flat_v, boundaries)
    return [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]
def _replace_params_data(params, new_params_data):
    """Replace data of params with new data, in place.

    Each new array is copied into the existing buffer of the matching
    parameter, so the parameter objects themselves are kept.

    Args:
        params: Parameters whose ``.array`` buffers are overwritten.
        new_params_data: Arrays with the same shapes as ``params``, in the
            same order.
    """
    for param, new_param_data in zip(params, new_params_data):
        assert param.shape == new_param_data.shape
        # ``[...]`` instead of ``[:]`` so that 0-dimensional (scalar)
        # parameters can be assigned as well; ``[:]`` raises IndexError there.
        param.array[...] = new_param_data
def _hessian_vector_product(flat_grads, params, vec):
    """Return the Hessian-vector product ``H @ vec`` as a flat ndarray.

    Differentiating the scalar ``sum(flat_grads * vec)`` with respect to
    ``params`` gives the product without materializing ``H``. For this to
    work, ``flat_grads`` must have been computed with
    ``enable_double_backprop=True``.
    """
    directional = F.sum(flat_grads * vec)
    second_grads = chainer.grad([directional], params)
    assert all(grad is not None for grad in second_grads),\
        "The Hessian-vector product contains None."
    return _flatten_and_concat_ndarrays([g.array for g in second_grads])
def _mean_or_nan(xs):
    """Return the mean of ``xs``, or ``numpy.nan`` if ``xs`` is empty.

    Emptiness is tested by truthiness, so ``xs`` should be a container such
    as a list or deque rather than a NumPy array.
    """
    if not xs:
        return np.nan
    return np.mean(xs)
def _find_old_style_function(outputs):
    """Find old-style functions in the computational graph.

    Starting from ``outputs``, walk backward through creators implemented
    as new-style ``chainer.FunctionNode`` objects and collect every
    old-style ``chainer.Function`` encountered. The walk does not continue
    past an old-style function.

    The walk uses an explicit stack instead of recursion, so a deep graph
    (for example a recurrent policy unrolled over a long sequence) cannot
    hit Python's recursion limit. Creators that were already visited are
    skipped, so a node reachable through several paths is expanded once
    and reported at most once.

    Args:
        outputs: Iterable of ``chainer.Variable`` or
            ``chainer.variable.VariableNode`` to start the walk from.

    Returns:
        list: Old-style function objects found, without duplicates.
    """
    found = []
    visited = set()  # id() of creators already handled
    # Push in reverse so nodes are popped in left-to-right depth-first order.
    stack = list(reversed(list(outputs)))
    while stack:
        v = stack.pop()
        assert isinstance(v, (chainer.Variable, chainer.variable.VariableNode))
        creator = v.creator
        if creator is None or id(creator) in visited:
            continue
        visited.add(id(creator))
        if isinstance(creator, chainer.Function):
            found.append(creator)
        else:
            assert isinstance(creator, chainer.FunctionNode)
            stack.extend(reversed(creator.inputs))
    return found
class TRPO(agent.AttributeSavingMixin, agent.Agent):
    """Trust Region Policy Optimization.

    A given stochastic policy is optimized by the TRPO algorithm. A given
    value function is also trained to predict by the TD(lambda) algorithm and
    used for Generalized Advantage Estimation (GAE).

    Since the policy is optimized via the conjugate gradient method and line
    search while the value function is optimized via SGD, these two models
    should be separate.

    Since TRPO requires second-order derivatives to compute Hessian-vector
    products, Chainer v3.0.0 or newer is required. In addition, your policy
    must contain only functions that support second-order derivatives.

    See https://arxiv.org/abs/1502.05477 for TRPO.
    See https://arxiv.org/abs/1506.02438 for GAE.

    Args:
        policy (Policy): Stochastic policy. Its forward computation must
            contain only functions that support second-order derivatives.
        vf (ValueFunction): Value function.
        vf_optimizer (chainer.Optimizer): Optimizer for the value function.
        obs_normalizer (chainerrl.links.EmpiricalNormalization or None):
            If set to chainerrl.links.EmpiricalNormalization, it is used to
            normalize observations based on the empirical mean and standard
            deviation of observations. These statistics are updated after
            computing advantages and target values and before updating the
            policy and the value function.
        gamma (float): Discount factor [0, 1]
        lambd (float): Lambda-return factor [0, 1]
        phi (callable): Feature extractor function
        entropy_coef (float): Weight coefficient for entropy bonus [0, inf)
        update_interval (int): Interval steps of TRPO iterations. After every
            this many steps, this agent updates the policy and the value
            function using data from these steps.
        max_kl (float): Upper bound on the mean KL divergence between the old
            and the new policy accepted by the line search for one update.
        vf_epochs (int): Number of epochs for which the value function is
            trained on each TRPO iteration.
        vf_batch_size (int): Batch size of SGD for the value function.
        standardize_advantages (bool): Use standardized advantages on updates
        batch_states (callable): Function called as
            ``batch_states(states, xp, phi)`` to turn a list of observations
            into a batch. It is used everywhere the agent batches
            observations.
        recurrent (bool): If True, ``policy`` and ``vf`` are used as
            stateless recurrent models and trained on sequences of episodes.
        max_recurrent_sequence_len (int or None): Forwarded to
            ``_make_dataset_recurrent`` when ``recurrent`` is True;
            presumably the maximum length of each training sequence.
        line_search_max_backtrack (int): Maximum number of backtracking in line
            search to tune step sizes of policy updates.
        conjugate_gradient_max_iter (int): Maximum number of iterations in
            the conjugate gradient method.
        conjugate_gradient_damping (float): Damping factor used in the
            conjugate gradient method.
        act_deterministically (bool): If set to True, choose most probable
            actions in the act method instead of sampling from distributions.
        value_stats_window (int): Window size used to compute statistics
            of value predictions.
        entropy_stats_window (int): Window size used to compute statistics
            of entropy of action distributions.
        kl_stats_window (int): Window size used to compute statistics
            of KL divergence between old and new policies.
        policy_step_size_stats_window (int): Window size used to compute
            statistics of step sizes of policy updates.
        logger (logging.Logger): Logger used to report line search progress.

    Statistics:
        average_value: Average of value predictions on non-terminal states.
            It's updated before the value function is updated.
        average_entropy: Average of entropy of action distributions on
            non-terminal states. It's updated on act_and_train.
        average_kl: Average of KL divergence between old and new policies.
            It's updated after the policy is updated.
        average_policy_step_size: Average of step sizes of policy updates.
            It's updated after the policy is updated.
        explained_variance: Explained variance of value predictions on the
            transitions used by the latest update (NaN before the first one).
    """

    saved_attributes = ['policy', 'vf', 'vf_optimizer', 'obs_normalizer']

    def __init__(self,
                 policy,
                 vf,
                 vf_optimizer,
                 obs_normalizer=None,
                 gamma=0.99,
                 lambd=0.95,
                 phi=lambda x: x,
                 entropy_coef=0.01,
                 update_interval=2048,
                 max_kl=0.01,
                 vf_epochs=3,
                 vf_batch_size=64,
                 standardize_advantages=True,
                 batch_states=batch_states,
                 recurrent=False,
                 max_recurrent_sequence_len=None,
                 line_search_max_backtrack=10,
                 conjugate_gradient_max_iter=10,
                 conjugate_gradient_damping=1e-2,
                 act_deterministically=False,
                 value_stats_window=1000,
                 entropy_stats_window=1000,
                 kl_stats_window=100,
                 policy_step_size_stats_window=100,
                 logger=getLogger(__name__),
                 ):

        self.policy = policy
        self.vf = vf
        assert policy.xp is vf.xp, 'policy and vf must be on the same device'
        if recurrent:
            self.model = chainerrl.links.StatelessRecurrentBranched(policy, vf)
        else:
            self.model = chainerrl.links.Branched(policy, vf)
        if policy.xp is not np:
            if hasattr(policy, 'device'):
                # Link.device is available only from chainer v6
                self.model.to_device(policy.device)
            else:
                self.model.to_gpu(device=policy._device_id)
        self.vf_optimizer = vf_optimizer
        self.obs_normalizer = obs_normalizer
        self.gamma = gamma
        self.lambd = lambd
        self.phi = phi
        self.entropy_coef = entropy_coef
        self.update_interval = update_interval
        self.max_kl = max_kl
        self.vf_epochs = vf_epochs
        self.vf_batch_size = vf_batch_size
        self.standardize_advantages = standardize_advantages
        self.batch_states = batch_states
        self.recurrent = recurrent
        self.max_recurrent_sequence_len = max_recurrent_sequence_len
        self.line_search_max_backtrack = line_search_max_backtrack
        self.conjugate_gradient_max_iter = conjugate_gradient_max_iter
        self.conjugate_gradient_damping = conjugate_gradient_damping
        self.act_deterministically = act_deterministically
        self.logger = logger

        self.value_record = collections.deque(maxlen=value_stats_window)
        self.entropy_record = collections.deque(maxlen=entropy_stats_window)
        self.kl_record = collections.deque(maxlen=kl_stats_window)
        self.policy_step_size_record = collections.deque(
            maxlen=policy_step_size_stats_window)
        self.explained_variance = np.nan

        if self.obs_normalizer is not None:
            assert self.policy.xp is self.obs_normalizer.xp,\
                'policy and obs_normalizer should be in the same device.'
        self.xp = self.policy.xp
        self.last_state = None
        self.last_action = None

        # Contains episodes used for next update iteration
        self.memory = []
        # Contains transitions of the last episode not moved to self.memory yet
        self.last_episode = []

        # Batch versions of last_episode, last_state, and last_action
        self.batch_last_episode = None
        self.batch_last_state = None
        self.batch_last_action = None

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

    def _initialize_batch_variables(self, num_envs):
        """Create empty per-environment buffers for batch training."""
        self.batch_last_episode = [[] for _ in range(num_envs)]
        self.batch_last_state = [None] * num_envs
        self.batch_last_action = [None] * num_envs

    def _update_if_dataset_is_ready(self):
        """Run one update once at least ``update_interval`` transitions exist.

        Transitions are counted over finished episodes in ``self.memory``,
        the current single-env episode and all unfinished batch episodes.
        They are all moved into ``self.memory``, used for the update and then
        discarded.
        """
        dataset_size = (
            sum(len(episode) for episode in self.memory)
            + len(self.last_episode)
            + (0 if self.batch_last_episode is None else sum(
                len(episode) for episode in self.batch_last_episode)))
        if dataset_size >= self.update_interval:
            self._flush_last_episode()
            if self.recurrent:
                dataset = _make_dataset_recurrent(
                    episodes=self.memory,
                    model=self.model,
                    phi=self.phi,
                    batch_states=self.batch_states,
                    obs_normalizer=self.obs_normalizer,
                    gamma=self.gamma,
                    lambd=self.lambd,
                    max_recurrent_sequence_len=self.max_recurrent_sequence_len,
                )
                self._update_recurrent(dataset)
            else:
                dataset = _make_dataset(
                    episodes=self.memory,
                    model=self.model,
                    phi=self.phi,
                    batch_states=self.batch_states,
                    obs_normalizer=self.obs_normalizer,
                    gamma=self.gamma,
                    lambd=self.lambd,
                )
                assert len(dataset) == dataset_size
                self._update(dataset)
            self.explained_variance = _compute_explained_variance(
                list(itertools.chain.from_iterable(self.memory)))
            self.memory = []

    def _flush_last_episode(self):
        """Move all non-empty in-progress episodes into ``self.memory``."""
        if self.last_episode:
            self.memory.append(self.last_episode)
            self.last_episode = []
        if self.batch_last_episode:
            for i, episode in enumerate(self.batch_last_episode):
                if episode:
                    self.memory.append(episode)
                    self.batch_last_episode[i] = []

    def _update(self, dataset):
        """Update both the policy and the value function."""

        if self.obs_normalizer:
            self._update_obs_normalizer(dataset)
        self._update_policy(dataset)
        self._update_vf(dataset)

    def _update_recurrent(self, dataset):
        """Update both the policy and the value function."""

        flat_dataset = list(itertools.chain.from_iterable(dataset))
        if self.obs_normalizer:
            self._update_obs_normalizer(flat_dataset)

        self._update_policy_recurrent(dataset)
        self._update_vf_recurrent(dataset)

    def _update_vf_recurrent(self, dataset):
        """Train the recurrent value function for ``vf_epochs`` epochs.

        ``dataset`` (a list of sequences) is shuffled in place every epoch.
        """
        for _ in range(self.vf_epochs):
            random.shuffle(dataset)
            for minibatch in _yield_subset_of_sequences_with_fixed_number_of_items(  # NOQA
                    dataset, self.vf_batch_size):
                self._update_vf_once_recurrent(minibatch)

    def _update_vf_once_recurrent(self, episodes):
        """Do one SGD step of the recurrent value function on ``episodes``."""

        xp = self.model.xp
        flat_transitions = list(itertools.chain.from_iterable(episodes))

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in episodes:
            states = self.batch_states(
                [transition['state'] for transition in ep], xp, self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_vs_teacher = xp.array(
            [[transition['v_teacher']] for transition in flat_transitions],
            dtype=np.float32)

        # Element 1 of each stored recurrent state is the value function's.
        with chainer.using_config('train', False),\
                chainer.no_backprop_mode():
            vf_rs = self.vf.concatenate_recurrent_states(
                [ep[0]['recurrent_state'][1] for ep in episodes])

        # Forward outside no_backprop_mode: gradients are needed here.
        flat_vs_pred, _ = self.vf.n_step_forward(
            seqs_states, recurrent_state=vf_rs, output_mode='concat')

        vf_loss = F.mean_squared_error(flat_vs_pred, flat_vs_teacher)
        self.vf_optimizer.update(lambda: vf_loss)

    def _update_obs_normalizer(self, dataset):
        """Feed the observations of ``dataset`` to the normalizer statistics."""
        assert self.obs_normalizer
        # Use the user-supplied batching function like the rest of the agent.
        states = self.batch_states(
            [b['state'] for b in dataset], self.obs_normalizer.xp, self.phi)
        self.obs_normalizer.experience(states)

    def _update_vf(self, dataset):
        """Update the value function using a given dataset.

        The value function is updated via SGD to minimize TD(lambda) errors.
        """

        xp = self.vf.xp

        assert 'state' in dataset[0]
        assert 'v_teacher' in dataset[0]

        dataset_iter = chainer.iterators.SerialIterator(
            dataset, self.vf_batch_size)

        while dataset_iter.epoch < self.vf_epochs:
            batch = next(dataset_iter)
            states = self.batch_states(
                [b['state'] for b in batch], xp, self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            vs_teacher = xp.array(
                [b['v_teacher'] for b in batch], dtype=xp.float32)
            vs_pred = self.vf(states)
            vf_loss = F.mean_squared_error(vs_pred, vs_teacher[..., None])
            self.vf_optimizer.update(lambda: vf_loss)

    def _compute_gain(self, log_prob, log_prob_old, entropy, advs):
        """Compute a gain to maximize.

        The gain is the importance-weighted advantage
        ``mean(exp(log_prob - log_prob_old) * advs)`` plus
        ``entropy_coef`` times the mean entropy.
        """
        prob_ratio = F.exp(log_prob - log_prob_old)
        mean_entropy = F.mean(entropy)
        surrogate_gain = F.mean(prob_ratio * advs)
        return surrogate_gain + self.entropy_coef * mean_entropy

    def _update_policy(self, dataset):
        """Update the policy using a given dataset.

        The policy is updated via CG and line search.
        """

        assert 'state' in dataset[0]
        assert 'action' in dataset[0]
        assert 'adv' in dataset[0]

        # Use full-batch
        xp = self.policy.xp
        states = self.batch_states(
            [b['state'] for b in dataset], xp, self.phi)
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        actions = xp.array([b['action'] for b in dataset])
        advs = xp.array([b['adv'] for b in dataset], dtype=np.float32)
        if self.standardize_advantages:
            mean_advs = xp.mean(advs)
            std_advs = xp.std(advs)
            advs = (advs - mean_advs) / (std_advs + 1e-8)

        # Recompute action distributions for batch backprop
        action_distrib = self.policy(states)

        log_prob_old = xp.array(
            [transition['log_prob'] for transition in dataset],
            dtype=np.float32)

        gain = self._compute_gain(
            log_prob=action_distrib.log_prob(actions),
            log_prob_old=log_prob_old,
            entropy=action_distrib.entropy,
            advs=advs)

        # Distribution to compute KL div against
        action_distrib_old = action_distrib.copy()

        full_step = self._compute_kl_constrained_step(
            action_distrib=action_distrib,
            action_distrib_old=action_distrib_old,
            gain=gain)

        self._line_search(
            full_step=full_step,
            dataset=dataset,
            advs=advs,
            action_distrib_old=action_distrib_old,
            gain=gain)

    def _update_policy_recurrent(self, dataset):
        """Update the policy using a given dataset.

        The policy is updated via CG and line search.
        """

        xp = self.model.xp
        flat_transitions = list(itertools.chain.from_iterable(dataset))

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in dataset:
            states = self.batch_states(
                [transition['state'] for transition in ep], xp, self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_actions = xp.array(
            [transition['action'] for transition in flat_transitions])
        flat_advs = xp.array(
            [transition['adv'] for transition in flat_transitions],
            dtype=np.float32)

        if self.standardize_advantages:
            mean_advs = xp.mean(flat_advs)
            std_advs = xp.std(flat_advs)
            flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)

        # Element 0 of each stored recurrent state is the policy's.
        with chainer.using_config('train', False),\
                chainer.no_backprop_mode():
            policy_rs = self.policy.concatenate_recurrent_states(
                [ep[0]['recurrent_state'][0] for ep in dataset])

        # Forward outside no_backprop_mode: gradients are needed here.
        flat_distribs, _ = self.policy.n_step_forward(
            seqs_states, recurrent_state=policy_rs, output_mode='concat')

        log_prob_old = xp.array(
            [transition['log_prob'] for transition in flat_transitions],
            dtype=np.float32)

        gain = self._compute_gain(
            log_prob=flat_distribs.log_prob(flat_actions),
            log_prob_old=log_prob_old,
            entropy=flat_distribs.entropy,
            advs=flat_advs)

        # Distribution to compute KL div against
        action_distrib_old = flat_distribs.copy()

        full_step = self._compute_kl_constrained_step(
            action_distrib=flat_distribs,
            action_distrib_old=action_distrib_old,
            gain=gain)

        self._line_search(
            full_step=full_step,
            dataset=dataset,
            advs=flat_advs,
            action_distrib_old=action_distrib_old,
            gain=gain)

    def _compute_kl_constrained_step(self, action_distrib, action_distrib_old,
                                     gain):
        """Compute a step of policy parameters with a KL constraint.

        Returns:
            Flat ndarray with one entry per policy parameter element, in the
            order given by ``_get_ordered_params``.

        Raises:
            RuntimeError: If the KL computation uses old-style functions,
                which do not support double backprop.
        """
        policy_params = _get_ordered_params(self.policy)
        kl = F.mean(action_distrib_old.kl(action_distrib))

        # Check if kl computation fully supports double backprop
        old_style_funcs = _find_old_style_function([kl])
        if old_style_funcs:
            raise RuntimeError("""\
Old-style functions (chainer.Function) are used to compute KL divergence.
Since TRPO requires second-order derivative of KL divergence, its computation
should be done with new-style functions (chainer.FunctionNode) only.
Found old-style functions: {}""".format(old_style_funcs))

        kl_grads = chainer.grad([kl], policy_params,
                                enable_double_backprop=True)
        assert all(g is not None for g in kl_grads), "\
The gradient contains None. The policy may have unused parameters."
        flat_kl_grads = _flatten_and_concat_variables(kl_grads)

        def fisher_vector_product_func(vec):
            fvp = _hessian_vector_product(flat_kl_grads, policy_params, vec)
            return fvp + self.conjugate_gradient_damping * vec

        gain_grads = chainer.grad([gain], policy_params)
        # Check the gain gradients themselves (not kl_grads again).
        assert all(g is not None for g in gain_grads), "\
The gradient contains None. The policy may have unused parameters."
        flat_gain_grads = _flatten_and_concat_ndarrays(gain_grads)
        step_direction = chainerrl.misc.conjugate_gradient(
            fisher_vector_product_func, flat_gain_grads,
            max_iter=self.conjugate_gradient_max_iter,
        )

        # We want a step size that satisfies KL(old|new) < max_kl.
        # Let d = alpha * step_direction be the actual parameter updates.
        # The second-order approximation of KL divergence is:
        #   KL(old|new) = 1/2 d^T I d + O(||d||^3),
        # where I is a Fisher information matrix.
        # Substitute d = alpha * step_direction and solve KL(old|new) = max_kl
        # for alpha to get the step size that tightly satisfies the constraint.

        dId = float(step_direction.dot(
            fisher_vector_product_func(step_direction)))
        scale = (2.0 * self.max_kl / (dId + 1e-8)) ** 0.5
        return scale * step_direction

    def _line_search(self, full_step, dataset, advs, action_distrib_old, gain):
        """Do line search for a safe step size.

        Tries step sizes 1, 1/2, 1/4, ... (at most
        ``line_search_max_backtrack + 1`` of them) and keeps the first one
        whose gain is finite and not lower than ``gain`` and whose mean KL
        from ``action_distrib_old`` is finite and at most ``max_kl``. If none
        qualifies, the original parameters are restored.
        """
        xp = self.policy.xp
        policy_params = _get_ordered_params(self.policy)
        policy_params_sizes = [param.size for param in policy_params]
        policy_params_shapes = [param.shape for param in policy_params]
        step_size = 1.0
        # Concatenation copies, so this is a snapshot of the old parameters.
        flat_params = _flatten_and_concat_ndarrays(policy_params)

        if self.recurrent:
            seqs_states = []
            for ep in dataset:
                states = self.batch_states(
                    [transition['state'] for transition in ep], xp, self.phi)
                if self.obs_normalizer:
                    states = self.obs_normalizer(states, update=False)
                seqs_states.append(states)
            with chainer.using_config('train', False),\
                    chainer.no_backprop_mode():
                policy_rs = self.policy.concatenate_recurrent_states(
                    [ep[0]['recurrent_state'][0] for ep in dataset])

            def evaluate_current_policy():
                distrib, _ = self.policy.n_step_forward(
                    seqs_states, recurrent_state=policy_rs,
                    output_mode='concat')
                return distrib
        else:
            states = self.batch_states(
                [transition['state'] for transition in dataset], xp, self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)

            def evaluate_current_policy():
                return self.policy(states)

        flat_transitions = (list(itertools.chain.from_iterable(dataset))
                            if self.recurrent else dataset)
        actions = xp.array(
            [transition['action'] for transition in flat_transitions])
        log_prob_old = xp.array(
            [transition['log_prob'] for transition in flat_transitions],
            dtype=np.float32)

        for i in range(self.line_search_max_backtrack + 1):
            self.logger.info(
                'Line search iteration: %s step size: %s', i, step_size)
            new_flat_params = flat_params + step_size * full_step
            new_params = _split_and_reshape_to_ndarrays(
                new_flat_params,
                sizes=policy_params_sizes,
                shapes=policy_params_shapes,
            )
            _replace_params_data(policy_params, new_params)
            with chainer.using_config('train', False),\
                    chainer.no_backprop_mode():
                new_action_distrib = evaluate_current_policy()
                new_gain = self._compute_gain(
                    log_prob=new_action_distrib.log_prob(actions),
                    log_prob_old=log_prob_old,
                    entropy=new_action_distrib.entropy,
                    advs=advs)
                new_kl = F.mean(action_distrib_old.kl(new_action_distrib))

            improve = new_gain.array - gain.array
            self.logger.info(
                'Surrogate objective improve: %s', float(improve))
            self.logger.info('KL divergence: %s', float(new_kl.array))
            if not xp.isfinite(new_gain.array):
                self.logger.info(
                    "Surrogate objective is not finite. Bakctracking...")
            elif not xp.isfinite(new_kl.array):
                self.logger.info(
                    "KL divergence is not finite. Bakctracking...")
            elif improve < 0:
                self.logger.info(
                    "Surrogate objective didn't improve. Bakctracking...")
            elif float(new_kl.array) > self.max_kl:
                self.logger.info(
                    "KL divergence exceeds max_kl. Bakctracking...")
            else:
                self.kl_record.append(float(new_kl.array))
                self.policy_step_size_record.append(step_size)
                break
            step_size *= 0.5
        else:
            self.logger.info("\
Line search coundn't find a good step size. The policy was not updated.")
            self.policy_step_size_record.append(0.)
            # Restore the parameters saved before the search.
            _replace_params_data(
                policy_params,
                _split_and_reshape_to_ndarrays(
                    flat_params,
                    sizes=policy_params_sizes,
                    shapes=policy_params_shapes),
            )

    def act_and_train(self, obs, reward):
        """Record the previous transition, maybe update, and act on ``obs``.

        Returns:
            An action sampled from the current policy for ``obs``.
        """

        if self.last_state is not None:
            transition = {
                'state': self.last_state,
                'action': self.last_action,
                'reward': reward,
                'next_state': obs,
                'nonterminal': 1.0,
            }
            if self.recurrent:
                transition['recurrent_state'] =\
                    self.model.get_recurrent_state_at(
                        self.train_prev_recurrent_states,
                        0, unwrap_variable=True)
                self.train_prev_recurrent_states = None
                transition['next_recurrent_state'] =\
                    self.model.get_recurrent_state_at(
                        self.train_recurrent_states, 0, unwrap_variable=True)
            self.last_episode.append(transition)

        self._update_if_dataset_is_ready()

        xp = self.xp
        b_state = self.batch_states([obs], xp, self.phi)
        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        # action_distrib will be recomputed when computing gradients
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            if self.recurrent:
                assert self.train_prev_recurrent_states is None
                self.train_prev_recurrent_states = self.train_recurrent_states
                (action_distrib, value), self.train_recurrent_states =\
                    self.model(b_state, self.train_prev_recurrent_states)
            else:
                action_distrib, value = self.model(b_state)
            action = chainer.cuda.to_cpu(action_distrib.sample().array)[0]
            self.entropy_record.append(float(action_distrib.entropy.array))
            self.value_record.append(float(value.array))

        self.last_state = obs
        self.last_action = action

        return action

    def act(self, obs):
        """Select an action for ``obs`` without recording anything."""
        xp = self.xp
        b_state = self.batch_states([obs], xp, self.phi)
        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        with chainer.using_config('train', False), chainer.no_backprop_mode():
            if self.recurrent:
                action_distrib, self.test_recurrent_states =\
                    self.policy(b_state, self.test_recurrent_states)
            else:
                action_distrib = self.policy(b_state)
            if self.act_deterministically:
                action = chainer.cuda.to_cpu(
                    action_distrib.most_probable.array)[0]
            else:
                action = chainer.cuda.to_cpu(
                    action_distrib.sample().array)[0]

        return action

    def stop_episode_and_train(self, state, reward, done=False):
        """Record the final transition of an episode and maybe update."""

        assert self.last_state is not None
        transition = {
            'state': self.last_state,
            'action': self.last_action,
            'reward': reward,
            'next_state': state,
            'nonterminal': 0.0 if done else 1.0,
        }
        if self.recurrent:
            transition['recurrent_state'] = self.model.get_recurrent_state_at(
                self.train_prev_recurrent_states, 0, unwrap_variable=True)
            self.train_prev_recurrent_states = None
            transition['next_recurrent_state'] =\
                self.model.get_recurrent_state_at(
                    self.train_recurrent_states, 0, unwrap_variable=True)
            self.train_recurrent_states = None
        self.last_episode.append(transition)

        self.last_state = None
        self.last_action = None

        self._flush_last_episode()
        self.stop_episode()

        self._update_if_dataset_is_ready()

    def stop_episode(self):
        """Reset the recurrent state used by ``act``/``batch_act``."""
        self.test_recurrent_states = None

    def batch_act(self, batch_obs):
        """Select actions for a batch of observations without recording."""
        xp = self.xp
        b_state = self.batch_states(batch_obs, xp, self.phi)

        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        with chainer.using_config('train', False), chainer.no_backprop_mode():
            if self.recurrent:
                (action_distrib, _), self.test_recurrent_states = self.model(
                    b_state, self.test_recurrent_states)
            else:
                action_distrib, _ = self.model(b_state)
            if self.act_deterministically:
                action = chainer.cuda.to_cpu(
                    action_distrib.most_probable.array)
            else:
                action = chainer.cuda.to_cpu(action_distrib.sample().array)

        return action

    def batch_act_and_train(self, batch_obs):
        """Sample actions for a batch of observations and remember them."""

        xp = self.xp
        b_state = self.batch_states(batch_obs, xp, self.phi)
        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        num_envs = len(batch_obs)
        if self.batch_last_episode is None:
            self._initialize_batch_variables(num_envs)
        assert len(self.batch_last_episode) == num_envs
        assert len(self.batch_last_state) == num_envs
        assert len(self.batch_last_action) == num_envs

        # action_distrib will be recomputed when computing gradients
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            if self.recurrent:
                assert self.train_prev_recurrent_states is None
                self.train_prev_recurrent_states = self.train_recurrent_states
                (action_distrib, batch_value), self.train_recurrent_states =\
                    self.model(b_state, self.train_prev_recurrent_states)
            else:
                action_distrib, batch_value = self.model(b_state)
            batch_action = chainer.cuda.to_cpu(action_distrib.sample().array)
            self.entropy_record.extend(
                chainer.cuda.to_cpu(action_distrib.entropy.array))
            # Ravel so value_record holds scalars, as act_and_train does,
            # instead of one length-1 array per environment.
            self.value_record.extend(
                chainer.cuda.to_cpu(batch_value.array).ravel())

        self.batch_last_state = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        """Reset test-time recurrent states of environments whose episodes ended."""
        if self.recurrent:
            # Reset recurrent states when episodes end
            indices_that_ended = [
                i for i, (done, reset)
                in enumerate(zip(batch_done, batch_reset)) if done or reset]
            if indices_that_ended:
                self.test_recurrent_states =\
                    self.model.mask_recurrent_state_at(
                        self.test_recurrent_states, indices_that_ended)

    def batch_observe_and_train(self, batch_obs, batch_reward,
                                batch_done, batch_reset):
        """Record one transition per environment and maybe update.

        Episodes that ended (``done`` or ``reset``) are moved to
        ``self.memory`` and their training recurrent states are reset.
        """

        for i, (state, action, reward, next_state, done, reset) in enumerate(zip(  # NOQA
            self.batch_last_state,
            self.batch_last_action,
            batch_reward,
            batch_obs,
            batch_done,
            batch_reset,
        )):
            if state is not None:
                assert action is not None
                transition = {
                    'state': state,
                    'action': action,
                    'reward': reward,
                    'next_state': next_state,
                    'nonterminal': 0.0 if done else 1.0,
                }
                if self.recurrent:
                    transition['recurrent_state'] =\
                        self.model.get_recurrent_state_at(
                            self.train_prev_recurrent_states,
                            i, unwrap_variable=True)
                    transition['next_recurrent_state'] =\
                        self.model.get_recurrent_state_at(
                            self.train_recurrent_states,
                            i, unwrap_variable=True)
                self.batch_last_episode[i].append(transition)
            if done or reset:
                assert self.batch_last_episode[i]
                self.memory.append(self.batch_last_episode[i])
                self.batch_last_episode[i] = []
            self.batch_last_state[i] = None
            self.batch_last_action[i] = None

        self.train_prev_recurrent_states = None

        if self.recurrent:
            # Reset recurrent states when episodes end
            indices_that_ended = [
                i for i, (done, reset)
                in enumerate(zip(batch_done, batch_reset)) if done or reset]
            if indices_that_ended:
                self.train_recurrent_states =\
                    self.model.mask_recurrent_state_at(
                        self.train_recurrent_states, indices_that_ended)

        self._update_if_dataset_is_ready()

    def get_statistics(self):
        """Return a list of ``(name, value)`` pairs of training statistics."""
        return [
            ('average_value', _mean_or_nan(self.value_record)),
            ('average_entropy', _mean_or_nan(self.entropy_record)),
            ('average_kl', _mean_or_nan(self.kl_record)),
            ('average_policy_step_size',
                _mean_or_nan(self.policy_step_size_record)),
            ('explained_variance', self.explained_variance),
        ]