In [None]:
import json
import logging
import os
import time
from pprint import pprint

import numpy as np
import pandas as pd
from IPython.display import Image, clear_output
# NOTE: this import deliberately comes after IPython's and shadows its `Image`;
# get_map_img relies on PIL's `Image.fromarray`.
from PIL import Image

from metadrive.component.map.base_map import BaseMap
from metadrive.component.map.pg_map import MapGenerateMethod
from metadrive.engine.logger import get_logger
from metadrive.envs.metadrive_env import MetaDriveEnv
from metadrive.examples.ppo_expert.torch_expert import torch_expert as expert

In [None]:
# Module-level logger (from MetaDrive's engine logger) used by get_max_steps,
# process_timestamps and run_scenario.
logger = get_logger()

In [None]:
def create_env(seed=0, block_num: int = 5):
    """Build a (not yet reset) MetaDrive environment for one evaluation scenario.

    Args:
        seed: Forwarded to MetaDrive as ``start_seed``.
        block_num: Number of map blocks requested via
            ``MapGenerateMethod.BIG_BLOCK_NUM``. Defaults to 5, the value that
            was previously hard-coded.

    Returns:
        A new ``MetaDriveEnv``; callers must call ``reset()`` before stepping it.
    """
    # ===== Termination Scheme =====
    # Only collisions end an episode; leaving the route or touching a
    # continuous line does not.
    termination_scheme = dict(
        out_of_route_done=False,
        on_continuous_line_done=False,
        crash_vehicle_done=True,
        crash_object_done=True,
        crash_human_done=True,
    )
    # ===== Map Config =====
    map_config = {
        BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_NUM,
        BaseMap.GENERATE_CONFIG: block_num,  # number of blocks in the generated map
    }

    cfg = dict(
        # use_render=True,  # uncomment for an on-screen render window
        log_level=logging.INFO,  # switch to logging.DEBUG for verbose engine logs
        start_seed=seed,
        map_config=map_config,
        **termination_scheme,
    )
    return MetaDriveEnv(config=cfg)

In [None]:
def get_max_steps(env: MetaDriveEnv, v_min: float = 2.0) -> int:
    """Estimate a step budget for covering the current route at a minimal speed.

    One agent step is treated as ``decision_repeat * physics_world_step_size``
    seconds of simulated time, so the budget is the route length divided by the
    distance travelled per step at ``v_min``.

    Args:
        env: A MetaDrive env that has already been reset (``env.agent`` must exist).
        v_min: Minimal average velocity in m/s. Defaults to 2.0, the value that
            was previously hard-coded.

    Returns:
        The rounded step budget, never less than 1.

    Raises:
        ValueError: If ``v_min`` is not positive.
    """
    if v_min <= 0:
        raise ValueError(f"v_min must be positive, got {v_min!r}")

    decision_repeat = env.config["decision_repeat"]
    dt = env.config["physics_world_step_size"]
    distance = env.agent.navigation.total_length

    # Distance [m] covered per agent step when driving at v_min.
    meters_per_step = v_min * decision_repeat * dt
    max_steps = distance / meters_per_step

    logger.info("Calculating max steps with: ")
    logger.info(
        f"V_min={v_min!r}, decision_repeat={decision_repeat!r}, dt={dt!r}, "
        f"distance={distance!r}, max_steps={max_steps!r}"
    )
    # Floor at 1: a very short route would otherwise round to a 0-step budget,
    # which a step counter starting at 1 can never match.
    return max(1, round(max_steps))

