Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fix bug in objective?
  • Loading branch information
cswinter committed Aug 26, 2019
1 parent 573bbe6 commit 7ef78a5
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions policy.py
Expand Up @@ -38,14 +38,13 @@ def backprop(self, hps, obs, actions, old_logprobs, returns, value_loss_scale, a
         probs = F.softmax(self.policy_head(x), dim=1)
 
         logprobs = distributions.Categorical(probs).log_prob(actions)
-        # TODO: other code has `logprobs-old_logprobs` and negated loss?
-        ratios = torch.exp(old_logprobs - logprobs)
+        ratios = torch.exp(logprobs - old_logprobs)
         vanilla_policy_loss = advantages * ratios
         if hps.ppo:
             clipped_policy_loss = torch.clamp(ratios, 1 - hps.cliprange, 1 + hps.cliprange) * advantages
-            policy_loss = torch.min(vanilla_policy_loss, clipped_policy_loss).mean()
+            policy_loss = -torch.min(vanilla_policy_loss, clipped_policy_loss).mean()
         else:
-            policy_loss = vanilla_policy_loss.mean()
+            policy_loss = -vanilla_policy_loss.mean()
 
         baseline = self.value_head(x)
         value_loss = F.mse_loss(returns, baseline.view(-1)).mean()
Expand Down

0 comments on commit 7ef78a5

Please sign in to comment.