# Plot the Env and some trees

In [1]:
# Load IPython's autoreload extension; mode 2 picks up edits to imported
# modules live, so source changes do not require a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import random
import time
import sys

In [3]:
# In case you need to tweak your PYTHONPATH: append "../flatland" to the
# module search path. The entry is a relative path, so it is resolved
# against the kernel's current working directory.
sys.path.append("../flatland")

In [4]:
import flatland.core.env
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

# Generate

## Rendering - notebook integration

Helper functions for rendering the Flatland environment inside the notebook.

In [5]:
from IPython.display import HTML, display, clear_output
import ipywidgets as ipw
from io import BytesIO
import PIL
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time      
  
def create_rendering_area():
    """Create an empty ``ipw.Image`` widget, display it, and return it.

    Returns:
        The widget instance that was just displayed, so the caller can keep
        a handle on it.
    """
    image_widget = ipw.Image()
    display(image_widget)
    return image_widget

def render_env_to_image(flatland_renderer):
    """Render the environment and return the renderer's current image.

    Args:
        flatland_renderer: object exposing ``render_env(...)`` and
            ``get_image()``.

    Returns:
        Whatever ``flatland_renderer.get_image()`` returns after
        ``render_env`` has been called with ``show=False`` and
        ``show_observations=False``.
    """
    flatland_renderer.render_env(show_observations=False, show=False)
    return flatland_renderer.get_image()

def render_env(flatland_renderer, rendering_area : ipw.Image):
    """Render the environment and show the resulting frame in the notebook.

    Args:
        flatland_renderer: renderer handed to ``render_env_to_image``; the
            array it returns is converted with ``PIL.Image.fromarray``.
        rendering_area: image widget whose ``value`` is replaced with the
            PNG-encoded frame. If ``None``, the cell output is cleared and
            the frame is displayed directly instead.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule, so
    # `PIL.Image` only resolves if some other library happened to import it
    # first. Import it explicitly so this helper works on its own.
    from PIL import Image as PilImage

    pil_image = PilImage.fromarray(render_env_to_image(flatland_renderer))
    if rendering_area is None:
        # No widget to update: replace the cell output with the new frame.
        clear_output(wait=False)
        display(pil_image)
        return

    # Encode the frame as PNG bytes in memory.
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="png")
        png_bytes = png_buffer.getvalue()

    # Assigning the PNG bytes to the widget's value updates the image
    # shown in the browser.
    rendering_area.value = png_bytes

### Create a render area - this is the object on which Flatland is visualised.

In [6]:
# Display an empty image widget and keep a handle to it; render_env later
# writes PNG bytes into this widget's value.
rendering_area = create_rendering_area()

Image(value=b'')

# Create the Flatland environment and its RenderTool

In [7]:
# Scenario parameters used to build the environment below.
nAgents = 3
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4  # passed below as max_rail_pairs_in_city
seed = 0

# 20 x 30 RailEnv: rails come from sparse_rail_generator, lines from
# sparse_line_generator, and observations are built by a
# TreeObsForRailEnv(max_depth=3) using a ShortestPathPredictorForRailEnv.
env = RailEnv(
    width=20,
    height=30,
    rail_generator=sparse_rail_generator(
        max_num_cities=n_cities,
        seed=seed,
        grid_mode=True,
        max_rails_between_cities=max_rails_between_cities,
        max_rail_pairs_in_city=max_rails_in_city,
    ),
    line_generator=sparse_line_generator(),
    number_of_agents=nAgents,
    obs_builder_object=TreeObsForRailEnv(
        max_depth=3,
        predictor=ShortestPathPredictorForRailEnv(),
    ),
)

# reset() returns a pair; keep the first element and discard the second.
init_observation, _ = env.reset()

In [8]:
# Build the RenderTool for `env` with a 750 x 750 screen and show_debug on.
env_renderer = RenderTool(
    env,
    gl="PILSVG",
    agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
    show_debug=True,
    screen_height=750,
    screen_width=750,
)

# Render once and push the frame into the rendering_area widget.
render_env(env_renderer, rendering_area)