From fcec1c299ddc08fc9d6dd21677fb965bc6643682 Mon Sep 17 00:00:00 2001 From: Heiner Date: Mon, 28 Jan 2019 13:22:53 +0000 Subject: [PATCH 1/2] Sample all actions in vtrace_test. Numpy's randint(a, b) samples from [a,b), including a but excluding b. --- vtrace_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vtrace_test.py b/vtrace_test.py index 2de05d1..f264e64 100644 --- a/vtrace_test.py +++ b/vtrace_test.py @@ -94,7 +94,7 @@ def test_log_probs_from_logits_and_actions(self, batch_size): policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 actions = np.random.randint( - 0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32) + 0, num_actions, size=(seq_len, batch_size), dtype=np.int32) action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions( policy_logits, actions) From e331dc3140e5ef53e62ac753f56539db4c4430eb Mon Sep 17 00:00:00 2001 From: Heiner Date: Mon, 28 Jan 2019 13:27:17 +0000 Subject: [PATCH 2/2] Use tf.scan's reverse=True for vtrace's vs_minus_v_xs computation. --- vtrace.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vtrace.py b/vtrace.py index d185051..f2a9f77 100644 --- a/vtrace.py +++ b/vtrace.py @@ -244,12 +244,7 @@ def from_importance_weights( [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) - # Note that all sequences are reversed, computation starts from the back. - sequences = ( - tf.reverse(discounts, axis=[0]), - tf.reverse(cs, axis=[0]), - tf.reverse(deltas, axis=[0]), - ) + sequences = (discounts, cs, deltas) # V-trace vs are calculated through a scan from the back to the beginning # of the given trajectory. def scanfunc(acc, sequence_item): @@ -263,9 +258,8 @@ def scanfunc(acc, sequence_item): initializer=initial_values, parallel_iterations=1, back_prop=False, + reverse=True, # Computation starts from the back. name='scan') - # Reverse the results back to original order. - vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs') # Add V(x_s) to get v_s. vs = tf.add(vs_minus_v_xs, values, name='vs')