-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Upgrade Gym to Gymnasium Wrapper (#458)
* upgrade wrapper * add another utility function to gym wrapper * format * try super * continue writing tests * stuff * continue working on multiagent env * finish testing * format code
- Loading branch information
Showing
6 changed files
with
498 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,144 @@ | ||
try: | ||
from typing import Any, Dict | ||
import inspect | ||
from typing import Any, Dict, Callable | ||
import gymnasium | ||
import gym | ||
import gym.spaces | ||
except ImportError: | ||
raise ImportError("Please install gym to use this wrapper. Your gym should be >=0.20.0, <=0.26.0") | ||
|
||
def gymnasiumToGym(space: gymnasium.spaces.Space) -> gym.spaces.Space:
    """
    Convert a gymnasium space into the equivalent gym space.

    Box, Discrete, MultiDiscrete, Tuple and Dict spaces are supported; Tuple and Dict
    spaces are converted recursively, subspace by subspace.

    :param space: the gymnasium space to convert
    :return: a gym space describing the same set of values
    :raises ValueError: if the space (or one of its subspaces) is of an unsupported kind
    """
    if isinstance(space, gymnasium.spaces.Box):
        # Forward the dtype explicitly; otherwise the new Box falls back to the library's
        # default dtype and e.g. an integer image space silently changes type.
        return gym.spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
    elif isinstance(space, gymnasium.spaces.Discrete):
        # n and start may be numpy integers; cast to plain ints for the other library
        return gym.spaces.Discrete(n=int(space.n), start=int(space.start))
    elif isinstance(space, gymnasium.spaces.MultiDiscrete):
        return gym.spaces.MultiDiscrete(nvec=space.nvec, dtype=space.dtype)
    elif isinstance(space, gymnasium.spaces.Tuple):
        return gym.spaces.Tuple([gymnasiumToGym(subspace) for subspace in space.spaces])
    elif isinstance(space, gymnasium.spaces.Dict):
        return gym.spaces.Dict({key: gymnasiumToGym(subspace) for key, subspace in space.spaces.items()})
    else:
        raise ValueError("unsupported space")
|
||
def gymToGymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space:
    """
    Convert a gym space into the equivalent gymnasium space.

    Box, Discrete, MultiDiscrete, Tuple and Dict spaces are supported; Tuple and Dict
    spaces are converted recursively, subspace by subspace.

    :param space: the gym space to convert
    :return: a gymnasium space describing the same set of values
    :raises ValueError: if the space (or one of its subspaces) is of an unsupported kind
    """
    if isinstance(space, gym.spaces.Box):
        # Forward the dtype explicitly; otherwise the new Box falls back to the library's
        # default dtype and e.g. an integer image space silently changes type.
        return gymnasium.spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
    elif isinstance(space, gym.spaces.Discrete):
        # n and start may be numpy integers; cast to plain ints for the other library
        return gymnasium.spaces.Discrete(n=int(space.n), start=int(space.start))
    elif isinstance(space, gym.spaces.MultiDiscrete):
        return gymnasium.spaces.MultiDiscrete(nvec=space.nvec, dtype=space.dtype)
    elif isinstance(space, gym.spaces.Tuple):
        return gymnasium.spaces.Tuple([gymToGymnasium(subspace) for subspace in space.spaces])
    elif isinstance(space, gym.spaces.Dict):
        return gymnasium.spaces.Dict({key: gymToGymnasium(subspace) for key, subspace in space.spaces.items()})
    else:
        raise ValueError("unsupported space")
|
||
class GymEnvWrapper(gym.Env): | ||
def __init__(self, config: Dict[str, Any]): | ||
def createGymWrapper(inner_class: type): | ||
""" | ||
Note that config must contain two items: | ||
"inner_class": the class of a Metadrive environment (not instantiated) | ||
"inner_config": The config that will be passed to the Metadrive environment | ||
"inner_class": A gymnasium based Metadrive environment class | ||
""" | ||
inner_class = config["inner_class"] | ||
inner_config = config["inner_config"] | ||
assert isinstance(inner_class, type) | ||
assert isinstance(inner_config, dict) | ||
super().__setattr__("_inner", inner_class(config=inner_config)) | ||
|
||
def step(self, actions): | ||
o, r, tm, tc, i = self._inner.step(actions) | ||
if isinstance(tm, dict) and isinstance(tc, dict): | ||
d = {tm[j] or tc[j] for j in set(list(tm.keys()) + list(tc.keys()))} | ||
else: | ||
d = tm or tc | ||
return o, r, d, i | ||
def was_overriden(a):
    """
    Tell whether function `a` comes from a different source file than `createGymWrapper`.
    """
    # HACK: the wrapper class is created dynamically, so comparing function objects for
    # equality is useless here. Comparing the defining source files is used as a proxy;
    # a more robust approach is still to be found.
    file_of_candidate = inspect.getfile(a)
    file_of_wrapper = inspect.getfile(createGymWrapper)
    return file_of_candidate != file_of_wrapper
|
||
def reset(self, *, seed=None, options=None): | ||
# pass non-none parameters to the reset (which may not support options or seed) | ||
params = {"seed": seed, "options": options} | ||
not_none_params = {k: v for k, v in params.items() if v is not None} | ||
obs, _ = self._inner.reset(**not_none_params) | ||
return obs | ||
def createOverridenDefaultConfigWrapper(base: type, new_default_config: Callable) -> type: | ||
""" | ||
Returns a class derived from the `base` class. It overrides the `default_config` classmethod, which is set to new_default_config | ||
""" | ||
class OverridenDefaultConfigWrapper(base): | ||
@classmethod | ||
def default_config(cls): | ||
return new_default_config() | ||
|
||
def render(self, *args, **kwargs): | ||
# remove mode from kwargs | ||
kwargs.pop("mode", None) | ||
return self._inner.render(*args, **kwargs) | ||
return OverridenDefaultConfigWrapper | ||
|
||
def close(self): | ||
self._inner.close() | ||
class GymEnvWrapper(gym.Env): | ||
@classmethod
def default_config(cls):
    """
    Default implementation: hand back the default config of `inner_class`.

    If a subclass overrides this method, the override is also installed on the inner
    class so that both sides stay consistent.
    """
    inner_defaults = inner_class.default_config()
    return inner_defaults
|
||
def seed(self, seed=None): | ||
""" | ||
We cannot seed a Gymnasium environment while running, so do nothing | ||
""" | ||
pass | ||
|
||
@property | ||
def observation_space(self): | ||
obs_space = self._inner.observation_space | ||
assert isinstance(obs_space, gymnasium.spaces.Box) | ||
return gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape) | ||
|
||
@property | ||
def action_space(self): | ||
action_space = self._inner.action_space | ||
assert isinstance(action_space, gymnasium.spaces.Box) | ||
return gym.spaces.Box(low=action_space.low, high=action_space.high, shape=action_space.shape) | ||
|
||
def __getattr__(self, name): | ||
return getattr(self._inner, name) | ||
|
||
def __setattr__(self, name, value): | ||
if hasattr(self._inner, name): | ||
setattr(self._inner, name, value) | ||
else: | ||
super().__setattr__(name, value) | ||
|
||
|
||
if __name__ == '__main__': | ||
from metadrive.envs.scenario_env import ScenarioEnv | ||
|
||
env = GymEnvWrapper(config={"inner_class": ScenarioEnv, "inner_config": {"manual_control": True}}) | ||
o, i = env.reset() | ||
assert isinstance(env.observation_space, gymnasium.Space) | ||
assert isinstance(env.action_space, gymnasium.Space) | ||
for s in range(600): | ||
o, r, d, i = env.step([0, -1]) | ||
env.vehicle.set_velocity([0, 0]) | ||
if d: | ||
assert s == env.config["horizon"] and i["max_step"] and d | ||
break | ||
def __init__(self, config: Dict[str, Any]):
    """
    Build the inner environment from `config` and store it as `_inner`.

    Whether `default_config` has been overridden can only be detected here, at
    construction time, so the decision about which inner class to use is made now.
    """
    own_default_config = type(self).default_config

    if not was_overriden(own_default_config):
        # nothing was overridden: use the inner class as is
        env_class = inner_class
    else:
        # hand the overriding method to the inner class, so that its own
        # initialisation sees the new default config
        env_class = createOverridenDefaultConfigWrapper(inner_class, own_default_config)

    # go through the parent's __setattr__: this class's __setattr__ reads self._inner
    super().__setattr__("_inner", env_class(config=config))
