Skip to content

Commit

Permalink
Make TopDownMultiChannel support ScenarioEnv (#498)
Browse files Browse the repository at this point in the history
* works!

* add preliminary support for top down view when using a scenario

* format

* remove extraeous change

* draw trajectories

* format and fix bug
  • Loading branch information
pimpale committed Oct 18, 2023
1 parent 653431b commit 3a7c740
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 12 deletions.
2 changes: 1 addition & 1 deletion metadrive/component/road_network/edge_road_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from metadrive.utils.math import get_boxes_bounding_box
from metadrive.utils.pg.utils import get_lanes_bounding_box

lane_info = namedtuple("neighbor_lanes", "lane entry_lanes exit_lanes left_lanes right_lanes")
lane_info = namedtuple("edge_lane", ["lane", "entry_lanes", "exit_lanes", "left_lanes", "right_lanes"])


class EdgeRoadNetwork(BaseRoadNetwork):
Expand Down
5 changes: 5 additions & 0 deletions metadrive/manager/scenario_traffic_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from metadrive.component.static_object.traffic_object import TrafficCone, TrafficBarrier
from metadrive.component.traffic_participants.cyclist import Cyclist
from metadrive.component.traffic_participants.pedestrian import Pedestrian
from metadrive.component.vehicle.base_vehicle import BaseVehicle
from metadrive.component.vehicle.vehicle_type import get_vehicle_type, reset_vehicle_type_count
from metadrive.constants import DEFAULT_AGENT
from metadrive.manager.base_manager import BaseManager
Expand Down Expand Up @@ -162,6 +163,10 @@ def sdc_object_id(self):
def current_scenario_length(self):
return self.engine.data_manager.current_scenario_length

@property
def vehicles(self):
return list(self.engine.get_objects(filter=lambda o: isinstance(o, BaseVehicle)).values())

def spawn_vehicle(self, v_id, track):
state = parse_object_state(track, self.episode_step)

Expand Down
49 changes: 38 additions & 11 deletions metadrive/obs/top_down_obs_multi_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
import numpy as np

from metadrive.component.vehicle.base_vehicle import BaseVehicle
from metadrive.component.traffic_participants.base_traffic_participant import BaseTrafficParticipant
from metadrive.scenario.scenario_description import ScenarioDescription
from metadrive.component.lane.point_lane import PointLane
from metadrive.constants import Decoration, DEFAULT_AGENT
from metadrive.obs.top_down_obs import TopDownObservation
from metadrive.obs.top_down_obs_impl import WorldSurface, COLOR_BLACK, ObjectGraphics, LaneGraphics, \
ObservationWindowMultiChannel
from metadrive.utils import import_pygame, clip

from metadrive.component.road_network.node_road_network import NodeRoadNetwork
from metadrive.component.vehicle_navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.component.vehicle_navigation_module.edge_network_navigation import EdgeNetworkNavigation
from metadrive.component.vehicle_navigation_module.trajectory_navigation import TrajectoryNavigation

pygame, gfxdraw = import_pygame()
COLOR_WHITE = pygame.Color("white")
DEFAULT_TRAJECTORY_LANE_WIDTH = 3


class TopDownMultiChannel(TopDownObservation):
Expand Down Expand Up @@ -106,16 +115,29 @@ def draw_map(self) -> pygame.Surface:
self.canvas_background.move_display_window_to(centering_pos)
self.canvas_road_network.move_display_window_to(centering_pos)

# self.draw_navigation(self.canvas_navigation)
self.draw_navigation(self.canvas_background, (64, 64, 64))
if isinstance(self.target_vehicle.navigation, NodeNetworkNavigation):
self.draw_navigation_node(self.canvas_background, (64, 64, 64))
elif isinstance(self.target_vehicle.navigation, EdgeNetworkNavigation):
# TODO: draw edge network navigation
pass
elif isinstance(self.target_vehicle.navigation, TrajectoryNavigation):
self.draw_navigation_trajectory(self.canvas_background, (64, 64, 64))

if isinstance(self.road_network, NodeRoadNetwork):
for _from in self.road_network.graph.keys():
decoration = True if _from == Decoration.start else False
for _to in self.road_network.graph[_from].keys():
for l in self.road_network.graph[_from][_to]:
two_side = True if l is self.road_network.graph[_from][_to][-1] or decoration else False
LaneGraphics.LANE_LINE_WIDTH = 0.5
LaneGraphics.display(l, self.canvas_background, two_side)
elif hasattr(self.engine, "map_manager"):
for data in self.engine.map_manager.current_map.blocks[-1].map_data.values():
if ScenarioDescription.POLYLINE in data:
LaneGraphics.display_scenario_line(
data[ScenarioDescription.POLYLINE], data[ScenarioDescription.TYPE], self.canvas_background
)

for _from in self.road_network.graph.keys():
decoration = True if _from == Decoration.start else False
for _to in self.road_network.graph[_from].keys():
for l in self.road_network.graph[_from][_to]:
two_side = True if l is self.road_network.graph[_from][_to][-1] or decoration else False
LaneGraphics.LANE_LINE_WIDTH = 0.5
LaneGraphics.display(l, self.canvas_background, two_side)
self.canvas_road_network.blit(self.canvas_background, (0, 0))
self.obs_window.reset(self.canvas_runtime)
self._should_draw_map = False
Expand All @@ -142,7 +164,8 @@ def draw_scene(self):
ego_heading = vehicle.heading_theta
ego_heading = ego_heading if abs(ego_heading) > 2 * np.pi / 180 else 0

for v in self.engine.traffic_manager.vehicles:
for v in self.engine.get_objects(lambda o: isinstance(o, BaseVehicle) or isinstance(o, BaseTrafficParticipant)
).values():
if v is vehicle:
continue
h = v.heading_theta
Expand Down Expand Up @@ -256,13 +279,17 @@ def observe(self, vehicle: BaseVehicle):
img = np.clip(img, 0, 255)
return np.transpose(img, (1, 0, 2))

def draw_navigation(self, canvas, color=(128, 128, 128)):
def draw_navigation_node(self, canvas, color=(128, 128, 128)):
checkpoints = self.target_vehicle.navigation.checkpoints
for i, c in enumerate(checkpoints[:-1]):
lanes = self.road_network.graph[c][checkpoints[i + 1]]
for lane in lanes:
LaneGraphics.draw_drivable_area(lane, canvas, color=color)

def draw_navigation_trajectory(self, canvas, color=(128, 128, 128)):
lane = PointLane(self.target_vehicle.navigation.checkpoints, DEFAULT_TRAJECTORY_LANE_WIDTH)
LaneGraphics.draw_drivable_area(lane, canvas, color=color)

def _get_stack_indices(self, length, frame_skip=None):
frame_skip = frame_skip or self.frame_skip
num = int(math.ceil(length / frame_skip))
Expand Down

0 comments on commit 3a7c740

Please sign in to comment.