
Merge pull request #6 from pranz24/master

Updates for PyTorch 0.4 and Gym 0.10.2
ikostrikov committed May 1, 2018
2 parents 7848013 + 9dbbfe0 commit 6215d4c6472075a437c432d1c1d894aa19797c17
Showing with 19 additions and 12 deletions.
  1. +2 −2 README.md
  2. +7 −4 ddpg.py
  3. +4 −3 main.py
  4. +6 −3 naf.py
4 README.md
@@ -9,10 +9,10 @@ Use the default hyperparameters.
#### For NAF:

```
-python main.py --algo NAF
+python main.py --algo NAF --env-name HalfCheetah-v2
```
#### For DDPG

```
-python main.py --algo DDPG
+python main.py --algo DDPG --env-name HalfCheetah-v2
```
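Any registered Gym environment id with a continuous action space can be passed through the new flag; for example, to run on the previous hardcoded default (assuming the classic-control Pendulum-v0 environment is available):

```
python main.py --algo DDPG --env-name Pendulum-v0
```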
11 ddpg.py
@@ -6,7 +6,8 @@
from torch.autograd import Variable
import torch.nn.functional as F

-MSELoss = nn.MSELoss()
+def MSELoss(input, target):
+    return torch.sum((input - target)**2) / input.data.nelement()

def soft_update(target, source, tau):
for target_param, param in zip(target.parameters(), source.parameters()):
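The hand-rolled MSELoss above computes the same mean of squared errors as the nn.MSELoss module it replaces (default mean reduction). A minimal sanity check, assuming two tensors of matching shape:

```
import torch
import torch.nn.functional as F

def MSELoss(input, target):
    # sum of squared differences averaged over the number of elements
    return torch.sum((input - target)**2) / input.data.nelement()

a = torch.randn(32, 4)
b = torch.randn(32, 4)
# the two values should agree up to floating-point error
print(MSELoss(a, b).item(), F.mse_loss(a, b).item())
```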
@@ -117,7 +118,8 @@ def __init__(self, gamma, tau, hidden_size, num_inputs, action_space):

def select_action(self, state, exploration=None):
self.actor.eval()
-mu = self.actor((Variable(state, volatile=True)))
+with torch.no_grad():
+    mu = self.actor((Variable(state)))
self.actor.train()
mu = mu.data
if exploration is not None:
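PyTorch 0.4 drops the volatile=True flag; inference-only forward passes are wrapped in a torch.no_grad() context instead, so no autograd graph is built. A minimal sketch of the pattern (the linear layer is only an illustrative stand-in for the actor network):

```
import torch
import torch.nn as nn

actor = nn.Linear(3, 1)      # illustrative stand-in for the actor network
state = torch.randn(1, 3)

actor.eval()
with torch.no_grad():        # no autograd graph is recorded inside this block
    mu = actor(state)
actor.train()

print(mu.requires_grad)      # False: mu is detached from autograd
```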
@@ -128,11 +130,12 @@ def select_action(self, state, exploration=None):

def update_parameters(self, batch):
state_batch = Variable(torch.cat(batch.state))
-next_state_batch = Variable(torch.cat(batch.next_state), volatile=True)
action_batch = Variable(torch.cat(batch.action))
reward_batch = Variable(torch.cat(batch.reward))
mask_batch = Variable(torch.cat(batch.mask))

+with torch.no_grad():
+    next_state_batch = Variable(torch.cat(batch.next_state))

next_action_batch = self.actor_target(next_state_batch)
next_state_action_values = self.critic_target(next_state_batch, next_action_batch)

7 main.py
@@ -17,6 +17,8 @@
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--algo', default='NAF',
help='algorithm to use: DDPG | NAF')
+parser.add_argument('--env-name', default="HalfCheetah-v2",
+                    help='name of the environment to run')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.001, metavar='G',
@@ -45,10 +47,9 @@
help='render the environment')
args = parser.parse_args()

-env_name = 'Pendulum-v0'
-env = NormalizedActions(gym.make(env_name))
+env = NormalizedActions(gym.make(args.env_name))

-env = wrappers.Monitor(env, '/tmp/{}-experiment'.format(env_name), force=True)
+env = wrappers.Monitor(env, '/tmp/{}-experiment'.format(args.env_name), force=True)

env.seed(args.seed)
torch.manual_seed(args.seed)
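The new --env-name flag simply threads the environment id from argparse into gym.make. A stripped-down sketch of the pattern (argument list reduced for brevity; running the MuJoCo default also requires mujoco-py to be installed):

```
import argparse
import gym

parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--env-name', default="HalfCheetah-v2",
                    help='name of the environment to run')
args = parser.parse_args()

# argparse maps --env-name to args.env_name
env = gym.make(args.env_name)
```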
9 naf.py
@@ -6,7 +6,8 @@
from torch.autograd import Variable
import torch.nn.functional as F

-MSELoss = nn.MSELoss()
+def MSELoss(input, target):
+    return torch.sum((input - target)**2) / input.data.nelement()

def soft_update(target, source, tau):
for target_param, param in zip(target.parameters(), source.parameters()):
@@ -98,7 +99,8 @@ def __init__(self, gamma, tau, hidden_size, num_inputs, action_space):

def select_action(self, state, exploration=None):
self.model.eval()
-mu, _, _ = self.model((Variable(state, volatile=True), None))
+with torch.no_grad():
+    mu, _, _ = self.model((Variable(state), None))
self.model.train()
mu = mu.data
if exploration is not None:
@@ -108,10 +110,11 @@ def select_action(self, state, exploration=None):

def update_parameters(self, batch):
state_batch = Variable(torch.cat(batch.state))
-next_state_batch = Variable(torch.cat(batch.next_state), volatile=True)
action_batch = Variable(torch.cat(batch.action))
reward_batch = Variable(torch.cat(batch.reward))
mask_batch = Variable(torch.cat(batch.mask))
+with torch.no_grad():
+    next_state_batch = Variable(torch.cat(batch.next_state))

_, _, next_state_values = self.target_model((next_state_batch, None))

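Both ddpg.py and naf.py share the soft_update helper that appears truncated in the hunk context above; its body is not part of this diff, so the following is only an assumed sketch of the usual Polyak-averaging update:

```
def soft_update(target, source, tau):
    # assumed body: move each target parameter a fraction tau toward the source
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
```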