Skip to content

Commit

Permalink
Re-purposing td_k in advantages, so that it can serve also as n-step …
Browse files Browse the repository at this point in the history
…returns in DQN.

PiperOrigin-RevId: 332792184
  • Loading branch information
henrykmichalewski authored and Copybara-Service committed Sep 21, 2020
1 parent 84fe848 commit 5acd9dc
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions trax/rl/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def estimator(rewards, returns, values, dones):


@gin.configurable(blacklist=common_args)
def td_k(gamma, margin):
"""Calculate TD-k advantage.
def td_k(gamma, margin, n_step=False):
"""Calculate TD-k advantage or n_step returns.
The k parameter is assumed to be the same as margin.
Expand All @@ -64,6 +64,9 @@ def td_k(gamma, margin):
Args:
gamma: float, gamma parameter for TD from the underlying task
margin: number of extra steps in the sequence
n_step: if set to True, then we return
gamma^n_steps * value(s_{i + n_steps}) + discounted_rewards
Returns:
Function (rewards, returns, values, dones) -> advantages, where advantages
Expand All @@ -83,7 +86,8 @@ def estimator(rewards, returns, values, dones):
dones = dones[:, :-k]
advantages[dones] = rewards[:, :-k][dones]
# Subtract the baseline (value).
advantages -= values[:, :-k]
if not n_step:
advantages -= values[:, :-k]
return advantages
return estimator

Expand Down

0 comments on commit 5acd9dc

Please sign in to comment.