Skip to content

Commit

Permalink
Supporting NuPlan env to export MetaDrive Scenario Description (#414)
Browse files Browse the repository at this point in the history
* fix nuplan

* test nuplan

* format

* rename

* revert opendrive changes

* fix nuplan issue in win

* format

* fix

* move assert scenario equal

* move

* fix bug

---------

Co-authored-by: pengzhenghao <pzh@cs.ucla.edu>
  • Loading branch information
QuanyiLi and pengzhenghao committed Apr 7, 2023
1 parent 5843289 commit 643c85e
Show file tree
Hide file tree
Showing 17 changed files with 388 additions and 240 deletions.
2 changes: 1 addition & 1 deletion metadrive/component/nuplan_block/nuplan_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def destroy(self):
def __del__(self):
self.destroy()
super(NuPlanBlock, self).__del__()
print("NuPlan Block is being deleted.")
# print("NuPlan Block is being deleted.")

@staticmethod
def _get_points_from_boundary(boundary, center):
Expand Down
13 changes: 6 additions & 7 deletions metadrive/envs/real_data_envs/nuplan_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _is_out_of_road(self, vehicle):
# "multi_thread_render_mode": "Cull/Draw",
"start_scenario_index": 0,
# "pstats": True,
"num_scenarios": 400,
"num_scenarios": 2,
"show_coordinates": False,
"horizon": 1000,
# "show_fps": False,
Expand All @@ -318,19 +318,18 @@ def _is_out_of_road(self, vehicle):
"force_render_fps": 40,
"show_fps": True,
"DATASET_PARAMS": [
'scenario_builder=nuplan_mini',
# use nuplan mini database (2.5h of 8 autolabeled logs in Las Vegas)
'scenario_builder=nuplan_mini', # use nuplan mini database (2.5h of 8 autolabeled logs in Las Vegas)
'scenario_filter=one_continuous_log', # simulate only one log
"scenario_filter.log_names=['2021.09.16.15.12.03_veh-42_01037_01434']",
'scenario_filter.limit_total_scenarios=1000', # use 2 total scenarios
"scenario_filter.log_names=['2021.05.12.22.00.38_veh-35_01008_01518']",
'scenario_filter.limit_total_scenarios=2', # use 2 total scenarios
],
"show_mouse": True,
}
)
success = []
env.reset(8)
env.reset()
for seed in [8, 14] * 10:
env.reset(seed)
env.reset()
# env.reset(seed)
for i in range(env.engine.data_manager.current_scenario_length * 10):
o, r, d, info = env.step([0, 0])
Expand Down
41 changes: 20 additions & 21 deletions metadrive/manager/nuplan_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from nuplan.planning.script.utils import set_up_common_builder
from dataclasses import dataclass
from metadrive.manager.base_manager import BaseManager
from metadrive.utils.utils import is_win

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -108,27 +109,25 @@ def _get_nuplan_cfg(self):
DATASET_PARAMS = self.engine.global_config["DATASET_PARAMS"]

# Compose the configuration
cfg = hydra.compose(
config_name=simulation_hydra_paths.config_name,
overrides=[
f'group={SAVE_DIR}',
'worker=sequential',
f'ego_controller={EGO_CONTROLLER}',
f'observation={OBSERVATION}',
f'hydra.searchpath=[{simulation_hydra_paths.common_dir}, {simulation_hydra_paths.experiment_dir}]',
'output_dir=${group}/${experiment}',
*DATASET_PARAMS,

# TODO: Check which one is correct.
# Option 1: LQY write:
# f'experiment_name=planner_tutorial',

# Option 2: planning tutorial:
# Copied from tutorial
f'job_name=planner_tutorial',
'experiment=${experiment_name}/${job_name}/${experiment_time}',
]
)
overrides = [
f'group={SAVE_DIR}',
'worker=sequential',
f'ego_controller={EGO_CONTROLLER}',
f'observation={OBSERVATION}',
f'hydra.searchpath=[{simulation_hydra_paths.common_dir}, {simulation_hydra_paths.experiment_dir}]',
'output_dir=${group}/${experiment}',
*DATASET_PARAMS,
]
if is_win():
overrides.extend(
[
f'job_name=planner_tutorial',
'experiment=${experiment_name}/${job_name}/${experiment_time}',
]
)
else:
overrides.append(f'experiment_name=planner_tutorial')
cfg = hydra.compose(config_name=simulation_hydra_paths.config_name, overrides=overrides)
return cfg

def get_scenario(self, index, force_get_current_scenario=True):
Expand Down
28 changes: 26 additions & 2 deletions metadrive/manager/nuplan_light_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy

