Skip to content

Commit

Permalink
Upgrade Gym to Gymnasium Wrapper (#458)
Browse files Browse the repository at this point in the history
* upgrade wrapper

* add another utility function to gym wrapper

* format

* try super

* continue writing tests

* stuff

* continue working on multiagent env

* finish testing

* format code
  • Loading branch information
pimpale committed Aug 1, 2023
1 parent cc8ca5f commit 86c1050
Show file tree
Hide file tree
Showing 6 changed files with 498 additions and 77 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
run: |
pip install cython
pip install numpy
pip install gym
pip install -e .
pip install pytest
pip install pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion metadrive/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def render(self,
ret = self._render_topdown(text=text, *args, **kwargs)
return ret
assert self.config["use_render"] or self.engine.mode != RENDER_MODE_NONE, \
("Panda Renderring is off now, can not render. Please set config['use_render'] = True!")
("Panda Rendering is off now, can not render. Please set config['use_render'] = True!")

self.engine.render_frame(text)

Expand Down
206 changes: 132 additions & 74 deletions metadrive/envs/gym_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,144 @@
try:
from typing import Any, Dict
import inspect
from typing import Any, Dict, Callable
import gymnasium
import gym
import gym.spaces
except ImportError:
raise ImportError("Please install gym to use this wrapper. Your gym should be >=0.20.0, <=0.26.0")

def gymnasiumToGym(space: gymnasium.spaces.Space) -> gym.spaces.Space:
    """
    Convert a gymnasium space into the equivalent gym space.

    Box, Discrete, MultiDiscrete, Tuple and Dict spaces are supported. Tuple and Dict
    spaces are converted recursively, subspace by subspace.

    :param space: the gymnasium space to convert
    :return: a newly constructed gym space
    :raises ValueError: if `space` (or one of its nested subspaces) is of any other type
    """
    if isinstance(space, gymnasium.spaces.Box):
        # Forward the dtype explicitly. Without it gym.spaces.Box falls back to its float32
        # default, silently changing the element type of e.g. uint8 image spaces.
        return gym.spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
    if isinstance(space, gymnasium.spaces.Discrete):
        # int(...) turns numpy integers into plain Python ints before they are handed to gym.
        return gym.spaces.Discrete(n=int(space.n), start=int(space.start))
    if isinstance(space, gymnasium.spaces.MultiDiscrete):
        return gym.spaces.MultiDiscrete(nvec=space.nvec)
    if isinstance(space, gymnasium.spaces.Tuple):
        return gym.spaces.Tuple([gymnasiumToGym(subspace) for subspace in space.spaces])
    if isinstance(space, gymnasium.spaces.Dict):
        return gym.spaces.Dict({key: gymnasiumToGym(subspace) for key, subspace in space.spaces.items()})
    raise ValueError("unsupported space")

def gymToGymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space:
    """
    Convert a gym space into the equivalent gymnasium space.

    Box, Discrete, MultiDiscrete, Tuple and Dict spaces are supported. Tuple and Dict
    spaces are converted recursively, subspace by subspace.

    :param space: the gym space to convert
    :return: a newly constructed gymnasium space
    :raises ValueError: if `space` (or one of its nested subspaces) is of any other type
    """
    if isinstance(space, gym.spaces.Box):
        # Forward the dtype explicitly. Without it gymnasium.spaces.Box falls back to its
        # float32 default, silently changing the element type of e.g. uint8 image spaces.
        return gymnasium.spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
    if isinstance(space, gym.spaces.Discrete):
        # int(...) turns numpy integers into plain Python ints before they are handed to gymnasium.
        return gymnasium.spaces.Discrete(n=int(space.n), start=int(space.start))
    if isinstance(space, gym.spaces.MultiDiscrete):
        return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
    if isinstance(space, gym.spaces.Tuple):
        return gymnasium.spaces.Tuple([gymToGymnasium(subspace) for subspace in space.spaces])
    if isinstance(space, gym.spaces.Dict):
        return gymnasium.spaces.Dict({key: gymToGymnasium(subspace) for key, subspace in space.spaces.items()})
    raise ValueError("unsupported space")

class GymEnvWrapper(gym.Env):
def __init__(self, config: Dict[str, Any]):
def createGymWrapper(inner_class: type):
"""
Note that config must contain two items:
"inner_class": the class of a Metadrive environment (not instantiated)
"inner_config": The config that will be passed to the Metadrive environment
"inner_class": A gymnasium based Metadrive environment class
"""
inner_class = config["inner_class"]
inner_config = config["inner_config"]
assert isinstance(inner_class, type)
assert isinstance(inner_config, dict)
super().__setattr__("_inner", inner_class(config=inner_config))

def step(self, actions):
o, r, tm, tc, i = self._inner.step(actions)
if isinstance(tm, dict) and isinstance(tc, dict):
d = {tm[j] or tc[j] for j in set(list(tm.keys()) + list(tc.keys()))}
else:
d = tm or tc
return o, r, d, i
def was_overriden(a):
    """
    Report whether the function `a` was defined outside of this file.

    The wrapper class is built dynamically, so (per the original author's note) comparing
    the function objects themselves for equality always fails. The source file that
    defines `a` is compared with the file that defines `createGymWrapper` instead.

    TODO: this is a hack; a more robust way of detecting the override is still wanted.
    """
    defining_file = inspect.getfile(a)
    wrapper_file = inspect.getfile(createGymWrapper)
    return defining_file != wrapper_file

def reset(self, *, seed=None, options=None):
# pass non-none parameters to the reset (which may not support options or seed)
params = {"seed": seed, "options": options}
not_none_params = {k: v for k, v in params.items() if v is not None}
obs, _ = self._inner.reset(**not_none_params)
return obs
def createOverridenDefaultConfigWrapper(base: type, new_default_config: Callable) -> type:
"""
Returns a class derived from the `base` class. It overrides the `default_config` classmethod, which is set to new_default_config
"""
class OverridenDefaultConfigWrapper(base):
@classmethod
def default_config(cls):
return new_default_config()

def render(self, *args, **kwargs):
# remove mode from kwargs
kwargs.pop("mode", None)
return self._inner.render(*args, **kwargs)
return OverridenDefaultConfigWrapper

def close(self):
self._inner.close()
class GymEnvWrapper(gym.Env):
@classmethod
def default_config(cls):
    """
    Return the default config of the wrapped `inner_class`.

    This is only the default behaviour: when a subclass overrides this method, the
    override is installed on the inner class as well so that both stay consistent.
    """
    inherited_config = inner_class.default_config()
    return inherited_config

def seed(self, seed=None):
"""
We cannot seed a Gymnasium environment while running, so do nothing
"""
pass

@property
def observation_space(self):
obs_space = self._inner.observation_space
assert isinstance(obs_space, gymnasium.spaces.Box)
return gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape)

@property
def action_space(self):
action_space = self._inner.action_space
assert isinstance(action_space, gymnasium.spaces.Box)
return gym.spaces.Box(low=action_space.low, high=action_space.high, shape=action_space.shape)

def __getattr__(self, name):
return getattr(self._inner, name)

def __setattr__(self, name, value):
if hasattr(self._inner, name):
setattr(self._inner, name, value)
else:
super().__setattr__(name, value)


if __name__ == '__main__':
from metadrive.envs.scenario_env import ScenarioEnv

env = GymEnvWrapper(config={"inner_class": ScenarioEnv, "inner_config": {"manual_control": True}})
o, i = env.reset()
assert isinstance(env.observation_space, gymnasium.Space)
assert isinstance(env.action_space, gymnasium.Space)
for s in range(600):
o, r, d, i = env.step([0, -1])
env.vehicle.set_velocity([0, 0])
if d:
assert s == env.config["horizon"] and i["max_step"] and d
break
def __init__(self, config: Dict[str, Any]):
    """
    Build the wrapped environment and store it as `_inner`.

    :param config: passed unchanged to the inner environment as its `config` keyword
    """
    # Whether a subclass replaced `default_config` can only be determined at init time.
    config_provider = type(self).default_config

    # Without an override the inner class is used as-is, no extra wrapper is made.
    actual_inner_class = inner_class
    if was_overriden(config_provider):
        # Hand the overriding method to the inner class so that its own __init__
        # sees the new default config.
        actual_inner_class = createOverridenDefaultConfigWrapper(inner_class, config_provider)

    # Go through the parent's __setattr__: the __setattr__ defined on this class reads
    # `self._inner`, which does not exist until this assignment has happened.
    super().__setattr__("_inner", actual_inner_class(config=config))

def step(self, actions):
    """
    Step the inner environment and merge its two end-of-episode flags into one `done`.

    :param actions: forwarded unchanged to the inner environment's `step`
    :return: the 4-tuple `(observation, reward, done, info)`. When both flags are dicts,
        `done` is a dict over the union of their keys; otherwise it is `terminated or truncated`.
    """
    observation, reward, terminated, truncated, info = self._inner.step(actions)

    if not (isinstance(terminated, dict) and isinstance(truncated, dict)):
        return observation, reward, terminated or truncated, info

    done = {}
    for key in set([*terminated, *truncated]):
        # A key missing from one of the two dicts counts as False for that dict.
        done[key] = terminated.get(key, False) or truncated.get(key, False)
    return observation, reward, done, info

def reset(self, *, seed=None, options=None):
    """
    Reset the inner environment and return only its observation.

    :param seed: forwarded to the inner `reset` only when it is not None
    :param options: forwarded to the inner `reset` only when it is not None
    :return: the first element of the pair returned by the inner `reset`
    """
    # The inner reset may not accept `seed` or `options`, so only pass what was actually given.
    forwarded = {}
    if seed is not None:
        forwarded["seed"] = seed
    if options is not None:
        forwarded["options"] = options

    observation, _info = self._inner.reset(**forwarded)
    return observation

def render(self, *args, **kwargs):
    """
    Render through the inner environment, discarding any `mode` keyword.

    :return: whatever the inner environment's `render` returns
    """
    # `mode` is dropped here and never reaches the inner environment.
    forwarded = {key: value for key, value in kwargs.items() if key != "mode"}
    return self._inner.render(*args, **forwarded)

def close(self):
    """Close the wrapped inner environment."""
    inner_env = self._inner
    inner_env.close()

def seed(self, seed=None):
    """
    Do nothing: a Gymnasium environment cannot be seeded while it is running.

    :param seed: accepted and ignored
    """
    return None

@property
def observation_space(self):
    """The inner environment's observation space, converted to a gym space on each access."""
    inner_space = self._inner.observation_space
    return gymnasiumToGym(inner_space)

@property
def action_space(self):
    """The inner environment's action space, converted to a gym space on each access."""
    inner_space = self._inner.action_space
    return gymnasiumToGym(inner_space)

def __getattr__(self, __name: str) -> Any:
    """
    Fall back to the inner environment for attributes not found on the wrapper.

    Python calls this only after normal attribute lookup has failed.

    :raises AttributeError: if `_inner` itself is missing, or the inner environment
        does not have the attribute either
    """
    if __name == "_inner":
        # `_inner` is only missing when __init__ has not (successfully) run, e.g. while an
        # instance is being rebuilt by copy/pickle. Reading `self._inner` below would then
        # re-enter this method forever and end in a RecursionError instead of the
        # AttributeError that hasattr()/getattr() callers expect.
        raise AttributeError(__name)
    return getattr(self._inner, __name)

def __setattr__(self, name, value):
    """
    Set an attribute on the inner environment when it already has one of that name;
    otherwise set it on the wrapper itself.
    """
    if not hasattr(self._inner, name):
        super().__setattr__(name, value)
        return
    setattr(self._inner, name, value)

return GymEnvWrapper

if __name__ == '__main__':
    # Manual smoke test: wrap ScenarioEnv and drive it through the gym-style API.
    from metadrive.envs.scenario_env import ScenarioEnv

    env = createGymWrapper(ScenarioEnv)(config={"manual_control": True})
    # The wrapper's reset() returns only the observation, so there is no info to unpack.
    o = env.reset()
    # The wrapper converts its spaces to gym spaces, so check against gym rather than gymnasium.
    assert isinstance(env.observation_space, gym.spaces.Space)
    assert isinstance(env.action_space, gym.spaces.Space)
    for s in range(600):
        # The wrapper's step() returns the 4-tuple (obs, reward, done, info).
        o, r, d, i = env.step([0, -1])
        env.vehicle.set_velocity([0, 0])
        if d:
            assert s == env.config["horizon"] and i["max_step"] and d
            break
except:
pass
3 changes: 2 additions & 1 deletion metadrive/examples/train_generalization_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np

from metadrive import MetaDriveEnv
from metadrive.envs.gym_wrapper import createGymWrapper

try:
import ray
Expand Down Expand Up @@ -173,7 +174,7 @@ def get_train_parser():

# ===== Training Environment =====
# Train the policies in scenario sets with different number of scenarios.
env=MetaDriveEnv,
env=createGymWrapper(MetaDriveEnv),
env_config=dict(
num_scenarios=tune.grid_search([1, 3, 5, 1000]),
start_seed=tune.grid_search([5000]),
Expand Down

0 comments on commit 86c1050

Please sign in to comment.