In [None]:
from marl_traffic_gen.environments.envs.metadrive_environment import CustomMultiAgentMetaDrive

# Build a two-agent MetaDrive scenario on an "SS" map with no background
# traffic, with per-agent spawn positions given explicitly below.
env = CustomMultiAgentMetaDrive(
    {
        # Map and multi-agent setup
        "map": "SS",
        "is_multi_agent": True,
        "force_reuse_object_name": True,
        "traffic_density": 0.0,
        "num_agents": 2,
        "allow_respawn": False,
        "random_agent_model": False,
        # Episode termination
        "crash_done": True,
        "out_of_road_done": True,
        "delay_done": 25,
        # Scenario selection, logging and episode length
        "start_seed": 0,
        "num_scenarios": 1,
        "log_level": 50,
        "horizon": 1000,
        "interface_panel": ["dashboard"],
        # Optional: "truncate_as_terminate": True,
        # Reward shaping
        "out_of_road_penalty": 5,
        "crash_vehicle_penalty": 5,
        "crash_object_penalty": 5,
        # Safety costs
        "crash_vehicle_cost": 1,
        "crash_object_cost": 1,
        "out_of_road_cost": 2,
        "success_reward": 10,
        "crash_sidewalk_penalty": 2,
        # Per-agent vehicle and spawn configuration
        "agent_configs": {
            "agent0": {
                "vehicle_model": "static_default",
                "random_color": False,
                "_specified_spawn_lane": True,
                "spawn_longitude": 40.0,
                "spawn_lateral": 7.0,
            },
            "agent1": {
                "vehicle_model": "static_default",
                "random_color": False,
                "_specified_spawn_lane": True,
                "spawn_longitude": 15.0,
                "spawn_lateral": 4.0,
            },
        },
        # Optional: "use_render": True,  (enable on-screen rendering)
    }
)

# Reset to the (single) scenario and show each agent's initial observation.
obs, _ = env.reset(seed=0)
print("MARL agent IDs:", list(obs))
print("------------------------")
for agent_id, agent_obs in obs.items():
    print(f"Observation for {agent_id}:")
    print(f"{agent_obs}")
    print("------------------------")

In [None]:
RED_COLOR = (153 / 255, 0 / 255, 0 / 255)
GREY_COLOR = (128 / 255, 128 / 255, 128 / 255)


def set_vehicle_color(env: CustomMultiAgentMetaDrive, adv_agent_ids: list[int | str] = None) -> None:
    # Default to empty list if no adversarial agent IDs are provided
    if adv_agent_ids is None:
        adv_agent_ids = []

    for key, vehicle in env.engine.get_objects().items():
        if key in adv_agent_ids:
            vehicle._panda_color = RED_COLOR
        else:
            vehicle._panda_color = GREY_COLOR

In [None]:
import matplotlib.pyplot as plt

# Highlight every agent currently in the env; all other objects get the default color.
set_vehicle_color(env, adv_agent_ids=list(env.agents.keys()))

# Render one top-down frame without opening a window; the returned frame is
# displayed below.
topdown = env.render(
    mode="topdown",
    window=False,
    scaling=3,
    screen_size=(600, 600),
    camera_position=(100, 0),
)

# No `cmap` is passed: assuming the frame is an (H, W, 3) RGB image,
# matplotlib would ignore a colormap anyway.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(topdown)
ax.set_title("Initial top-down view")
ax.axis("off")
plt.show()

In [None]:
from metadrive.utils import generate_gif

# Start from a fresh episode (same seed as the setup cell) so this cell does not
# depend on, or continue, state left behind by earlier cells or a previous run.
obs, _ = env.reset(seed=0)
# Re-apply the agent colors after the reset. NOTE(review): this may be a no-op
# if reset keeps the same vehicle objects and their colors; confirm.
set_vehicle_color(env, adv_agent_ids=list(env.agents.keys()))

# Action sent to every agent on every step. NOTE(review): assumes MetaDrive's
# default continuous action layout [steering, throttle/brake]; confirm for this env.
CONSTANT_ACTION = (0.0, 0.2)

frames = []
while True:
    # Only agents present in the latest observation receive an action; each
    # gets its own list so no action object is shared between agents.
    actions = {agent_id: list(CONSTANT_ACTION) for agent_id in obs}
    obs, reward, terminated, truncated, _ = env.step(actions)
    # NOTE(review): a top-down renderer may already exist from an earlier render
    # call, in which case settings such as screen_size / screen_record may not
    # take effect here. Confirm.
    topdown = env.render(
        mode="topdown",
        window=False,
        screen_record=True,
        screen_size=(650, 650),
    )
    frames.append(topdown)
    # Stop once the env flags the whole multi-agent episode as finished.
    if terminated["__all__"] or truncated["__all__"]:
        break

generate_gif(
    frames=frames,
    gif_name="demo.gif",
)