Merge branch 'master' of https://github.com/seba-1511/cherry
seba-1511 committed Sep 9, 2019
2 parents cce50a6 + 7ef3868 commit b28d75d
Showing 9 changed files with 161 additions and 12 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,8 +1,8 @@
-<p align="center"><img src="https://seba-1511.github.io/cherry/assets/img/cherry_full.png" height="150px" /></p>
+<p align="center"><img src="http://cherry-rl.net/assets/img/cherry_full.png" height="150px" /></p>

--------------------------------------------------------------------------------

-[![Build Status](https://travis-ci.org/seba-1511/cherry.svg?branch=master)](https://travis-ci.org/seba-1511/cherry)
+[![Build Status](https://travis-ci.org/learnables/cherry.svg?branch=master)](https://travis-ci.org/learnables/cherry)

Cherry is a reinforcement learning framework for researchers built on top of PyTorch.

2 changes: 1 addition & 1 deletion cherry/distributions.py
@@ -104,7 +104,7 @@ def __init__(self, env, logstd=None, use_probs=False, reparam=False):
        super(ActionDistribution, self).__init__()
        self.use_probs = use_probs
        self.reparam = reparam
-        self.is_discrete = env.discrete_action
+        self.is_discrete = ch.envs.is_discrete(env.action_space)
        if not self.is_discrete:
            if logstd is None:
                action_size = ch.envs.get_space_dimension(env.action_space)
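The new call replaces a custom `env.discrete_action` attribute with a check on the gym action space itself (the `is_discrete` helper from `cherry/envs/utils.py`, which also appears further down this diff). A minimal sketch of what such a check typically does, assuming gym's standard space classes; cherry's real helper additionally handles vectorized spaces, and `is_discrete_sketch` is a hypothetical name:

~~~python
import gym

def is_discrete_sketch(space):
    # Hypothetical simplification of ch.envs.is_discrete: Discrete and
    # MultiDiscrete action spaces count as discrete, Box does not.
    return isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete))

env = gym.make('CartPole-v0')
assert is_discrete_sketch(env.action_space)  # CartPole has 2 discrete actions
~~~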
4 changes: 3 additions & 1 deletion cherry/envs/__init__.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

-from .utils import get_space_dimension
+from .utils import *
from .base import Wrapper
from .runner_wrapper import Runner
from .logger_wrapper import Logger
@@ -11,6 +11,8 @@
from .monitor_wrapper import Monitor
from .recorder_wrapper import Recorder
from .normalizer_wrapper import Normalizer
+from .state_normalizer_wrapper import StateNormalizer
+from .reward_normalizer_wrapper import RewardNormalizer
from .state_lambda_wrapper import StateLambda
from .action_lambda_wrapper import ActionLambda
from .action_space_scaler_wrapper import ActionSpaceScaler
11 changes: 9 additions & 2 deletions cherry/envs/action_space_scaler_wrapper.py
@@ -25,9 +25,16 @@ def __init__(self, env, clip=1.0):
    def reset(self, *args, **kwargs):
        return self.env.reset(*args, **kwargs)

-    def step(self, action):
+    def _normalize(self, action):
        lb = self.env.action_space.low
        ub = self.env.action_space.high
        scaled_action = lb + (action + self.clip) * 0.5 * (ub - lb)
        scaled_action = np.clip(scaled_action, lb, ub)
-        return self.env.step(scaled_action)
+        return scaled_action
+
+    def step(self, action):
+        if self.is_vectorized:
+            action = [self._normalize(a) for a in action]
+        else:
+            action = self._normalize(action)
+        return self.env.step(action)
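The refactor factors the affine rescaling into `_normalize` so that vectorized environments can map each sub-action independently. The map sends an agent action in `[-clip, clip]` onto the environment's `[low, high]` box. A quick standalone check of the formula, with hypothetical bounds in place of a real action space:

~~~python
import numpy as np

clip = 1.0
lb, ub = np.array([-2.0]), np.array([2.0])  # hypothetical action bounds

for action in (-1.0, 0.0, 1.0):
    scaled = lb + (action + clip) * 0.5 * (ub - lb)
    scaled = np.clip(scaled, lb, ub)
    print(action, '->', scaled)
# -1.0 -> [-2.], 0.0 -> [0.], 1.0 -> [2.]
~~~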
69 changes: 69 additions & 0 deletions cherry/envs/reward_normalizer_wrapper.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

import numpy as np
from .base import Wrapper


class RewardNormalizer(Wrapper):

"""
[[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/envs/normalizer_wrapper.py)
**Description**
Normalizes the rewards with a running average.
**Arguments**
* **env** (Environment) - Environment to normalize.
* **statistics** (dict, *optional*, default=None) - Dictionary used to
bootstrap the normalizing statistics.
* **beta** (float, *optional*, default=0.99) - Moving average weigth.
* **eps** (float, *optional*, default=1e-8) - Numerical stability.
**Credit**
Adapted from Tristan Deleu's implementation.
**Example**
~~~python
env = gym.make('CartPole-v0')
env = cherry.envs.RewardNormalizer(env)
env2 = gym.make('CartPole-v0')
env2 = cherry.envs.RewardNormalizer(env2,
statistics=env.statistics)
~~~
"""

    def __init__(self, env, statistics=None, beta=0.99, eps=1e-8):
        super(RewardNormalizer, self).__init__(env)
        self.beta = beta
        self.eps = eps
        if statistics is not None and 'mean' in statistics:
            self._reward_mean = np.copy(statistics['mean'])
        else:
            # Rewards are scalar (or one scalar per sub-env), so the running
            # statistics should be too, not shaped like the observations.
            self._reward_mean = np.zeros(())

        if statistics is not None and 'var' in statistics:
            self._reward_var = np.copy(statistics['var'])
        else:
            self._reward_var = np.ones(())

    @property
    def statistics(self):
        return {
            'mean': self._reward_mean,
            'var': self._reward_var,
        }

    def _reward_normalize(self, reward):
        self._reward_mean = self.beta * self._reward_mean + (1.0 - self.beta) * reward
        self._reward_var = self.beta * self._reward_var + (1.0 - self.beta) * np.square(reward - self._reward_mean)
        return (reward - self._reward_mean) / (np.sqrt(self._reward_var) + self.eps)

    def reset(self, *args, **kwargs):
        # reset() yields a state, not a reward: pass it through untouched.
        return self.env.reset(*args, **kwargs)

    def step(self, *args, **kwargs):
        state, reward, done, infos = self.env.step(*args, **kwargs)
        return state, self._reward_normalize(reward), done, infos
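The wrapper keeps exponential moving averages of the reward mean and variance. A small self-contained sanity check of that update rule (mirroring the corrected `_reward_normalize` above, without the wrapper itself): with `beta=0.99`, a constant reward of 1.0 drives the running mean toward 1.0 and the normalized reward toward 0.

~~~python
import numpy as np

beta, eps = 0.99, 1e-8
mean, var = 0.0, 1.0
for _ in range(1000):
    reward = 1.0
    mean = beta * mean + (1.0 - beta) * reward
    var = beta * var + (1.0 - beta) * np.square(reward - mean)

print(mean)                                    # close to 1.0
print((reward - mean) / (np.sqrt(var) + eps))  # normalized reward near 0
~~~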
69 changes: 69 additions & 0 deletions cherry/envs/state_normalizer_wrapper.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

import numpy as np
from .base import Wrapper


class StateNormalizer(Wrapper):

"""
[[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/envs/normalizer_wrapper.py)
**Description**
Normalizes the states with a running average.
**Arguments**
* **env** (Environment) - Environment to normalize.
* **statistics** (dict, *optional*, default=None) - Dictionary used to
bootstrap the normalizing statistics.
* **beta** (float, *optional*, default=0.99) - Moving average weigth.
* **eps** (float, *optional*, default=1e-8) - Numerical stability.
**Credit**
Adapted from Tristan Deleu's implementation.
**Example**
~~~python
env = gym.make('CartPole-v0')
env = cherry.envs.StateNormalizer(env)
env2 = gym.make('CartPole-v0')
env2 = cherry.envs.StateNormalizer(env2,
statistics=env.statistics)
~~~
"""

    def __init__(self, env, statistics=None, beta=0.99, eps=1e-8):
        super(StateNormalizer, self).__init__(env)
        self.beta = beta
        self.eps = eps
        if statistics is not None and 'mean' in statistics:
            self._state_mean = np.copy(statistics['mean'])
        else:
            self._state_mean = np.zeros(self.observation_space.shape)

        if statistics is not None and 'var' in statistics:
            self._state_var = np.copy(statistics['var'])
        else:
            self._state_var = np.ones(self.observation_space.shape)

    @property
    def statistics(self):
        return {
            'mean': self._state_mean,
            'var': self._state_var,
        }

    def _state_normalize(self, state):
        self._state_mean = self.beta * self._state_mean + (1.0 - self.beta) * state
        self._state_var = self.beta * self._state_var + (1.0 - self.beta) * np.square(state - self._state_mean)
        return (state - self._state_mean) / (np.sqrt(self._state_var) + self.eps)

    def reset(self, *args, **kwargs):
        state = self.env.reset(*args, **kwargs)
        return self._state_normalize(state)

    def step(self, *args, **kwargs):
        state, reward, done, infos = self.env.step(*args, **kwargs)
        return self._state_normalize(state), reward, done, infos
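Since `statistics` accepts a plain dict of arrays, the running statistics can be shared between environments (as in the docstring example) or persisted across runs. A sketch of the latter, assuming the corrected wrapper above; the file name is arbitrary:

~~~python
import gym
import numpy as np
import cherry

env = cherry.envs.StateNormalizer(gym.make('CartPole-v0'))
env.reset()  # reset() calls _state_normalize, updating the running stats

# Save the running mean/var, then bootstrap a fresh wrapper from them.
np.savez('state_stats.npz', **env.statistics)
loaded = dict(np.load('state_stats.npz'))
env2 = cherry.envs.StateNormalizer(gym.make('CartPole-v0'),
                                   statistics=loaded)
~~~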
2 changes: 1 addition & 1 deletion cherry/envs/utils.py
@@ -16,7 +16,7 @@


def is_vectorized(env):
-    return hasattr(env, 'num_envs')
+    return hasattr(env, 'num_envs') and env.num_envs > 1


def is_discrete(space, vectorized=False):
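The tightened predicate now requires more than one sub-environment, so a vectorized wrapper around a single env is treated as non-vectorized. A quick standalone check with dummy classes (names are illustrative):

~~~python
def is_vectorized(env):
    return hasattr(env, 'num_envs') and env.num_envs > 1

class ManyEnvs:
    num_envs = 8

class OneEnv:
    num_envs = 1

class PlainEnv:
    pass

assert is_vectorized(ManyEnvs())
assert not is_vectorized(OneEnv())  # changed behavior: a single env no longer counts
assert not is_vectorized(PlainEnv())
~~~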
8 changes: 5 additions & 3 deletions docs/pydocmd.yml
@@ -28,6 +28,8 @@ generate:
- cherry.envs.visdom_logger_wrapper.VisdomLogger++
- cherry.envs.torch_wrapper.Torch++
- cherry.envs.normalizer_wrapper.Normalizer++
+  - cherry.envs.state_normalizer_wrapper.StateNormalizer++
+  - cherry.envs.reward_normalizer_wrapper.RewardNormalizer++
- cherry.envs.reward_clipper_wrapper.RewardClipper++
- cherry.envs.monitor_wrapper.Monitor++
- cherry.envs.openai_atari_wrapper.OpenAIAtari++
@@ -77,8 +79,8 @@ pages:
- cherry.pg: docs/cherry.pg.md
- cherry.plot: docs/cherry.plot.md
- cherry.td: docs/cherry.td.md
-  - Examples: https://github.com/seba-1511/cherry/tree/master/examples
-  - GitHub: https://github.com/seba-1511/cherry/
+  - Examples: https://github.com/learnables/cherry/tree/master/examples
+  - GitHub: https://github.com/learnables/cherry/

# These options all show off their default values. You don't have to add
# them to your configuration if you're fine with the default.
@@ -87,7 +89,7 @@ gens_dir: _build/pydocmd # This will end up as the MkDocs 'docs_dir'
site_dir: _build/site
site_url: http://cherry-rl.net
site_author: Seb Arnold
-google_analytics: ['UA-68693545-3', 'seba-1511.github.com']
+google_analytics: ['UA-68693545-3', 'learnables.github.com']
theme:
name: mkdocs
custom_dir: 'cherry_theme/'
4 changes: 2 additions & 2 deletions setup.py
@@ -26,8 +26,8 @@
    long_description_content_type='text/markdown',
    author='Seb Arnold',
    author_email='smr.arnold@gmail.com',
-    url='https://seba-1511.github.com/cherry',
-    download_url='https://github.com/seba-1511/cherry/archive/' + str(VERSION) + '.zip',
+    url='https://learnables.github.com/cherry',
+    download_url='https://github.com/learnables/cherry/archive/' + str(VERSION) + '.zip',
    license='License :: OSI Approved :: Apache Software License',
    classifiers=[],
    scripts=[],
