Skip to content

Commit

Permalink
Add reset params (with np support) (#2926)
Browse files Browse the repository at this point in the history
  • Loading branch information
psc-g committed Jul 6, 2022
1 parent 071f0bf commit ca39816
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 6 deletions.
8 changes: 7 additions & 1 deletion gym/envs/classic_control/acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

# SOURCE:
# https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py
from gym.envs.classic_control import utils
from gym.utils.renderer import Renderer


Expand Down Expand Up @@ -187,7 +188,12 @@ def reset(
options: Optional[dict] = None
):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
low, high = utils.maybe_parse_reset_bounds(
options, -0.1, 0.1 # default low
) # default high
self.state = self.np_random.uniform(low=low, high=high, size=(4,)).astype(
np.float32
)

Expand Down
8 changes: 7 additions & 1 deletion gym/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import gym
from gym import logger, spaces
from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer

Expand Down Expand Up @@ -194,7 +195,12 @@ def reset(
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
low, high = utils.maybe_parse_reset_bounds(
options, -0.05, 0.05 # default low
) # default high
self.state = self.np_random.uniform(low=low, high=high, size=(4,))
self.steps_beyond_done = None
self.renderer.reset()
self.renderer.render_step()
Expand Down
8 changes: 7 additions & 1 deletion gym/envs/classic_control/continuous_mountain_car.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import gym
from gym import spaces
from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer

Expand Down Expand Up @@ -180,7 +181,12 @@ def reset(
options: Optional[dict] = None
):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
# Without "low"/"high" options the car position is sampled from the same
# [-0.6, -0.4] range that was hard-coded before reset bounds became
# customizable (the upper default is negative, not +0.4).
low, high = utils.maybe_parse_reset_bounds(
    options, default_low=-0.6, default_high=-0.4
)
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset()
self.renderer.render_step()
if not return_info:
Expand Down
8 changes: 7 additions & 1 deletion gym/envs/classic_control/mountain_car.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import gym
from gym import spaces
from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer

Expand Down Expand Up @@ -154,7 +155,12 @@ def reset(
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
# Without "low"/"high" options the car position is sampled from the same
# [-0.6, -0.4] range that was hard-coded before reset bounds became
# customizable (the upper default is negative, not +0.4).
low, high = utils.maybe_parse_reset_bounds(
    options, default_low=-0.6, default_high=-0.4
)
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset()
self.renderer.render_step()
if not return_info:
Expand Down
18 changes: 16 additions & 2 deletions gym/envs/classic_control/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@

import gym
from gym import spaces
from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer

DEFAULT_X = np.pi
DEFAULT_Y = 1.0


class PendulumEnv(gym.Env):
"""
Expand Down Expand Up @@ -142,8 +146,18 @@ def reset(
options: Optional[dict] = None
):
super().reset(seed=seed)
high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high)
if options is None:
high = np.array([DEFAULT_X, DEFAULT_Y])
else:
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
x = options.get("x_init") if "x_init" in options else DEFAULT_X
y = options.get("y_init") if "y_init" in options else DEFAULT_Y
x = utils.verify_number_and_cast(x)
y = utils.verify_number_and_cast(y)
high = np.array([x, y])
low = -high # We enforce symmetric limits.
self.state = self.np_random.uniform(low=low, high=high)
self.last_u = None

self.renderer.reset()
Expand Down
44 changes: 44 additions & 0 deletions gym/envs/classic_control/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Utility functions used for classic control environments.
"""

from typing import Optional, SupportsFloat, Tuple, Union


def verify_number_and_cast(x: SupportsFloat) -> float:
    """Verify parameter is a single number and cast to a float.

    Args:
        x: Value to convert; anything the built-in ``float()`` accepts.

    Returns:
        The result of ``float(x)``.

    Raises:
        ValueError: If ``float(x)`` raises ``ValueError`` or ``TypeError``
            (for example a non-numeric string, a list, or ``None``).
    """
    try:
        return float(x)
    except (ValueError, TypeError) as err:
        # Both failure modes are reported as ValueError so callers only need
        # to handle one exception type; the original error is kept as the
        # explicit cause to make the traceback unambiguous.
        raise ValueError(
            f"Your input must support being cast to a float: {x}"
        ) from err


def maybe_parse_reset_bounds(
    options: Optional[dict], default_low: float, default_high: float
) -> Tuple[float, float]:
    """Resolve the sampling range for the initial state during ``reset()``.

    Args:
        options: (Optional) options passed in to reset(). When given, the
            keys ``"low"`` and ``"high"`` override the matching default;
            a missing key falls back to that default.
        default_low: Default lower limit to use, if none specified in options.
        default_high: Default upper limit to use, if none specified in options.

    Returns:
        A ``(low, high)`` tuple. If ``options`` is ``None`` the defaults are
        returned unchanged; otherwise both values have been cast to ``float``.

    Raises:
        ValueError: If a bound cannot be cast to a float, or if the resolved
            lower bound is greater than the resolved upper bound.
    """
    if options is None:
        return default_low, default_high

    # We expect only numerical inputs, so every bound goes through the cast.
    low = verify_number_and_cast(options.get("low", default_low))
    high = verify_number_and_cast(options.get("high", default_high))
    if low > high:
        raise ValueError("Lower bound must be lower than higher bound.")
    return low, high
80 changes: 80 additions & 0 deletions tests/envs/test_env_implementation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Optional

import numpy as np
import pytest

import gym
Expand Down Expand Up @@ -142,3 +145,80 @@ def test_taxi_encode_decode():
env.encode(*env.decode(state)) == state
), f"state={state}, encode(decode(state))={env.encode(*env.decode(state))}"
state, _, _, _ = env.step(env.action_space.sample())


@pytest.mark.parametrize(
    "env_name",
    [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
    ],
)
@pytest.mark.parametrize(
    "low_high",
    [
        None,
        (-0.4, 0.4),
        (np.array(-0.4), np.array(0.4)),
    ],
)
def test_customizable_resets(env_name: str, low_high: Optional[list]):
    """Reset with default or custom bounds, then take one step.

    With custom bounds, every component of the state produced by the reset
    must lie inside the requested [low, high] interval.
    """
    env = gym.make(env_name)
    env.action_space.seed(0)
    if low_high is not None:
        lower, upper = low_high
        env.reset(options={"low": lower, "high": upper})
        within_bounds = (env.state >= lower) & (env.state <= upper)
        assert np.all(within_bounds)
    else:
        # No options: just make sure a plain reset works.
        env.reset()
    # The environment must still be steppable after the reset.
    sampled_action = env.action_space.sample()
    env.step(sampled_action)


@pytest.mark.parametrize(
    "env_name",
    [
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
    ],
)
@pytest.mark.parametrize(
    "low_high",
    [
        (-10.0, -9.0),
        (np.array(-10.0), np.array(-9.0)),
    ],
)
def test_customizable_out_of_bounds_resets(env_name: str, low_high: Optional[list]):
    """Resetting with the bounds [-10, -9] is expected to raise AssertionError."""
    env = gym.make(env_name)
    lower, upper = low_high
    reset_options = {"low": lower, "high": upper}
    with pytest.raises(AssertionError):
        env.reset(options=reset_options)


# We test Pendulum separately, as the parameters are handled differently.
@pytest.mark.parametrize(
    "low_high",
    [
        None,
        (1.2, 1.0),
        (np.array(1.2), np.array(1.0)),
    ],
)
def test_customizable_pendulum_resets(low_high: Optional[list]):
    """Reset Pendulum with default or custom limits and check the state range.

    Pendulum is initialized a little differently than the other environments:
    the options carry the upper limits ``x_init`` and ``y_init`` and the lower
    limits are just their negatives, so the parametrized pair is
    ``(x_init, y_init)`` rather than ``(low, high)``.
    """
    env = gym.make("Pendulum-v1")
    env.action_space.seed(0)
    if low_high is None:
        env.reset()
        # Defaults used by PendulumEnv.reset(): DEFAULT_X = pi, DEFAULT_Y = 1.
        x_init, y_init = np.pi, 1.0
    else:
        x_init, y_init = low_high
        env.reset(options={"x_init": x_init, "y_init": y_init})
    # Ensure the sampled state is within the symmetric limits.
    bound = np.array([float(x_init), float(y_init)])
    assert np.all((env.state >= -bound) & (env.state <= bound))
    # Make sure we can take a step.
    env.step(env.action_space.sample())


@pytest.mark.parametrize(
    "env_name",
    [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
    ],
)
@pytest.mark.parametrize(
    "low_high",
    [
        ("x", "y"),
        (10.0, 8.0),
        ([-1.0, -1.0], [1.0, 1.0]),
        (np.array([-1.0, -1.0]), np.array([1.0, 1.0])),
    ],
)
def test_invalid_customizable_resets(env_name: str, low_high: list):
    """Invalid bounds must make reset() raise ValueError.

    Covers non-numeric strings, a lower bound above the upper bound, and
    multi-element lists/arrays given where a single number is required.
    """
    bad_low, bad_high = low_high
    env = gym.make(env_name)
    reset_options = {"low": bad_low, "high": bad_high}
    with pytest.raises(ValueError):
        env.reset(options=reset_options)

0 comments on commit ca39816

Please sign in to comment.