Add ability to normalize and add DDPG with Mujoco example (#137)
* Add ability to normalize and add DDPG with Mujoco example

* Add documentation

* Add backwards compatibility

* Update tests

* Update docs
matthiasplappert committed Sep 23, 2017
1 parent 68e5d86 commit 35f9b50
Showing 12 changed files with 237 additions and 65 deletions.
5 changes: 5 additions & 0 deletions docs/autogen.py
@@ -16,6 +16,7 @@

import rl
import rl.core
import rl.processors
import rl.agents


@@ -28,6 +29,10 @@
        'page': 'core.md',
        'all_module_classes': [rl.core],
    },
    {
        'page': 'processors.md',
        'all_module_classes': [rl.processors],
    },
    {
        'page': 'agents/overview.md',
        'functions': [
30 changes: 2 additions & 28 deletions docs/sources/core.md
@@ -50,33 +50,7 @@ or write your own.

----

<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/core.py#L529)</span>
### MultiInputProcessor

```python
rl.core.MultiInputProcessor(nb_inputs)
```

Converts observations from an environment with multiple observations for use in a neural network
policy.

In some cases, you have environments that return multiple different observations per timestep
(in a robotics context, for example, a camera may be used to view the scene and a joint encoder may
be used to report the angles for each joint). Usually, this can be handled by a policy that has
multiple inputs, one for each modality. However, observations are returned by the environment
in the form of a tuple `[(modality1_t, modality2_t, ..., modalityn_t) for t in T]` but the neural network
expects them in per-modality batches like so: `[[modality1_1, ..., modality1_T], ..., [[modalityn_1, ..., modalityn_T]]`.
This processor converts observations appropriate for this use case.

__Arguments__

- __nb_inputs__ (integer): The number of inputs, that is different modalities, to be used.
Your neural network that you use for the policy must have a corresponding number of
inputs.

----

<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/core.py#L566)</span>
<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/core.py#L533)</span>
### Env

@@ -90,7 +64,7 @@ implementation.

----

<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/core.py#L642)</span>
<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/core.py#L609)</span>
### Space

41 changes: 41 additions & 0 deletions docs/sources/processors.md
@@ -0,0 +1,41 @@
<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/processors.py#L7)</span>
### MultiInputProcessor

```python
rl.processors.MultiInputProcessor(nb_inputs)
```

Converts observations from an environment with multiple observations for use in a neural network
policy.

In some cases, you have environments that return multiple different observations per timestep
(in a robotics context, for example, a camera may be used to view the scene and a joint encoder may
be used to report the angles for each joint). Usually, this can be handled by a policy that has
multiple inputs, one for each modality. However, observations are returned by the environment
in the form of a tuple `[(modality1_t, modality2_t, ..., modalityn_t) for t in T]`, whereas the neural network
expects them in per-modality batches like so: `[[modality1_1, ..., modality1_T], ..., [modalityn_1, ..., modalityn_T]]`.
This processor converts observations into a form appropriate for this use case.

__Arguments__

- __nb_inputs__ (integer): The number of inputs, i.e. different modalities, to be used.
    The neural network that you use for the policy must have a corresponding number of
    inputs.
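
To illustrate the conversion, here is a minimal usage sketch; the `(16, 16)` camera frames, the 7 joint
angles, and the window length of 1 are hypothetical values chosen purely for demonstration:

```python
import numpy as np

from rl.processors import MultiInputProcessor

# Hypothetical setup: every observation is a (camera, joints) tuple, i.e. two modalities.
processor = MultiInputProcessor(nb_inputs=2)

# A batch of 3 states, each holding a window of 1 observation
# (the nesting a window_length=1 memory would produce).
state_batch = [
    [(np.zeros((16, 16)), np.zeros(7))],
    [(np.ones((16, 16)), np.ones(7))],
    [(np.ones((16, 16)), np.zeros(7))],
]

camera_batch, joints_batch = processor.process_state_batch(state_batch)
print(camera_batch.shape)  # (3, 1, 16, 16): one array per modality, ready for a two-input model
print(joints_batch.shape)  # (3, 1, 7)
```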

----

<span style="float:right;">[[source]](https://github.com/matthiasplappert/keras-rl/blob/master/rl/processors.py#L40)</span>
### WhiteningNormalizerProcessor

```python
rl.processors.WhiteningNormalizerProcessor()
```

Normalizes the observations to have zero mean and a standard deviation of one,
i.e. it applies whitening to the inputs.

This typically helps significantly with learning, especially if different dimensions are
on different scales. However, it complicates re-using a trained policy since you will have to
store the normalizer's statistics alongside the policy weights if you intend to load it later.
It is the responsibility of the user to do so.
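
Since the commit does not ship a built-in way to persist these statistics, the sketch below shows one
possible approach; the two helpers are hypothetical and simply dump and restore the attributes of the
underlying `WhiteningNormalizer` with NumPy (use a path ending in `.npz`):

```python
import numpy as np

from rl.util import WhiteningNormalizer


def save_normalizer(processor, path):
    # Hypothetical helper: store the running statistics next to the policy weights.
    n = processor.normalizer
    np.savez(path, sum=n._sum, sumsq=n._sumsq, count=n._count, mean=n.mean, std=n.std)


def load_normalizer(processor, path, shape, dtype=np.float64):
    # Hypothetical helper: restore the statistics before evaluating a reloaded policy.
    data = np.load(path)
    n = WhiteningNormalizer(shape=shape, dtype=dtype)
    n._sum, n._sumsq, n._count = data['sum'], data['sumsq'], int(data['count'])
    n.mean, n.std = data['mean'], data['std']
    processor.normalizer = n
```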

1 change: 1 addition & 0 deletions docs/templates/processors.md
@@ -0,0 +1 @@
{{autogenerated}}
77 changes: 77 additions & 0 deletions examples/ddpg_mujoco.py
@@ -0,0 +1,77 @@
import numpy as np

import gym
from gym import wrappers

from keras.models import Sequential, Model
from keras.layers import Dense, Activation, Flatten, Input
from keras.optimizers import Adam

from rl.processors import WhiteningNormalizerProcessor
from rl.agents import DDPGAgent
from rl.memory import SequentialMemory
from rl.random import OrnsteinUhlenbeckProcess
from rl.keras_future import concatenate


class MujocoProcessor(WhiteningNormalizerProcessor):
    def process_action(self, action):
        # The exploration noise can push actions outside the valid [-1, 1] range, so clip them.
        return np.clip(action, -1., 1.)


ENV_NAME = 'HalfCheetah-v1'
gym.undo_logger_setup()


# Get the environment and extract the number of actions.
env = gym.make(ENV_NAME)
env = wrappers.Monitor(env, '/tmp/{}'.format(ENV_NAME), force=True)
np.random.seed(123)
env.seed(123)
assert len(env.action_space.shape) == 1
nb_actions = env.action_space.shape[0]

# Next, we build a very simple model.
actor = Sequential()
actor.add(Flatten(input_shape=(1,) + env.observation_space.shape))
actor.add(Dense(400))
actor.add(Activation('relu'))
actor.add(Dense(300))
actor.add(Activation('relu'))
actor.add(Dense(nb_actions))
actor.add(Activation('tanh'))
print(actor.summary())

action_input = Input(shape=(nb_actions,), name='action_input')
observation_input = Input(shape=(1,) + env.observation_space.shape, name='observation_input')
flattened_observation = Flatten()(observation_input)
x = Dense(400)(flattened_observation)
x = Activation('relu')(x)
x = concatenate([x, action_input])
x = Dense(300)(x)
x = Activation('relu')(x)
x = Dense(1)(x)
x = Activation('linear')(x)
critic = Model(input=[action_input, observation_input], output=x)
print(critic.summary())

# Finally, we configure and compile our agent. You can use any built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=100000, window_length=1)
random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=.15, mu=0., sigma=.1)
agent = DDPGAgent(nb_actions=nb_actions, actor=actor, critic=critic, critic_action_input=action_input,
                  memory=memory, nb_steps_warmup_critic=1000, nb_steps_warmup_actor=1000,
                  random_process=random_process, gamma=.99, target_model_update=1e-3,
                  processor=MujocoProcessor())
agent.compile([Adam(lr=1e-4), Adam(lr=1e-3)], metrics=['mae'])

# Okay, now it's time to learn something! Visualization is disabled during training since it
# slows things down quite a lot. You can always safely abort the training prematurely using
# Ctrl + C.
agent.fit(env, nb_steps=1000000, visualize=False, verbose=1)

# After training is done, we save the final weights.
agent.save_weights('ddpg_{}_weights.h5f'.format(ENV_NAME), overwrite=True)

# Finally, evaluate our algorithm for 5 episodes.
agent.test(env, nb_episodes=5, visualize=True, nb_max_episode_steps=200)
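
To evaluate the saved policy in a later session, one possible sketch is to rebuild the environment,
models, and agent exactly as above and reload the weights before testing; note that the whitening
statistics accumulated by `MujocoProcessor` are not part of these weight files:

```python
# Rebuild `env`, the actor/critic models, and `agent` exactly as above, then:
agent.load_weights('ddpg_{}_weights.h5f'.format(ENV_NAME))
agent.test(env, nb_episodes=5, visualize=True, nb_max_episode_steps=200)
```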
33 changes: 0 additions & 33 deletions rl/core.py
@@ -526,39 +526,6 @@ def metrics_names(self):
        return []


class MultiInputProcessor(Processor):
    """Converts observations from an environment with multiple observations for use in a neural network
    policy.
    In some cases, you have environments that return multiple different observations per timestep
    (in a robotics context, for example, a camera may be used to view the scene and a joint encoder may
    be used to report the angles for each joint). Usually, this can be handled by a policy that has
    multiple inputs, one for each modality. However, observations are returned by the environment
    in the form of a tuple `[(modality1_t, modality2_t, ..., modalityn_t) for t in T]` but the neural network
    expects them in per-modality batches like so: `[[modality1_1, ..., modality1_T], ..., [[modalityn_1, ..., modalityn_T]]`.
    This processor converts observations appropriate for this use case.
    # Arguments
        nb_inputs (integer): The number of inputs, that is different modalities, to be used.
            Your neural network that you use for the policy must have a corresponding number of
            inputs.
    """
    def __init__(self, nb_inputs):
        self.nb_inputs = nb_inputs

    def process_state_batch(self, state_batch):
        input_batches = [[] for x in range(self.nb_inputs)]
        for state in state_batch:
            processed_state = [[] for x in range(self.nb_inputs)]
            for observation in state:
                assert len(observation) == self.nb_inputs
                for o, s in zip(observation, processed_state):
                    s.append(o)
            for idx, s in enumerate(processed_state):
                input_batches[idx].append(s)
        return [np.array(x) for x in input_batches]


# Note: the API of the `Env` and `Space` classes are taken from the OpenAI Gym implementation.
# https://github.com/openai/gym/blob/master/gym/core.py

57 changes: 57 additions & 0 deletions rl/processors.py
@@ -0,0 +1,57 @@
import numpy as np

from rl.core import Processor
from rl.util import WhiteningNormalizer


class MultiInputProcessor(Processor):
    """Converts observations from an environment with multiple observations for use in a neural network
    policy.
    In some cases, you have environments that return multiple different observations per timestep
    (in a robotics context, for example, a camera may be used to view the scene and a joint encoder may
    be used to report the angles for each joint). Usually, this can be handled by a policy that has
    multiple inputs, one for each modality. However, observations are returned by the environment
    in the form of a tuple `[(modality1_t, modality2_t, ..., modalityn_t) for t in T]`, whereas the neural network
    expects them in per-modality batches like so: `[[modality1_1, ..., modality1_T], ..., [modalityn_1, ..., modalityn_T]]`.
    This processor converts observations into a form appropriate for this use case.
    # Arguments
        nb_inputs (integer): The number of inputs, i.e. different modalities, to be used.
            The neural network that you use for the policy must have a corresponding number of
            inputs.
    """
    def __init__(self, nb_inputs):
        self.nb_inputs = nb_inputs

    def process_state_batch(self, state_batch):
        input_batches = [[] for x in range(self.nb_inputs)]
        for state in state_batch:
            processed_state = [[] for x in range(self.nb_inputs)]
            for observation in state:
                assert len(observation) == self.nb_inputs
                for o, s in zip(observation, processed_state):
                    s.append(o)
            for idx, s in enumerate(processed_state):
                input_batches[idx].append(s)
        return [np.array(x) for x in input_batches]


class WhiteningNormalizerProcessor(Processor):
    """Normalizes the observations to have zero mean and a standard deviation of one,
    i.e. it applies whitening to the inputs.
    This typically helps significantly with learning, especially if different dimensions are
    on different scales. However, it complicates re-using a trained policy since you will have to
    store the normalizer's statistics alongside the policy weights if you intend to load it later.
    It is the responsibility of the user to do so.
    """
    def __init__(self):
        self.normalizer = None

    def process_state_batch(self, batch):
        if self.normalizer is None:
            self.normalizer = WhiteningNormalizer(shape=batch.shape[1:], dtype=batch.dtype)
        self.normalizer.update(batch)
        return self.normalizer.normalize(batch)

33 changes: 33 additions & 0 deletions rl/util.py
@@ -98,3 +98,36 @@ def get_updates(self, params, constraints, loss):

    def get_config(self):
        return self.optimizer.get_config()


# Based on https://github.com/openai/baselines/blob/master/baselines/common/mpi_running_mean_std.py
class WhiteningNormalizer(object):
    def __init__(self, shape, eps=1e-2, dtype=np.float64):
        self.eps = eps
        self.shape = shape
        self.dtype = dtype

        self._sum = np.zeros(shape, dtype=dtype)
        self._sumsq = np.zeros(shape, dtype=dtype)
        self._count = 0

        self.mean = np.zeros(shape, dtype=dtype)
        self.std = np.ones(shape, dtype=dtype)

    def normalize(self, x):
        return (x - self.mean) / self.std

    def denormalize(self, x):
        return self.std * x + self.mean

    def update(self, x):
        if x.ndim == len(self.shape):
            x = x.reshape(-1, *self.shape)
        assert x.shape[1:] == self.shape

        self._count += x.shape[0]
        self._sum += np.sum(x, axis=0)
        self._sumsq += np.sum(np.square(x), axis=0)

        self.mean = self._sum / float(self._count)
        # Var[x] = E[x^2] - E[x]^2, floored at eps^2 to avoid dividing by a (near-)zero std.
        self.std = np.sqrt(np.maximum(np.square(self.eps), self._sumsq / float(self._count) - np.square(self.mean)))
2 changes: 1 addition & 1 deletion tests/rl/agents/test_cem.py
@@ -10,7 +10,7 @@

from rl.agents.cem import CEMAgent
from rl.memory import EpisodeParameterMemory
from rl.core import MultiInputProcessor
from rl.processors import MultiInputProcessor

from ..util import MultiInputTestEnv

2 changes: 1 addition & 1 deletion tests/rl/agents/test_ddpg.py
@@ -10,7 +10,7 @@

from rl.agents.ddpg import DDPGAgent
from rl.memory import SequentialMemory
from rl.core import MultiInputProcessor
from rl.processors import MultiInputProcessor

from ..util import MultiInputTestEnv

2 changes: 1 addition & 1 deletion tests/rl/agents/test_dqn.py
@@ -10,7 +10,7 @@

from rl.agents.dqn import NAFLayer, DQNAgent, NAFAgent
from rl.memory import SequentialMemory
from rl.core import MultiInputProcessor
from rl.processors import MultiInputProcessor
from rl.keras_future import concatenate, Model

from ..util import MultiInputTestEnv
19 changes: 18 additions & 1 deletion tests/rl/test_util.py
@@ -8,7 +8,7 @@
from keras.optimizers import SGD
import keras.backend as K

from rl.util import clone_optimizer, clone_model, huber_loss
from rl.util import clone_optimizer, clone_model, huber_loss, WhiteningNormalizer


def test_clone_sequential_model():
@@ -68,5 +68,22 @@ def test_huber_loss():
    assert_allclose(K.eval(huber_loss(a, b, np.inf)), np.array([.125, .125, 2., 2.]))


def test_whitening_normalizer():
    x = np.random.normal(loc=.2, scale=2., size=(1000, 5))
    normalizer = WhiteningNormalizer(shape=(5,))
    normalizer.update(x[:500])
    normalizer.update(x[500:])

    assert_allclose(normalizer.mean, np.mean(x, axis=0))
    assert_allclose(normalizer.std, np.std(x, axis=0))

    x_norm = normalizer.normalize(x)
    assert_allclose(np.mean(x_norm, axis=0), np.zeros(5, dtype=normalizer.dtype), atol=1e-5)
    assert_allclose(np.std(x_norm, axis=0), np.ones(5, dtype=normalizer.dtype), atol=1e-5)

    x_denorm = normalizer.denormalize(x_norm)
    assert_allclose(x_denorm, x)


if __name__ == '__main__':
    pytest.main([__file__])
