Merge pull request #6 from pranz24/master
Updates for pytorch 0.4 and gym 10.2
ikostrikov2 committed May 1, 2018
2 parents 7848013 + 9dbbfe0 commit 6215d4c
Showing 4 changed files with 19 additions and 12 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -9,10 +9,10 @@ Use the default hyperparameters.
 #### For NAF:
 
 ```
-python main.py --algo NAF
+python main.py --algo NAF --env-name HalfCheetah-v2
 ```
 #### For DDPG
 
 ```
-python main.py --algo DDPG
+python main.py --algo DDPG --env-name HalfCheetah-v2
 ```
11 changes: 7 additions & 4 deletions ddpg.py
@@ -6,7 +6,8 @@
 from torch.autograd import Variable
 import torch.nn.functional as F
 
-MSELoss = nn.MSELoss()
+def MSELoss(input, target):
+    return torch.sum((input - target)**2) / input.data.nelement()
 
 def soft_update(target, source, tau):
     for target_param, param in zip(target.parameters(), source.parameters()):
@@ -117,7 +118,8 @@ def __init__(self, gamma, tau, hidden_size, num_inputs, action_space):
 
     def select_action(self, state, exploration=None):
         self.actor.eval()
-        mu = self.actor((Variable(state, volatile=True)))
+        with torch.no_grad():
+            mu = self.actor((Variable(state)))
         self.actor.train()
         mu = mu.data
         if exploration is not None:
@@ -128,11 +130,12 @@ def select_action(self, state, exploration=None):
 
     def update_parameters(self, batch):
         state_batch = Variable(torch.cat(batch.state))
-        next_state_batch = Variable(torch.cat(batch.next_state), volatile=True)
         action_batch = Variable(torch.cat(batch.action))
         reward_batch = Variable(torch.cat(batch.reward))
         mask_batch = Variable(torch.cat(batch.mask))
+        with torch.no_grad():
+            next_state_batch = Variable(torch.cat(batch.next_state))
 
         next_action_batch = self.actor_target(next_state_batch)
         next_state_action_values = self.critic_target(next_state_batch, next_action_batch)
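For context, and not part of the diff itself: PyTorch 0.4 deprecates the `volatile=True` flag on `Variable` in favour of the `torch.no_grad()` context manager, which is the substitution made in `select_action` and `update_parameters` above. A minimal sketch of the idiom, using a stand-in `nn.Linear` in place of the repository's actor network:

```
# Sketch of the PyTorch 0.4 inference idiom adopted by this commit (stand-in network).
import torch
import torch.nn as nn

actor = nn.Linear(4, 2)      # placeholder for the actual Actor module
state = torch.randn(1, 4)    # placeholder for an observation batch

with torch.no_grad():        # no autograd graph is built inside this block
    mu = actor(state)

assert not mu.requires_grad  # output is detached, as volatile=True used to ensure
```

Since Tensors and Variables were merged in 0.4, the `Variable(...)` wrappers that remain in the diff are effectively no-ops kept for backward compatibility.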
7 changes: 4 additions & 3 deletions main.py
@@ -17,6 +17,8 @@
 parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
 parser.add_argument('--algo', default='NAF',
                     help='algorithm to use: DDPG | NAF')
+parser.add_argument('--env-name', default="HalfCheetah-v2",
+                    help='name of the environment to run')
 parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                     help='discount factor for reward (default: 0.99)')
 parser.add_argument('--tau', type=float, default=0.001, metavar='G',
@@ -45,10 +47,9 @@
                     help='render the environment')
 args = parser.parse_args()
 
-env_name = 'Pendulum-v0'
-env = NormalizedActions(gym.make(env_name))
+env = NormalizedActions(gym.make(args.env_name))
 
-env = wrappers.Monitor(env, '/tmp/{}-experiment'.format(env_name), force=True)
+env = wrappers.Monitor(env, '/tmp/{}-experiment'.format(args.env_name), force=True)
 
 env.seed(args.seed)
 torch.manual_seed(args.seed)
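The practical effect of the new flag: the environment id is no longer hard-coded to `Pendulum-v0` but read from the command line, so any continuous-control Gym environment can be selected at launch. A trimmed, illustrative sketch of the wiring (the real script defines many more arguments than shown here):

```
# Illustrative sketch only; names mirror main.py but most options are omitted.
import argparse
import gym

parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--env-name', default="HalfCheetah-v2",
                    help='name of the environment to run')
args = parser.parse_args()

# argparse exposes '--env-name' as the attribute 'env_name'
env = gym.make(args.env_name)  # HalfCheetah-v2 additionally requires mujoco-py
```

For example, `python main.py --algo DDPG --env-name Pendulum-v0` reproduces the old default behaviour.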
9 changes: 6 additions & 3 deletions naf.py
@@ -6,7 +6,8 @@
 from torch.autograd import Variable
 import torch.nn.functional as F
 
-MSELoss = nn.MSELoss()
+def MSELoss(input, target):
+    return torch.sum((input - target)**2) / input.data.nelement()
 
 def soft_update(target, source, tau):
     for target_param, param in zip(target.parameters(), source.parameters()):
@@ -98,7 +99,8 @@ def __init__(self, gamma, tau, hidden_size, num_inputs, action_space):
 
     def select_action(self, state, exploration=None):
         self.model.eval()
-        mu, _, _ = self.model((Variable(state, volatile=True), None))
+        with torch.no_grad():
+            mu, _, _ = self.model((Variable(state), None))
         self.model.train()
         mu = mu.data
         if exploration is not None:
@@ -108,10 +110,11 @@ def select_action(self, state, exploration=None):
 
     def update_parameters(self, batch):
         state_batch = Variable(torch.cat(batch.state))
-        next_state_batch = Variable(torch.cat(batch.next_state), volatile=True)
         action_batch = Variable(torch.cat(batch.action))
         reward_batch = Variable(torch.cat(batch.reward))
         mask_batch = Variable(torch.cat(batch.mask))
+        with torch.no_grad():
+            next_state_batch = Variable(torch.cat(batch.next_state))
 
         _, _, next_state_values = self.target_model((next_state_batch, None))
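A quick sanity check, not part of the commit: the hand-rolled `MSELoss` introduced in both `ddpg.py` and `naf.py` divides the summed squared error by the number of elements, i.e. it computes the same mean-reduced value that `nn.MSELoss()` / `F.mse_loss` return with their default reduction:

```
# Verifies that the replacement loss matches the mean-reduced MSE (sanity check only).
import torch
import torch.nn.functional as F

def MSELoss(input, target):
    return torch.sum((input - target)**2) / input.data.nelement()

a = torch.randn(8, 3)
b = torch.randn(8, 3)

print(MSELoss(a, b))     # sum of squared errors / number of elements
print(F.mse_loss(a, b))  # identical value with the default 'mean' reduction
```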