from metadrive.component.traffic_light.nuplan_traffic_light import NuplanTrafficLight
from metadrive.manager.base_manager import BaseManager
from metadrive.scenario.scenario_description import ScenarioDescription as SD


class NuPlanLightManager(BaseManager):
Expand All @@ -8,18 +11,26 @@ class NuPlanLightManager(BaseManager):
def __init__(self):
super(NuPlanLightManager, self).__init__()
self._lane_to_lights = {}
self.nuplan_id_to_obj_id = {}
self.obj_id_to_nuplan_id = {}

self._episode_light_data = None

def before_reset(self):
super(NuPlanLightManager, self).before_reset()
self._lane_to_lights = {}
self.nuplan_id_to_obj_id = {}
self.obj_id_to_nuplan_id = {}

self._episode_light_data = self._get_episode_light_data()

def after_reset(self):
for light in self._episode_light_data[0]:
lane_info = self.engine.current_map.road_network.graph[str(light.lane_connector_id)]
traffic_light = self.spawn_object(NuplanTrafficLight, lane=lane_info.lane)
self._lane_to_lights[lane_info.lane.index] = traffic_light
self.nuplan_id_to_obj_id[str(light.lane_connector_id)] = traffic_light.name
self.obj_id_to_nuplan_id[traffic_light.name] = str(light.lane_connector_id)
traffic_light.set_status(light.status)

def after_step(self, *args, **kwargs):
Expand All @@ -31,8 +42,11 @@ def after_step(self, *args, **kwargs):
if self.CLEAR_LIGHTS:
light_to_eliminate = self._lane_to_lights.keys() - set([str(i.lane_connector_id) for i in step_data])
for lane_id in light_to_eliminate:
self.clear_objects([self._lane_to_lights[lane_id].id])
self._lane_to_lights.pop(lane_id)
light = self._lane_to_lights.pop(lane_id)
nuplan_id = self.obj_id_to_nuplan_id.pop(light.id)
obj_id = self.nuplan_id_to_obj_id.pop(nuplan_id)
assert obj_id == light.name
self.clear_objects([obj_id])

for light in self._episode_light_data[self.episode_step]:
if str(light.lane_connector_id) in self._lane_to_lights:
Expand All @@ -45,8 +59,12 @@ def after_step(self, *args, **kwargs):
traffic_light = self.spawn_object(NuplanTrafficLight, lane=lane_info.lane)
assert str(light.lane_connector_id) == lane_info.lane.index
self._lane_to_lights[lane_info.lane.index] = traffic_light
self.nuplan_id_to_obj_id[str(light.lane_connector_id)] = traffic_light.name
self.obj_id_to_nuplan_id[traffic_light.name] = str(light.lane_connector_id)
traffic_light.set_status(light.status)

assert len(self._lane_to_lights) == len(self.nuplan_id_to_obj_id) == len(self.obj_id_to_nuplan_id)

def has_traffic_light(self, lane_index):
return True if lane_index in self._lane_to_lights else False

Expand All @@ -65,3 +83,9 @@ def _get_episode_light_data(self):
length = self.engine.data_manager.current_scenario.get_number_of_iterations()
ret = {i: [t for t in self.traffic_light_status_at(i)] for i in range(length)}
return ret

def get_state(self):
return {
SD.OBJ_ID_TO_ORIGINAL_ID: copy.deepcopy(self.obj_id_to_nuplan_id),
SD.ORIGINAL_ID_TO_OBJ_ID: copy.deepcopy(self.nuplan_id_to_obj_id)
}
27 changes: 24 additions & 3 deletions metadrive/manager/nuplan_traffic_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@
from metadrive.component.traffic_participants.cyclist import Cyclist
from metadrive.component.traffic_participants.pedestrian import Pedestrian
from metadrive.component.vehicle.vehicle_type import get_vehicle_type
from metadrive.constants import DEFAULT_AGENT
from metadrive.manager.base_manager import BaseManager
from metadrive.policy.replay_policy import NuPlanReplayTrafficParticipantPolicy
from metadrive.scenario.scenario_description import ScenarioDescription as SD
from metadrive.utils.nuplan_utils.parse_object_state import parse_object_state


class NuPlanTrafficManager(BaseManager):
EGO_TOKEN = "ego"

