In [None]:
# Reload edited modules (e.g. a local flatland checkout) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from flatland.trajectories.trajectories import Trajectory
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator
from IPython.display import HTML, display, clear_output
import ipywidgets as ipw
from io import BytesIO
import PIL
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time      
from pathlib import Path
from flatland.envs.persistence import RailEnvPersister
import tqdm
from flatland.utils.rendertools import RenderTool
from IPython.display import display
from PIL import Image
import pandas as pd
import ast

## Get episode data

In [None]:
# Download the benchmark episodes archive and extract it into /tmp/episodes
# (`unzip -o` overwrites existing files without prompting, so re-running is safe).
!wget https://github.com/flatland-association/flatland-scenarios/raw/refs/heads/main/trajectories/FLATLAND_BENCHMARK_EPISODES_FOLDER_v3.zip -O /tmp/FLATLAND_BENCHMARK_EPISODES_FOLDER_v3.zip
!mkdir -p /tmp/episodes
!unzip -o /tmp/FLATLAND_BENCHMARK_EPISODES_FOLDER_v3.zip -d /tmp/episodes

In [None]:
# Sanity check: list one extracted scenario folder.
!ls -al  "/tmp/episodes/30x30 map/10_trains"

## Rendering

In [None]:
def create_frames(snapshots):
    """Render each serialised env snapshot to an image.

    Terribly slow: every snapshot is reloaded from disk and rendered with its
    own RenderTool. Generating the images in the runner with the same
    RenderTool would be much faster.

    Args:
        snapshots: iterable of paths to serialised env snapshots, loaded with
            ``RailEnvPersister.load_new``.

    Returns:
        list of images from ``RenderTool.get_image()``, in the same order as
        ``snapshots``.
    """
    frames = []
    for p in tqdm.tqdm(snapshots):
        env, _ = RailEnvPersister.load_new(str(p))
        env_renderer = RenderTool(env, gl="PGL", show_debug=True)
        env_renderer.render_env(show=False, show_observations=False)
        frames.append(env_renderer.get_image())
    # Callers assign the result (`frames = create_frames(...)`), so it must be returned.
    return frames

