diff --git a/pgdrive/envs/pgdrive_env.py b/pgdrive/envs/pgdrive_env.py index 1884544ea..d9971e45d 100644 --- a/pgdrive/envs/pgdrive_env.py +++ b/pgdrive/envs/pgdrive_env.py @@ -74,8 +74,6 @@ # ===== use image ===== image_source="rgb_cam", # take effect when only when use_image == True - # use_image=False, - # rgb_clip=True, # ===== vehicle spawn ===== spawn_lane_index=(FirstBlock.NODE_1, FirstBlock.NODE_2, 0), diff --git a/pgdrive/scene_manager/agent_manager.py b/pgdrive/scene_manager/agent_manager.py index 427c024ee..06b278d4d 100644 --- a/pgdrive/scene_manager/agent_manager.py +++ b/pgdrive/scene_manager/agent_manager.py @@ -2,7 +2,8 @@ import logging from typing import Dict -from gym.spaces import Box +from gym.spaces import Box, Dict + from pgdrive.scene_creator.vehicle.base_vehicle import BaseVehicle @@ -112,7 +113,10 @@ def init(self, pg_world, config_dict: Dict): obs_space = self._init_observation_spaces[agent_id] self.observation_spaces[vehicle.name] = obs_space - assert isinstance(obs_space, Box) + if not vehicle.vehicle_config["use_image"]: + assert isinstance(obs_space, Box) + else: + assert isinstance(obs_space, Dict), "Multi-agent observation should be gym.Dict" action_space = self._init_action_spaces[agent_id] self.action_spaces[vehicle.name] = action_space assert isinstance(action_space, Box) diff --git a/pgdrive/tests/vis_funtionality/vis_depth_cam_ground.py b/pgdrive/tests/vis_funtionality/vis_depth_cam_ground.py index 1ffd945ae..15be93f5c 100644 --- a/pgdrive/tests/vis_funtionality/vis_depth_cam_ground.py +++ b/pgdrive/tests/vis_funtionality/vis_depth_cam_ground.py @@ -9,12 +9,11 @@ def __init__(self): "environment_num": 1, "traffic_density": 0.1, "start_seed": 4, - "image_source": "depth_cam", "manual_control": True, "use_render": True, "use_image": True, "rgb_clip": True, - "vehicle_config": dict(depth_cam=(200, 88, True)), + "vehicle_config": dict(depth_cam=(200, 88, True), image_source="depth_cam"), "pg_world_config": { "headless_image": False, }, @@ -31,7 +30,7 @@ def __init__(self): if __name__ == "__main__": def get_image(env): - env.vehicle.image_sensors[env.config["image_source"]].save_image() + env.vehicle.image_sensors[env.vehicle.vehicle_config["image_source"]].save_image() env.pg_world.screenshot() env = TestEnv() diff --git a/pgdrive/tests/vis_funtionality/vis_depth_cam_no_ground.py b/pgdrive/tests/vis_funtionality/vis_depth_cam_no_ground.py index eff4f0943..4c66776ea 100644 --- a/pgdrive/tests/vis_funtionality/vis_depth_cam_no_ground.py +++ b/pgdrive/tests/vis_funtionality/vis_depth_cam_no_ground.py @@ -9,12 +9,11 @@ def __init__(self): "environment_num": 1, "traffic_density": 0.1, "start_seed": 4, - "image_source": "depth_cam", "manual_control": True, "use_render": True, "use_image": True, "rgb_clip": True, - "vehicle_config": dict(depth_cam=(200, 88, False)), + "vehicle_config": dict(depth_cam=(200, 88, False), image_source="depth_cam"), "pg_world_config": { "headless_image": False, }, @@ -31,7 +30,7 @@ def __init__(self): if __name__ == "__main__": env = TestEnv() env.reset() - env.pg_world.accept("m", env.vehicle.image_sensors[env.config["image_source"]].save_image) + env.pg_world.accept("m", env.vehicle.image_sensors[env.vehicle.vehicle_config["image_source"]].save_image) for i in range(1, 100000): o, r, d, info = env.step([0, 1]) diff --git a/pgdrive/tests/vis_funtionality/vis_rgb_cam.py b/pgdrive/tests/vis_funtionality/vis_rgb_cam.py index d7eba92f7..ce4d43b11 100644 --- a/pgdrive/tests/vis_funtionality/vis_rgb_cam.py +++ b/pgdrive/tests/vis_funtionality/vis_rgb_cam.py @@ -8,15 +8,10 @@ def __init__(self): "environment_num": 1, "traffic_density": 0.1, "start_seed": 4, - "image_source": "rgb_cam", "manual_control": True, "use_render": True, - "use_image": True, - "rgb_clip": True, - # "vehicle_config": dict(rgb_cam=(200, 88)), - "pg_world_config": { - "headless_image": False - } + "use_image": True, # it is a switch telling pgdrive to use rgb as observation + "rgb_clip": True, # clip rgb to range(0,1) instead of (0, 255) } ) @@ -24,7 +19,8 @@ def __init__(self): if __name__ == "__main__": env = TestEnv() env.reset() - env.pg_world.accept("m", env.vehicle.image_sensors[env.config["image_source"]].save_image) + # print m to capture rgb observation + env.pg_world.accept("m", env.vehicle.image_sensors[env.vehicle.vehicle_config["image_source"]].save_image) for i in range(1, 100000): o, r, d, info = env.step([0, 1])