def __init__(self):
super(NuPlanTrafficManager, self).__init__()
self.nuplan_id_to_obj_id = {}
self.obj_id_to_nuplan_id = {}
self.need_traffic = not self.engine.global_config["no_traffic"]
self.need_pedestrian = not self.engine.global_config["no_pedestrian"]
self._episode_traffic_data = None
Expand All @@ -23,7 +28,9 @@ def __init__(self):
def after_reset(self):
self._episode_traffic_data = self._get_episode_traffic_data()
assert self.engine.episode_step == 0
self.nuplan_id_to_obj_id = {}
# according to scenario.initial_ego_state, the ego token is ego
self.nuplan_id_to_obj_id = {self.EGO_TOKEN: self.engine.agents[DEFAULT_AGENT].id}
self.obj_id_to_nuplan_id = {self.engine.agents[DEFAULT_AGENT].id: self.EGO_TOKEN}
for nuplan_id, obj_state in self._episode_traffic_data[0].items():
if obj_state.tracked_object_type == TrackedObjectType.VEHICLE and self.need_traffic:
state = parse_object_state(obj_state, self.engine.current_map.nuplan_center)
Expand Down Expand Up @@ -64,8 +71,12 @@ def after_step(self, *args, **kwargs):
self.spawn_pedestrian(state, nuplan_id)

for nuplan_id in list(vehicles_to_eliminate):
self.clear_objects([self.nuplan_id_to_obj_id[nuplan_id]])
self.nuplan_id_to_obj_id.pop(nuplan_id)
if nuplan_id != self.EGO_TOKEN:
self.clear_objects([self.nuplan_id_to_obj_id[nuplan_id]])
obj_id = self.nuplan_id_to_obj_id.pop(nuplan_id)
assert nuplan_id == self.obj_id_to_nuplan_id.pop(obj_id)

assert len(self.nuplan_id_to_obj_id) == len(self.obj_id_to_nuplan_id)
return dict(default_agent=dict(replay_done=False))

@property
Expand Down Expand Up @@ -111,6 +122,7 @@ def spawn_vehicle(self, state, nuplan_id):
vehicle_config=v_config,
)
self.nuplan_id_to_obj_id[nuplan_id] = v.name
self.obj_id_to_nuplan_id[v.name] = nuplan_id
v.set_velocity(state["velocity"])
v.set_position(state["position"], 0.5)
self.add_policy(v.name, NuPlanReplayTrafficParticipantPolicy, v)
Expand All @@ -122,6 +134,7 @@ def spawn_pedestrian(self, state, nuplan_id):
heading_theta=state["heading"],
)
self.nuplan_id_to_obj_id[nuplan_id] = obj.name
self.obj_id_to_nuplan_id[obj.name] = nuplan_id
obj.set_velocity(state["velocity"])
self.add_policy(obj.name, NuPlanReplayTrafficParticipantPolicy, obj)

Expand All @@ -132,6 +145,7 @@ def spawn_cyclist(self, state, nuplan_id):
heading_theta=state["heading"],
)
self.nuplan_id_to_obj_id[nuplan_id] = obj.name
self.obj_id_to_nuplan_id[obj.name] = nuplan_id
obj.set_velocity(state["velocity"])
self.add_policy(obj.name, NuPlanReplayTrafficParticipantPolicy, obj)

Expand All @@ -149,3 +163,10 @@ def is_outlier(self, nuplan_id):
return True
else:
return False

def get_state(self):
# Record mapping from original_id to new_id
ret = {}
ret[SD.ORIGINAL_ID_TO_OBJ_ID] = copy.deepcopy(self.nuplan_id_to_obj_id)
ret[SD.OBJ_ID_TO_ORIGINAL_ID] = copy.deepcopy(self.obj_id_to_nuplan_id)
return ret
35 changes: 17 additions & 18 deletions metadrive/policy/replay_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,25 +102,24 @@ def get_trajectory_info(self, *args, **kwargs):
return parse_ego_vehicle_state_trajectory(scenario, self.engine.current_map.nuplan_center)

def act(self, *args, **kwargs):

if self.episode_step < len(self.traj_info):
self.control_object.set_position(self.traj_info[int(self.episode_step)]["position"])
if self.episode_step < len(self.traj_info) - 1:
velocity = self.traj_info[int(self.episode_step +
1)]["position"] - self.traj_info[int(self.episode_step)]["position"]
velocity /= self.sim_time_interval
self.control_object.set_velocity(velocity, in_local_frame=False)
else:
velocity = self.traj_info[int(self.episode_step)]["velocity"]
self.control_object.set_velocity(velocity, in_local_frame=True)
# self.control_object.set_velocity(self.traj_info[int(self.episode_step)]["velocity"])
if self.heading is None or self.episode_step >= len(self.traj_info):
pass
if self.episode_step >= len(self.traj_info):
return

