# Render Episode
Render a stored episode. The env file needs to have "episode" and "action" keys.
- creates a moving gif file of the episode
- displays the episode in a widget with a slider for the time steps.

# Setup

In [None]:
#!apt -qq install graphviz libgraphviz-dev pkg-config
#!pip install -qq git+https://gitlab.aicrowd.com/flatland/flatland.git

In [None]:
# Load IPython's autoreload extension; mode 2 reloads all modules before each
# cell is executed, so edits to imported library code are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
from IPython import display

In [None]:
import os
import pandas as pd
import PIL


In [None]:
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.states import TrainState
from flatland.envs.persistence import RailEnvPersister

In [None]:
from IPython.display import HTML, display, clear_output
import ipywidgets as ipw
from io import BytesIO
import PIL
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time      
  
def create_rendering_area():
    """Create an empty image widget, display it, and return it.

    Returns:
        The ``ipw.Image`` widget that was displayed; ``render_env`` accepts it
        as its ``rendering_area`` argument.
    """
    image_widget = ipw.Image()
    display(image_widget)
    return image_widget

def render_env_to_image(flatland_renderer):
    """Re-render the environment and return the renderer's current image.

    Args:
        flatland_renderer: object exposing ``render_env`` and ``get_image``
            (a ``RenderTool`` in this notebook).

    Returns:
        Whatever ``flatland_renderer.get_image()`` returns after the render.
    """
    # Rendering is requested with both ``show`` and ``show_observations`` off.
    flatland_renderer.render_env(show=False, show_observations=False)
    return flatland_renderer.get_image()

def render_env(flatland_renderer, rendering_area : ipw.Image):
    """Render the current environment state and show it in the notebook.

    Args:
        flatland_renderer: renderer passed on to ``render_env_to_image``.
        rendering_area: image widget to refresh in place. When ``None``, the
            cell output is cleared and the frame is displayed directly.
    """
    # ``import PIL`` on its own does not load the ``PIL.Image`` submodule, so
    # import it explicitly instead of relying on another library having done so.
    import PIL.Image

    pil_image = PIL.Image.fromarray(render_env_to_image(flatland_renderer))
    if rendering_area is None:
        clear_output(wait=False)
        display(pil_image)
        return

    # Encode the frame as PNG bytes in an in-memory buffer.
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="png")
        png_bytes = png_buffer.getvalue()

    # Assigning the PNG bytes to the widget's value updates the image
    # shown in the browser.
    rendering_area.value = png_bytes

def process_frames(frames, frames_per_second=1000/20):
    """Build a matplotlib animation that plays ``frames`` in order.

    Args:
        frames: non-empty sequence of image arrays. ``frames[0]`` is read to
            size the figure (``shape[1]`` x ``shape[0]`` pixels at 72 dpi), so
            an empty sequence fails on that lookup.
        frames_per_second: despite its name, this value is handed to
            ``FuncAnimation`` as ``interval``, i.e. the delay between frames
            in milliseconds (default ``1000/20`` = 50 ms).

    Returns:
        The ``FuncAnimation`` instance. Its figure is closed before returning
        (presumably so the notebook does not also show it as a static plot).
    """
    dpi = 72
    interval = frames_per_second  # ms

    first_frame = frames[0]
    fig = plt.figure(
        figsize=(first_frame.shape[1] / dpi, first_frame.shape[0] / dpi),
        dpi=dpi,
    )
    plot = plt.imshow(first_frame)
    # Hide the axes. This has to be a call: assigning to ``plt.axis`` would
    # overwrite the pyplot function and leave the axes visible.
    plt.axis('off')

    def init():
        # Nothing to set up; the first frame is already drawn by imshow.
        pass

    def update(i):
        # Swap the displayed pixel data for frame ``i``.
        plot.set_data(frames[i])
        return plot,

    anim = FuncAnimation(fig=fig,
                         func=update,
                         frames=len(frames),
                         init_func=init,
                         interval=interval,
                         repeat=True,
                         repeat_delay=20)
    plt.close(fig)
    return anim


# Experiments

Environments are now mostly loaded using `importlib_resources`. The one below is read from the package `env_data.railway`.

In [None]:
# Load the stored environment "complex_scene_2.pkl" from the package
# "env_data.railway"; load_new returns an (env, env_dict) pair.
env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway", legacy=True)
# Reset the environment; the return value is not needed here.
_ = env.reset()
# Override the private step limit — presumably caps the rollout at 100 steps; TODO confirm.
env._max_episode_steps = 100

