Skip to content

Commit

Permalink
Fix minor issue in gym_wrapper (#628)
Browse files Browse the repository at this point in the history
* Fix minor issue in gym_wrapper

* fix
  • Loading branch information
pengzhenghao committed Jan 30, 2024
1 parent 1180543 commit eea6a95
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions metadrive/envs/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import gym.spaces

def gymnasiumToGym(space: gymnasium.spaces.Space) -> gym.spaces.Space:
return gymnasium_to_gym(space)

def gymnasium_to_gym(space: gymnasium.spaces.Space) -> gym.spaces.Space:
if isinstance(space, gym.spaces.Space):
return space
if isinstance(space, gymnasium.spaces.Box):
return gym.spaces.Box(low=space.low, high=space.high, shape=space.shape)
elif isinstance(space, gymnasium.spaces.Discrete):
Expand All @@ -17,9 +22,14 @@ def gymnasiumToGym(space: gymnasium.spaces.Space) -> gym.spaces.Space:
elif isinstance(space, gymnasium.spaces.Dict):
return gym.spaces.Dict({key: gymnasiumToGym(subspace) for key, subspace in space.spaces.items()})
else:
raise ValueError("unsupported space")
raise ValueError(f"unsupported space: {type(space)}!")

def gymToGymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space:
return gym_to_gymnasium(space)

def gym_to_gymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space:
if isinstance(space, gymnasium.spaces.Space):
return space
if isinstance(space, gym.spaces.Box):
return gymnasium.spaces.Box(low=space.low, high=space.high, shape=space.shape)
elif isinstance(space, gym.spaces.Discrete):
Expand All @@ -31,9 +41,12 @@ def gymToGymnasium(space: gym.spaces.Space) -> gymnasium.spaces.Space:
elif isinstance(space, gym.spaces.Dict):
return gymnasium.spaces.Dict({key: gymToGymnasium(subspace) for key, subspace in space.spaces.items()})
else:
raise ValueError("unsupported space")
raise ValueError(f"unsupported space: {type(space)}!")

def createGymWrapper(inner_class: type):
return create_gym_wrapper(inner_class)

def create_gym_wrapper(inner_class: type):
"""
"inner_class": A gymnasium based Metadrive environment class
"""
Expand Down

0 comments on commit eea6a95

Please sign in to comment.