Skip to content

Commit

Permalink
Add connectiveity when saving to scenario description (#508)
Browse files Browse the repository at this point in the history
* Add connectivity

* readd information

* save nodenet to edge net

* format
  • Loading branch information
QuanyiLi committed Oct 4, 2023
1 parent 72e842c commit fb92adc
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 9 deletions.
26 changes: 18 additions & 8 deletions metadrive/component/road_network/edge_road_network.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from metadrive.scenario.scenario_description import ScenarioDescription as SD
from collections import namedtuple
from typing import List

from metadrive.component.road_network.base_road_network import BaseRoadNetwork
from metadrive.component.road_network.base_road_network import LaneIndex
from metadrive.scenario.scenario_description import ScenarioDescription as SD
from metadrive.utils.math import get_boxes_bounding_box
from metadrive.utils.pg.utils import get_lanes_bounding_box

Expand Down Expand Up @@ -64,16 +64,23 @@ def bfs_paths(self, start: str, goal: str) -> List[List[str]]:
"""
Breadth-first search of all routes from start to goal.
:param start: starting node
:param goal: goal node
:param start: starting edges
:param goal: goal edge
:return: list of paths from start to goal.
"""
queue = [(start, [start])]
lanes = self.graph[start].left_lanes + self.graph[start].right_lanes + [start]

queue = [(lane, [lane]) for lane in lanes]
while queue:
(node, path) = queue.pop(0)
if node not in self.graph:
(lane, path) = queue.pop(0)
if lane not in self.graph:
yield []
for _next in set(self.graph[node].exit_lanes) - set(path):
if len(self.graph[lane].exit_lanes) == 0:
continue
for _next in set(self.graph[lane].exit_lanes):
if _next in path:
# circle
continue
if _next == goal:
yield path + [_next]
elif _next in self.graph:
Expand All @@ -99,7 +106,6 @@ def __del__(self):
logging.debug("{} is released".format(self.__class__.__name__))

def get_map_features(self, interval=2):
from metadrive.type import MetaDriveType

ret = {}
for id, lane_info in self.graph.items():
Expand All @@ -108,6 +114,10 @@ def get_map_features(self, interval=2):
SD.POLYLINE: lane_info.lane.get_polyline(interval),
SD.POLYGON: lane_info.lane.polygon,
SD.TYPE: lane_info.lane.metadrive_type,
SD.ENTRY: lane_info.entry_lanes,
SD.EXIT: lane_info.exit_lanes,
SD.LEFT_NEIGHBORS: lane_info.left_lanes,
SD.RIGHT_NEIGHBORS: lane_info.right_lanes,
"speed_limit_kmh": lane_info.lane.speed_limit
}
return ret
Expand Down
26 changes: 25 additions & 1 deletion metadrive/component/road_network/node_road_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,39 @@ def shortest_path(self, start: str, goal: str) -> List[str]:
return next(self.bfs_paths(start_road_node, goal), [])

def get_map_features(self, interval=2):
from metadrive.type import MetaDriveType
def find_entry_exit():
entries = dict()
exits = dict()

for _from, _to_dict in self.graph.items():
for _to, lanes in _to_dict.items():
if _from in exits:
exits[_from] += ["{}".format(l.index) for l in lanes]
else:
exits[_from] = ["{}".format(l.index) for l in lanes]

if _to in entries:
entries[_to] += ["{}".format(l.index) for l in lanes]
else:
entries[_to] = ["{}".format(l.index) for l in lanes]
return entries, exits

entries, exits = find_entry_exit()

ret = {}
for _from, _to_dict in self.graph.items():
for _to, lanes in _to_dict.items():
for k, lane in enumerate(lanes):
left_n = ["{}".format(l.index) for l in lanes[:k]]
right_n = ["{}".format(l.index) for l in lanes[k + 1:]]
ret["{}".format(lane.index)] = {
SD.POLYLINE: lane.get_polyline(interval),
SD.POLYGON: lane.polygon,
# Convert to EdgeNetwork
SD.LEFT_NEIGHBORS: left_n,
SD.RIGHT_NEIGHBORS: right_n,
SD.ENTRY: entries.get(_from, []),
SD.EXIT: exits.get(_to, []),
SD.TYPE: lane.metadrive_type,
"speed_limit_kmh": lane.speed_limit
}
Expand Down
104 changes: 104 additions & 0 deletions metadrive/tests/test_export_record_scenario/test_connectivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import copy
import os
import pickle
import shutil

