Fix bugs introduced with keepdim
ikostrikov2 committed Oct 7, 2017
1 parent 8238c58 commit 7650592
Showing 5 changed files with 26 additions and 17 deletions.
6 changes: 6 additions & 0 deletions .idea/vcs.xml


10 changes: 10 additions & 0 deletions README.md
@@ -14,6 +14,16 @@ Contributions are very welcome. If you know how to make this code better, don't
python main.py --env-name "Reacher-v1"
```

## Recommended hyperparameters

InvertedPendulum-v1: 5000

Reacher-v1, InvertedDoublePendulum-v1: 15000

HalfCheetah-v1, Hopper-v1, Swimmer-v1, Walker2d-v1: 25000

Ant-v1, Humanoid-v1: 50000
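
These appear to be per-environment batch sizes; assuming they map to main.py's `--batch-size` argument (an assumption, not stated in this diff), a run would look like:

```
python main.py --env-name "HalfCheetah-v1" --batch-size 25000
```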

## Results

More or less similar to the original code. Coming soon.
17 changes: 5 additions & 12 deletions main.py
@@ -1,26 +1,20 @@
import argparse
import sys
from collections import namedtuple
from itertools import count

import gym
import numpy as np
import scipy.optimize
from gym import wrappers

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from trpo import trpo_step
from utils import *

torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

torch.set_default_tensor_type('torch.DoubleTensor')

parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
@@ -86,7 +80,6 @@ def update_params(batch):
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

-   values_ = value_net(Variable(states))
    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
@@ -112,7 +105,7 @@ def get_value_loss(flat_params):
    advantages = (advantages - advantages.mean()) / advantages.std()

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
-   fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data
+   fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        action_means, action_log_stds, action_stds = policy_net(Variable(states, volatile=volatile))
@@ -128,7 +121,7 @@ def get_kl():
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
-       return kl.sum(1)
+       return kl.sum(1, keepdim=True)

    trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)

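For context on the `kl.sum(1, keepdim=True)` change above, a minimal sketch (not repo code) of the PyTorch 0.2 behavior this commit adapts to: reductions now drop the reduced dimension by default, which is what the `keepdim_warning` toggle added in main.py flags.

```python
import torch

# Minimal sketch (not repo code): since PyTorch 0.2, reductions drop the
# reduced dimension unless keepdim=True is passed. For per-dimension KL
# terms of shape (N, A), sum(1) yields (N,) instead of the old (N, 1).
kl = torch.randn(4, 3)
print(kl.sum(1).size())                # torch.Size([4])    -- new default
print(kl.sum(1, keepdim=True).size())  # torch.Size([4, 1]) -- pre-0.2 shape
```

The quantity being summed is the closed-form KL between diagonal Gaussians, log(σ1/σ0) + (σ0² + (μ0 − μ1)²)/(2σ1²) − 1/2 per dimension, so keeping the dimension preserves a (N, 1) column of per-sample KLs.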
6 changes: 3 additions & 3 deletions trpo.py
@@ -50,7 +50,7 @@ def linesearch(model,

def trpo_step(model, get_loss, get_kl, max_kl, damping):
    loss = get_loss()
-   grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
+   grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
@@ -68,12 +68,12 @@ def Fvp(v):

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

-   shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0)
+   shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]

-   neggdotstepdir = (-loss_grad * stepdir).sum(0)
+   neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))

    prev_params = get_flat_params_from(model)
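As a hedged note on the arithmetic above (my annotation, not repo code): with `shs` = 0.5·sᵀHs, dividing the conjugate-gradient direction by `lm` = sqrt(shs / max_kl) rescales it to the largest step satisfying the quadratic trust-region constraint 0.5·β²·sᵀHs ≤ max_kl.

```python
import torch

# Sketch (not repo code): solve 0.5 * beta^2 * s^T H s = max_kl for the
# step length beta along the search direction s, as trpo.py does above.
def full_step(stepdir, shs, max_kl):
    # shs is assumed to equal 0.5 * s^T H s, matching the variable in trpo.py
    lm = torch.sqrt(shs / max_kl)  # Lagrange multiplier of the KL constraint
    return stepdir / lm            # equals sqrt(2 * max_kl / s^T H s) * stepdir
```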
4 changes: 2 additions & 2 deletions utils.py
@@ -8,14 +8,14 @@
def normal_entropy(std):
    var = std.pow(2)
    entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi)
-   return entropy.sum(1)
+   return entropy.sum(1, keepdim=True)


def normal_log_density(x, mean, log_std, std):
    var = std.pow(2)
    log_density = -(x - mean).pow(2) / (
        2 * var) - 0.5 * math.log(2 * math.pi) - log_std
-   return log_density.sum(1)
+   return log_density.sum(1, keepdim=True)


def get_flat_params_from(model):
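The `keepdim=True` returns above keep these reductions as (N, 1) columns. A sketch (not repo code) of the silent-broadcasting failure this avoids, presumably the class of bug named in the commit title; the `broadcast_warning` toggle added in main.py surfaces exactly these cases:

```python
import torch

# Sketch (not repo code): a (N,) reduction mixed with a (N, 1) column
# broadcasts to an unintended (N, N) result instead of failing loudly.
logp = torch.randn(5)     # e.g. log densities reduced without keepdim
adv = torch.randn(5, 1)   # e.g. a column of advantages or returns
print((logp * adv).size())              # torch.Size([5, 5]) -- shape blow-up
print((logp.view(-1, 1) * adv).size())  # torch.Size([5, 1]) -- intended product
```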
