
Commit

Same bc maths
cswinter committed Aug 26, 2019
1 parent 9bbb2d9 commit 228357d
Showing 2 changed files with 3 additions and 10 deletions.
2 changes: 0 additions & 2 deletions hyper_params.py
@@ -30,8 +30,6 @@ def __init__(self):
         self.ppo = True # Use PPO-clip instead of vanilla policy gradients objective
         self.cliprange = 0.2 # PPO cliprange
 
-        self.inverted = False # Invert probability ratio on objective (probably wrong? but empiric evidence that it works better under current hyperparameters ¯\_(ツ)_/¯)
-
         # Task
         self.objective = envs.Objective.DISTANCE_TO_CRYSTAL
         self.game_length = 3 * 60 * 60
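For reference, the standard PPO-clip surrogate objective that the `ppo` and `cliprange` settings refer to (not part of this diff; shown here as a hedged reminder, with ε corresponding to `cliprange = 0.2` and the objective maximized with respect to θ):

```latex
% Standard PPO-clip surrogate (Schulman et al., 2017):
L^{\mathrm{CLIP}}(\theta)
  = \mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\;
      \operatorname{clip}\!\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\big)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
```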
11 changes: 3 additions & 8 deletions policy.py
@@ -38,18 +38,13 @@ def backprop(self, hps, obs, actions, old_logprobs, returns, value_loss_scale, a
         probs = F.softmax(self.policy_head(x), dim=1)
 
         logprobs = distributions.Categorical(probs).log_prob(actions)
-        if hps.inverted:
-            ratios = torch.exp(old_logprobs - logprobs)
-            policy_loss_sign = 1
-        else:
-            ratios = torch.exp(logprobs - old_logprobs)
-            policy_loss_sign = -1
+        ratios = torch.exp(old_logprobs - logprobs)
         vanilla_policy_loss = advantages * ratios
         if hps.ppo:
             clipped_policy_loss = torch.clamp(ratios, 1 - hps.cliprange, 1 + hps.cliprange) * advantages
-            policy_loss = policy_loss_sign * torch.min(vanilla_policy_loss, clipped_policy_loss).mean()
+            policy_loss = torch.min(vanilla_policy_loss, clipped_policy_loss).mean()
         else:
-            policy_loss = policy_loss_sign * vanilla_policy_loss.mean()
+            policy_loss = vanilla_policy_loss.mean()
 
         approxkl = 0.5 * (old_logprobs - logprobs).pow(2).mean()
         clipfrac = ((ratios - 1.0).abs() > hps.cliprange).sum().type(torch.float32) / ratios.numel()
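A minimal standalone sketch of the loss computation as it stands after this commit, run on dummy tensors so it can be executed outside the repository. The tensor values, `cliprange`, and `use_ppo` flag are illustrative assumptions, not taken from the codebase; note that the retained ratio is `exp(old_logprobs - logprobs)` with a positive sign on the loss, whereas the conventional PPO-clip formulation uses `exp(logprobs - old_logprobs)` with a leading minus.

```python
# Sketch of the post-commit policy-loss computation on dummy data (assumptions labeled).
import torch

cliprange = 0.2   # matches the default shown in hyper_params.py
use_ppo = True    # stand-in for hps.ppo

# Hypothetical per-sample quantities for a batch of 4 actions.
old_logprobs = torch.tensor([-1.2, -0.7, -2.0, -0.4])  # log-probs under the old policy
logprobs = torch.tensor([-1.0, -0.9, -1.8, -0.5])       # log-probs under the current policy
advantages = torch.tensor([0.5, -0.3, 1.2, -0.8])

# Probability ratio as computed after this commit: pi_old / pi_new.
ratios = torch.exp(old_logprobs - logprobs)
vanilla_policy_loss = advantages * ratios
if use_ppo:
    clipped_policy_loss = torch.clamp(ratios, 1 - cliprange, 1 + cliprange) * advantages
    policy_loss = torch.min(vanilla_policy_loss, clipped_policy_loss).mean()
else:
    policy_loss = vanilla_policy_loss.mean()

# Diagnostics reported alongside the loss in backprop().
approxkl = 0.5 * (old_logprobs - logprobs).pow(2).mean()
clipfrac = ((ratios - 1.0).abs() > cliprange).sum().type(torch.float32) / ratios.numel()
print(policy_loss.item(), approxkl.item(), clipfrac.item())
```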
