Skip to content

Commit

Permalink
Update to pytorch 0.4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
ikostrikov2 committed Oct 7, 2018
1 parent e898f75 commit 8826e21
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 26 deletions.
17 changes: 8 additions & 9 deletions test.py
Expand Up @@ -3,7 +3,6 @@

import torch
import torch.nn.functional as F
from torch.autograd import Variable

from envs import create_atari_env
from model import ActorCritic
Expand Down Expand Up @@ -34,16 +33,16 @@ def test(rank, args, shared_model, counter):
# Sync with the shared model
if done:
model.load_state_dict(shared_model.state_dict())
cx = Variable(torch.zeros(1, 256), volatile=True)
hx = Variable(torch.zeros(1, 256), volatile=True)
cx = torch.zeros(1, 256)
hx = torch.zeros(1, 256)
else:
cx = Variable(cx.data, volatile=True)
hx = Variable(hx.data, volatile=True)
cx = cx.detach()
hx = hx.detach()

value, logit, (hx, cx) = model((Variable(
state.unsqueeze(0), volatile=True), (hx, cx)))
prob = F.softmax(logit)
action = prob.max(1, keepdim=True)[1].data.numpy()
with torch.no_grad():
value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
prob = F.softmax(logit, dim=-1)
action = prob.max(1, keepdim=True)[1].numpy()

state, reward, done, _ = env.step(action[0, 0])
done = done or episode_length >= args.max_episode_length
Expand Down
32 changes: 15 additions & 17 deletions train.py
@@ -1,7 +1,6 @@
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from envs import create_atari_env
from model import ActorCritic
Expand Down Expand Up @@ -37,11 +36,11 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
# Sync with the shared model
model.load_state_dict(shared_model.state_dict())
if done:
cx = Variable(torch.zeros(1, 256))
hx = Variable(torch.zeros(1, 256))
cx = torch.zeros(1, 256)
hx = torch.zeros(1, 256)
else:
cx = Variable(cx.data)
hx = Variable(hx.data)
cx = cx.detach()
hx = hx.detach()

values = []
log_probs = []
Expand All @@ -50,15 +49,15 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):

for step in range(args.num_steps):
episode_length += 1
value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)),
value, logit, (hx, cx) = model((state.unsqueeze(0),
(hx, cx)))
prob = F.softmax(logit)
log_prob = F.log_softmax(logit)
prob = F.softmax(logit, dim=-1)
log_prob = F.log_softmax(logit, dim=-1)
entropy = -(log_prob * prob).sum(1, keepdim=True)
entropies.append(entropy)

action = prob.multinomial(num_samples=1).data
log_prob = log_prob.gather(1, Variable(action))
action = prob.multinomial(num_samples=1).detach()
log_prob = log_prob.gather(1, action)

state, reward, done, _ = env.step(action.numpy())
done = done or episode_length >= args.max_episode_length
Expand All @@ -81,13 +80,12 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):

R = torch.zeros(1, 1)
if not done:
value, _, _ = model((Variable(state.unsqueeze(0)), (hx, cx)))
R = value.data
value, _, _ = model((state.unsqueeze(0), (hx, cx)))
R = value.detach()

values.append(Variable(R))
values.append(R)
policy_loss = 0
value_loss = 0
R = Variable(R)
gae = torch.zeros(1, 1)
for i in reversed(range(len(rewards))):
R = args.gamma * R + rewards[i]
Expand All @@ -96,16 +94,16 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):

# Generalized Advantage Estimataion
delta_t = rewards[i] + args.gamma * \
values[i + 1].data - values[i].data
values[i + 1] - values[i]
gae = gae * args.gamma * args.tau + delta_t

policy_loss = policy_loss - \
log_probs[i] * Variable(gae) - args.entropy_coef * entropies[i]
log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]

optimizer.zero_grad()

(policy_loss + args.value_loss_coef * value_loss).backward()
torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

ensure_shared_grads(model, shared_model)
optimizer.step()

0 comments on commit 8826e21

Please sign in to comment.