In [None]:
from flatland.trajectories.trajectories import Trajectory
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator
from IPython.display import HTML, display, clear_output
import ipywidgets as ipw
from io import BytesIO
import PIL
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time      
from pathlib import Path
from flatland.envs.persistence import RailEnvPersister
import tqdm
from flatland.utils.rendertools import RenderTool
from IPython.display import display
from PIL import Image

## Get episode data

In [None]:
# Download the Flatland benchmark episodes archive and unpack it into /tmp/episodes
# (-O / -o overwrite any previous download/extraction, so the cell can be re-run).
!wget https://github.com/flatland-association/flatland-scenarios/raw/refs/heads/main/trajectories/FLATLAND_BENCHMARK_EPISODES_FOLDER_v2.zip -O /tmp/FLATLAND_BENCHMARK_EPISODES_FOLDER_v2.zip
!mkdir -p /tmp/episodes
!unzip -o /tmp/FLATLAND_BENCHMARK_EPISODES_FOLDER_v2.zip -d /tmp/episodes

In [None]:
# List the contents of the "30x30 map / 10_trains" scenario directory (the same path used as data_dir below).
!ls -al  "/tmp/episodes/30x30 map/10_trains"

## Rendering

In [None]:
def render_env_to_image(flatland_renderer):
    """Render the renderer's environment off-screen and hand back the resulting image.

    ``render_env`` is called with ``show=False`` and ``show_observations=False``;
    the return value is whatever ``flatland_renderer.get_image()`` produces.
    """
    flatland_renderer.render_env(show_observations=False, show=False)
    return flatland_renderer.get_image()

def process_frames(frames, frames_per_second=1000/20, dpi=72):
    """Build a looping matplotlib animation from a sequence of image frames.

    Parameters
    ----------
    frames : iterable of array-like
        Images to animate, in display order. Each one is passed to ``imshow``.
        The first frame's ``shape[0]`` / ``shape[1]`` (height / width in pixels)
        set the figure size.
    frames_per_second : float, optional
        Despite its name, this is the delay *between* frames in milliseconds.
        It is passed straight to ``FuncAnimation(interval=...)``. The default
        ``1000/20`` = 50 ms corresponds to 20 frames per second. The name is
        kept for backward compatibility.
    dpi : int, optional
        Figure resolution. The figure size is ``width/dpi`` x ``height/dpi`` inches,
        so the figure's pixel dimensions equal the first frame's.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Repeating animation. Its figure is already closed, so the notebook does not
        also show it as a static image. Display it with e.g. ``HTML(anim.to_jshtml())``.

    Raises
    ------
    ValueError
        If ``frames`` is empty.
    """
    frames = list(frames)  # allow any iterable; indexing below needs a sequence
    if not frames:
        raise ValueError("process_frames() needs at least one frame")

    interval_ms = frames_per_second  # see docstring: this is a delay in ms, not a rate
    height, width = frames[0].shape[:2]

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    # The original `plt.axis=('off')` overwrote pyplot's axis() function instead of
    # calling it; hide the axes on this Axes object explicitly.
    ax.axis("off")
    image_artist = ax.imshow(frames[0])

    def init():
        image_artist.set_data(frames[0])
        return (image_artist,)

    def update(i):
        image_artist.set_data(frames[i])
        return (image_artist,)

    anim = FuncAnimation(
        fig=fig,
        func=update,
        frames=len(frames),
        init_func=init,
        interval=interval_ms,
        repeat=True,
        repeat_delay=20,
    )
    # Close the figure so the inline backend doesn't render an extra static copy.
    plt.close(fig)
    return anim

## Run trajectory

In [None]:
# Episode to replay: scenario directory inside the unzipped benchmark archive, and the episode id.
data_dir, ep_id = (
    "/tmp/episodes/30x30 map/10_trains",
    "1649ef98-e3a8-4dd3-a289-bbfff12876ce",
)

In [None]:
# Re-evaluate the recorded episode with snapshot_interval=1. Presumably this writes an
# env snapshot at every step; the cells below read them from <data_dir>/serialised_state.
trajectory = Trajectory(data_dir=data_dir, ep_id=ep_id)
evaluator = TrajectoryEvaluator(trajectory)
evaluator.evaluate(snapshot_interval=1)

In [None]:
# Peek at the first few step snapshot files for this episode (IPython expands $data_dir and $ep_id).
!find "$data_dir" -name "$ep_id""_step*.pkl" | sort -u | head

