# Simple Animation Demo

This notebook demonstrates how Flatland can be rendered in Colab / Jupyter notebooks. Two variants are implemented as examples:
- Direct rendering, i.e. during the simulation (Flatland) 
- Indirect rendering, i.e. it is rendered in image memory and all frames are stored so that everything can then be played back using HTML / video.

Install Flatland (only needed on Colab): either the released package from PyPI or the latest version from source.

In [None]:
# Set to True when running on Google Colab to install Flatland into the runtime.
use_colab = False
if use_colab:
    # True: install the released `flatland-rl` package from PyPI.
    # False: clone the GitHub repository and use the latest source version.
    use_pip_installer = True
    if use_pip_installer:
      # `&> /dev/null` hides all pip output, including error messages.
      !pip install flatland-rl &> /dev/null
    else:
      # Clone the repository; `git pull` refreshes the checkout when the cell is re-run.
      !git clone https://github.com/flatland-association/flatland-rl.git
      %cd flatland-rl
      !git pull
      %cd ..
      # Only the development requirements are pip-installed; flatland itself is
      # imported directly from the clone (see sys.path below).
      # NOTE(review): the absolute /content paths assume the cell is run from
      # Colab's default working directory /content — confirm when running elsewhere.
      !pip install -r /content/flatland-rl/requirements_dev.txt &> /dev/null

      import os
      import sys
      # PYTHONPATH only affects subprocesses started from this kernel ...
      os.environ['PYTHONPATH'] = "/env/python:/content/flatland-rl"
      # ... while inserting into sys.path makes the running kernel find the clone.
      if "/content/flatland-rl" not in sys.path:
        sys.path.insert(1, "/content/flatland-rl")

Import Flatland

In [None]:
# import all flatland dependance
import time
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.malfunction_generators import ParamMalfunctionGen, MalfunctionParameters
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.rail_env_action import RailEnvActions

Import the display libraries and define helper functions for rendering Flatland within Colab / Jupyter notebooks.

In [None]:
from IPython.display import HTML, display, clear_output
import ipywidgets as ipw
from io import BytesIO
import PIL
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time      
  
def create_rendering_area():
  """Create an empty image widget, display it in the current cell and return it.

  The returned widget can be handed to `render_env`, which replaces the
  widget's PNG contents with each newly rendered frame.
  """
  image_widget = ipw.Image()
  display(image_widget)
  return image_widget

def render_env_to_image(flatland_renderer):
  """Render the environment without showing it and return the resulting frame.

  Calls ``flatland_renderer.render_env(show=False, show_observations=False)``
  and then returns whatever ``flatland_renderer.get_image()`` produces; in this
  notebook that value is consumed as an image array (PIL / imshow).
  """
  flatland_renderer.render_env(show_observations=False, show=False)
  return flatland_renderer.get_image()

def render_env(flatland_renderer, rendering_area: "ipw.Image | None" = None):
  """Render the current environment state and show it in the notebook.

  Args:
    flatland_renderer: renderer passed to `render_env_to_image` (a flatland
      `RenderTool` in this notebook).
    rendering_area: an ``ipw.Image`` widget (e.g. from `create_rendering_area`)
      whose contents are replaced in place with the new frame. If ``None``,
      the cell output is cleared and the frame is displayed as a PIL image.
  """
  # `import PIL` alone does not load the `PIL.Image` submodule; import it
  # explicitly instead of relying on another library having done so.
  from PIL import Image

  pil_image = Image.fromarray(render_env_to_image(flatland_renderer))

  if rendering_area is None:
    # No widget: wipe the previous output and show the frame directly.
    clear_output(wait=False)
    display(pil_image)
    return

  # Encode the PIL image as PNG bytes; assigning them to the widget's
  # `value` updates the image shown in the browser.
  with BytesIO() as png_buffer:
    pil_image.save(png_buffer, format="png")
    rendering_area.value = png_buffer.getvalue()

def process_frames(frames, frames_per_second=1000/20):
  """Build a looping Matplotlib animation from a sequence of rendered frames.

  Args:
    frames: non-empty sequence of image arrays (e.g. from `render_env_to_image`).
      ``frames[0].shape[:2]`` is taken as (height, width) to size the figure;
      presumably all frames share that size — only the first one is inspected.
    frames_per_second: despite its name, this value is passed to
      `FuncAnimation` as ``interval``, i.e. the delay between frames in
      milliseconds. The default 1000/20 = 50 ms plays at 20 frames per second.

  Returns:
    The `FuncAnimation`. Its figure is already closed so Jupyter does not
    also show a static image; render it with e.g. ``anim.to_jshtml()``.

  Raises:
    ValueError: if ``frames`` is empty.
  """
  if len(frames) == 0:
    raise ValueError("process_frames() needs at least one frame")

  dpi = 72
  height, width = frames[0].shape[:2]

  # Size the figure from the frame's pixel dimensions at the chosen dpi.
  fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
  ax = fig.add_subplot()
  # Hide the axes. (The previous `plt.axis=('off')` rebound pyplot's `axis`
  # function to the string 'off' instead of calling it.)
  ax.axis('off')
  image = ax.imshow(frames[0])

  def init():
    # Nothing to reset: imshow above already draws the first frame.
    pass

  def update(frame_index):
    image.set_data(frames[frame_index])
    return (image,)

  anim = FuncAnimation(fig=fig,
                       func=update,
                       frames=len(frames),
                       init_func=init,
                       interval=frames_per_second,  # milliseconds, see docstring
                       repeat=True,
                       repeat_delay=20)
  # Close the figure so it is not displayed as a static image as well; the
  # animation keeps its own reference to it.
  plt.close(fig)
  return anim