In [None]:
# the seed has to match that used to record the episode, in order for the malfunctions to match.
# NOTE(review): no seed is set in this cell — confirm where (or whether) it is applied.
# Create the image widget in this cell's output and echo the configured step limit.
rendering_area = create_rendering_area()
print(env._max_episode_steps)

In [None]:
# Tabulate the direction/position attributes of the agents stored in env_dict.
loAgs = env_dict["agents"]
lCols = ["initial_direction", "direction", "initial_position", "position"]
agent_rows = [[getattr(oAg, sCol) for sCol in lCols] for oAg in loAgs]
pd.DataFrame(agent_rows, columns=lCols)

In [None]:
# Tabulate the attributes named in lCols for the agents of the live environment.
pd.DataFrame(
    data=[[getattr(agent, column) for column in lCols] for agent in env.agents],
    columns=lCols,
)

In [None]:
# One table row per agent, one column per instance attribute.
pd.DataFrame(list(map(vars, env.agents)))

In [None]:
# from persistence.py
def get_agent_state(env):
    """Collect a compact state record for every agent in ``env``.

    Args:
        env: object whose ``agents`` attribute is iterable; each agent must
            expose ``position``, ``direction`` and ``malfunction_handler``.

    Returns:
        A list with one ``[position[0], position[1], direction,
        malfunction_handler]`` entry per agent. Agents whose position is
        ``None`` are recorded at ``(0, 0)``.
    """
    agent_states = []
    for agent in env.agents:
        position = agent.position
        # In env v2 agents may have no position before starting; otherwise
        # cast to plain ints to avoid numpy types, which may cause problems
        # with msgpack.
        if position is not None:
            cell = (int(position[0]), int(position[1]))
        else:
            cell = (0, 0)
        agent_states.append(
            [*cell, int(agent.direction), agent.malfunction_handler])

    return agent_states

In [None]:
# Show the instance attributes of every agent, one row per agent.
agent_attributes = [vars(agent) for agent in env.agents]
pd.DataFrame(agent_attributes)

In [None]:
# Start with an empty list and an empty dict (presumably for recorded expert
# actions and a per-agent action mapping).
# NOTE(review): neither name is used by the cells visible here — confirm they are still needed.
expert_actions = []
action = {}

In [None]:
# Roll the environment forward with every agent told to move forward on each
# step, capturing one rendered frame and the agents' states per step.
env_renderer = RenderTool(env, gl="PGL", show_debug=True)

n_agents = env.get_num_agents()
x_dim, y_dim = env.width, env.height
max_steps = env._max_episode_steps

action_dict = {}

# Log everything in its original state before the first step.
statuses = [env.agents[a].state for a in range(n_agents)]
rendered_image = render_env_to_image(env_renderer)
frames = [rendered_image]
status_info = [statuses]

step = 0
all_done = False
failed_action_check = False
print("Processing episode steps:")
while not all_done:
    print(step, end=", ")

    # Same action for every agent handle on every step.
    action_dict.update(
        {agent_handle: RailEnvActions.MOVE_FORWARD
         for agent_handle, _ in enumerate(env.agents)}
    )

    next_obs, all_rewards, done, info = env.step(action_dict)

    # Snapshot states first, then render, and store both for this step.
    statuses = [env.agents[a].state for a in range(n_agents)]
    rendered_image = render_env_to_image(env_renderer)
    frames.append(rendered_image)
    status_info.append(statuses)

    # Stop once the environment reports that the whole episode is done.
    all_done = bool(done["__all__"])
    if all_done:
        max_steps = step + 1
        print("done")

    step += 1

In [None]:
# Fail loudly if the rollout flagged a mismatch. The flag is initialised to
# False in the rollout cell; test its truthiness instead of comparing to False.
assert not failed_action_check, "Realised states did not match stored states."

In [None]:
# Build the animation; the returned object becomes this cell's output value.
# NOTE(review): the following cell builds the same animation again and keeps it — this call looks redundant.
process_frames(frames)

In [None]:
# Build the animation from the captured frames and keep a reference to it.
anim = process_frames(frames)

In [None]:
# Convert the animation to its JavaScript/HTML representation and embed it in the cell output.
display(HTML(anim.to_jshtml()))