From e331dc3140e5ef53e62ac753f56539db4c4430eb Mon Sep 17 00:00:00 2001 From: Heiner Date: Mon, 28 Jan 2019 13:27:17 +0000 Subject: [PATCH] Use tf.scan's reverse=True for vtrace's vs_minus_v_xs computation. --- vtrace.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vtrace.py b/vtrace.py index d185051..f2a9f77 100644 --- a/vtrace.py +++ b/vtrace.py @@ -244,12 +244,7 @@ def from_importance_weights( [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) - # Note that all sequences are reversed, computation starts from the back. - sequences = ( - tf.reverse(discounts, axis=[0]), - tf.reverse(cs, axis=[0]), - tf.reverse(deltas, axis=[0]), - ) + sequences = (discounts, cs, deltas) # V-trace vs are calculated through a scan from the back to the beginning # of the given trajectory. def scanfunc(acc, sequence_item): @@ -263,9 +258,8 @@ def scanfunc(acc, sequence_item): initializer=initial_values, parallel_iterations=1, back_prop=False, + reverse=True, # Computation starts from the back. name='scan') - # Reverse the results back to original order. - vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs') # Add V(x_s) to get v_s. vs = tf.add(vs_minus_v_xs, values, name='vs')