from metadrive.component.map.base_map import BaseMap
from metadrive.component.map.pg_map import MapGenerateMethod
from metadrive.envs.metadrive_env import MetaDriveEnv
from metadrive.envs.real_data_envs.waymo_env import WaymoEnv
from metadrive.policy.idm_policy import IDMPolicy
from metadrive.policy.replay_policy import ReplayEgoCarPolicy


def test_search_path(render_export_env=False, render_load_env=False):
# Origin Data
env = MetaDriveEnv(
dict(
start_seed=0,
use_render=render_export_env,
num_scenarios=1,
map_config={
BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_SEQUENCE,
BaseMap.GENERATE_CONFIG: "OSXCTrCS", # it can be a file path / block num / block ID sequence
BaseMap.LANE_WIDTH: 3.5,
BaseMap.LANE_NUM: 1,
"exit_length": 50,
},
agent_policy=IDMPolicy
)
)
policy = lambda x: [0, 1]
dir = None

try:
scenarios, done_info = env.export_scenarios(policy, scenario_index=[i for i in range(1)])
dir = os.path.join(os.path.dirname(__file__), "../test_component/test_export")
os.makedirs(dir, exist_ok=True)
for i, data in scenarios.items():
with open(os.path.join(dir, "{}.pkl".format(i)), "wb+") as file:
pickle.dump(data, file)
node_roadnet = copy.deepcopy(env.current_map.road_network)
env.close()

# Loaded Data
env = WaymoEnv(
dict(agent_policy=ReplayEgoCarPolicy, data_directory=dir, use_render=render_load_env, num_scenarios=1)
)
scenarios, done_info = env.export_scenarios(policy, scenario_index=[i for i in range(1)])
dir = os.path.join(os.path.dirname(__file__), "../test_component/test_export")
os.makedirs(dir, exist_ok=True)
for i, data in scenarios.items():
with open(os.path.join(dir, "{}.pkl".format(i)), "wb+") as file:
pickle.dump(data, file)
env.close()

# reload
env = WaymoEnv(
dict(agent_policy=ReplayEgoCarPolicy, data_directory=dir, use_render=render_load_env, num_scenarios=1)
)
for index in range(1):
env.reset(seed=index)
done = False
while not done:
o, r, tm, tc, i = env.step([0, 0])
done = tm or tc
edge_roadnet = copy.deepcopy(env.current_map.road_network)
all_node_lanes = node_roadnet.get_all_lanes()
all_edge_lanes = edge_roadnet.get_all_lanes()
diff = set(["{}".format(l.index) for l in all_node_lanes]) - set(["{}".format(l.index) for l in all_edge_lanes])
assert len(diff) == 0
nodes = node_roadnet.shortest_path('>', "8S0_0_")
print(nodes)
edges = edge_roadnet.shortest_path("('>', '>>', 0)", "('7C0_1_', '8S0_0_', 0)")

def process_data(input_list):
# Initialize the output list
output_list = []

for item in input_list:
# Remove the outer double quotes and then split the string based on commas
elements = item.strip('""').split(',')

# Extract the first two elements, strip the unnecessary characters and append to the output list
for elem in elements[:2]:
output_list.append(elem.strip(' "\'()'))

# Remove duplicates while maintaining order
output_list = list(dict.fromkeys(output_list))

return output_list

to_node = process_data(edges)
print(to_node)
assert to_node == nodes

finally:
env.close()
if dir is not None:
shutil.rmtree(dir)


if __name__ == '__main__':
test_search_path()

0 comments on commit fb92adc

Please sign in to comment.