In [None]:
# Collect the per-step snapshot files in step order. This assumes the files use the same
# zero-padded `_step{step:04d}` naming that `inspect` reads, so a lexicographic sort is step order.
snapshot_dir = Path(data_dir) / "serialised_state"
snapshots = sorted(snapshot_dir.glob(f'{ep_id}_step*.pkl'))
# Fail here with a clear message instead of an IndexError later in process_frames.
assert snapshots, f"no snapshots for episode {ep_id} in {snapshot_dir} - run the evaluation cell first"
snapshots

## Animate trajectory

In [None]:
# Render one frame per snapshot. A new RenderTool is built for each snapshot because each
# snapshot is loaded as a separate env. This is slow: one full render per step.
# The loop uses snapshot_* names so it does not overwrite the global `env` / `env_dict`
# that the "Inspect single steps" cells set and use.
frames = []
for snapshot_path in tqdm.tqdm(snapshots, desc="rendering snapshots"):
    snapshot_env, _ = RailEnvPersister.load_new(str(snapshot_path))
    snapshot_renderer = RenderTool(snapshot_env, gl="PGL", show_debug=True)
    frames.append(render_env_to_image(snapshot_renderer))

In [None]:
# Turn the rendered frames into an animation and embed it as an interactive JS/HTML player.
anim = process_frames(frames)
anim_html = anim.to_jshtml()
HTML(anim_html)

In [None]:
!python -m pip install ipyplot

In [None]:
import ipyplot

# Static grid of the rendered frames as 400 px wide thumbnails.
# NOTE(review): ipyplot.plot_images appears to cap the number of images via a `max_images`
# default - confirm whether all frames are expected to be shown.
thumbnail_width_px = 400
ipyplot.plot_images(frames, img_width=thumbnail_width_px)

## Inspect single steps

In [None]:
def inspect(step):
    """Load the snapshot saved for ``step`` of episode ``ep_id``.

    The file read is ``<data_dir>/serialised_state/<ep_id>_step<NNNN>.pkl``, with the
    step zero-padded to 4 digits.

    Returns
    -------
    tuple
        ``(env, env_dict)`` as unpacked from ``RailEnvPersister.load_new``.
    """
    snapshot_file = Path(data_dir).joinpath("serialised_state", f'{ep_id}_step{step:04d}.pkl')
    loaded_env, loaded_env_dict = RailEnvPersister.load_new(str(snapshot_file))
    return loaded_env, loaded_env_dict

def show_frame(env, dpi = 40):
    """Render ``env`` with a debug-enabled PGL ``RenderTool`` and plot the image.

    Reuses ``render_env_to_image`` instead of repeating its render/get_image steps.

    Parameters
    ----------
    env :
        Environment to render, e.g. the first element returned by ``inspect``.
    dpi : int, optional
        Pixels per inch used only to convert the image size into a figure size
        (``width/dpi`` x ``height/dpi`` inches). The figure itself keeps matplotlib's
        default resolution, so smaller values give a larger on-screen image.
    """
    frame = render_env_to_image(RenderTool(env, gl="PGL", show_debug=True))
    plt.figure(figsize=(frame.shape[1] / dpi, frame.shape[0] / dpi))
    plt.imshow(frame)

In [None]:
# Env state as saved at step 25
env, env_dict = inspect(step=25)
show_frame(env)

In [None]:
# Env state as saved at step 26 (the step queried by the lookups at the end)
env, env_dict = inspect(step=26)
show_frame(env)

In [None]:
# Inspect agent 0 (the agent queried by the trajectory lookups below) in the loaded snapshot.
# `agent` is not defined by any earlier cell on a fresh kernel, so bind it explicitly.
agent = env_dict["agents"][0]
agent.to_agent()

In [None]:
import pandas as pd

# One row per agent in the loaded snapshot, built from each agent's to_agent() record.
agent_records = [snapshot_agent.to_agent()._asdict() for snapshot_agent in env_dict["agents"]]
pd.DataFrame.from_records(agent_records)

In [None]:
# Recorded position of agent 0 at env time 26
trains_positions = trajectory.read_trains_positions()
trajectory.position_lookup(trains_positions, env_time=26, agent_id=0)

In [None]:
# Recorded action of agent 0 at env time 26
recorded_actions = trajectory.read_actions()
trajectory.action_lookup(recorded_actions, env_time=26, agent_id=0)