Skip to content

Commit

Permalink
Add support for integrating sumo map (#597)
Browse files Browse the repository at this point in the history
* Add support for integrating sumo map

* fix

* fix test
  • Loading branch information
QuanyiLi committed Jan 15, 2024
1 parent 9276a2c commit 6d52930
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 28 deletions.
8 changes: 6 additions & 2 deletions metadrive/component/map/scenario_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@


class ScenarioMap(BaseMap):
def __init__(self, map_index, random_seed=None):
def __init__(self, map_index, map_data, random_seed=None):
self.map_index = map_index
self.map_data = map_data
self.need_lane_localization = self.engine.global_config["need_lane_localization"]
super(ScenarioMap, self).__init__(dict(id=self.map_index), random_seed=random_seed)

Expand All @@ -27,6 +28,7 @@ def _generate(self):
global_network=self.road_network,
random_seed=0,
map_index=self.map_index,
map_data=self.map_data,
need_lane_localization=self.need_lane_localization
)
self.crosswalks = block.crosswalks
Expand Down Expand Up @@ -144,6 +146,7 @@ def get_boundary_line_vector(self, interval):
# data = read_scenario_data(file_path)

default_config = ScenarioEnv.default_config()
default_config["_render_mode"] = "onscreen"
default_config["use_render"] = True
default_config["debug"] = True
default_config["debug_static_world"] = True
Expand All @@ -154,7 +157,8 @@ def get_boundary_line_vector(self, interval):
engine = initialize_engine(default_config)

engine.data_manager = ScenarioDataManager()
map = ScenarioMap(map_index=0)
m_data = engine.data_manager.get_scenario(0, should_copy=False)["map_features"]
map = ScenarioMap(map_index=0, map_data=m_data)
map.attach_to_world()
engine.enableMouse()
map.road_network.show_bounding_box(engine)
Expand Down
21 changes: 2 additions & 19 deletions metadrive/component/scenario_block/scenario_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,22 @@
from metadrive.component.road_network.edge_road_network import EdgeRoadNetwork
from metadrive.constants import PGDrivableAreaProperty
from metadrive.constants import PGLineType, PGLineColor
from metadrive.engine.engine_utils import get_engine
from metadrive.scenario.scenario_description import ScenarioDescription
from metadrive.type import MetaDriveType
from metadrive.utils.interpolating_line import InterpolatingLine
from metadrive.utils.math import norm
from metadrive.utils.vertex import make_polygon_model


class ScenarioBlock(BaseBlock):
LINE_CULL_DIST = 500

def __init__(self, block_index: int, global_network, random_seed, map_index, need_lane_localization):
def __init__(self, block_index: int, global_network, random_seed, map_index, map_data, need_lane_localization):
# self.map_data = map_data
self.need_lane_localization = need_lane_localization
self.map_index = map_index
data = self.engine.data_manager.current_scenario
sdc_track = data.get_sdc_track()
self.sdc_start_point = sdc_track["state"]["position"][0]
self.map_data = map_data
super(ScenarioBlock, self).__init__(block_index, global_network, random_seed)

@property
def map_data(self):
e = get_engine()
return e.data_manager.get_scenario(self.map_index, should_copy=False)["map_features"]

def _sample_topology(self) -> bool:
for object_id, data in self.map_data.items():
if MetaDriveType.is_lane(data.get("type", False)):
Expand Down Expand Up @@ -85,10 +76,6 @@ def construct_continuous_line(self, polyline, color):
segment_num = int(line.length / PGDrivableAreaProperty.STRIPE_LENGTH)
for segment in range(segment_num):
start = line.get_point(PGDrivableAreaProperty.STRIPE_LENGTH * segment)
# trick for optimizing
dist = norm(start[0] - self.sdc_start_point[0], start[1] - self.sdc_start_point[1])
if dist > self.LINE_CULL_DIST:
continue

if segment == segment_num - 1:
end = line.get_point(line.length)
Expand All @@ -102,10 +89,6 @@ def construct_broken_line(self, polyline, color):
segment_num = int(line.length / (2 * PGDrivableAreaProperty.STRIPE_LENGTH))
for segment in range(segment_num):
start = line.get_point(segment * PGDrivableAreaProperty.STRIPE_LENGTH * 2)
# trick for optimizing
dist = norm(start[0] - self.sdc_start_point[0], start[1] - self.sdc_start_point[1])
if dist > self.LINE_CULL_DIST:
continue
end = line.get_point(
segment * PGDrivableAreaProperty.STRIPE_LENGTH * 2 + PGDrivableAreaProperty.STRIPE_LENGTH
)
Expand Down
4 changes: 2 additions & 2 deletions metadrive/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def get_semantic_map_pixel_per_meter(cls):
return 22 if cls.map_region_size <= 1024 else 11

@classmethod
def point_in_map(cls, point, map_center=None):
def point_in_map(cls, point):
"""
Return if the point is in the map region
Args:
Expand All @@ -508,7 +508,7 @@ def point_in_map(cls, point, map_center=None):
return -x <= x_ <= x and -y <= y_ <= y

@classmethod
def clip_polygon(cls, polygon, map_center=None):
def clip_polygon(cls, polygon):
"""
Clip the Polygon. Make it fit into the map region and throw away the part outside the map region
Args:
Expand Down
3 changes: 2 additions & 1 deletion metadrive/manager/scenario_map_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def reset(self):
self.sdc_dest_point = None

if self._stored_maps[seed] is None:
new_map = ScenarioMap(map_index=seed)
m_data = self.engine.data_manager.get_scenario(seed, should_copy=False)["map_features"]
new_map = ScenarioMap(map_index=seed, map_data=m_data)
if self.store_map:
self._stored_maps[seed] = new_map
else:
Expand Down
6 changes: 4 additions & 2 deletions metadrive/tests/test_component/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def test_map_get_semantic_map(dir="waymo", render=False, show=False):
engine.data_manager = ScenarioDataManager()
for idx in range(default_config["num_scenarios"]):
engine.seed(idx)
map = ScenarioMap(map_index=idx)
m_data = engine.data_manager.get_scenario(idx, should_copy=False)["map_features"]
map = ScenarioMap(map_index=idx, map_data=m_data)
heightfield = map.get_semantic_map([0, 0], size, res)
assert heightfield.shape[0] == heightfield.shape[1] == int(size * res)
if show:
Expand All @@ -48,7 +49,8 @@ def test_map_get_elevation_map(dir="waymo", render=False, show=False):
engine.data_manager = ScenarioDataManager()
for idx in range(default_config["num_scenarios"]):
engine.seed(idx)
map = ScenarioMap(map_index=idx)
m_data = engine.data_manager.get_scenario(idx, should_copy=False)["map_features"]
map = ScenarioMap(map_index=idx, map_data=m_data)
heightfield = map.get_height_map([0, 0], size, res, extension=4)
assert heightfield.shape[0] == heightfield.shape[1] == int(size * res)
if show:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def test_waymo_map_memory_leak():

for t in range(10):
lt = time.time()

map = ScenarioMap(map_index=0)
m_data = engine.data_manager.get_scenario(0, should_copy=False)["map_features"]
map = ScenarioMap(map_index=0, map_data=m_data)
map.attach_to_world(engine.render, engine.physics_world)
map.destroy()

Expand Down

0 comments on commit 6d52930

Please sign in to comment.