Migrate to pytorch 0.4

ikostrikov committed Apr 25, 2018
1 parent 9d2ad2e commit e7f739f3ad1a377bfbcd0b92c55c2b9d947764fc
Showing with 45 additions and 51 deletions.
  1. +2 −2 README.md
  2. +11 −14 algo/a2c_acktr.py
  3. +1 −1 algo/kfac.py
  4. +6 −9 algo/ppo.py
  5. +2 −3 distributions.py
  6. +6 −7 enjoy.py
  7. +3 −2 envs.py
  8. +11 −10 main.py
  9. +3 −3 model.py
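
For context, the 0.4 changes this commit tracks are: Tensor and Variable are merged into one class, inference runs under `torch.no_grad()` instead of `volatile=True`, Python scalars come from `.item()` rather than `.data[0]`, and in-place utilities grew a trailing underscore (`clip_grad_norm_`, `orthogonal_`). A minimal sketch of a 0.4-style training step, with hypothetical `net`/`loss_fn` names not taken from this repo:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 2)

# 0.4 style: no Variable wrappers, tensors carry autograd state directly.
loss = loss_fn(net(x), y)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), 0.5)  # renamed from clip_grad_norm
optimizer.step()
print(loss.item())  # .item() replaces loss.data[0]

# Inference: torch.no_grad() replaces Variable(x, volatile=True).
with torch.no_grad():
    pred = net(x)
```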
README.md
@@ -99,13 +99,13 @@ python main.py --env-name "PongNoFrameskip-v4" --algo acktr --num-processes 32 -
#### A2C

```bash
python main.py --env-name "Reacher-v1" --num-stack 1 --num-frames 1000000
python main.py --env-name "Reacher-v2" --num-stack 1 --num-frames 1000000
```

#### PPO

```bash
python main.py --env-name "Reacher-v1" --algo ppo --use-gae --vis-interval 1 --log-interval 1 --num-stack 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 1 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-frames 1000000
python main.py --env-name "Reacher-v2" --algo ppo --use-gae --vis-interval 1 --log-interval 1 --num-stack 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 1 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-frames 1000000
```

#### ACKTR
algo/a2c_acktr.py
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from .kfac import KFACOptimizer

@@ -37,32 +36,30 @@ def update(self, rollouts):
num_steps, num_processes, _ = rollouts.rewards.size()

values, action_log_probs, dist_entropy, states = self.actor_critic.evaluate_actions(
Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
Variable(rollouts.states[0].view(-1,
self.actor_critic.state_size)),
Variable(rollouts.masks[:-1].view(-1, 1)),
Variable(rollouts.actions.view(-1, action_shape)))
rollouts.observations[:-1].view(-1, *obs_shape),
rollouts.states[0].view(-1, self.actor_critic.state_size),
rollouts.masks[:-1].view(-1, 1),
rollouts.actions.view(-1, action_shape))

values = values.view(num_steps, num_processes, 1)
action_log_probs = action_log_probs.view(num_steps, num_processes, 1)

advantages = Variable(rollouts.returns[:-1]) - values
advantages = rollouts.returns[:-1] - values
value_loss = advantages.pow(2).mean()

action_loss = -(Variable(advantages.data) * action_log_probs).mean()
action_loss = -(advantages.detach() * action_log_probs).mean()

if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0:
# Sampled fisher, see Martens 2014
self.actor_critic.zero_grad()
pg_fisher_loss = -action_log_probs.mean()

value_noise = Variable(torch.randn(values.size()))
value_noise = torch.randn(values.size())
if values.is_cuda:
value_noise = value_noise.cuda()

sample_values = values + value_noise
vf_fisher_loss = -(
values - Variable(sample_values.data)).pow(2).mean()
vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean()

fisher_loss = pg_fisher_loss + vf_fisher_loss
self.optimizer.acc_stats = True
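
The `Variable(advantages.data)` pattern above was the 0.3 way to stop gradients; `.detach()` is the direct 0.4 equivalent, returning a tensor that shares storage but is cut out of the graph. A standalone illustration (tensors are made up, not repo code):

```python
import torch

w = torch.randn(3, requires_grad=True)
adv = w * 2.0                # lives in the autograd graph
target = adv.detach()        # same values, but gradients will not flow through it

loss = ((w - target) ** 2).mean()
loss.backward()              # w receives gradients only through the non-detached term
print(target.requires_grad)  # False
```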
@@ -74,9 +71,9 @@ def update(self, rollouts):
dist_entropy * self.entropy_coef).backward()

if self.acktr == False:
nn.utils.clip_grad_norm(self.actor_critic.parameters(),
self.max_grad_norm)
nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
self.max_grad_norm)

self.optimizer.step()

return value_loss, action_loss, dist_entropy
return value_loss.item(), action_loss.item(), dist_entropy.item()
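
`.item()` converts a zero-dimensional tensor to a plain Python number, so the losses returned here no longer drag the autograd graph along into logging code. A quick standalone check:

```python
import torch

value_loss = torch.tensor(0.25, requires_grad=True) * 2
print(value_loss)         # tensor(0.5000, grad_fn=<MulBackward0>): a 0-dim tensor
print(value_loss.item())  # 0.5: a plain float, the 0.4 replacement for .data[0]
```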
algo/kfac.py
@@ -142,7 +142,7 @@ def split_bias(module):
momentum=self.momentum)

def _save_input(self, module, input):
if input[0].volatile == False and self.steps % self.Ts == 0:
if torch.is_grad_enabled() and self.steps % self.Ts == 0:
classname = module.__class__.__name__
layer_info = None
if classname == 'Conv2d':
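
`volatile` disappeared with the Tensor/Variable merge, so the hook now asks the global question instead: is autograd currently recording? A minimal forward-pre-hook sketch in the same spirit (module and counter are hypothetical):

```python
import torch
import torch.nn as nn

stats = {"updates": 0}

def save_input(module, inputs):
    # Only gather statistics while autograd is recording,
    # mirroring the old `input[0].volatile == False` test.
    if torch.is_grad_enabled():
        stats["updates"] += 1

layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(save_input)

layer(torch.randn(1, 4))          # grad enabled: counted
with torch.no_grad():
    layer(torch.randn(1, 4))      # grad disabled: skipped
print(stats["updates"])           # 1
```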
algo/ppo.py
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from .kfac import KFACOptimizer

@@ -51,24 +50,22 @@ def update(self, rollouts):

# Reshape to do in a single forward pass for all steps
values, action_log_probs, dist_entropy, states = self.actor_critic.evaluate_actions(
Variable(observations_batch), Variable(states_batch),
Variable(masks_batch), Variable(actions_batch))
observations_batch, states_batch,
masks_batch, actions_batch)

adv_targ = Variable(adv_targ)
ratio = torch.exp(
action_log_probs - Variable(old_action_log_probs_batch))
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
1.0 + self.clip_param) * adv_targ
action_loss = -torch.min(surr1, surr2).mean()

value_loss = (Variable(return_batch) - values).pow(2).mean()
value_loss = (return_batch - values).pow(2).mean()

self.optimizer.zero_grad()
(value_loss * self.value_loss_coef + action_loss -
dist_entropy * self.entropy_coef).backward()
nn.utils.clip_grad_norm(self.actor_critic.parameters(),
self.max_grad_norm)
nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
self.max_grad_norm)
self.optimizer.step()

return value_loss, action_loss, dist_entropy
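
All of the `Variable(...)` wrappers above can simply be dropped because, as of 0.4, `torch.Tensor` carries `requires_grad` and `.grad` itself; the batches coming out of the rollout sampler are already valid autograd inputs. A standalone check of the merged behaviour:

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
print(isinstance(x, torch.Tensor))  # True: there is no separate Variable class to wrap into
(x * 2).sum().backward()
print(x.grad.shape)                 # torch.Size([2, 3])
```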
distributions.py
@@ -3,7 +3,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from utils import AddBias


@@ -48,7 +47,7 @@ def forward(self, x):
action_mean = self.fc_mean(x)

# An ugly hack for my KFAC implementation.
zeros = Variable(torch.zeros(action_mean.size()), volatile=x.volatile)
zeros = torch.zeros(action_mean.size())
if x.is_cuda:
zeros = zeros.cuda()

@@ -61,7 +60,7 @@ def sample(self, x, deterministic):
action_std = action_logstd.exp()

if deterministic is False:
noise = Variable(torch.randn(action_std.size()))
noise = torch.randn(action_std.size())
if action_std.is_cuda:
noise = noise.cuda()
action = action_mean + action_std * noise
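
The new-style code still builds `zeros`/`noise` on the CPU and moves them with `.cuda()` when needed. Since 0.4 the factory functions also accept a `device` argument, so a possible further cleanup (not part of this commit, helper name hypothetical) would be:

```python
import torch

def sample_noise(action_std):
    # Create the noise directly on action_std's device,
    # avoiding the construct-then-.cuda() two-step.
    return torch.randn(action_std.size(), device=action_std.device)

std = torch.ones(4)            # torch.ones(4, device="cuda") on a GPU machine
print(sample_noise(std).device)
```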
enjoy.py
@@ -4,7 +4,6 @@

import numpy as np
import torch
from torch.autograd import Variable
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize

@@ -76,12 +75,12 @@ def update_current_obs(obs):
torsoId = i

while True:
value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
Variable(states, volatile=True),
Variable(masks, volatile=True),
deterministic=True)
states = states.data
cpu_actions = action.data.squeeze(1).cpu().numpy()
with torch.no_grad():
value, action, _, states = actor_critic.act(current_obs,
states,
masks,
deterministic=True)
cpu_actions = action.squeeze(1).cpu().numpy()
# Observe reward and next obs
obs, reward, done, _ = env.step(cpu_actions)

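`torch.no_grad()` replaces `volatile=True` for evaluation rollouts: nothing inside the block is recorded, and the outputs come back with `requires_grad=False`, so they can go straight to NumPy without the old `.data` hop. A small sketch with a stand-in policy:

```python
import torch
import torch.nn as nn

policy = nn.Linear(4, 2)          # stand-in for actor_critic.act
obs = torch.randn(1, 4)

with torch.no_grad():             # replaces Variable(obs, volatile=True)
    action = policy(obs)

print(action.requires_grad)       # False: no graph was built
cpu_actions = action.squeeze(0).cpu().numpy()
```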
envs.py
@@ -49,8 +49,9 @@ def __init__(self, env=None):
self.observation_space = Box(
self.observation_space.low[0,0,0],
self.observation_space.high[0,0,0],
[obs_shape[2], obs_shape[1], obs_shape[0]]
[obs_shape[2], obs_shape[1], obs_shape[0]],
dtype=self.observation_space.dtype
)

def _observation(self, observation):
def observation(self, observation):
return observation.transpose(2, 0, 1)
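
This hunk tracks two gym API changes: `Box` now wants an explicit `dtype`, and `ObservationWrapper` subclasses override `observation()` rather than the underscore-prefixed `_observation()`. A self-contained version of the same channel-first transpose (class name hypothetical):

```python
import gym
from gym.spaces import Box

class TransposeImage(gym.ObservationWrapper):
    """Move HWC image observations to CHW, e.g. for PyTorch conv nets."""

    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape
        self.observation_space = Box(
            self.observation_space.low[0, 0, 0],
            self.observation_space.high[0, 0, 0],
            [obs_shape[2], obs_shape[1], obs_shape[0]],
            dtype=self.observation_space.dtype)

    def observation(self, observation):  # was _observation in older gym
        return observation.transpose(2, 0, 1)

# usage: env = TransposeImage(gym.make("PongNoFrameskip-v4"))
```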
main.py
@@ -9,7 +9,6 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from arguments import get_args
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
@@ -125,10 +124,11 @@ def update_current_obs(obs):
for j in range(num_updates):
for step in range(args.num_steps):
# Sample actions
value, action, action_log_prob, states = actor_critic.act(
Variable(rollouts.observations[step], volatile=True),
Variable(rollouts.states[step], volatile=True),
Variable(rollouts.masks[step], volatile=True))
with torch.no_grad():
value, action, action_log_prob, states = actor_critic.act(
rollouts.observations[step],
rollouts.states[step],
rollouts.masks[step])
cpu_actions = action.data.squeeze(1).cpu().numpy()

# Observe reward and next obs
@@ -153,9 +153,10 @@ def update_current_obs(obs):
update_current_obs(obs)
rollouts.insert(current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

next_value = actor_critic.get_value(Variable(rollouts.observations[-1], volatile=True),
Variable(rollouts.states[-1], volatile=True),
Variable(rollouts.masks[-1], volatile=True)).data
with torch.no_grad():
next_value = actor_critic.get_value(rollouts.observations[-1],
rollouts.states[-1],
rollouts.masks[-1]).detach()

rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

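With 0.4 the bootstrap value can be cut from the graph either by running the forward pass under `torch.no_grad()` or by calling `.detach()` on the result; the hunk above does both, which is safe, but the second step adds nothing once the first is in place. A quick check:

```python
import torch
import torch.nn as nn

value_head = nn.Linear(4, 1)      # stand-in for actor_critic.get_value

with torch.no_grad():
    next_value = value_head(torch.randn(1, 4))

print(next_value.requires_grad)   # already False, so the extra .detach() is functionally a no-op
```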
@@ -189,8 +190,8 @@ def update_current_obs(obs):
final_rewards.mean(),
final_rewards.median(),
final_rewards.min(),
final_rewards.max(), dist_entropy.data[0],
value_loss.data[0], action_loss.data[0]))
final_rewards.max(), dist_entropy,
value_loss, action_loss))
if args.vis and j % args.vis_interval == 0:
try:
# Sometimes monitor doesn't properly flush the outputs
model.py
@@ -7,7 +7,7 @@
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1 or classname.find('Linear') != -1:
nn.init.orthogonal(m.weight.data)
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0)

@@ -100,8 +100,8 @@ def mult_gain(m):
self.main.apply(mult_gain)

if hasattr(self, 'gru'):
nn.init.orthogonal(self.gru.weight_ih.data)
nn.init.orthogonal(self.gru.weight_hh.data)
nn.init.orthogonal_(self.gru.weight_ih.data)
nn.init.orthogonal_(self.gru.weight_hh.data)
self.gru.bias_ih.data.fill_(0)
self.gru.bias_hh.data.fill_(0)
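
The in-place `nn.init` functions were renamed with a trailing underscore in 0.4 (`orthogonal_`, `constant_`, `xavier_uniform_`, ...); the old spellings were deprecated. A small usage sketch with a throwaway GRU cell:

```python
import torch.nn as nn

gru = nn.GRUCell(16, 32)
nn.init.orthogonal_(gru.weight_ih)  # in-place: note the trailing underscore
nn.init.orthogonal_(gru.weight_hh)
gru.bias_ih.data.fill_(0)
gru.bias_hh.data.fill_(0)
```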
