Commit 573bbe6

PPO
cswinter committed Aug 25, 2019
1 parent ae6fedb commit 573bbe6
Showing 3 changed files with 15 additions and 3 deletions.
hyper_params.py: 2 additions & 0 deletions
@@ -27,6 +27,8 @@ def __init__(self):
         self.lamb = 0.9  # Generalized advantage estimation parameter lambda
         self.norm_advs = True  # Normalize advantage values
         self.rewscale = 20.0  # Scaling of reward values
+        self.ppo = False  # Use PPO-clip instead of the vanilla policy gradient objective
+        self.cliprange = 0.2  # PPO cliprange

         # Task
         self.objective = envs.Objective.DISTANCE_TO_CRYSTAL
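
For reference, the new cliprange setting corresponds to the clip parameter ε in the PPO clipped surrogate objective of Schulman et al. (2017), which the policy.py change below is based on (with a sign convention flagged in its TODO):

    L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t \left[ \min\left( r_t(\theta)\, \hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\, 1 - \epsilon,\, 1 + \epsilon\big)\, \hat{A}_t \right) \right],
    \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

The objective is maximized with respect to θ, so implementations that minimize a loss typically negate it.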
main.py: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def train(hps: HyperParams) -> None:
             advs = torch.tensor(advantages[start:end]).to(device)

             optimizer.zero_grad()
-            policy_loss, value_loss = policy.backprop(o, actions, probs, returns, hps.vf_coef, advs)
+            policy_loss, value_loss = policy.backprop(hps, o, actions, probs, returns, hps.vf_coef, advs)
             episode_loss += policy_loss
             batch_value_loss += value_loss
             gradnorm += torch.nn.utils.clip_grad_norm_(policy.parameters(), hps.max_grad_norm)
policy.py: 12 additions & 2 deletions
@@ -33,13 +33,23 @@ def evaluate(self, observation):
         entropy = action_dist.entropy()
         return actions, action_dist.log_prob(actions), entropy, v.detach().view(-1).cpu().numpy()

-    def backprop(self, obs, actions, old_logprobs, returns, value_loss_scale, advantages):
+    def backprop(self, hps, obs, actions, old_logprobs, returns, value_loss_scale, advantages):
         x = self.latents(obs)
         probs = F.softmax(self.policy_head(x), dim=1)
+
         logprobs = distributions.Categorical(probs).log_prob(actions)
+        # TODO: other code has `logprobs - old_logprobs` and negated loss?
+        ratios = torch.exp(old_logprobs - logprobs)
+        vanilla_policy_loss = advantages * ratios
+        if hps.ppo:
+            clipped_policy_loss = torch.clamp(ratios, 1 - hps.cliprange, 1 + hps.cliprange) * advantages
+            policy_loss = torch.min(vanilla_policy_loss, clipped_policy_loss).mean()
+        else:
+            policy_loss = vanilla_policy_loss.mean()
+
         baseline = self.value_head(x)
-        policy_loss = (advantages * torch.exp(old_logprobs - logprobs)).mean()
         value_loss = F.mse_loss(returns, baseline.view(-1)).mean()
+
         loss = policy_loss + value_loss_scale * value_loss
         loss.backward()
         return policy_loss.data.tolist(), value_loss.data.tolist()
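
The TODO in the new backprop notes that other implementations compute the ratio as logprobs - old_logprobs and negate the loss. For comparison only (not this repository's code; the function name and default cliprange are illustrative), a minimal sketch of that more common convention, for an optimizer that minimizes the returned value:

import torch

def ppo_clip_loss(logprobs, old_logprobs, advantages, cliprange=0.2):
    # Probability ratio r = pi_new(a|s) / pi_old(a|s); old log-probs come from
    # the rollout and should carry no gradient.
    ratios = torch.exp(logprobs - old_logprobs.detach())
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1.0 - cliprange, 1.0 + cliprange) * advantages
    # The clipped surrogate is maximized, so its negation is returned as a loss.
    return -torch.min(unclipped, clipped).mean()

In the committed version, ratios = exp(old_logprobs - logprobs) is the reciprocal 1/r, so the clamp bounds act on 1/r rather than r; both conventions give the same gradient at the on-policy point where r = 1, but they diverge once the updated policy moves away from the sampling policy.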
