Permalink
Browse files

Update to pytorch 0.4

  • Loading branch information...
ikostrikov committed Sep 13, 2018
1 parent eb26e29 commit e200eb8a23b3c7941a0091efb9750dafa4b23cbb
Showing with 15 additions and 11 deletions.
  1. +7 −2 main.py
  2. +4 −5 models.py
  3. +4 −4 trpo.py
@@ -97,7 +97,7 @@ def get_value_loss(flat_params):
for param in value_net.parameters():
value_loss += param.pow(2).sum() * args.l2_reg
value_loss.backward()
return (value_loss.data.double().numpy()[0], get_flat_grad_from(value_net).data.double().numpy())
return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())

flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
set_flat_params_to(value_net, torch.Tensor(flat_params))
@@ -108,7 +108,12 @@ def get_value_loss(flat_params):
fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

def get_loss(volatile=False):
    """Surrogate policy loss: -mean( A * exp(log_pi(a|s) - log_pi_old(a|s)) ).

    Args:
        volatile: when True, run the policy forward pass under
            ``torch.no_grad()`` so no autograd graph is built — the
            pytorch >= 0.4 replacement for the removed ``volatile`` flag.

    Returns:
        Scalar tensor with the mean importance-weighted (negated) advantage.

    NOTE(review): closes over ``policy_net``, ``states``, ``actions``,
    ``advantages``, ``fixed_log_prob``, ``normal_log_density`` and
    ``Variable`` from the enclosing scope — confirm against main.py.
    """
    if volatile:
        # Inference-only evaluation: disable gradient tracking.
        with torch.no_grad():
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
    else:
        action_means, action_log_stds, action_stds = policy_net(Variable(states))

    log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
    # fixed_log_prob is the detached log-prob under the old policy, so the
    # exp(...) term is the TRPO importance ratio pi_new / pi_old.
    action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
    return action_loss.mean()
@@ -1,7 +1,6 @@
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
@@ -21,8 +20,8 @@ def __init__(self, num_inputs, num_outputs):
self.final_value = 0

def forward(self, x):
x = F.tanh(self.affine1(x))
x = F.tanh(self.affine2(x))
x = torch.tanh(self.affine1(x))
x = torch.tanh(self.affine2(x))

action_mean = self.action_mean(x)
action_log_std = self.action_log_std.expand_as(action_mean)
@@ -41,8 +40,8 @@ def __init__(self, num_inputs):
self.value_head.bias.data.mul_(0.0)

def forward(self, x):
x = F.tanh(self.affine1(x))
x = F.tanh(self.affine2(x))
x = torch.tanh(self.affine1(x))
x = torch.tanh(self.affine2(x))

state_values = self.value_head(x)
return state_values
@@ -32,18 +32,18 @@ def linesearch(model,
max_backtracks=10,
accept_ratio=.1):
fval = f(True).data
print("fval before", fval[0])
print("fval before", fval.item())
for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
xnew = x + stepfrac * fullstep
set_flat_params_to(model, xnew)
newfval = f(True).data
actual_improve = fval - newfval
expected_improve = expected_improve_rate * stepfrac
ratio = actual_improve / expected_improve
print("a/e/r", actual_improve[0], expected_improve[0], ratio[0])
print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())

if ratio[0] > accept_ratio and actual_improve[0] > 0:
print("fval after", newfval[0])
if ratio.item() > accept_ratio and actual_improve.item() > 0:
print("fval after", newfval.item())
return True, xnew
return False, x

0 comments on commit e200eb8

Please sign in to comment.