Commit
Merge pull request #3 from pranz24/master
pytorch-rl
ikostrikov2 committed Dec 10, 2017
2 parents 3398d0b + 5093533 commit 7848013
Showing 4 changed files with 207 additions and 27 deletions.
13 changes: 10 additions & 3 deletions README.md
@@ -1,11 +1,18 @@
### Description
-Reimplementation of [Continuous Deep Q-Learning with Model-based Acceleration](https://arxiv.org/pdf/1603.00748v1.pdf).
+Reimplementation of [Continuous Deep Q-Learning with Model-based Acceleration](https://arxiv.org/pdf/1603.00748v1.pdf) and [Continuous control with deep reinforcement learning](https://arxiv.org/pdf/1509.02971.pdf).

Contributions are welcome. If you know how to make it more stable, don't hesitate to send a pull request.

### Run
Use the default hyperparameters.

-```python
-python main.py
+#### For NAF:
+
+```
+python main.py --algo NAF
+```
+#### For DDPG:
+
+```
+python main.py --algo DDPG
```
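The other hyperparameters exposed by main.py's argument parser (`--gamma`, `--tau`, `--exploration_end`, `--seed`, `--batch_size`, `--num_steps`, `--num_episodes`, ...) can be overridden on the command line as well, for example:

```
python main.py --algo DDPG --seed 42 --batch_size 64
```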
159 changes: 159 additions & 0 deletions ddpg.py
@@ -0,0 +1,159 @@
import sys

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import torch.nn.functional as F

MSELoss = nn.MSELoss()

def soft_update(target, source, tau):
    # Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

def hard_update(target, source):
    # Copy all parameters from the source network into the target network
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


class Actor(nn.Module):

    def __init__(self, hidden_size, num_inputs, action_space):
        super(Actor, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        self.bn0 = nn.BatchNorm1d(num_inputs)
        self.bn0.weight.data.fill_(1)
        self.bn0.bias.data.fill_(0)

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.bn1.weight.data.fill_(1)
        self.bn1.bias.data.fill_(0)

        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.bn2.weight.data.fill_(1)
        self.bn2.bias.data.fill_(0)

        self.mu = nn.Linear(hidden_size, num_outputs)
        self.mu.weight.data.mul_(0.1)
        self.mu.bias.data.mul_(0.1)

    def forward(self, inputs):
        x = inputs
        x = self.bn0(x)
        x = F.tanh(self.linear1(x))
        x = F.tanh(self.linear2(x))

        # Deterministic policy output, squashed to [-1, 1]
        mu = F.tanh(self.mu(x))
        return mu


class Critic(nn.Module):

    def __init__(self, hidden_size, num_inputs, action_space):
        super(Critic, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        self.bn0 = nn.BatchNorm1d(num_inputs)
        self.bn0.weight.data.fill_(1)
        self.bn0.bias.data.fill_(0)

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.bn1.weight.data.fill_(1)
        self.bn1.bias.data.fill_(0)

        self.linear_action = nn.Linear(num_outputs, hidden_size)
        self.bn_a = nn.BatchNorm1d(hidden_size)
        self.bn_a.weight.data.fill_(1)
        self.bn_a.bias.data.fill_(0)

        self.linear2 = nn.Linear(hidden_size + hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.bn2.weight.data.fill_(1)
        self.bn2.bias.data.fill_(0)

        self.V = nn.Linear(hidden_size, 1)
        self.V.weight.data.mul_(0.1)
        self.V.bias.data.mul_(0.1)

    def forward(self, inputs, actions):
        # State and action are embedded separately, then concatenated to produce Q(s, a)
        x = inputs
        x = self.bn0(x)
        x = F.tanh(self.linear1(x))
        a = F.tanh(self.linear_action(actions))
        x = torch.cat((x, a), 1)
        x = F.tanh(self.linear2(x))

        V = self.V(x)
        return V


class DDPG(object):

    def __init__(self, gamma, tau, hidden_size, num_inputs, action_space):

        self.num_inputs = num_inputs
        self.action_space = action_space

        self.actor = Actor(hidden_size, self.num_inputs, self.action_space)
        self.actor_target = Actor(hidden_size, self.num_inputs, self.action_space)
        self.actor_optim = Adam(self.actor.parameters(), lr=1e-4)

        self.critic = Critic(hidden_size, self.num_inputs, self.action_space)
        self.critic_target = Critic(hidden_size, self.num_inputs, self.action_space)
        self.critic_optim = Adam(self.critic.parameters(), lr=1e-3)

        self.gamma = gamma
        self.tau = tau

        hard_update(self.actor_target, self.actor)  # make sure the targets start with the same weights
        hard_update(self.critic_target, self.critic)

    def select_action(self, state, exploration=None):
        # Evaluate the deterministic policy (eval mode so BatchNorm uses running statistics)
        self.actor.eval()
        mu = self.actor((Variable(state, volatile=True)))
        self.actor.train()
        mu = mu.data
        if exploration is not None:
            mu += torch.Tensor(exploration.noise())

        return mu.clamp(-1, 1)

    def update_parameters(self, batch):
        state_batch = Variable(torch.cat(batch.state))
        next_state_batch = Variable(torch.cat(batch.next_state), volatile=True)
        action_batch = Variable(torch.cat(batch.action))
        reward_batch = Variable(torch.cat(batch.reward))
        mask_batch = Variable(torch.cat(batch.mask))

        next_action_batch = self.actor_target(next_state_batch)
        next_state_action_values = self.critic_target(next_state_batch, next_action_batch)

        # TD target: r + gamma * Q_target(s', mu_target(s'))
        # (mask_batch is built above but not applied to the target in this version)
        reward_batch = torch.unsqueeze(reward_batch, 1)
        expected_state_action_batch = reward_batch + (self.gamma * next_state_action_values)

        self.critic_optim.zero_grad()

        state_action_batch = self.critic((state_batch), (action_batch))

        value_loss = MSELoss(state_action_batch, expected_state_action_batch)
        value_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()

        # Deterministic policy gradient: maximize Q(s, mu(s)), i.e. minimize its negative
        policy_loss = -self.critic((state_batch), self.actor((state_batch)))

        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        # Slowly track the learned networks with the target networks
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)
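For orientation, a minimal sketch of driving the DDPG agent above outside of main.py. The environment name and hidden size are arbitrary choices here, and the OUNoise helper from ounoise.py (not shown in this diff) is only assumed to expose the constructor and `noise()` method used above:

```python
import gym
import torch

from ddpg import DDPG
from ounoise import OUNoise

env = gym.make("Pendulum-v0")  # any environment with a Box action space
agent = DDPG(gamma=0.99, tau=0.001, hidden_size=128,
             num_inputs=env.observation_space.shape[0],
             action_space=env.action_space)
noise = OUNoise(env.action_space.shape[0])

state = torch.Tensor([env.reset()])                      # keep a batch dimension for BatchNorm
action = agent.select_action(state, exploration=noise)   # action tensor clamped to [-1, 1]
next_state, reward, done, _ = env.step(action.numpy()[0])
```

main.py below additionally wraps the environment with NormalizedActions, presumably to rescale the [-1, 1] policy output to the environment's action bounds.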
22 changes: 15 additions & 7 deletions main.py
@@ -8,12 +8,15 @@
from gym import wrappers

import torch
+from ddpg import DDPG
from naf import NAF
from normalized_actions import NormalizedActions
from ounoise import OUNoise
from replay_memory import ReplayMemory, Transition

parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
+parser.add_argument('--algo', default='NAF',
+                    help='algorithm to use: DDPG | NAF')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.001, metavar='G',
@@ -24,10 +27,10 @@
                    help='final noise scale (default: 0.3)')
parser.add_argument('--exploration_end', type=int, default=100, metavar='N',
                    help='number of episodes with noise (default: 100)')
-parser.add_argument('--seed', type=int, default=42, metavar='N',
-                    help='random seed (default: 42)')
-parser.add_argument('--batch_size', type=int, default=64, metavar='N',
-                    help='batch size (default: 64)')
+parser.add_argument('--seed', type=int, default=4, metavar='N',
+                    help='random seed (default: 4)')
+parser.add_argument('--batch_size', type=int, default=128, metavar='N',
+                    help='batch size (default: 128)')
parser.add_argument('--num_steps', type=int, default=1000, metavar='N',
                    help='max episode length (default: 1000)')
parser.add_argument('--num_episodes', type=int, default=1000, metavar='N',
@@ -50,9 +53,13 @@
env.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
+if args.algo == "NAF":
+    agent = NAF(args.gamma, args.tau, args.hidden_size,
+                env.observation_space.shape[0], env.action_space)
+else:
+    agent = DDPG(args.gamma, args.tau, args.hidden_size,
+                 env.observation_space.shape[0], env.action_space)

-agent = NAF(args.gamma, args.tau, args.hidden_size,
-            env.observation_space.shape[0], env.action_space)
memory = ReplayMemory(args.replay_size)
ounoise = OUNoise(env.action_space.shape[0])
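Exploration is driven by the OUNoise class from ounoise.py, which is not part of this diff. As a rough sketch of what such a helper looks like (the class name and the theta/sigma values are assumptions, not the repository's actual implementation):

```python
import numpy as np

class SimpleOUNoise:
    """Sketch of an Ornstein-Uhlenbeck process (hypothetical stand-in for ounoise.OUNoise)."""

    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu, self.theta, self.sigma = mu, theta, sigma
        self.state = np.ones(action_dim) * mu

    def noise(self):
        # Temporally correlated noise: x <- x + theta * (mu - x) + sigma * N(0, 1)
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state
```

The DDPG paper uses this kind of temporally correlated noise because it explores more effectively in physical control tasks with momentum than independent Gaussian noise.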

@@ -89,6 +96,7 @@
agent.update_parameters(batch)

if done:

break
rewards.append(episode_reward)
else:
@@ -112,5 +120,5 @@
rewards.append(episode_reward)
print("Episode: {}, noise: {}, reward: {}, average reward: {}".format(i_episode, ounoise.scale,
rewards[-1], np.mean(rewards[-100:])))

env.close()
40 changes: 23 additions & 17 deletions naf.py
@@ -1,13 +1,20 @@
import sys

import torch
-import torch.autograd as autograd
import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torchvision.transforms as T
+from torch.optim import Adam
from torch.autograd import Variable
+import torch.nn.functional as F

+MSELoss = nn.MSELoss()

+def soft_update(target, source, tau):
+    for target_param, param in zip(target.parameters(), source.parameters()):
+        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

+def hard_update(target, source):
+    for target_param, param in zip(target.parameters(), source.parameters()):
+        target_param.data.copy_(param.data)

class Policy(nn.Module):

@@ -43,7 +50,7 @@ def __init__(self, hidden_size, num_inputs, action_space):
        self.L.bias.data.mul_(0.1)

        self.tril_mask = Variable(torch.tril(torch.ones(
-            num_outputs, num_outputs), k=-1).unsqueeze(0))
+            num_outputs, num_outputs), diagonal=-1).unsqueeze(0))
        self.diag_mask = Variable(torch.diag(torch.diag(
            torch.ones(num_outputs, num_outputs))).unsqueeze(0))
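For reference, `tril_mask` and `diag_mask` carve a state-dependent lower-triangular matrix L(s) out of a linear layer's output; the NAF paper uses it to parameterize a quadratic advantage around the greedy action mu(s):

```
P(s) = L(s) L(s)^T
A(s, a) = -1/2 * (a - mu(s))^T P(s) (a - mu(s))
Q(s, a) = V(s) + A(s, a)
```

Because P(s) is positive (semi-)definite (the paper exponentiates the diagonal entries of L; that part of the forward pass is collapsed out of this diff), Q(s, a) is maximized exactly at a = mu(s), which is what makes the greedy continuous action tractable.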

@@ -78,15 +85,16 @@ class NAF:

    def __init__(self, gamma, tau, hidden_size, num_inputs, action_space):
        self.action_space = action_space
+        self.num_inputs = num_inputs

        self.model = Policy(hidden_size, num_inputs, action_space)
        self.target_model = Policy(hidden_size, num_inputs, action_space)
-        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
+        self.optimizer = Adam(self.model.parameters(), lr=1e-3)

        self.gamma = gamma
        self.tau = tau

-        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
-            target_param.data.copy_(param.data)
+        hard_update(self.target_model, self.model)

    def select_action(self, state, exploration=None):
        self.model.eval()
@@ -106,19 +114,17 @@ def update_parameters(self, batch):
        mask_batch = Variable(torch.cat(batch.mask))

        _, _, next_state_values = self.target_model((next_state_batch, None))
-        next_state_values.volatile = False
-        expected_state_action_values = (
-            next_state_values * self.gamma) + reward_batch

+        reward_batch = (torch.unsqueeze(reward_batch, 1))
+        expected_state_action_values = reward_batch + (next_state_values * self.gamma)

        _, state_action_values, _ = self.model((state_batch, action_batch))

-        loss = (state_action_values - expected_state_action_values).pow(2).mean()
+        loss = MSELoss(state_action_values, expected_state_action_values)

        self.optimizer.zero_grad()
        loss.backward()
-        for param in self.model.parameters():
-            param.grad.data.clamp(-1, 1)
+        torch.nn.utils.clip_grad_norm(self.model.parameters(), 1)
        self.optimizer.step()

-        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
-            target_param.data.copy_(
-                target_param.data * (1.0 - self.tau) + param.data * self.tau)
+        soft_update(self.target_model, self.model, self.tau)
