Skip to content

Commit

Permalink
Fix a bug that observation is not in np.float32 to support gym==0.20.0 (#74)
Browse files Browse the repository at this point in the history

* fix the shape of float32

* fix tests

* fix
  • Loading branch information
PENG Zhenghao committed Sep 15, 2021
1 parent 65cb8e4 commit dec4a3c
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 11 deletions.
4 changes: 2 additions & 2 deletions metadrive/component/vehicle/vehicle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, horizon, dim, dt):
self.mass = 800
self.len = 1
self.bounds = []
self.u = np.zeros(self.dim * self.horizon)
self.u = np.zeros(self.dim * self.horizon, dtype=np.float32)
self.config = {"replan": False}

def cost(self, u, *args):
Expand All @@ -54,7 +54,7 @@ def plant_model(self, state, dt, *control):

def solve(self):
if self.config['replan']:
self.u = np.zeros(self.dim * self.horizon)
self.u = np.zeros(self.dim * self.horizon, dtype=np.float32)
else:
for _ in range(self.dim):
self.u = np.delete(self.u, 0)
Expand Down
2 changes: 1 addition & 1 deletion metadrive/component/vehicle_module/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self.current_road = None
self.next_road = None
self._target_checkpoints_index = None
self._navi_info = np.zeros((self.navigation_info_dim, )) # navi information res
self._navi_info = np.zeros((self.navigation_info_dim, ), dtype=np.float32) # navi information res

# Vis
self._show_navi_info = (engine.mode == RENDER_MODE_ONSCREEN and not engine.global_config["debug_physics_world"])
Expand Down
2 changes: 1 addition & 1 deletion metadrive/envs/marl_envs/marl_inout_roundabout.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def observe(self, vehicle):
self.cloud_points = cloud_points
self.detected_objects = detected_objects
self.current_observation = np.concatenate((state, np.asarray(other_v_info)))
return self.current_observation
return self.current_observation.astype(np.float32)

def state_observe(self, vehicle):
return self.state_obs.observe(vehicle)
Expand Down
3 changes: 2 additions & 1 deletion metadrive/envs/marl_envs/marl_tollgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def observe(self, vehicle):
# print(toll_obs)
state = self.state_observe(vehicle)
other_v_info = self.lidar_observe(vehicle)
return np.concatenate((state, np.asarray(other_v_info), np.asarray(toll_obs)))
ret = np.concatenate((state, np.asarray(other_v_info), np.asarray(toll_obs)))
return ret.astype(np.float32)


class MATollGateMap(PGMap):
Expand Down
4 changes: 2 additions & 2 deletions metadrive/obs/image_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, config, image_source: str, clip_rgb: bool):
self.image_source = image_source
super(ImageObservation, self).__init__(config)
self.rgb_clip = clip_rgb
self.state = np.zeros(self.observation_space.shape)
self.state = np.zeros(self.observation_space.shape, dtype=np.float32)

@property
def observation_space(self):
Expand All @@ -70,4 +70,4 @@ def reset(self, env, vehicle=None):
:param vehicle: BaseVehicle
:return: None
"""
self.state = np.zeros(self.observation_space.shape)
self.state = np.zeros(self.observation_space.shape, dtype=np.float32)
6 changes: 4 additions & 2 deletions metadrive/obs/state_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def observe(self, vehicle):
"""
navi_info = vehicle.navigation.get_navi_info()
ego_state = self.vehicle_state(vehicle)
return np.concatenate([ego_state, navi_info])
ret = np.concatenate([ego_state, navi_info])
return ret.astype(np.float32)

