Deprecate the resource-expensive read_summary pipeline (#639)
* deprecate the resource-expensive read_summary pipeline; raise an error instead

* remove a test for no_summary

* fix tests

* fix test
pengzhenghao committed Feb 12, 2024
1 parent ffdb1e7 commit a4da4c1
Showing 6 changed files with 172 additions and 107 deletions.
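
This commit removes the fallback that rebuilt a dataset summary by re-reading every scenario file. Datasets are now expected to ship a dataset_summary.pkl, produced by the new save_dataset helper. A minimal migration sketch (not part of this diff; env and policy are constructed as in the tests below, and the dataset name and output folder are hypothetical):

    from metadrive.scenario.utils import save_dataset

    scenarios, done_info = env.export_scenarios(policy, scenario_index=[0])
    save_dataset(
        scenario_list=list(scenarios.values()),
        dataset_name="my_dataset",   # hypothetical dataset name
        dataset_version="v0",
        dataset_dir="./my_dataset",  # hypothetical output folder
    )

This replaces the old test pattern of pickling each scenario by hand, which left the folder without a summary file.
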
124 changes: 98 additions & 26 deletions metadrive/scenario/utils.py
@@ -6,12 +6,14 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import figure

from metadrive.component.static_object.traffic_object import TrafficCone, TrafficBarrier
from metadrive.component.traffic_light.base_traffic_light import BaseTrafficLight
from metadrive.component.traffic_participants.cyclist import Cyclist
from metadrive.component.traffic_participants.pedestrian import Pedestrian
from metadrive.component.vehicle.base_vehicle import BaseVehicle
from metadrive.constants import DATA_VERSION, DEFAULT_AGENT
from metadrive.engine import get_logger
from metadrive.scenario import ScenarioDescription as SD
from metadrive.scenario.scenario_description import ScenarioDescription
from metadrive.type import MetaDriveType
@@ -21,6 +23,19 @@
VELOCITY_DECIMAL = 1 # velocity can not be set accurately
MIN_LENGTH_RATIO = 0.8

logger = get_logger()


def dict_recursive_remove_array_and_set(d):
    """Recursively convert numpy arrays to lists and sets to tuples, in place, so d can be pickled."""
    if isinstance(d, np.ndarray):
        return d.tolist()
    if isinstance(d, set):
        return tuple(d)
    if isinstance(d, dict):
        for k in d.keys():
            d[k] = dict_recursive_remove_array_and_set(d[k])
    return d
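
A quick illustration of the helper above (a sketch, not part of this diff): it makes a summary dict serializable by replacing numpy arrays with lists and sets with tuples, recursing through nested dicts.

    import numpy as np

    d = {"positions": np.zeros((2, 2)), "ids": {"a", "b"}, "nested": {"v": np.arange(3)}}
    out = dict_recursive_remove_array_and_set(d)
    assert out["positions"] == [[0.0, 0.0], [0.0, 0.0]]
    assert sorted(out["ids"]) == ["a", "b"]
    assert out["nested"]["v"] == [0, 1, 2]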


def draw_map(map_features, show=False):
    figure(figsize=(8, 6), dpi=500)
@@ -94,10 +109,19 @@ def find_data_manager_name(manager_info)

def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1, to_dict=True):
    """
    This function utilizes the recorded data natively emerging from MetaDrive run.
    The output data structure follows MetaDrive data format, but some changes might happen compared to original data.
    For example, MetaDrive InterpolateLane will reformat the Lane data and making all waypoints equal distancing.
    We call this lane sampling rate, which is 0.2m in MetaDrive but might different in other dataset.
    This function converts the internal run data of MetaDrive into ScenarioNet's scenario description.
    The output data structure follows the MetaDrive data format, but some changes may occur compared to the original data.
    For example, MetaDrive InterpolateLane will reformat the Lane data, making all waypoints equally spaced.
    We call this the lane sampling rate, which is 0.2 m in MetaDrive but might differ in other datasets.
    Args:
        record_episode: the internal data structure from a MetaDrive run.
        scenario_log_interval: the time interval of one step.
        to_dict: whether to return a Python dict or a ScenarioDescription object.
    Returns:
        A Python dict or a ScenarioDescription object.
    """
    result = SD()
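
As a concrete illustration of the lane sampling rate described in the docstring, the waypoint spacing of an exported lane can be checked directly (a sketch, not part of this commit; it assumes the exported dict carries lane entries with a "polyline" field, per the MetaDrive scenario format):

    import numpy as np

    sd = convert_recorded_scenario_exported(record_episode)  # record_episode comes from a recorded run
    lane = next(v for v in sd[SD.MAP_FEATURES].values() if "polyline" in v)
    pts = np.asarray(lane["polyline"], dtype=float)
    gaps = np.linalg.norm(np.diff(pts[:, :2], axis=0), axis=1)
    print(gaps.mean())  # expect roughly 0.2 m for interpolated lanes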

@@ -294,18 +318,6 @@ def convert_recorded_scenario_exported(record_episode, scenario_log_interval=0.1, to_dict=True):
    result[SD.TRACKS] = tracks
    result[SD.DYNAMIC_MAP_STATES] = lights

    # # Traffic Light: Straight-through forward from original data
    # for k, manager_state in record_episode["manager_metadata"].items():
    #     if "DataManager" in k:
    #         if "raw_data" in manager_state:
    #             original_dynamic_map = copy.deepcopy(manager_state["raw_data"][SD.DYNAMIC_MAP_STATES])
    #             clipped_dynamic_map = {}
    #             for obj_id, obj_state in original_dynamic_map.items():
    #                 obj_state["state"] = {k: v[:episode_len] for k, v in obj_state["state"].items()}
    #                 clipped_dynamic_map[obj_id] = obj_state
    #             result[SD.METADATA]["history_metadata"] = manager_state["raw_data"][SD.METADATA]
    #             result[SD.DYNAMIC_MAP_STATES] = clipped_dynamic_map

    # Record agent2object, object2agent metadata
    result[SD.METADATA]["agent_to_object"] = {str(k): str(v) for k, v in agent_to_object.items()}
    result[SD.METADATA]["object_to_agent"] = {str(k): str(v) for k, v in object_to_agent.items()}
@@ -364,17 +376,19 @@ def read_dataset_summary(file_folder, check_file_existence=True):
            summary_dict = pickle.load(f)

    else:
        raise ValueError(f"Summary file is not found at {summary_file}!")

        # === The following is deprecated ===
        # Create a fake one
        files = []
        for file in os.listdir(file_folder):
            if SD.is_scenario_file(os.path.basename(file)):
                files.append(file)
        try:
            files = sorted(files, key=lambda file_name: int(file_name.replace(".pkl", "")))
        except ValueError:
            files = sorted(files, key=lambda file_name: file_name.replace(".pkl", ""))
        files = [p for p in files]
        summary_dict = {f: read_scenario_data(os.path.join(file_folder, f))["metadata"] for f in files}
        # files = []
        # for file in os.listdir(file_folder):
        #     if SD.is_scenario_file(os.path.basename(file)):
        #         files.append(file)
        # try:
        #     files = sorted(files, key=lambda file_name: int(file_name.replace(".pkl", "")))
        # except ValueError:
        #     files = sorted(files, key=lambda file_name: file_name.replace(".pkl", ""))
        # summary_dict = {f: read_scenario_data(os.path.join(file_folder, f))["metadata"] for f in files}

    mapping = None
    if os.path.exists(mapping_file):
@@ -395,6 +409,64 @@
    return summary_dict, list(summary_dict.keys()), mapping
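
With this change, pointing read_dataset_summary at a folder that lacks dataset_summary.pkl fails fast instead of silently rebuilding the summary. A sketch of the caller-facing behavior (the path is hypothetical):

    from metadrive.scenario.utils import read_dataset_summary

    try:
        summary_dict, scenario_files, mapping = read_dataset_summary("/path/to/dataset")
    except ValueError as e:
        # Raised when dataset_summary.pkl is missing; regenerate it with save_dataset.
        print(e)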


def extract_dataset_summary_and_mapping(scenario_list, dataset_name, dataset_version):
    """Extract the dataset summary and mapping for dataset_summary.pkl and dataset_mapping.pkl.
    Args:
        scenario_list: A list of ScenarioDescription objects.
        dataset_name: The dataset name, used to build the export file names.
        dataset_version: The dataset version, used to build the export file names.
    Returns:
        Summary dict, mapping dict, scenario dict.
    """
    summary = {}
    mapping = {}
    scenario_dict = {}
    for sd_scenario in scenario_list:
        scenario_id = sd_scenario[SD.ID]
        export_file_name = SD.get_export_file_name(dataset_name, dataset_version, scenario_id)

        sd_scenario = SD(sd_scenario)
        if hasattr(SD, "update_summaries"):
            SD.update_summaries(sd_scenario)
        else:
            raise ValueError("Please update MetaDrive to the latest version.")

        # update summary/mapping dict
        if export_file_name in summary:
            logger.warning("Scenario {} already exists and will be overwritten!".format(export_file_name))
        summary[export_file_name] = copy.deepcopy(sd_scenario[SD.METADATA])
        mapping[export_file_name] = ""  # in the same dir

        # sanity check
        sd_scenario = sd_scenario.to_dict()
        SD.sanity_check(sd_scenario, check_self_type=True)

        scenario_dict[export_file_name] = sd_scenario

    return summary, mapping, scenario_dict
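
The helper above is the pure in-memory half of the export pipeline; save_dataset below handles the disk I/O. A usage sketch (my_scenarios is a hypothetical list of scenario descriptions):

    summary, mapping, scenario_dict = extract_dataset_summary_and_mapping(
        scenario_list=my_scenarios,
        dataset_name="demo",
        dataset_version="v0",
    )
    # All three dicts are keyed by the export file names from SD.get_export_file_name.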


def save_dataset(scenario_list, dataset_name, dataset_version, dataset_dir):
    """Save the scenarios into dataset_dir, together with dataset_summary.pkl and dataset_mapping.pkl."""
    summary, mapping, scenario_dict = extract_dataset_summary_and_mapping(scenario_list, dataset_name, dataset_version)
    summary_file = pathlib.Path(dataset_dir) / ScenarioDescription.DATASET.SUMMARY_FILE
    summary_file = summary_file.resolve()
    mapping_file = pathlib.Path(dataset_dir) / ScenarioDescription.DATASET.MAPPING_FILE
    mapping_file = mapping_file.resolve()
    os.makedirs(dataset_dir, exist_ok=True)
    for file_name, scenario in scenario_dict.items():
        file_path = pathlib.Path(dataset_dir) / file_name
        with open(file_path, "wb") as file:
            pickle.dump(scenario, file)
    with open(summary_file, "wb") as file:
        pickle.dump(dict_recursive_remove_array_and_set(summary), file)
    with open(mapping_file, "wb") as file:
        pickle.dump(mapping, file)
    print(
        "\n ================ Dataset Summary and Mapping are saved at: {} "
        "================ \n".format(summary_file)
    )
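
save_dataset writes each scenario pickle plus the summary and mapping files that read_dataset_summary now requires, so the two functions round-trip (a sketch; my_scenarios is a hypothetical list of scenario descriptions):

    save_dataset(my_scenarios, dataset_name="demo", dataset_version="v0", dataset_dir="./demo_dataset")
    summary_dict, scenario_files, mapping = read_dataset_summary("./demo_dataset")
    assert get_number_of_scenarios("./demo_dataset") == len(my_scenarios)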


def get_number_of_scenarios(dataset_path):
    _, files, _ = read_dataset_summary(dataset_path)
    return len(files)
24 changes: 12 additions & 12 deletions metadrive/tests/test_export_record_scenario/test_connectivity.py
@@ -1,7 +1,6 @@
import os
import copy
import os
import pickle
import pathlib
import shutil

from metadrive.component.map.base_map import BaseMap
@@ -10,6 +9,7 @@
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.policy.idm_policy import IDMPolicy
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.scenario.utils import save_dataset


def test_search_path(render_export_env=False, render_load_env=False):
@@ -34,11 +34,11 @@ def test_search_path(render_export_env=False, render_load_env=False):

    try:
        scenarios, done_info = env.export_scenarios(policy, scenario_index=[i for i in range(1)])
        dir = os.path.join(os.path.dirname(__file__), "../test_component/test_export")
        os.makedirs(dir, exist_ok=True)
        for i, data in scenarios.items():
            with open(os.path.join(dir, "{}.pkl".format(i)), "wb+") as file:
                pickle.dump(data, file)

        dir = pathlib.Path(os.path.dirname(__file__)) / "../test_component/test_export"
        save_dataset(
            scenario_list=list(scenarios.values()), dataset_name="reconstructed", dataset_version="v0", dataset_dir=dir
        )
        node_roadnet = copy.deepcopy(env.current_map.road_network)
        env.close()

@@ -47,11 +47,11 @@ def test_search_path(render_export_env=False, render_load_env=False):
            dict(agent_policy=ReplayEgoCarPolicy, data_directory=dir, use_render=render_load_env, num_scenarios=1)
        )
        scenarios, done_info = env.export_scenarios(policy, scenario_index=[i for i in range(1)])
        dir = os.path.join(os.path.dirname(__file__), "../test_component/test_export")
        os.makedirs(dir, exist_ok=True)
        for i, data in scenarios.items():
            with open(os.path.join(dir, "{}.pkl".format(i)), "wb+") as file:
                pickle.dump(data, file)

        dir = pathlib.Path(os.path.dirname(__file__)) / "../test_component/test_export"
        save_dataset(
            scenario_list=list(scenarios.values()), dataset_name="reconstructed", dataset_version="v0", dataset_dir=dir
        )
        env.close()

# reload
48 changes: 27 additions & 21 deletions metadrive/tests/test_export_record_scenario/test_export_scenario.py
@@ -1,11 +1,12 @@
import os
import pickle
import pathlib
import shutil

from metadrive.envs.metadrive_env import MetaDriveEnv
from metadrive.envs.scenario_env import ScenarioEnv, AssetLoader
from metadrive.policy.idm_policy import IDMPolicy
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.scenario.utils import save_dataset


def test_export_metadrive_scenario(render_export_env=False, render_load_env=False):
@@ -14,20 +15,23 @@ def test_export_metadrive_scenario(render_export_env=False, render_load_env=False):
        dict(start_seed=0, use_render=render_export_env, num_scenarios=num_scenarios, agent_policy=IDMPolicy)
    )
    policy = lambda x: [0, 1]
    dir = None
    dataset_dir = None
    try:
        scenarios, done_info = env.export_scenarios(policy, scenario_index=[i for i in range(num_scenarios)])
        dir = os.path.join(os.path.dirname(__file__), "../test_component/test_export")
        os.makedirs(dir, exist_ok=True)
        for i, data in scenarios.items():
            with open(os.path.join(dir, "{}.pkl".format(i)), "wb+") as file:
                pickle.dump(data, file)

        dataset_dir = pathlib.Path(os.path.dirname(__file__)) / "../test_component/test_export"
        save_dataset(
            scenario_list=list(scenarios.values()),
            dataset_name="reconstructed_waymo",
            dataset_version="v0",
            dataset_dir=dataset_dir
        )
        env.close()

        env = ScenarioEnv(
            dict(
                agent_policy=ReplayEgoCarPolicy,
                data_directory=dataset_dir,
                use_render=render_load_env,
                num_scenarios=num_scenarios
            )
@@ -40,8 +44,8 @@ def test_export_metadrive_scenario(render_export_env=False, render_load_env=False):
                done = tm or tc
    finally:
        env.close()
        if dir is not None:
            shutil.rmtree(dir)
        if dataset_dir is not None:
            shutil.rmtree(dataset_dir)


def test_export_waymo_scenario(num_scenarios=3, render_export_env=False, render_load_env=False):
@@ -55,23 +59,25 @@ def test_export_waymo_scenario(num_scenarios=3, render_export_env=False, render_load_env=False):
        )
    )
    policy = lambda x: [0, 1]
    dir = None
    dataset_dir = None
    try:
        scenarios, done_info = env.export_scenarios(
            policy, scenario_index=[i for i in range(num_scenarios)], verbose=True
        )
        dir = os.path.join(os.path.dirname(__file__), "../test_component/test_export")
        os.makedirs(dir, exist_ok=True)
        for i, data in scenarios.items():
            with open(os.path.join(dir, "{}.pkl".format(i)), "wb+") as file:
                pickle.dump(data, file)
        dataset_dir = pathlib.Path(os.path.dirname(__file__)) / "../test_component/test_export"
        save_dataset(
            scenario_list=list(scenarios.values()),
            dataset_name="reconstructed_waymo",
            dataset_version="v0",
            dataset_dir=dataset_dir
        )
        env.close()

        print("===== Start restoring =====")
        env = ScenarioEnv(
            dict(
                agent_policy=ReplayEgoCarPolicy,
                data_directory=dataset_dir,
                use_render=render_load_env,
                num_scenarios=num_scenarios
            )
@@ -88,10 +94,10 @@ def test_export_waymo_scenario(num_scenarios=3, render_export_env=False, render_load_env=False):
print("Finish replaying scenario {} with step {}".format(index, count))
finally:
env.close()
if dir is not None:
shutil.rmtree(dir)
if dataset_dir is not None:
shutil.rmtree(dataset_dir)


if __name__ == "__main__":
    # test_export_metadrive_scenario(render_export_env=False, render_load_env=False)
    test_export_waymo_scenario(num_scenarios=3, render_export_env=False, render_load_env=False)
    test_export_metadrive_scenario(render_export_env=False, render_load_env=False)
    # test_export_waymo_scenario(num_scenarios=3, render_export_env=False, render_load_env=False)
