Skip to content

Commit

Permalink
Merge pull request #34 from heiner/master
Browse files Browse the repository at this point in the history
Small improvements in the vtrace implementation
  • Loading branch information
lespeholt committed Mar 13, 2019
2 parents 2bbbf45 + e331dc3 commit 6c0c8a7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions vtrace.py
Expand Up @@ -244,12 +244,7 @@ def from_importance_weights(
[values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

# Note that all sequences are reversed, computation starts from the back.
sequences = (
tf.reverse(discounts, axis=[0]),
tf.reverse(cs, axis=[0]),
tf.reverse(deltas, axis=[0]),
)
sequences = (discounts, cs, deltas)
# V-trace vs are calculated through a scan from the back to the beginning
# of the given trajectory.
def scanfunc(acc, sequence_item):
Expand All @@ -263,9 +258,8 @@ def scanfunc(acc, sequence_item):
initializer=initial_values,
parallel_iterations=1,
back_prop=False,
reverse=True, # Computation starts from the back.
name='scan')
# Reverse the results back to original order.
vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs')

# Add V(x_s) to get v_s.
vs = tf.add(vs_minus_v_xs, values, name='vs')
Expand Down
2 changes: 1 addition & 1 deletion vtrace_test.py
Expand Up @@ -94,7 +94,7 @@ def test_log_probs_from_logits_and_actions(self, batch_size):

policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10
actions = np.random.randint(
0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32)
0, num_actions, size=(seq_len, batch_size), dtype=np.int32)

action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
policy_logits, actions)
Expand Down

0 comments on commit 6c0c8a7

Please sign in to comment.