Create a Flatland environment.

In [None]:
def create_flatland_env(
        obs_builder_object: ObservationBuilder,
        max_rails_between_cities=2,
        max_rails_in_city=4,
        malfunction_rate=1 / 1000,
        n_cities=5,
        number_of_agents=10,
        grid_width=30,
        grid_height=40,
        random_seed=0) -> RailEnv:
    """Build a sparse-rail Flatland environment with random malfunctions.

    Args:
        obs_builder_object: observation builder handed to ``RailEnv``.
        max_rails_between_cities: forwarded to ``sparse_rail_generator``.
        max_rails_in_city: forwarded to ``sparse_rail_generator`` as
            ``max_rail_pairs_in_city``.
        malfunction_rate: rate given to ``MalfunctionParameters``; the
            malfunction duration bounds are fixed at 10 (min) and 50 (max).
        n_cities: forwarded to ``sparse_rail_generator`` as ``max_num_cities``.
        number_of_agents: number of agents in the environment.
        grid_width: environment width.
        grid_height: environment height.
        random_seed: used both as the rail generator seed and as the
            environment's ``random_seed``.

    Returns:
        The constructed ``RailEnv``; this function does not call ``reset()``.
    """
    rail_generator = sparse_rail_generator(
        max_num_cities=n_cities,
        seed=random_seed,
        grid_mode=True,
        max_rails_between_cities=max_rails_between_cities,
        max_rail_pairs_in_city=max_rails_in_city,
    )
    malfunction_generator = ParamMalfunctionGen(
        MalfunctionParameters(malfunction_rate=malfunction_rate,
                              min_duration=10,
                              max_duration=50)
    )
    return RailEnv(
        width=grid_width,
        height=grid_height,
        rail_generator=rail_generator,
        malfunction_generator=malfunction_generator,
        random_seed=random_seed,
        number_of_agents=number_of_agents,
        obs_builder_object=obs_builder_object,
    )

Create a simple flatland simulation (bring it all together).

In [None]:
def run_simulation(max_steps=50, enable_in_simulation_rendering=False):
  """Run one episode in which every agent always moves forward, rendering each step.

  Args:
    max_steps: upper bound on the number of ``env.step`` calls; the loop ends
      earlier once the environment reports all agents done.
    enable_in_simulation_rendering: if True, each step is drawn live into an
      ``ipw.Image`` widget displayed in the current cell (direct rendering);
      otherwise each step is rendered off-screen and the frame is kept
      (indirect rendering).

  Returns:
    In indirect mode, a list with one rendered frame per executed step; in
    direct mode, an empty list.
  """
  env = create_flatland_env(DummyObservationBuilder())

  rendering_area = create_rendering_area() if enable_in_simulation_rendering else None

  env_renderer = RenderTool(env, gl="PILSVG",
                            agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
                            show_debug=True,
                            screen_height=750,
                            screen_width=750)
  # Observations and info are not needed: the policy below is fixed.
  env.reset()

  offscreen_rendered_frames = []

  for _ in range(max_steps):
    # Fixed policy: every agent is told to move forward on every step.
    action_dict = {handle: RailEnvActions.MOVE_FORWARD
                   for handle in range(env.get_num_agents())}
    _, _, dones, _ = env.step(action_dict)

    if enable_in_simulation_rendering:
      render_env(env_renderer, rendering_area)
    else:
      offscreen_rendered_frames.append(render_env_to_image(env_renderer))

    # Flatland's RailEnv refuses further step() calls once the episode is
    # done, so stop here (the final state has already been rendered).
    if dones["__all__"]:
      break

  return offscreen_rendered_frames

Direct rendering: the environment is drawn live while the simulation runs.

In [None]:
# Direct mode draws every step live into a widget and returns an empty list.
# Discard that result so re-running this cell does not clobber `frames`
# collected by the indirect-rendering cell below.
_ = run_simulation(enable_in_simulation_rendering=True)

Indirect rendering: run the simulation loop and collect the rendered frames in memory.

In [None]:
# Render every step off-screen and keep the frames for later playback.
frames = run_simulation(max_steps=50, enable_in_simulation_rendering=False)

Process the collected frames and prepare a Matplotlib animation.

In [None]:
# Turn the collected frames into a looping Matplotlib animation.
anim = process_frames(frames=frames)

Render the animation.

In [None]:
# Convert the animation into an HTML/JavaScript player and show it in the notebook.
display(HTML(data=anim.to_jshtml()))