|
||
def step(self, actions):
    """
    Step the inner environment and fold its two end-of-episode flags into one.

    The inner step returns (obs, reward, terminated, truncated, info); this method
    returns (obs, reward, done, info). When both flags are dicts, `done` is a dict over
    the union of their keys, a key missing from one dict counting as False there.
    """
    obs, reward, terminated, truncated, info = self._inner.step(actions)

    if not (isinstance(terminated, dict) and isinstance(truncated, dict)):
        return obs, reward, terminated or truncated, info

    done = {}
    for key in set(terminated) | set(truncated):
        done[key] = terminated.get(key, False) or truncated.get(key, False)
    return obs, reward, done, info
|
||
def reset(self, *, seed=None, options=None):
    """
    Reset the inner environment and return only the observation.

    `seed` and `options` are forwarded only when they are not None, because the inner
    reset may not accept them. The second element returned by the inner reset is dropped.
    """
    forwarded = {}
    if seed is not None:
        forwarded["seed"] = seed
    if options is not None:
        forwarded["options"] = options

    observation, _unused = self._inner.reset(**forwarded)
    return observation
|
||
def render(self, *args, **kwargs):
    """
    Render through the inner environment, dropping any `mode` keyword argument first.
    """
    if "mode" in kwargs:
        del kwargs["mode"]
    return self._inner.render(*args, **kwargs)
