Skip to content

Commit

Permalink
Introduce route completion for PG map (#638)
Browse files Browse the repository at this point in the history
* up

* finish

* finish

* add doc

* fix
  • Loading branch information
pengzhenghao committed Feb 12, 2024
1 parent a369fd1 commit ffdb1e7
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 11 deletions.
56 changes: 52 additions & 4 deletions metadrive/component/navigation_module/node_network_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,14 @@ def auto_assign_task(map, current_lane_index, final_road_node=None, random_seed=

def set_route(self, current_lane_index: str, destination: str):
"""
Find a shortest path from start road to end road
:param current_lane_index: start road node
:param destination: end road node or end lane index
:return: None
Find the shortest path from start road to the end road.
Args:
current_lane_index: start road node
destination: end road node or end lane index
Returns:
None
"""
self.spawn_road = current_lane_index[:-1]
self.checkpoints = self.map.road_network.shortest_path(current_lane_index, destination)
Expand Down Expand Up @@ -127,13 +131,37 @@ def set_route(self, current_lane_index: str, destination: str):
check_point = ref_lane.position(ref_lane.length, later_middle)
self._dest_node_path.setPos(panda_vector(check_point[0], check_point[1], self.MARK_HEIGHT))

# Compute the total length of the route for computing route completion
self.total_length = 0.0
self.travelled_length = 0.0
self._last_long_in_ref_lane = 0.0
for ckpt1, ckpt2 in zip(self.checkpoints[:-1], self.checkpoints[1:]):
self.total_length += self.map.road_network.graph[ckpt1][ckpt2][0].length

def update_localization(self, ego_vehicle):
"""
Update current position, route completion and checkpoints according to current position.
Args:
ego_vehicle: a vehicle object
Returns:
None
"""
position = ego_vehicle.position
lane, lane_index = self._update_current_lane(ego_vehicle)
long, _ = lane.local_coordinates(position)
need_update = self._update_target_checkpoints(lane_index, long)
assert len(self.checkpoints) >= 2

# Update travelled_length for route completion
long_in_ref_lane, _ = self.current_ref_lanes[0].local_coordinates(position)
travelled = long_in_ref_lane - self._last_long_in_ref_lane
self.travelled_length += travelled
self._last_long_in_ref_lane = long_in_ref_lane
# print(f"{self.travelled_length=}, {travelled=}, {long_in_ref_lane=}, "
# f"{self.route_completion=}, {self._last_long_in_ref_lane=}")

# target_road_1 is the road segment the vehicle is driving on.
if need_update:
target_road_1_start = self.checkpoints[self._target_checkpoints_index[0]]
Expand All @@ -142,6 +170,8 @@ def update_localization(self, ego_vehicle):
self.current_ref_lanes = target_lanes_1
self.current_road = Road(target_road_1_start, target_road_1_end)

self._last_long_in_ref_lane = self.current_ref_lanes[0].local_coordinates(position)[0]

# target_road_2 is next road segment the vehicle should drive on.
target_road_2_start = self.checkpoints[self._target_checkpoints_index[1]]
target_road_2_end = self.checkpoints[self._target_checkpoints_index[1] + 1]
Expand All @@ -157,10 +187,12 @@ def update_localization(self, ego_vehicle):

self._navi_info.fill(0.0)
half = self.CHECK_POINT_INFO_DIM
# Put the next checkpoint's information into the first half of the navi_info
self._navi_info[:half], lanes_heading1, checkpoint = self._get_info_for_checkpoint(
lanes_id=0, ref_lane=self.current_ref_lanes[0], ego_vehicle=ego_vehicle
)

# Put the next of the next checkpoint's information into the first half of the navi_info
self._navi_info[half:], lanes_heading2, _ = self._get_info_for_checkpoint(
lanes_id=1,
ref_lane=self.next_ref_lanes[0] if self.next_ref_lanes is not None else self.current_ref_lanes[0],
Expand Down Expand Up @@ -241,7 +273,17 @@ def _get_current_lane(self, ego_vehicle):
return (*possible_lanes[0][:-1], on_lane) if len(possible_lanes) > 0 else (None, None, on_lane)

def _get_info_for_checkpoint(self, lanes_id, ref_lane, ego_vehicle):
"""
Return the information of checkpoints for state observation.
Args:
lanes_id: the lane index of current lane. (lanes is a list so each lane has an index in this list)
ref_lane: the reference lane.
ego_vehicle: the vehicle object.
Returns:
navi_information, lanes_heading, check_point
"""
navi_information = []
# Project the checkpoint position into the target vehicle's coordination, where
# +x is the heading and +y is the right hand side.
Expand Down Expand Up @@ -304,5 +346,11 @@ def _update_current_lane(self, ego_vehicle):
return lane, lane_index

def get_state(self):
"""Return the navigation information for recording/replaying."""
final_road = self.final_road
return {"spawn_road": self.spawn_road, "destination": (final_road.start_node, final_road.end_node)}

@property
def route_completion(self):
"""Return the route completion at this moment."""
return self.travelled_length / self.total_length
20 changes: 13 additions & 7 deletions metadrive/component/road_network/node_road_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,12 @@ def bfs_paths(self, start: str, goal: str) -> List[List[str]]:
"""
Breadth-first search of all routes from start to goal.
:param start: starting node
:param goal: goal node
:return: list of paths from start to goal.
Args:
start: starting node
goal: goal node
Returns:
list of paths from start to goal.
"""
queue = [(start, [start])]
while queue:
Expand All @@ -260,11 +263,14 @@ def bfs_paths(self, start: str, goal: str) -> List[List[str]]:

def shortest_path(self, start: str, goal: str) -> List[str]:
"""
Breadth-first search of shortest checkpoints from start to goal.
Breadth-first search of the shortest checkpoints from start to goal.
Args:
start: starting node
goal: goal node
:param start: starting node
:param goal: goal node
:return: shortest checkpoints from start to goal.
Returns:
The shortest checkpoints from start to goal.
"""
start_road_node = start[0]
assert start != goal
Expand Down
3 changes: 3 additions & 0 deletions metadrive/envs/metadrive_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ def reward_function(self, vehicle_id: str):
reward = -self.config["crash_vehicle_penalty"]
elif vehicle.crash_object:
reward = -self.config["crash_object_penalty"]

step_info["route_completion"] = vehicle.navigation.route_completion

return reward, step_info

def setup_engine(self):
Expand Down
68 changes: 68 additions & 0 deletions metadrive/tests/test_functionality/test_route_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from metadrive.envs import MetaDriveEnv

from metadrive.policy.idm_policy import IDMPolicy
from metadrive.utils import setup_logger


def test_route_completion_easy():
"""
Use an easy map to test whether the route completion is computed correctly.
"""
# In easy map
config = {}
config["map"] = "SSS"
config["traffic_density"] = 0
try:
env = MetaDriveEnv(config=config)
o, i = env.reset()
assert "route_completion" in i
rc = env.vehicle.navigation.route_completion
epr = 0
for _ in range(1000):
o, r, tm, tc, i = env.step([0, 1])
epr += r
env.render(mode="topdown")
if tm or tc:
epr = 0
break
assert "route_completion" in i
print(i["route_completion"])
assert i["route_completion"] > 0.95
finally:
if "env" in locals():
env.close()


def test_route_completion_hard():
"""
Use a hard map to test whether the route completion is computed correctly.
"""
# In hard map
config = {}
config["map"] = "SCXTO"
config["agent_policy"] = IDMPolicy
config["traffic_density"] = 0
try:
env = MetaDriveEnv(config=config)
o, i = env.reset()
assert "route_completion" in i
rc = env.vehicle.navigation.route_completion
epr = 0
for _ in range(1000):
o, r, tm, tc, i = env.step([0, 0])
epr += r
env.render(mode="topdown")
if tm or tc:
epr = 0
break
assert "route_completion" in i
print(i["route_completion"], i)
assert i["route_completion"] > 0.8 # The vehicle will not reach destination due to randomness in IDM.
finally:
if "env" in locals():
env.close()


if __name__ == '__main__':
setup_logger(True)
test_route_completion_hard()

0 comments on commit ffdb1e7

Please sign in to comment.