Implement GAE?
cswinter committed Aug 25, 2019
1 parent 5b17812 commit b75f9de
Showing 3 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion baselines/baselines/ppo2/ppo2.py
@@ -155,7 +155,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
                     mbinds = inds[start:end]
                     slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                     mblossvals.append(model.train(lrnow, cliprangenow, *slices))
-        else: # recurrent version
+        else: # recurrent version
             assert nenvs % nminibatches == 0
             envsperbatch = nenvs // nminibatches
             envinds = np.arange(nenvs)
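
For context on what the recurrent branch does (this commit only touches the else line): instead of shuffling individual timesteps, it shuffles whole environments, so each minibatch keeps every step of an environment together and recurrent state stays aligned with its trajectory. A hedged sketch of how such a branch typically continues in baselines-style ppo2 — not part of this commit's change, and assuming noptepochs, nsteps, states, lrnow, cliprangenow, and the flat rollout arrays are defined as in the surrounding learn():

            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    # All timesteps of the selected environments, in trajectory order
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))
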
3 changes: 3 additions & 0 deletions hyper_params.py
@@ -19,11 +19,14 @@ def __init__(self):
         self.width = 1024 # Number of activations on each hidden layer
         self.conv = True # Use convolution to share weights on objects
 
+        self.fp16 = True
+
         # RL
         self.steps = 2e7 # Total number of timesteps
         self.seq_rosteps = 64 # Number of sequential steps per rollout
         self.rosteps = 64 * 32 # Number of total rollout steps
         self.gamma = 0.9 # Discount factor
+        self.lamb = 0.9 # Generalized advantage estimation parameter lambda
         self.norm_advs = True # Normalize advantage values
         self.rewscale = 20.0 # Scaling of reward values

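For reference, the new lamb hyperparameter is the λ of generalized advantage estimation (Schulman et al.); together with gamma it enters the recursion that main.py now implements. A minimal sketch of the equations, with d_t denoting the done flag at step t (notation chosen here for illustration):

\begin{aligned}
\delta_t &= r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t) \\
\hat{A}_t &= \delta_t + \gamma\lambda\,(1 - d_t)\,\hat{A}_{t+1} \\
R_t &= \hat{A}_t + V(s_t)
\end{aligned}

With lamb = 0 this reduces to one-step TD targets; with lamb = 1 it recovers the bootstrapped discounted return.
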
20 changes: 13 additions & 7 deletions main.py
@@ -103,19 +103,25 @@ def train(hps: HyperParams) -> None:
         obs_tensor = torch.tensor(obs).to(device)
         _, _, _, final_values = policy.evaluate(obs_tensor)
 
         all_rewards = np.array(all_rewards) * hps.rewscale
         all_returns = np.zeros(len(all_rewards), dtype=np.float32)
-        ret = np.array(final_values)
-        retscale = (1.0 - hps.gamma) * hps.rewscale
+        all_values = np.array(all_values)
+        last_gae = np.zeros(num_envs)
         for t in reversed(range(hps.seq_rosteps)):
+            # TODO: correct for action delay?
+            # TODO: vectorize
             for i in range(num_envs):
                 ti = t * num_envs + i
-                ret[i] = hps.gamma * ret[i] + all_rewards[ti]
-                all_returns[ti] = ret[i] * retscale
-                if all_dones[ti] == 1:
-                    ret[i] = 0
+                tnext_i = (t + 1) * num_envs + i
+                nextnonterminal = 1.0 - all_dones[ti]
+                if t == hps.seq_rosteps - 1:
+                    next_value = final_values[i]
+                else:
+                    next_value = all_values[tnext_i]
+                td_error = all_rewards[ti] + hps.gamma * next_value * nextnonterminal - all_values[ti]
+                last_gae[i] = td_error + hps.gamma * hps.lamb * last_gae[i] * nextnonterminal
+                all_returns[ti] = last_gae[i] + all_values[ti]
 
-        all_values = np.array(all_values)
         advantages = all_returns - all_values
         if hps.norm_advs:
             advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
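
Since the new loop carries a "# TODO: vectorize" note, here is a hedged sketch of the same GAE computation done one NumPy operation per timestep instead of per environment. The function name gae_returns, the argument names, and the (steps, num_envs)-shaped arrays are illustrative assumptions, not the repository's actual interface:

import numpy as np

def gae_returns(rewards, values, dones, final_values, gamma=0.9, lamb=0.9):
    # rewards, values, dones: float arrays of shape (steps, num_envs)
    # final_values: shape (num_envs,), bootstrap values V(s_T) for the last observation
    steps, num_envs = rewards.shape
    returns = np.zeros_like(rewards)
    last_gae = np.zeros(num_envs)
    next_value = final_values
    for t in reversed(range(steps)):
        # Mask out the bootstrap and the carried advantage at episode boundaries,
        # mirroring the nextnonterminal logic in the loop above.
        nonterminal = 1.0 - dones[t]
        td_error = rewards[t] + gamma * next_value * nonterminal - values[t]
        last_gae = td_error + gamma * lamb * nonterminal * last_gae
        returns[t] = last_gae + values[t]
        next_value = values[t]
    advantages = returns - values
    return advantages, returns

Flattening the (steps, num_envs) outputs with .ravel() reproduces the t * num_envs + i ordering used above, after which advantage normalization can proceed unchanged.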
