Seeding update (#2422)
* Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass

* Updated a bunch of RNG calls from the RandomState API to Generator API

* black; didn't expect that, did ya?

* Undo a typo

* blaaack

* More typo fixes

* Fixed setting/getting state in multidiscrete spaces

* Fix typo, fix a test to work with the new sampling

* Correctly (?) pass the randomly generated seed if np_random is called with None as seed

* Convert the Discrete sample to a python int (as opposed to np.int64)

* Remove some redundant imports

* First version of the compatibility layer for old-style RNG. Mainly to trigger tests.

* Removed redundant f-strings

* Style fixes, removing unused imports

* Try to make tests pass by removing atari from the dockerfile

* Try to make tests pass by removing atari from the setup

* Try to make tests pass by removing atari from the setup

* Try to make tests pass by removing atari from the setup

* First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings.

* black; didn't expect that, did ya?

* Rename the reset parameter in VecEnvs back to `seed`

* Updated tests to use the new seeding method

* Removed a bunch of old `seed` calls.

Fixed a bug in AsyncVectorEnv

* Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset)

* Add explicit seed to wrappers reset

* Remove an accidental return

* Re-add some legacy functions with a warning.

* Use deprecation instead of regular warnings for the newly deprecated methods/functions
RedTachyon committed Dec 8, 2021
1 parent b84b69c commit c364506
Showing 59 changed files with 386 additions and 294 deletions.
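For orientation, here is a minimal sketch of what the seeding change means for callers; the environment id and seed value below are only illustrative:

```python
import gym

env = gym.make("CartPole-v1")

# Old, now-deprecated API: seed first, then reset.
# env.seed(42)
# obs = env.reset()

# New API introduced by this commit: pass the seed to reset.
obs = env.reset(seed=42)

# Internally, env.np_random is now a numpy Generator (from default_rng)
# rather than the legacy RandomState, so e.g. randint becomes integers.
obs, reward, done, info = env.step(env.action_space.sample())
```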
4 changes: 3 additions & 1 deletion docs/creating_environments.md
@@ -50,6 +50,7 @@

* `gym-foo/gym_foo/envs/foo_env.py` should look something like:
```python
from typing import Optional
import gym
from gym import error, spaces, utils
from gym.utils import seeding
@@ -61,7 +62,8 @@
...
def step(self, action):
...
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
...
def render(self, mode='human'):
...
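# Illustrative sketch, not part of the diff above: one way the abbreviated docs
# example could look as a complete environment under the new contract. `FooEnv`
# and its spaces are placeholders; Env.reset(seed=...) manages self.np_random.
from typing import Optional

import gym
from gym import spaces


class FooEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))
        self.action_space = spaces.Discrete(2)
        self.state = None

    def reset(self, seed: Optional[int] = None):
        super().reset(seed=seed)  # (re)creates self.np_random when needed
        self.state = self.np_random.uniform(low=-1.0, high=1.0, size=(2,))
        return self.state

    def step(self, action):
        self.state = self.state + (0.1 if action == 1 else -0.1)
        done = bool((abs(self.state) > 1.0).any())
        return self.state, 0.0, done, {}

    def render(self, mode="human"):
        print(self.state)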
39 changes: 26 additions & 13 deletions gym/core.py
@@ -1,8 +1,10 @@
from abc import abstractmethod
from typing import Optional

import gym
from gym import error
from gym.utils import closer
from gym.utils import closer, seeding
from gym.logger import deprecation


class Env:
@@ -38,6 +40,9 @@ class Env:
action_space = None
observation_space = None

# Created lazily by reset() or the deprecated seed()
np_random = None

@abstractmethod
def step(self, action):
"""Run one timestep of the environment's dynamics. When end of
@@ -58,7 +63,7 @@ def step(self, action):
raise NotImplementedError

@abstractmethod
def reset(self):
def reset(self, seed: Optional[int] = None):
"""Resets the environment to an initial state and returns an initial
observation.
@@ -71,7 +76,9 @@ def reset(self):
Returns:
observation (object): the initial observation.
"""
raise NotImplementedError
# Initialize the RNG if it's the first reset, or if the seed is manually passed
if seed is not None or self.np_random is None:
self.np_random, seed = seeding.np_random(seed)

@abstractmethod
def render(self, mode="human"):
@@ -136,7 +143,12 @@ def seed(self, seed=None):
'seed'. Often, the main seed equals the provided 'seed', but
this won't be true if seed=None, for example.
"""
return
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed) instead."
)
self.np_random, seed = seeding.np_random(seed)
return [seed]

@property
def unwrapped(self):
@@ -173,7 +185,8 @@ class GoalEnv(Env):
actual observations of the environment as per usual.
"""

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
# Enforce that each GoalEnv uses a Goal-compatible observation space.
if not isinstance(self.observation_space, gym.spaces.Dict):
raise error.Error(
@@ -286,8 +299,8 @@ def metadata(self, value):
def step(self, action):
return self.env.step(action)

def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(seed=seed, **kwargs)

def render(self, mode="human", **kwargs):
return self.env.render(mode, **kwargs)
@@ -313,8 +326,8 @@ def unwrapped(self):


class ObservationWrapper(Wrapper):
def reset(self, **kwargs):
observation = self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
observation = self.env.reset(seed=seed, **kwargs)
return self.observation(observation)

def step(self, action):
@@ -327,8 +340,8 @@ def observation(self, observation):


class RewardWrapper(Wrapper):
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(seed=seed, **kwargs)

def step(self, action):
observation, reward, done, info = self.env.step(action)
@@ -340,8 +353,8 @@ def reward(self, reward):


class ActionWrapper(Wrapper):
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(seed=seed, **kwargs)

def step(self, action):
return self.env.step(self.action(action))
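Both the new Env.reset and the now-deprecated Env.seed lean on gym.utils.seeding.np_random, which is not shown in this diff. A rough sketch of the behaviour those call sites assume (not the actual implementation) might be:

```python
from typing import Optional, Tuple

import numpy as np


def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, int]:
    """Return a numpy Generator plus the seed that was actually used.

    When seed is None, fresh entropy is drawn so the caller can still record
    the value and reproduce the run later.
    """
    seed_seq = np.random.SeedSequence(seed)
    used_seed = seed_seq.entropy  # equals `seed` when one was provided
    rng = np.random.Generator(np.random.PCG64(seed_seq))
    return rng, used_seed
```

This mirrors the commit-message item about correctly passing the randomly generated seed back when np_random is called with seed=None.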
23 changes: 10 additions & 13 deletions gym/envs/box2d/bipedal_walker.py
@@ -1,5 +1,6 @@
import sys
import math
from typing import Optional

import numpy as np
import Box2D
@@ -122,7 +123,6 @@ class BipedalWalker(gym.Env, EzPickle):

def __init__(self):
EzPickle.__init__(self)
self.seed()
self.viewer = None

self.world = Box2D.b2World()
@@ -149,10 +149,6 @@ def __init__(self):
)
self.observation_space = spaces.Box(-high, high)

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def _destroy(self):
if not self.terrain:
return
@@ -188,7 +184,7 @@ def _generate_terrain(self, hardcore):
y += velocity

elif state == PIT and oneshot:
counter = self.np_random.randint(3, 5)
counter = self.np_random.integers(3, 5)
poly = [
(x, y),
(x + TERRAIN_STEP, y),
@@ -215,7 +211,7 @@ def _generate_terrain(self, hardcore):
y -= 4 * TERRAIN_STEP

elif state == STUMP and oneshot:
counter = self.np_random.randint(1, 3)
counter = self.np_random.integers(1, 3)
poly = [
(x, y),
(x + counter * TERRAIN_STEP, y),
@@ -228,9 +224,9 @@ def _generate_terrain(self, hardcore):
self.terrain.append(t)

elif state == STAIRS and oneshot:
stair_height = +1 if self.np_random.rand() > 0.5 else -1
stair_width = self.np_random.randint(4, 5)
stair_steps = self.np_random.randint(3, 5)
stair_height = +1 if self.np_random.random() > 0.5 else -1
stair_width = self.np_random.integers(4, 5)
stair_steps = self.np_random.integers(3, 5)
original_y = y
for s in range(stair_steps):
poly = [
@@ -266,9 +262,9 @@ def _generate_terrain(self, hardcore):
self.terrain_y.append(y)
counter -= 1
if counter == 0:
counter = self.np_random.randint(TERRAIN_GRASS / 2, TERRAIN_GRASS)
counter = self.np_random.integers(TERRAIN_GRASS / 2, TERRAIN_GRASS)
if state == GRASS and hardcore:
state = self.np_random.randint(1, _STATES_)
state = self.np_random.integers(1, _STATES_)
oneshot = True
else:
state = GRASS
@@ -312,7 +308,8 @@ def _generate_clouds(self):
x2 = max(p[0] for p in poly)
self.cloud_poly.append((poly, x1, x2))

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_bug_workaround = ContactDetector(self)
self.world.contactListener = self.world.contactListener_bug_workaround
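The walker changes above also show the method renames that come with moving np_random from the RandomState API to the Generator API; a quick side-by-side, with illustrative arguments:

```python
import numpy as np

rng = np.random.default_rng(seed=0)  # Generator; replaces np.random.RandomState(0)

# RandomState (old)    ->  Generator (new, as used in this commit)
# rng.randint(3, 5)    ->  rng.integers(3, 5)        # upper bound exclusive in both
# rng.rand()           ->  rng.random()
# rng.randn(4)         ->  rng.standard_normal(4)
print(rng.integers(3, 5))          # 3 or 4
print(rng.random() > 0.5)          # coin flip, as in the STAIRS branch
print(rng.uniform(-0.1, 0.1, 4))   # uniform() keeps its name and signature
```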
10 changes: 4 additions & 6 deletions gym/envs/box2d/car_racing.py
@@ -32,6 +32,8 @@
"""
import sys
import math
from typing import Optional

import numpy as np

import Box2D
@@ -121,7 +123,6 @@ class CarRacing(gym.Env, EzPickle):

def __init__(self, verbose=1):
EzPickle.__init__(self)
self.seed()
self.contactListener_keepref = FrictionDetector(self)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.viewer = None
@@ -145,10 +146,6 @@ def __init__(self, verbose=1):
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
)

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def _destroy(self):
if not self.road:
return
@@ -343,7 +340,8 @@ def _create_track(self):
self.track = track
return True

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy()
self.reward = 0.0
self.prev_reward = 0.0
13 changes: 5 additions & 8 deletions gym/envs/box2d/lunar_lander.py
@@ -28,6 +28,8 @@

import math
import sys
from typing import Optional

import numpy as np

import Box2D
@@ -93,7 +95,6 @@ class LunarLander(gym.Env, EzPickle):

def __init__(self):
EzPickle.__init__(self)
self.seed()
self.viewer = None

self.world = Box2D.b2World()
@@ -117,10 +118,6 @@ def __init__(self):
# Nop, fire left engine, main engine, right engine
self.action_space = spaces.Discrete(4)

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def _destroy(self):
if not self.moon:
return
@@ -133,7 +130,8 @@ def _destroy(self):
self.world.DestroyBody(self.legs[0])
self.world.DestroyBody(self.legs[1])

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_keepref = ContactDetector(self)
self.world.contactListener = self.world.contactListener_keepref
@@ -504,10 +502,9 @@ def heuristic(env, s):


def demo_heuristic_lander(env, seed=None, render=False):
env.seed(seed)
total_reward = 0
steps = 0
s = env.reset()
s = env.reset(seed=seed)
while True:
a = heuristic(env, s)
s, r, done, info = env.step(a)
10 changes: 4 additions & 6 deletions gym/envs/classic_control/acrobot.py
@@ -1,4 +1,6 @@
"""classic Acrobot task"""
from typing import Optional

import numpy as np
from numpy import sin, cos, pi

@@ -94,13 +96,9 @@ def __init__(self):
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
self.action_space = spaces.Discrete(3)
self.state = None
self.seed()

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
np.float32
)
10 changes: 4 additions & 6 deletions gym/envs/classic_control/cartpole.py
@@ -5,6 +5,8 @@
"""

import math
from typing import Optional

import gym
from gym import spaces, logger
from gym.utils import seeding
@@ -90,16 +92,11 @@ def __init__(self):
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Box(-high, high, dtype=np.float32)

self.seed()
self.viewer = None
self.state = None

self.steps_beyond_done = None

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def step(self, action):
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg
@@ -158,7 +155,8 @@ def step(self, action):

return np.array(self.state, dtype=np.float32), reward, done, {}

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None
return np.array(self.state, dtype=np.float32)
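With CartPole seeding through reset, reproducing an initial state becomes a one-liner; a small sanity check, assuming the standard CartPole-v1 registration:

```python
import gym
import numpy as np

env = gym.make("CartPole-v1")
first = env.reset(seed=123)
second = env.reset(seed=123)  # passing a seed recreates np_random, so states match
assert np.allclose(first, second)

third = env.reset()           # no seed: the existing np_random simply keeps advancing
```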
10 changes: 3 additions & 7 deletions gym/envs/classic_control/continuous_mountain_car.py
@@ -14,6 +14,7 @@
"""

import math
from typing import Optional

import numpy as np

@@ -83,12 +84,6 @@ def __init__(self, goal_velocity=0):
low=self.low_state, high=self.high_state, dtype=np.float32
)

self.seed()

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def step(self, action):

position = self.state[0]
Expand Down Expand Up @@ -119,7 +114,8 @@ def step(self, action):
self.state = np.array([position, velocity], dtype=np.float32)
return self.state, reward, done, {}

def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32)

