Commit
Merge pull request #462 from muupan/stop-using-underscore-methods
Support `gym>=0.12.2` by no longer using underscore methods in gym wrappers
toslunar committed Jun 12, 2019
2 parents 98c3c8a + 33a4d39 commit 4e4d16c
Showing 8 changed files with 24 additions and 24 deletions.
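
The change is the same across all eight files: each wrapper previously implemented the underscore-prefixed hook methods (`_step`, `_reset`, `_observation`, `_reward`, `_action`) that older gym versions dispatched to, and now overrides the public methods directly, which is what `gym>=0.12.2` requires. A minimal sketch of the two styles (class names here are illustrative, not from this commit):

```python
import gym


class OldStyleWrapper(gym.Wrapper):
    """Pre-0.12.2 style: older gym dispatched step()/reset() to these hooks."""

    def _reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def _step(self, action):
        return self.env.step(action)


class NewStyleWrapper(gym.Wrapper):
    """gym>=0.12.2 style: override the public methods directly."""

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
```

The same rename applies to `gym.ObservationWrapper` (`_observation` to `observation`), `gym.RewardWrapper` (`_reward` to `reward`), and `gym.ActionWrapper` (`_action` to `action`), which is what the remaining hunks below do.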
30 changes: 15 additions & 15 deletions chainerrl/wrappers/atari_wrappers.py
@@ -32,7 +32,7 @@ def __init__(self, env, noop_max=30):
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

-def _reset(self, **kwargs):
+def reset(self, **kwargs):
"""Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset(**kwargs)
if self.override_num_noops is not None:
@@ -48,7 +48,7 @@ def _reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return obs

-def _step(self, ac):
+def step(self, ac):
return self.env.step(ac)


@@ -59,7 +59,7 @@ def __init__(self, env):
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3

-def _reset(self, **kwargs):
+def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, info = self.env.step(1)
if done or info.get('needs_reset', False):
@@ -69,7 +69,7 @@ def _reset(self, **kwargs):
self.env.reset(**kwargs)
return obs

-def _step(self, ac):
+def step(self, ac):
return self.env.step(ac)


@@ -83,7 +83,7 @@ def __init__(self, env):
self.lives = 0
self.needs_real_reset = True

-def _step(self, action):
+def step(self, action):
obs, reward, done, info = self.env.step(action)
self.needs_real_reset = done or info.get('needs_reset', False)
# check current lives, make loss of life terminal,
@@ -98,7 +98,7 @@ def _step(self, action):
self.lives = lives
return obs, reward, done, info

-def _reset(self, **kwargs):
+def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
@@ -122,7 +122,7 @@ def __init__(self, env, skip=4):
(2,) + env.observation_space.shape, dtype=np.uint8)
self._skip = skip

-def _step(self, action):
+def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
@@ -141,15 +141,15 @@ def _step(self, action):

return max_frame, total_reward, done, info

-def _reset(self, **kwargs):
+def reset(self, **kwargs):
return self.env.reset(**kwargs)


class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)

-def _reward(self, reward):
+def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)

@@ -173,7 +173,7 @@ def __init__(self, env, channel_order='hwc'):
low=0, high=255,
shape=shape[channel_order], dtype=np.uint8)

-def _observation(self, frame):
+def observation(self, frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (self.width, self.height),
interpolation=cv2.INTER_AREA)
@@ -200,13 +200,13 @@ def __init__(self, env, k, channel_order='hwc'):
self.observation_space = spaces.Box(
low=low, high=high, dtype=orig_obs_space.dtype)

-def _reset(self):
+def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()

-def _step(self, action):
+def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
@@ -231,11 +231,11 @@ def __init__(self, env):

orig_obs_space = env.observation_space
self.observation_space = spaces.Box(
-low=self._observation(orig_obs_space.low),
-high=self._observation(orig_obs_space.high),
+low=self.observation(orig_obs_space.low),
+high=self.observation(orig_obs_space.high),
dtype=np.float32)

-def _observation(self, observation):
+def observation(self, observation):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(observation).astype(np.float32) / self.scale
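For context, a hypothetical sketch of how these Atari wrappers are typically chained for DeepMind-style preprocessing; the wrapper names other than `ClipRewardEnv` follow the usual baselines-style naming and are assumed here rather than taken from this diff, and the exact order and defaults in this repository may differ:

```python
import gym

# Assumed imports: these baselines-style wrapper names are not all visible
# in this diff and may differ in the actual module.
from chainerrl.wrappers.atari_wrappers import (
    ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, FrameStack,
    MaxAndSkipEnv, NoopResetEnv, ScaledFloatFrame, WarpFrame)

env = gym.make('BreakoutNoFrameskip-v4')
env = NoopResetEnv(env, noop_max=30)   # random number of no-ops after reset
env = MaxAndSkipEnv(env, skip=4)       # repeat action, max over last frames
env = EpisodicLifeEnv(env)             # treat loss of life as episode end
env = FireResetEnv(env)                # press FIRE after reset when required
env = WarpFrame(env)                   # grayscale and resize the observation
env = ScaledFloatFrame(env)            # uint8 -> float32 scaled observations
env = ClipRewardEnv(env)               # clip rewards to {-1, 0, +1} by sign
env = FrameStack(env, 4)               # stack the last 4 frames

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
```
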
2 changes: 1 addition & 1 deletion chainerrl/wrappers/cast_observation.py
@@ -25,7 +25,7 @@ def __init__(self, env, dtype):
super().__init__(env)
self.dtype = dtype

-def _observation(self, observation):
+def observation(self, observation):
self.original_observation = observation
return observation.astype(self.dtype, copy=False)

2 changes: 1 addition & 1 deletion chainerrl/wrappers/randomize_action.py
@@ -36,7 +36,7 @@ def __init__(self, env, random_fraction):
self._random_fraction = random_fraction
self._np_random = np.random.RandomState()

-def _action(self, action):
+def action(self, action):
if self._np_random.rand() < self._random_fraction:
return self._np_random.randint(self.env.action_space.n)
else:
2 changes: 1 addition & 1 deletion chainerrl/wrappers/scale_reward.py
@@ -26,6 +26,6 @@ def __init__(self, env, scale):
self.scale = scale
self.original_reward = None

-def _reward(self, reward):
+def reward(self, reward):
self.original_reward = reward
return self.scale * reward
2 changes: 1 addition & 1 deletion examples/grasping/README.md
@@ -10,7 +10,7 @@ This directory contains example scripts that learn to grasp objects in an enviro

## Requirements

-- pybullet>=2.1.2
+- pybullet>=2.4.9

## How to run

6 changes: 3 additions & 3 deletions examples/grasping/train_dqn_batch_grasping.py
@@ -32,7 +32,7 @@ def __init__(self, env, type_):
super().__init__(env)
self.type_ = type_

-def _action(self, action):
+def action(self, action):
return self.type_(action)


@@ -49,7 +49,7 @@ def __init__(self, env, axes):
dtype=env.observation_space.dtype,
)

-def _observation(self, observation):
+def observation(self, observation):
return observation.transpose(*self._axes)


@@ -73,7 +73,7 @@ def reset(self):
self._elapsed_steps = 0
return self.env.reset(), self._elapsed_steps

-def _step(self, action):
+def step(self, action):
observation, reward, done, info = self.env.step(action)
self._elapsed_steps += 1
assert self._elapsed_steps <= self._max_steps
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ chainer>=4.0.0
fastcache; python_version<'3.2'
funcsigs; python_version<'3.5'
future
-gym>=0.9.7,<=0.12.1
+gym>=0.9.7
numpy>=1.10.4
pillow
scipy
2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@
'cached-property',
'chainer>=2.0.0',
'future',
-'gym>=0.9.7,<=0.12.1',
+'gym>=0.9.7',
'numpy>=1.10.4',
'pillow',
'scipy',