def vehicle_state(self, vehicle):
"""
Expand Down Expand Up @@ -167,7 +168,8 @@ def observe(self, vehicle):
state = self.state_observe(vehicle)
other_v_info = self.lidar_observe(vehicle)
self.current_observation = np.concatenate((state, np.asarray(other_v_info)))
return self.current_observation
ret = self.current_observation
return ret.astype(np.float32)

def state_observe(self, vehicle):
return self.state_obs.observe(vehicle)
Expand Down
2 changes: 1 addition & 1 deletion metadrive/tests/test_component/test_detector_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def fake_cutils_perceive(

env = MetaDriveEnv({"map": "C", "traffic_density": 1.0, "environment_num": 10, "use_render": False})
env.reset()
env.vehicle.lidar.cloud_points = np.ones((env.vehicle.lidar.num_lasers, ), dtype=float)
env.vehicle.lidar.cloud_points = np.ones((env.vehicle.lidar.num_lasers, ), dtype=np.float32)
env.vehicle.lidar.detected_objects = []
try:
for _ in range(3):
Expand Down
2 changes: 1 addition & 1 deletion metadrive/tests/test_env/test_ma_roundabout_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def test_ma_roundabout_reward_done_alignment_1():
# #assert r[kkk] == -1.7777
for kkk, ddd in d.items():
if ddd and kkk != "__all__":
assert i[kkk][TerminationState.OUT_OF_ROAD]
assert i[kkk][TerminationState.OUT_OF_ROAD], i[kkk]
# print('{} done passed!'.format(kkk))
for kkk, rrr in r.items():
if rrr == -1.7777:
Expand Down
71 changes: 71 additions & 0 deletions metadrive/tests/test_env/test_metadrive_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import copy
import os

import numpy as np
import pytest

from metadrive import MetaDriveEnv
from metadrive.constants import TerminationState

# Black-box cases, keyed by case name; each value is the environment config used for that case.
blackbox_test_configs = {
    "default": {},
    "random_traffic": {"random_traffic": True},
    "large_seed": {"start_seed": 1000000},
    "traffic_density_0": {"traffic_density": 0},
    "traffic_density_1": {"traffic_density": 1},
    "decision_repeat_50": {"decision_repeat": 50},
    "map_7": {"map": 7},
    "map_30": {"map": 30},
    "map_CCC": {"map": "CCC"},
    "envs_100": {"environment_num": 100},
    "envs_1000": {"environment_num": 1000},
    "envs_10000": {"environment_num": 10000},
    "envs_100000": {"environment_num": 100000},
    # Lidar cases: every combination of 0 or 10 for num_lasers, distance and num_others.
    "no_lidar0": {"vehicle_config": {"lidar": {"num_lasers": 0, "distance": 0, "num_others": 0}}},
    "no_lidar1": {"vehicle_config": {"lidar": {"num_lasers": 0, "distance": 10, "num_others": 0}}},
    "no_lidar2": {"vehicle_config": {"lidar": {"num_lasers": 10, "distance": 0, "num_others": 0}}},
    "no_lidar3": {"vehicle_config": {"lidar": {"num_lasers": 0, "distance": 0, "num_others": 10}}},
    "no_lidar4": {"vehicle_config": {"lidar": {"num_lasers": 10, "distance": 10, "num_others": 0}}},
    "no_lidar5": {"vehicle_config": {"lidar": {"num_lasers": 10, "distance": 0, "num_others": 10}}},
    "no_lidar6": {"vehicle_config": {"lidar": {"num_lasers": 0, "distance": 10, "num_others": 10}}},
    "no_lidar7": {"vehicle_config": {"lidar": {"num_lasers": 10, "distance": 10, "num_others": 10}}},
}

# Environment config with a single environment, a fixed start seed and map string, no traffic, no rendering.
pid_control_config = {
    "environment_num": 1,
    "start_seed": 5,
    "map": "CrXROSTR",
    "traffic_density": 0.0,
    "use_render": False,
}

# Keys that every ``info`` dict returned by ``env.step`` is checked for:
# the plain string keys first, followed by the TerminationState entries.
info_keys = ["cost", "velocity", "steering", "acceleration", "step_reward"]
info_keys += [TerminationState.CRASH_VEHICLE, TerminationState.OUT_OF_ROAD, TerminationState.SUCCESS]

# Guard against an __init__.py being present in the directory that holds this test file.
# The directory is resolved from an absolute path: when this file is run directly as a script on
# interpreters that keep ``__file__`` relative (CPython before 3.9), ``os.path.dirname(__file__)``
# is "" and ``os.listdir("")`` raises FileNotFoundError before the check can run.
assert "__init__.py" not in os.listdir(
    os.path.dirname(os.path.abspath(__file__))
), "Please remove __init__.py in tests directory."


def _act(env, action):
assert env.action_space.contains(action)
obs, reward, done, info = env.step(action)
assert env.observation_space.contains(obs)
assert np.isscalar(reward)
assert isinstance(info, dict)
for k in info_keys:
assert k in info


@pytest.mark.parametrize("config", list(blackbox_test_configs.values()), ids=list(blackbox_test_configs.keys()))
def test_pgdrive_env_blackbox(config):
    """
    Build a MetaDriveEnv from one black-box config and drive it through a few checked steps.

    The environment is reset, one sampled action is applied, and then every pair of values
    from (-1, 0, 1) is applied as an action, resetting before each new first value.
    The environment is closed even when a check fails.

    :param config: environment config taken from ``blackbox_test_configs``
    :return: None
    """
    extremes = [-1, 0, 1]
    env = MetaDriveEnv(config=copy.deepcopy(config))
    try:
        initial_obs = env.reset()
        assert env.observation_space.contains(initial_obs)
        sampled_action = env.action_space.sample()
        _act(env, sampled_action)
        for first in extremes:
            env.reset()
            for second in extremes:
                _act(env, [first, second])
    finally:
        env.close()


if __name__ == '__main__':
    # Pass this file's own path instead of a hard-coded relative name, so running the script
    # works from any working directory and keeps working if the file is renamed.
    pytest.main(["-sv", __file__])
1 change: 1 addition & 0 deletions metadrive/tests/test_env/test_safe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_safe_env():
for i in range(1, 100):
o, r, d, info = env.step([0, 1])
total_cost += info["cost"]
assert env.observation_space.contains(o)
if d:
total_cost = 0
print("Reset")
Expand Down

0 comments on commit dec4a3c

Please sign in to comment.