In [None]:
# Enable IPython's autoreload extension so that edits to imported modules are picked up
# without restarting the kernel (mode 2).
%load_ext autoreload
%autoreload 2

In [None]:
from flatland.trajectories.trajectories import Trajectory
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator
from IPython.display import HTML, display, clear_output
import ipywidgets as ipw
from io import BytesIO
import PIL
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time      
from pathlib import Path
from flatland.envs.persistence import RailEnvPersister
import tqdm
from flatland.utils.rendertools import RenderTool
from IPython.display import display
from PIL import Image
import pandas as pd
import ast

## Get episode data

In [None]:
# Download the benchmark episodes archive to /tmp and unpack it into /tmp/episodes.
# `unzip -o` overwrites files from a previous run, so the cell can be re-executed safely.
!wget -q https://github.com/flatland-association/flatland-scenarios/raw/refs/heads/main/trajectories/FLATLAND_BENCHMARK_EPISODES_FOLDER_v4.zip -O /tmp/FLATLAND_BENCHMARK_EPISODES_FOLDER_v4.zip
!mkdir -p /tmp/episodes
!unzip -o -qq /tmp/FLATLAND_BENCHMARK_EPISODES_FOLDER_v4.zip -d /tmp/episodes

In [None]:
# Sanity check: list one folder of the unpacked archive.
# NOTE(review): this folder differs from the root directory used by the aggregation cell
# (malfunction_deadlock_avoidance_heuristics/Test_02) — confirm which layout the v4 archive has.
!ls -al  "/tmp/episodes/30x30 map/10_trains"

## Aggregate stats

In [None]:
# Per-episode frames are collected in lists and concatenated once after the loop.
all_actions = []
all_trains_positions = []
all_trains_arrived = []
all_trains_rewards_dones_infos = []
env_stats = []
agent_stats = []

root_data_dir = Path("/tmp/episodes/malfunction_deadlock_avoidance_heuristics/Test_02/")
# Every sub-directory of the root is treated as one episode; stray files are ignored.
data_dirs = sorted(path for path in root_data_dir.glob("*") if path.is_dir())
if not data_dirs:
    # Fail early with the offending path instead of a "No objects to concatenate" error below.
    raise FileNotFoundError(f"No episode directories found under {root_data_dir}")

# Keys of the `info` entries that are promoted to columns of their own.
info_keys = ["action_required", "malfunction", "speed", "state"]

for data_dir in data_dirs:
    # The episode id is the stem of the single snapshot whose file name does not contain "step".
    snapshots = [
        snapshot
        for snapshot in (data_dir / "serialised_state").glob("*.pkl")
        if "step" not in snapshot.name
    ]
    assert len(snapshots) == 1, (
        f"expected exactly one non-step snapshot in {data_dir / 'serialised_state'}, "
        f"found {len(snapshots)}"
    )
    ep_id = snapshots[0].stem
    trajectory = Trajectory(data_dir=data_dir, ep_id=ep_id)
    trajectory.load()
    env = trajectory.restore_episode()

    all_actions.append(trajectory.actions)
    all_trains_positions.append(trajectory.trains_positions)
    all_trains_arrived.append(trajectory.trains_arrived)

    # Flatten the selected `info` entries into plain columns (added to the trajectory's frame in place).
    rewards_dones_infos = trajectory.trains_rewards_dones_infos
    for info_key in info_keys:
        rewards_dones_infos[info_key] = rewards_dones_infos["info"].map(lambda d: d[info_key])
    all_trains_rewards_dones_infos.append(rewards_dones_infos)

    # One row per episode.
    env_stats.append(pd.DataFrame.from_records([{
        "episode_id": ep_id,
        "max_episode_steps": env._max_episode_steps,
        "num_agents": len(env.agents)
    }]))

    # One row per agent of the episode.
    agent_stats.append(pd.DataFrame.from_records([{
        "episode_id": ep_id,
        "agent_id": agent.handle,
        "earliest_departure": agent.earliest_departure,
        "latest_arrival": agent.latest_arrival,
        "num_waypoints": len(agent.waypoints),
    } for agent in env.agents]))

# From here on the names refer to single frames spanning all episodes.
all_actions = pd.concat(all_actions)
all_trains_positions = pd.concat(all_trains_positions)
all_trains_arrived = pd.concat(all_trains_arrived)
all_trains_rewards_dones_infos = pd.concat(all_trains_rewards_dones_infos)
env_stats = pd.concat(env_stats)
agent_stats = pd.concat(agent_stats)


In [None]:
# Actions of all loaded episodes (concatenation of every trajectory's `actions` frame).
all_actions

In [None]:
# Train positions of all loaded episodes (concatenation of every trajectory's `trains_positions` frame).
all_trains_positions

In [None]:
# Arrival data of all loaded episodes (concatenation of every trajectory's `trains_arrived` frame).
all_trains_arrived

In [None]:
# Rewards/dones/infos of all loaded episodes, including the columns extracted from `info`
# (action_required, malfunction, speed, state).
all_trains_rewards_dones_infos

In [None]:
# One row per episode: episode_id, max_episode_steps, num_agents.
env_stats

## Results stats

In [None]:
import seaborn as sns

In [None]:
def plot_stats(df, col):
    """Plot the distribution of one column as a histogram with a box plot underneath.

    Args:
        df: DataFrame that contains the column to visualise.
        col: Name of the column in ``df`` to plot.

    Returns:
        None. A new two-row figure is created as a side effect.
    """
    _, (hist_ax, box_ax) = plt.subplots(nrows=2)
    sns.histplot(data=df, x=col, ax=hist_ax)
    sns.boxplot(x=df[col], ax=box_ax)

In [None]:
# Distribution of `max_episode_steps` across episodes (env_stats holds one row per episode).
plot_stats(env_stats, "max_episode_steps")

In [None]:
# Distribution of the `success_rate` column of the concatenated arrival frames.
# NOTE(review): `success_rate` is not created in this notebook; presumably it comes from the
# trajectory's `trains_arrived` data — confirm.
plot_stats(all_trains_arrived, "success_rate")

In [None]:
# Distribution of the number of agents per episode.
plot_stats(env_stats, "num_agents")

In [None]:
# Distribution of `earliest_departure` over all agents of all episodes.
plot_stats(agent_stats, "earliest_departure")

In [None]:
# Distribution of `latest_arrival` over all agents of all episodes.
plot_stats(agent_stats, "latest_arrival")

In [None]:
# Distribution of the number of waypoints per agent (`len(agent.waypoints)`).
plot_stats(agent_stats, "num_waypoints")

In [None]:
# Cumulative reward per episode.
# The `reward` column is selected explicitly before summing: the first positional argument of
# GroupBy.sum() is `numeric_only`, not a column list, so `.sum(["reward"])` summed every numeric
# column (ids, speeds, flags, ...) rather than just the reward.
episode_cum_rewards = all_trains_rewards_dones_infos.groupby(["episode_id"])[["reward"]].sum()
episode_cum_rewards

In [None]:
# Distribution of the cumulative reward per episode.
plot_stats(episode_cum_rewards, "reward")

In [None]:
# Cumulative reward per agent and episode.
# The `reward` column is selected explicitly before summing: the first positional argument of
# GroupBy.sum() is `numeric_only`, not a column list, so `.sum(["reward"])` summed every numeric
# column rather than just the reward.
agent_cum_rewards = all_trains_rewards_dones_infos.groupby(["episode_id", "agent_id"])[["reward"]].sum()
agent_cum_rewards

In [None]:
# Distribution of the cumulative reward per (episode, agent) pair.
plot_stats(agent_cum_rewards, "reward")