Update to pytorch 0.4
ikostrikov2 committed Sep 13, 2018
1 parent eb26e29 commit e200eb8
Showing 3 changed files with 15 additions and 11 deletions.
main.py: 9 changes (7 additions & 2 deletions)
@@ -97,7 +97,7 @@ def get_value_loss(flat_params):
         for param in value_net.parameters():
             value_loss += param.pow(2).sum() * args.l2_reg
         value_loss.backward()
-        return (value_loss.data.double().numpy()[0], get_flat_grad_from(value_net).data.double().numpy())
+        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())

     flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
     set_flat_params_to(value_net, torch.Tensor(flat_params))
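The dropped [0] tracks PyTorch 0.4's move to 0-dimensional tensors: value_loss is now 0-dim, so .numpy() yields a 0-d array that scipy accepts as the scalar objective, while indexing it with [0] would fail. A minimal sketch of the (loss, gradient) contract that fmin_l_bfgs_b expects, with a placeholder quadratic objective standing in for the value-network loss:

import numpy as np
import scipy.optimize
import torch

def objective(flat_params):
    # flat_params arrives from scipy as a float64 numpy array.
    x = torch.from_numpy(flat_params)
    loss = (x ** 2).sum()   # 0-dim tensor; .numpy() gives a 0-d array
    grad = 2 * x            # analytic gradient of the quadratic
    return (loss.double().numpy(), grad.double().numpy())

xopt, fopt, info = scipy.optimize.fmin_l_bfgs_b(objective, np.ones(3), maxiter=25)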
@@ -108,7 +108,12 @@ def get_value_loss(flat_params):
     fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

     def get_loss(volatile=False):
-        action_means, action_log_stds, action_stds = policy_net(Variable(states, volatile=volatile))
+        if volatile:
+            with torch.no_grad():
+                action_means, action_log_stds, action_stds = policy_net(Variable(states))
+        else:
+            action_means, action_log_stds, action_stds = policy_net(Variable(states))

         log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
         action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
         return action_loss.mean()
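The get_loss rewrite reflects the headline 0.4 change: the volatile flag on Variable is gone, and gradient-free evaluation is expressed with the torch.no_grad() context manager instead. A minimal sketch of the pattern, with a throwaway two-layer network standing in for policy_net:

import torch

policy_net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
states = torch.randn(16, 4)

# Pre-0.4 style: out = policy_net(Variable(states, volatile=True))
with torch.no_grad():
    out = policy_net(states)   # no autograd graph is built
print(out.requires_grad)       # False

out = policy_net(states)       # the gradient-tracked path needs no wrapper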
models.py: 9 changes (4 additions & 5 deletions)
@@ -1,7 +1,6 @@
 import torch
 import torch.autograd as autograd
 import torch.nn as nn
-import torch.nn.functional as F


 class Policy(nn.Module):
@@ -21,8 +20,8 @@ def __init__(self, num_inputs, num_outputs):
         self.final_value = 0

     def forward(self, x):
-        x = F.tanh(self.affine1(x))
-        x = F.tanh(self.affine2(x))
+        x = torch.tanh(self.affine1(x))
+        x = torch.tanh(self.affine2(x))

         action_mean = self.action_mean(x)
         action_log_std = self.action_log_std.expand_as(action_mean)
@@ -41,8 +40,8 @@ def __init__(self, num_inputs):
         self.value_head.bias.data.mul_(0.0)

     def forward(self, x):
-        x = F.tanh(self.affine1(x))
-        x = F.tanh(self.affine2(x))
+        x = torch.tanh(self.affine1(x))
+        x = torch.tanh(self.affine2(x))

         state_values = self.value_head(x)
         return state_values
trpo.py: 8 changes (4 additions & 4 deletions)
@@ -32,18 +32,18 @@ def linesearch(model,
                max_backtracks=10,
                accept_ratio=.1):
     fval = f(True).data
-    print("fval before", fval[0])
+    print("fval before", fval.item())
     for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
         xnew = x + stepfrac * fullstep
         set_flat_params_to(model, xnew)
         newfval = f(True).data
         actual_improve = fval - newfval
         expected_improve = expected_improve_rate * stepfrac
         ratio = actual_improve / expected_improve
-        print("a/e/r", actual_improve[0], expected_improve[0], ratio[0])
+        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())

-        if ratio[0] > accept_ratio and actual_improve[0] > 0:
-            print("fval after", newfval[0])
+        if ratio.item() > accept_ratio and actual_improve.item() > 0:
+            print("fval after", newfval.item())
             return True, xnew
     return False, x
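The [0] to .item() substitutions in linesearch follow from the same 0-dim tensor change: f(True).data is now 0-dimensional, and .item() is the supported way to pull out the Python scalar for printing and comparison. A minimal illustration:

import torch

fval = torch.tensor([1.5, 2.5]).mean()   # 0-dim tensor in 0.4+
print(fval.item())                       # 2.0, a plain Python float
# fval[0] raises IndexError: invalid index of a 0-dim tensor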


