Skip to content

Commit

Permalink
Use tf.scan's reverse=True for vtrace's vs_minus_v_xs computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Heiner committed Mar 11, 2019
1 parent fcec1c2 commit e331dc3
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions vtrace.py
Expand Up @@ -244,12 +244,7 @@ def from_importance_weights(
[values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

- # Note that all sequences are reversed, computation starts from the back.
- sequences = (
- tf.reverse(discounts, axis=[0]),
- tf.reverse(cs, axis=[0]),
- tf.reverse(deltas, axis=[0]),
- )
+ sequences = (discounts, cs, deltas)
# V-trace vs are calculated through a scan from the back to the beginning
# of the given trajectory.
def scanfunc(acc, sequence_item):
Expand All @@ -263,9 +258,8 @@ def scanfunc(acc, sequence_item):
initializer=initial_values,
parallel_iterations=1,
back_prop=False,
+ reverse=True, # Computation starts from the back.
name='scan')
- # Reverse the results back to original order.
- vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs')

# Add V(x_s) to get v_s.
vs = tf.add(vs_minus_v_xs, values, name='vs')
Expand Down

0 comments on commit e331dc3

Please sign in to comment.