|
||
def close(self):
    """
    Close the inner environment.
    """
    inner_env = self._inner
    inner_env.close()
|
||
def seed(self, seed=None):
    """
    Do nothing: a Gymnasium environment cannot be seeded while it is running.
    """
    return None
|
||
@property
def observation_space(self):
    """
    Observation space of the inner environment, converted with `gymnasiumToGym`.
    """
    inner_space = self._inner.observation_space
    return gymnasiumToGym(inner_space)
|
||
@property
def action_space(self):
    """
    Action space of the inner environment, converted with `gymnasiumToGym`.
    """
    inner_space = self._inner.action_space
    return gymnasiumToGym(inner_space)
|
||
def __getattr__(self, __name: str) -> Any:
    """
    Fall back to the inner environment for attributes not found on the wrapper.

    :param __name: name of the attribute that normal lookup failed to find
    :return: the attribute of the same name on the inner environment
    :raises AttributeError: if `_inner` itself is missing, or the inner environment
        has no such attribute
    """
    # `_inner` normally lives in the instance dict, so this method is only asked for it
    # when it has not been set yet. Reading self._inner below would then re-enter this
    # method and recurse without bound, so fail with a plain AttributeError instead.
    if __name == "_inner":
        raise AttributeError(__name)
    return getattr(self._inner, __name)
|
||
def __setattr__(self, name, value):
    """
    Set `name` on the inner environment when it already has such an attribute;
    otherwise set it on the wrapper itself.
    """
    if not hasattr(self._inner, name):
        super().__setattr__(name, value)
        return
    setattr(self._inner, name, value)
|
||
return GymEnvWrapper | ||
|
||
# Manual smoke test: wrap a ScenarioEnv and step it until the episode ends.
if __name__ == '__main__':
    from metadrive.envs.scenario_env import ScenarioEnv

    env = createGymWrapper(ScenarioEnv)(config={"manual_control": True})
    # the wrapper's reset() returns the observation only, not an (obs, info) pair
    o = env.reset()
    # the wrapper's space properties return spaces converted to gym, not gymnasium ones
    assert isinstance(env.observation_space, gym.spaces.Space)
    assert isinstance(env.action_space, gym.spaces.Space)
    for s in range(600):
        o, r, d, i = env.step([0, -1])
        # keep the vehicle still so the episode can only end by hitting the step limit
        env.vehicle.set_velocity([0, 0])
        if d:
            assert s == env.config["horizon"] and i["max_step"] and d
            break
except: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.