Skip to content

Commit

Permalink
update gym wrapper (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanyiLi committed Jun 23, 2023
1 parent a4aa036 commit 16d1d4c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import deque

import numpy as np
from panda3d.core import NodePath
from panda3d.core import NodePath, Material

from metadrive.component.vehicle_navigation_module.base_navigation import BaseNavigation
from metadrive.engine.asset_loader import AssetLoader
Expand Down Expand Up @@ -58,6 +58,13 @@ def __init__(
for model in self._ckpt_vis_models:
if self._navi_point_model is None:
self._navi_point_model = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam"))
self._navi_point_model.setScale(0.5)
if self.engine.use_render_pipeline:
material = Material()
material.setBaseColor((1, 1, 1, 1))
material.setShininess(128)
material.setEmission((1, 1, 1, 1))
self._navi_point_model.setMaterial(material, True)
self._navi_point_model.instanceTo(model)
model.reparentTo(self.origin)

Expand Down Expand Up @@ -86,7 +93,7 @@ def set_route(self):
self.next_ref_lanes = None
if self._dest_node_path is not None:
check_point = self.reference_trajectory.end
self._dest_node_path.setPos(panda_vector(check_point[0], check_point[1], 1.8))
self._dest_node_path.setPos(panda_vector(check_point[0], check_point[1], 1))

def discretize_reference_trajectory(self):
ret = []
Expand Down
3 changes: 2 additions & 1 deletion metadrive/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def done_function(self, vehicle_id: str) -> Tuple[bool, Dict]:
def render(self,
text: Optional[Union[dict, str]] = None,
return_bytes=False,
mode=None,
*args,
**kwargs) -> Optional[np.ndarray]:
"""
Expand All @@ -353,7 +354,7 @@ def render(self,
:return: when mode is 'rgb', image array is returned
"""

mode = self.config["render_mode"]
mode = mode or self.config["render_mode"] # for compatibility

if mode in ["top_down", "topdown", "bev", "birdview"]:
ret = self._render_topdown(text=text, *args, **kwargs)
Expand Down
132 changes: 70 additions & 62 deletions metadrive/envs/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,84 @@
import gymnasium
import gym
import gym.spaces
except ImportError:
raise ImportError("Please install gym to use this wrapper. Your gym should be >=0.20.0, <=0.26.0")

class GymEnvWrapper(gym.Env):
def __init__(self, config: Dict[str, Any]):
"""
Note that config must contain two items:
"inner_class": the class of a Metadrive environment (not instantiated)
"inner_config": The config that will be passed to the Metadrive environment
"""
inner_class = config["inner_class"]
inner_config = config["inner_config"]
assert isinstance(inner_class, type)
assert isinstance(inner_config, dict)
self._inner = inner_class(config=inner_config)

def step(self, actions):
o, r, tm, tc, i = self._inner.step(actions)
if isinstance(tm, dict) and isinstance(tc, dict):
d = {tm[j] or tc[j] for j in set(list(tm.keys()) + list(tc.keys()))}
else:
d = tm or tc
return o, r, d, i
class GymEnvWrapper(gym.Env):
def __init__(self, config: Dict[str, Any]):
"""
Note that config must contain two items:
"inner_class": the class of a Metadrive environment (not instantiated)
"inner_config": The config that will be passed to the Metadrive environment
"""
inner_class = config["inner_class"]
inner_config = config["inner_config"]
assert isinstance(inner_class, type)
assert isinstance(inner_config, dict)
super().__setattr__("_inner", inner_class(config=inner_config))

def reset(self, *, seed=None, options=None):
# pass non-none parameters to the reset (which may not support options or seed)
params = {"seed": seed, "options": options}
not_none_params = {k: v for k, v in params.items() if v is not None}
obs, _ = self._inner.reset(**not_none_params)
return obs
def step(self, actions):
o, r, tm, tc, i = self._inner.step(actions)
if isinstance(tm, dict) and isinstance(tc, dict):
d = {tm[j] or tc[j] for j in set(list(tm.keys()) + list(tc.keys()))}
else:
d = tm or tc
return o, r, d, i

def render(self, *args, **kwargs):
# remove mode from kwargs
kwargs.pop("mode", None)
return self._inner.render(*args, **kwargs)
def reset(self, *, seed=None, options=None):
# pass non-none parameters to the reset (which may not support options or seed)
params = {"seed": seed, "options": options}
not_none_params = {k: v for k, v in params.items() if v is not None}
obs, _ = self._inner.reset(**not_none_params)
return obs

def close(self):
self._inner.close()
def render(self, *args, **kwargs):
# remove mode from kwargs
kwargs.pop("mode", None)
return self._inner.render(*args, **kwargs)

def seed(self, seed=None):
"""
We cannot seed a Gymnasium environment while running, so do nothing
"""
pass
def close(self):
self._inner.close()

@property
def observation_space(self):
obs_space = self._inner.observation_space
assert isinstance(obs_space, gymnasium.spaces.Box)
return gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape)
def seed(self, seed=None):
"""
We cannot seed a Gymnasium environment while running, so do nothing
"""
pass

@property
def action_space(self):
action_space = self._inner.action_space
assert isinstance(action_space, gymnasium.spaces.Box)
return gym.spaces.Box(low=action_space.low, high=action_space.high, shape=action_space.shape)
@property
def observation_space(self):
obs_space = self._inner.observation_space
assert isinstance(obs_space, gymnasium.spaces.Box)
return gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape)

def __getattr__(self, __name: str) -> Any:
return self._inner[__name]
@property
def action_space(self):
action_space = self._inner.action_space
assert isinstance(action_space, gymnasium.spaces.Box)
return gym.spaces.Box(low=action_space.low, high=action_space.high, shape=action_space.shape)

if __name__ == '__main__':
from metadrive.envs.scenario_env import ScenarioEnv
def __getattr__(self, name):
return getattr(self._inner, name)

env = GymEnvWrapper(config={"inner_class": ScenarioEnv, "inner_config": {"manual_control": True}})
o, i = env.reset()
assert isinstance(env.observation_space, gymnasium.Space)
assert isinstance(env.action_space, gymnasium.Space)
for s in range(600):
o, r, d, i = env.step([0, -1])
env.vehicle.set_velocity([0, 0])
if d:
assert s == env.config["horizon"] and i["max_step"] and d
break
except ImportError:
pass
def __setattr__(self, name, value):
if hasattr(self._inner, name):
setattr(self._inner, name, value)
else:
super().__setattr__(name, value)


if __name__ == '__main__':
from metadrive.envs.scenario_env import ScenarioEnv

env = GymEnvWrapper(config={"inner_class": ScenarioEnv, "inner_config": {"manual_control": True}})
o, i = env.reset()
assert isinstance(env.observation_space, gymnasium.Space)
assert isinstance(env.action_space, gymnasium.Space)
for s in range(600):
o, r, d, i = env.step([0, -1])
env.vehicle.set_velocity([0, 0])
if d:
assert s == env.config["horizon"] and i["max_step"] and d
break
2 changes: 1 addition & 1 deletion metadrive/render_pipeline/config/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pipeline:
# it will also disable the hotkeys, and give a small performance boost.
# Most likely you also don't want to show it in your own game, so set
# it to false in that case.
display_debugger: true
display_debugger: false

# Affects which debugging information is displayed. If this is set to false,
# only frame time is displayed, otherwise much more information is visible.
Expand Down

0 comments on commit 16d1d4c

Please sign in to comment.