def process_frames(frames, frames_per_second=1000/20):
    """Build a looping matplotlib animation from a sequence of rendered frames.

    Args:
        frames: non-empty sequence of images accepted by ``imshow``. The figure
            is sized from ``frames[0].shape`` at 72 dpi, i.e. roughly one
            figure pixel per image pixel.
        frames_per_second: despite its name, this is the delay between frames
            in milliseconds, passed to ``FuncAnimation(interval=...)``. The
            default of 1000/20 = 50 ms gives 20 frames per second. The name is
            kept so that existing calls keep working.

    Returns:
        The ``FuncAnimation``; display it e.g. with ``HTML(anim.to_jshtml())``.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("process_frames needs at least one frame")

    dpi = 72
    interval = frames_per_second  # milliseconds between frames, see docstring
    height, width = frames[0].shape[:2]

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    # Hide ticks and frame via the Axes method; assigning to `plt.axis` would
    # replace pyplot's function for the whole kernel instead.
    ax.axis('off')
    plot = ax.imshow(frames[0])

    def init():
        # No-op: the first frame is already shown by imshow above.
        pass

    def update(i):
        plot.set_data(frames[i])
        return plot,

    anim = FuncAnimation(fig=fig,
                         func=update,
                         frames=len(frames),
                         init_func=init,
                         interval=interval,
                         repeat=True,
                         repeat_delay=20)
    # Close the figure so the notebook does not also render a static copy;
    # the animation keeps its own reference to it.
    plt.close(fig)
    return anim

## Run trajectory

In [None]:
# Select the episode to replay: `data_dir` is an extracted scenario folder and
# `ep_id` the episode id inside it. Swap in one of the commented alternatives
# to inspect a different scenario.
# data_dir = "/tmp/episodes/30x30 map/10_trains"
# ep_id = "1649ef98-e3a8-4dd3-a289-bbfff12876ce"

# data_dir = "/tmp/episodes/malfunction_deadlock_avoidance_heuristics/Test_02/Level_6"
# ep_id = "Test_02_Level_6"

data_dir = "/tmp/episodes/malfunction_deadlock_avoidance_heuristics/Test_01/Level_3"
ep_id = "Test_01_Level_3"

In [None]:
# run with snapshots
trajectory = Trajectory(data_dir=data_dir, ep_id=ep_id)
TrajectoryEvaluator(trajectory).evaluate(snapshot_interval=1)

In [None]:
# Quick check that per-step snapshots exist (IPython expands $data_dir and $ep_id).
!find "$data_dir" -name "$ep_id""_step*.pkl" | sort -u | head

In [None]:
# Snapshot files of the selected episode, sorted by file name; the step number
# in the name is zero-padded, so this is also step order.
snapshots = sorted((Path(data_dir) / "serialised_state").glob(f'{ep_id}_step*.pkl'))

## Animate trajectory

In [None]:
# Render every snapshot to an image (slow: one RenderTool per snapshot).
frames = create_frames(snapshots)

In [None]:
# Build the animation and embed it as an interactive JavaScript player.
anim = process_frames(frames)
HTML(anim.to_jshtml())

In [None]:
!python -m pip install ipyplot

In [None]:
# Show all rendered frames as a static image grid (img_width=400).
import ipyplot

ipyplot.plot_images(frames, img_width=400)

## Aggregate stats

In [None]:
# One row per (snapshot, agent): each agent's state, with the fields of its
# speed_counter, malfunction_handler and action_saver flattened into columns.
# NOTE(review): env_time is the snapshot's position in the sorted list; this
# equals the step only if there is exactly one snapshot per step, starting at 0.
dfs = []
for env_time, snapshot in enumerate(snapshots):
    _, env_dict = RailEnvPersister.load_new(str(snapshot))
    records = []
    for agent in env_dict["agents"]:
        record = {"env_time": env_time, "source": snapshot, **agent.to_agent()._asdict()}
        for nested in ("speed_counter", "malfunction_handler", "action_saver"):
            record.update(record[nested].to_dict())
        records.append(record)
    dfs.append(pd.DataFrame.from_records(records))

stats = (
    pd.concat(dfs)
    .assign(agent_id=lambda frame: frame["handle"])
    .set_index(['env_time', 'agent_id'], verify_integrity=True)
)
stats

In [None]:
# Recorded actions of the selected episode, indexed by (env_time, agent_id).
trajectory_actions = (
    trajectory.read_actions()
    .loc[lambda actions: actions["episode_id"] == ep_id]
    .set_index(['env_time', 'agent_id'], verify_integrity=True)
)
trajectory_actions

In [None]:
# Recorded train positions of the selected episode, indexed by
# (env_time, agent_id). The raw 'position' column holds the string form of a
# (position, direction) pair; it is split into two columns below.
# Non-inplace chaining yields an independent frame, so the column assignments
# below do not act on a view of the unfiltered data.
trajectory_positions = (
    trajectory.read_trains_positions()
    .loc[lambda positions: positions["episode_id"] == ep_id]
    .set_index(['env_time', 'agent_id'], verify_integrity=True)
)
# Parse each string once; ast.literal_eval only evaluates Python literals.
position_and_direction = trajectory_positions["position"].map(ast.literal_eval)
trajectory_positions = trajectory_positions.assign(
    direction=position_and_direction.map(lambda pair: pair[1]),
    position=position_and_direction.map(lambda pair: pair[0]),
)
trajectory_positions

In [None]:
# Line up recorded actions, recorded positions and the replayed snapshot state
# per (env_time, agent_id). On overlapping column names the left frame gets the
# suffix: "_actions" for the actions frame, and "_expected" for the recorded
# trajectory columns (e.g. position_expected vs. the snapshot's position).
df = (
    trajectory_actions
    .join(trajectory_positions, lsuffix="_actions")
    .join(stats, lsuffix="_expected")
    .reset_index()
)
df

In [None]:
# Replayed snapshot position vs. recorded trajectory position, first 200 rows.
df[["position", "position_expected"]].head(200)

In [None]:
# Rows where the replayed snapshot position ("position") differs from the
# recorded trajectory position ("position_expected"). Values are compared by
# their string form so that None/None (agent not on the map) counts as equal.
# NOTE(review): assumes both columns hold tuples or None — confirm.
position_mismatch = df["position"].astype(str) != df["position_expected"].astype(str)
df.loc[position_mismatch, ["env_time", "agent_id", "position", "position_expected"]]

In [None]:
# Rows of agent 6 with a positive malfunction_down_counter.
df[(df["agent_id"] == 6) & (df["malfunction_down_counter"] > 0)]

## Inspect single steps

In [None]:
def inspect(step):
    """Load the serialised env snapshot of the selected episode at ``step``.

    Reads ``<data_dir>/serialised_state/<ep_id>_step<NNNN>.pkl`` (step
    zero-padded to four digits), using the notebook-level ``data_dir`` and
    ``ep_id``.

    Returns:
        (env, env_dict) as produced by ``RailEnvPersister.load_new``.
    """
    # NOTE(review): snapshots come from a downloaded archive; only load trusted files.
    snapshot_file = Path(data_dir, "serialised_state", f'{ep_id}_step{step:04d}.pkl')
    loaded_env, loaded_dict = RailEnvPersister.load_new(str(snapshot_file))
    return loaded_env, loaded_dict

def show_frame(env, dpi=40):
    """Render ``env`` and plot it with grid-cell indices as axis ticks.

    The figure size is the image size divided by ``dpi``. The figure itself uses
    matplotlib's default dpi, so ``dpi`` effectively scales how large the image
    appears. Prints the image shape as a side effect.
    """
    renderer = RenderTool(env, gl="PGL", show_debug=True)
    renderer.render_env(show=False, show_observations=False)
    image = renderer.get_image()
    print(image.shape)

    img_h, img_w = image.shape[0], image.shape[1]
    _, ax = plt.subplots(figsize=(img_w / dpi, img_h / dpi))
    # One tick per grid cell, placed at the cell centre in image pixel coordinates.
    ax.set_xticks([(col + 0.5) / env.width * img_w for col in range(env.width)])
    ax.set_xticklabels(list(range(env.width)))
    ax.set_yticks([(row + 0.5) / env.height * img_h for row in range(env.height)])
    ax.set_yticklabels(list(range(env.height)))
    ax.imshow(image)

In [None]:
# Load snapshot 81 and show the state of agent 6.
env0, env_dict0 = inspect(81)
env0.agents[6]

In [None]:
# Load snapshot 82 and render it with grid coordinates on the axes.
env, env_dict = inspect(82)
show_frame(env)

In [None]:
# Print every agent's position and direction, prefixed with the env's elapsed steps.
# Note: the loop variable `agent` stays bound to the last agent after the loop.
for agent in env.agents:
    print(f"[{env._elapsed_steps}][{agent.handle}] {agent.position} {agent.direction}")

In [None]:
# Full state of agent 6 in `env`.
env.agents[6]

In [None]:
# Episode data held by the env (presumably the recorded episode; TODO confirm).
env.cur_episode

In [None]:
# Hard-coded list of (cell, cell) pairs pasted in for debugging; source not
# shown here. Entries with -1 look like placeholders for agents that are not
# on the map — TODO confirm.
edges = [((-1, 0), (-1, 0)), ((10, 23), (10, 23)), ((14, 8), (14, 8)), ((-1, 3), (-1, 3)), ((21, 11), (20, 11)), ((10, 14), (10, 14)), ((8, 12), (7, 12)), ((-1, 7), (-1, 7)), ((-1, 8), (-1, 8)), ((8, 14), (8, 14)), ((23, 11), (24, 11)), ((-1, 11), (-1, 11)), ((-1, 12), (-1, 12)), ((-1, 13), (-1, 13)), ((-1, 14), (-1, 14)), ((10, 2), (10, 2)), ((-1, 16), (-1, 16)), ((-1, 17), (-1, 17)), ((-1, 18), (-1, 18)), ((-1, 19), (-1, 19))]
edges[10]

In [None]:
# Entry 17 of the debug edge list (possibly indexed by agent handle — TODO confirm).
edges[17]

In [None]:
# Load snapshot 83 and render it; it is the freshly loaded env2 that must be
# shown, not `env`, which still holds snapshot 82.
env2, env_dict2 = inspect(83)
show_frame(env2)

In [None]:
# Agent 6 in snapshot 83, for comparison with the earlier steps.
print(env2.agents[6])

In [None]:
# Full state of agent 17 in `env`.
env.agents[17]

In [None]:
# Convert the last agent of `env` to its plain record form. This is written
# explicitly instead of reusing a loop variable leaked from an earlier cell,
# so the cell does not depend on hidden kernel state.
env.agents[-1].to_agent()