In [None]:
def _to_json_native(value):
    """Recursively convert numpy scalars/arrays (and tuples) to JSON-native Python values."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_json_native(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_json_native(item) for item in value]
    return value


def serialize_step_info(info) -> dict:
    """Return a JSON-serializable copy of a step ``info`` dict.

    ``action`` and ``raw_action`` (when present) become lists of Python floats.
    Any other numpy scalar or array anywhere in the dict is converted to its
    native Python equivalent, so ``json.dump`` does not fail on e.g.
    ``np.float32`` or ``np.bool_``.

    Args:
        info: The info mapping returned by ``env.step``. It is not modified.

    Returns:
        A new dict with the converted values.
    """
    serialized = dict(info)  # shallow copy: never mutate the caller's dict
    for key in ("action", "raw_action"):
        if key in serialized:
            serialized[key] = [float(x) for x in serialized[key]]
    return {key: _to_json_native(value) for key, value in serialized.items()}


def state_action_loop(env: MetaDriveEnv, max_step: int) -> list:
    """Drive a reset ``env`` with the PPO expert policy until the episode ends.

    The episode ends when the env reports ``terminated``/``truncated``, or when
    ``info["episode_length"]`` reaches ``max_step``. In the latter case the step is
    marked truncated and its info is tagged with ``"max_step": True``.

    Args:
        env: An env on which ``reset()`` has already been called.
        max_step: Step budget for the episode.

    Returns:
        One serialized info dict per step, in order.
    """
    steps_info = []
    terminated = truncated = False
    while not (terminated or truncated):
        action = expert(env.agent, deterministic=True)
        _, _, terminated, truncated, info = env.step(action)

        # ">=" rather than "==": the budget must still stop the loop if it is
        # already at or below the current episode length.
        if info["episode_length"] >= max_step:
            truncated = True
            info["max_step"] = True

        steps_info.append(serialize_step_info(info))

    return steps_info

In [None]:
def process_timestamps(start_ts, initialized_ts, scenario_done_ts):
    """Log phase durations and return them together with the raw timestamps.

    Args:
        start_ts: Timestamp taken before the env was created.
        initialized_ts: Timestamp taken after ``env.reset()``.
        scenario_done_ts: Timestamp taken after the driving loop finished.

    Returns:
        A dict with keys ``start_ts``, ``initialized_ts``, ``scenario_done_ts``,
        ``init_time`` and ``scenario_time`` (differences of the timestamps, in the
        same unit as the inputs).
    """
    init_time = initialized_ts - start_ts
    logger.info(f"Initializing the env took {init_time:.2f}s")

    scenario_time = scenario_done_ts - initialized_ts
    logger.info(f"Running the scenario took {scenario_time:.2f}s")

    # Explicit dict instead of locals(): adding a helper variable to this
    # function must not silently add a key to the saved scenario JSON.
    return {
        "start_ts": start_ts,
        "initialized_ts": initialized_ts,
        "scenario_done_ts": scenario_done_ts,
        "init_time": init_time,
        "scenario_time": scenario_time,
    }

In [None]:
def get_map_img(env, gain: float = 4):
    """Render the current map's semantic top-down view as a PIL image.

    Args:
        env: A reset MetaDrive env whose ``current_map`` is populated.
        gain: Brightness multiplier applied on top of the x255 scaling. Defaults
            to 4, the value that was previously hard-coded.

    Returns:
        A PIL image built from the squeezed semantic map (presumably
        single-channel grayscale — TODO confirm against get_semantic_map).
    """
    current_map = env.current_map
    semantic_map = current_map.get_semantic_map(current_map.get_center_point())
    semantic_map = semantic_map.squeeze()  # drop singleton dimensions
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be undefined) instead of saturating at white.
    pixels = np.clip(semantic_map * 255 * gain, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)

In [None]:
def run_scenario(seed: int = 0, output_dir: str = "eval_data"):
    """Run one expert-driven scenario and save its data to disk.

    Writes two files:

    - ``{output_dir}/{seed}.json``: timings, per-step infos (reset info first),
      the map block sequence and the step budget.
    - ``{output_dir}/{seed}.png``: the semantic map image.

    Args:
        seed: Scenario seed forwarded to ``create_env``.
        output_dir: Output directory; created if it does not exist.
    """
    start_ts = time.perf_counter()

    env = create_env(seed)
    try:
        _, reset_info = env.reset()
        initialized_ts = time.perf_counter()

        # running loop
        max_step = get_max_steps(env)
        steps_info = state_action_loop(env, max_step)
        scenario_done_ts = time.perf_counter()

        # collect metadata
        scenario_data = process_timestamps(start_ts, initialized_ts, scenario_done_ts)
        scenario_data["steps_infos"] = [reset_info, *steps_info]
        scenario_data["map_data"] = env.current_map.get_meta_data()["block_sequence"]
        scenario_data["max_steps"] = max_step

        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f"{seed}.json"), "w") as f:
            json.dump(scenario_data, f, indent=4)
        get_map_img(env).save(os.path.join(output_dir, f"{seed}.png"))

        data_saved_ts = time.perf_counter()
        logger.info(f"Saving data took {data_saved_ts-scenario_done_ts:.2f}s")
        logger.info("Running scenario finished.")
    finally:
        # Release the env even if a step or the save raises, so later
        # run_scenario calls in the same kernel start clean (MetaDrive presumably
        # allows only one live engine at a time — verify).
        env.close()

In [None]:
# Evaluate the expert on a fixed set of scenario seeds; each run writes
# eval_data/<seed>.json and eval_data/<seed>.png.
for scenario_seed in (10, 69, 2137):
    run_scenario(seed=scenario_seed)