self.control_object.set_position(self.traj_info[int(self.episode_step)]["position"])
if self.episode_step < len(self.traj_info) - 1:
velocity = self.traj_info[int(self.episode_step + 1)]["position"] - self.traj_info[int(self.episode_step
)]["position"]
velocity /= self.sim_time_interval
self.control_object.set_velocity(velocity, in_local_frame=False)
else:
this_heading = self.traj_info[int(self.episode_step)]["heading"]
angular_v = self.traj_info[int(self.episode_step)]["angular_velocity"]
self.control_object.set_heading_theta(this_heading)
self.control_object.set_angular_velocity(angular_v)
velocity = self.traj_info[int(self.episode_step)]["velocity"]
self.control_object.set_velocity(velocity, in_local_frame=True)
# self.control_object.set_velocity(self.traj_info[int(self.episode_step)]["velocity"])

this_heading = self.traj_info[int(self.episode_step)]["heading"]
angular_v = self.traj_info[int(self.episode_step)]["angular_velocity"]
self.control_object.set_heading_theta(this_heading)
self.control_object.set_angular_velocity(angular_v)

return [0, 0]

Expand Down
41 changes: 31 additions & 10 deletions metadrive/scenario/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import copy

import pickle

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import figure

from metadrive.component.traffic_light.base_traffic_light import BaseTrafficLight
from metadrive.component.traffic_participants.cyclist import Cyclist
from metadrive.component.traffic_participants.pedestrian import Pedestrian
Expand Down Expand Up @@ -51,6 +51,26 @@ def _convert_type_to_string(nested):
return nested


def find_light_manager_name(manager_info):
"""
Find the light_manager in real data manager
"""
for manager_name in manager_info:
if "LightManager" in manager_name:
return manager_name
return None


def find_traffic_manager_name(manager_info):
"""
Find the traffic_manager in real data manager
"""
for manager_name in manager_info:
if "TrafficManager" in manager_name and manager_name != "PGTrafficManager":
return manager_name
return None


def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1):
"""
This function utilizes the recorded data natively emerging from MetaDrive run.
Expand Down Expand Up @@ -96,6 +116,9 @@ def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1
for frame in frames:
all_objs.update(frame.step_info.keys())

traffic_manager_name = find_traffic_manager_name(record_episode["manager_metadata"])
light_manager_name = find_light_manager_name(record_episode["manager_metadata"])

tracks = {
k: dict(
type=MetaDriveType.UNSET,
Expand All @@ -116,10 +139,9 @@ def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1
}

all_lights = set()
# TODO LightManager may have different name
for frame in frames:
if "ScenarioLightManager" in frame.manager_info:
all_lights.update(frame.manager_info["ScenarioLightManager"][SD.ORIGINAL_ID_TO_OBJ_ID].keys())
if light_manager_name is not None:
for frame in frames:
all_lights.update(frame.manager_info[light_manager_name][SD.ORIGINAL_ID_TO_OBJ_ID].keys())

lights = {
k: dict(
Expand All @@ -146,10 +168,10 @@ def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1
# pop id from tracks
if id in tracks:
tracks.pop(id)
assert "ScenarioLightManager" in frames[frame_idx].manager_info, "Can not find light manager info"
assert light_manager_name in frames[frame_idx].manager_info, "Can not find light manager info"

# convert to original id
id = frames[frame_idx].manager_info["ScenarioLightManager"][SD.OBJ_ID_TO_ORIGINAL_ID][id]
id = frames[frame_idx].manager_info[light_manager_name][SD.OBJ_ID_TO_ORIGINAL_ID][id]

lights[id]["type"] = type
lights[id][SD.METADATA]["type"] = lights[id]["type"]
Expand Down Expand Up @@ -199,9 +221,8 @@ def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1
if id in frames[frame_idx]._object_to_agent:
tracks[id]["metadata"]["agent_name"] = frames[frame_idx]._object_to_agent[id]

# TODO Manager may have different name
if "ScenarioTrafficManager" in frames[frame_idx].manager_info:
origin_id = frames[frame_idx].manager_info["ScenarioTrafficManager"][SD.OBJ_ID_TO_ORIGINAL_ID][id]
if traffic_manager_name is not None:
origin_id = frames[frame_idx].manager_info[traffic_manager_name][SD.OBJ_ID_TO_ORIGINAL_ID][id]
if tracks[id]["metadata"]["original_id"] == id:
tracks[id]["metadata"]["original_id"] = origin_id
else:
Expand Down

0 comments on commit 643c85e

Please sign in to comment.