# Flatland JavaScript Renderer

The JavaScript (JS) renderer is in active development and is intended to replace the tkinter / PIL renderer.
This notebook shows how to start the JS renderer inside a notebook using an "iframe". The renderer runs a Flask server in a background thread, listening on localhost, port 8080.
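
The background-thread server pattern described above can be sketched with the standard library alone. This is a minimal illustration and not the Flatland implementation: it swaps Flask for `http.server`, and `PlaceholderHandler`, `start_background_server` and the response body are hypothetical names chosen for the example. Port 0 asks the operating system for any free port instead of a fixed one.

```python
import http.client
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer


class PlaceholderHandler(BaseHTTPRequestHandler):
    """Hypothetical handler that serves one fixed page for every GET request."""

    def do_GET(self):
        body = b"renderer placeholder"
        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        # Silence per-request logging so the notebook output stays clean.
        pass


def start_background_server(host="127.0.0.1", port=0):
    """Start the server in a daemon thread and return (server, thread)."""
    server = HTTPServer((host, port), PlaceholderHandler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return server, thread


server, thread = start_background_server()
port = server.server_address[1]  # the port actually chosen by the OS

# Fetch the page once from the main thread to show the server is reachable.
connection = http.client.HTTPConnection("127.0.0.1", port, timeout=5)
connection.request("GET", "/")
response = connection.getresponse()
status = response.status
page = response.read()
connection.close()

# Stop the server cleanly once it is no longer needed.
server.shutdown()
server.server_close()
thread.join(timeout=5)
```

In a notebook, a URL served this way could be embedded with `IPython.display.IFrame(src, width, height)`. The thread is a daemon so that it does not keep the kernel process alive on shutdown.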

In [90]:
import time
import os
import numpy as np 
from numpy import array
import threading
import webbrowser
import pandas as pd # useful for displaying tables
#import ipysheet as ips
from IPython.display import HTML, display
from matplotlib import pyplot as plt
import PIL
import jpy_canvas
from ipywidgets import Button

In [2]:
%matplotlib inline

In [3]:
#pd.describe_option("rows")

In Flatland you can use custom observation builders and predictors.<br>
Observation builders generate the observations needed by the controller.<br>
Predictors provide short-term predictions, which can help avoid conflicts in the network.

In [4]:
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
# We also include a renderer because we want to visualize what is going on in the environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant


Flask static folder:  /home/jeremy/projects/aicrowd/rl-trains/flatland/notebooks/static


In [5]:
seed = 14  # Random seed
grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
max_rails_between_cities = 2  # Max number of tracks allowed between cities, i.e. the number of entry points to a city
max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station

In [6]:
generator_configs = [
    [30, 30, 5, 2],
    [8*7, 8*7, 20, 10],
    [16*7, 9*7, 50, 20],
]
dfConfig = pd.DataFrame(generator_configs, columns="width height nr_trains cities_in_map".split(" "))
dfConfig.index.name="config"
dfConfig

        width  height  nr_trains  cities_in_map
config
0          30      30          5              2
1          56      56         20             10
2         112      63         50             20


In [20]:
i_config = 0
srConf = dfConfig.iloc[i_config,:]
srConf

width            30
height           30
nr_trains         5
cities_in_map     2
Name: 0, dtype: int64

In [21]:
rail_generator = sparse_rail_generator(max_num_cities=srConf.cities_in_map,
                                       seed=seed,
                                       grid_mode=grid_distribution_of_cities,
                                       max_rails_between_cities=max_rails_between_cities,
                                       max_rails_in_city=max_rail_in_cities,
                                       )

The schedule generator can make very basic schedules with a start point, an end point and a speed profile for each agent.<br>
The speed profiles can also be adjusted directly, as shown later on. We start by introducing a statistical<br>
distribution of speed profiles.

Different agent types (trains) with different speeds.

In [22]:
speed_ration_map = {1.: 0.25,  # Fast passenger train
                    1. / 2.: 0.25,  # Fast freight train
                    1. / 3.: 0.25,  # Slow commuter train
                    1. / 4.: 0.25}  # Slow freight train
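
The map above reads as "speed: share of agents with that speed". As a rough sketch of what such a distribution means, the snippet below draws one speed per agent with weighted sampling from the standard library. It is an illustration only and not necessarily how `sparse_schedule_generator` draws speeds internally; `sample_speeds` is a hypothetical helper.

```python
import random

# Same mapping as in the cell above: speed -> probability of that speed.
speed_ration_map = {1.: 0.25,
                    1. / 2.: 0.25,
                    1. / 3.: 0.25,
                    1. / 4.: 0.25}


def sample_speeds(ratio_map, n_agents, seed):
    """Hypothetical helper: draw one speed per agent, weighted by ratio_map."""
    rng = random.Random(seed)
    speeds = list(ratio_map.keys())
    weights = list(ratio_map.values())
    return rng.choices(speeds, weights=weights, k=n_agents)


# The probabilities should describe a full distribution.
total_probability = sum(speed_ration_map.values())

# Five agents, matching nr_trains of config 0; the seed is arbitrary.
agent_speeds = sample_speeds(speed_ration_map, n_agents=5, seed=14)
print(agent_speeds)
```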

We can now initialize the schedule generator with the given speed profiles

In [23]:
schedule_generator = sparse_schedule_generator(speed_ration_map)

We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions<br>
during an episode.

In [24]:
stochastic_data = MalfunctionParameters(malfunction_rate=10000,  # Rate of malfunction occurrence
                                        min_duration=15,  # Minimal duration of malfunction
                                        max_duration=50  # Max duration of malfunction
                                        )
# Custom observation builder without predictor
observation_builder = GlobalObsForRailEnv()
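
How `malfunction_rate`, `min_duration` and `max_duration` interact is not spelled out in this notebook. The sketch below shows one plausible reading, stated as an assumption and not as Flatland's actual implementation: the rate is treated as the mean number of steps between malfunctions, giving a per-step probability of `1 - exp(-1 / rate)`, and the duration is drawn uniformly between the two bounds. Both function names are hypothetical.

```python
import math
import random


def malfunction_probability(rate):
    """Assumed per-step probability when `rate` is the mean steps between malfunctions."""
    if rate <= 0:
        return 0.0
    return 1.0 - math.exp(-1.0 / rate)


def sample_malfunction(rng, rate, min_duration, max_duration):
    """Return a malfunction duration in steps, or 0 if none starts this step."""
    if rng.random() < malfunction_probability(rate):
        # Assumption: durations are uniform over the inclusive range.
        return rng.randint(min_duration, max_duration)
    return 0


rng = random.Random(14)

# With the parameters used above, a malfunction is rare on any single step.
p_step = malfunction_probability(10000)
print(p_step)

# With a very small rate the probability saturates at 1, so every draw
# yields a duration inside [min_duration, max_duration].
durations = [sample_malfunction(rng, 0.01, 15, 50) for _ in range(100)]
```

Under this reading, a rate of 10000 means a malfunction starts on roughly one step in ten thousand per agent, so short runs like the one in this notebook would rarely show any.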

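To try a custom observation builder with a predictor, replace the line above with the one below. It needs TreeObsForRailEnv from flatland.envs.observations and ShortestPathPredictorForRailEnv from flatland.envs.predictions, which are not imported in this notebook:<br>
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())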

Construct the environment with the given observation builder, generators, predictors, and stochastic data

In [25]:
env = RailEnv(width=srConf.width,
              height=srConf.height,
              rail_generator=rail_generator,
              schedule_generator=schedule_generator,
              number_of_agents=srConf.nr_trains,
              obs_builder_object=observation_builder,
              malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
              remove_agents_at_target=True)
env.reset()
pass

Initialize the renderer

In [26]:
if False:
    env_renderer_js = RenderTool(env,  # gl="PILSVG",  # defaults to BROWSER
                                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                                 show_debug=False,
                                 host="0.0.0.0",
                                 )

env_renderer = render_pil = RenderTool(env, gl="PILSVG",
                                       agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                                       show_debug=False,
                                       screen_height=1000,  # Adjust these parameters to fit your resolution
                                       screen_width=1300)  # Adjust these parameters to fit your resolution



WebSocket transport not available. Install eventlet or gevent and gevent-websocket for improved performance.


 * Serving Flask app "flatland.utils.flask_util" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://0.0.0.0:8085/ (Press CTRL+C to quit)


The first thing we notice is that some agents don't have feasible paths to their target.<br>
We first look at the map we have created.

env_renderer.render_env(show=True)<br>
time.sleep(2)<br>
Import your own agent or use RLlib to train agents on Flatland.<br>
As an example we use a random agent instead.

In [27]:
class RandomAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
    def act(self, state):
        """
        :param state: input is the observation of the agent
        :return: returns an action
        """
        return np.random.choice([RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT,
                                 RailEnvActions.STOP_MOVING])
    def step(self, memories):
        """
        Step function to improve the agent by adjusting its policy given the observations
        :param memories: SARS tuple (state, action, reward, next state, done)
        :return:
        """
        return
    def save(self, filename):
        # Store the current policy
        return
    def load(self, filename):
        # Load a policy
        return
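
The agent above draws from NumPy's global random state, so two runs generally pick different actions. A minimal sketch of a reproducible variant follows, assuming a private seeded generator is acceptable. `Action` is a hypothetical stand-in for the four `RailEnvActions` members sampled above, with illustrative integer values, so that the snippet runs without Flatland installed.

```python
import random
from enum import IntEnum


class Action(IntEnum):
    """Hypothetical stand-in for the actions sampled above (values are illustrative)."""
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4


class SeededRandomAgent:
    """Random agent with its own generator, so a given seed replays the same actions."""

    def __init__(self, seed):
        self.rng = random.Random(seed)

    def act(self, state):
        # The observation is ignored, as in the random agent above.
        return self.rng.choice(list(Action))


first_run = SeededRandomAgent(seed=14)
second_run = SeededRandomAgent(seed=14)
actions_first = [first_run.act(None) for _ in range(20)]
actions_second = [second_run.act(None) for _ in range(20)]
print(actions_first == actions_second)
```

Keeping the generator on the agent also means that seeding or sampling elsewhere in the notebook does not change the action sequence.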

Initialize the agent with the parameters corresponding to the environment and observation_builder

In [28]:
controller = RandomAgent(218, env.action_space[0])

In [29]:
env_renderer.reset()

In [30]:
score = 0
# Run episode
frame_step = 0

In [31]:
nSteps = 3
action_dict = {}
observations, rewards, dones, information = env.step(action_dict)

In [92]:
def run_steps():
    global controller, observations, action_dict, env, render_pil, wCanvas, frame_step, score
    for step in range(nSteps):
        # Choose an action for each agent in the environment
        for a in range(env.get_num_agents()):
            action = controller.act(observations[a])
            action_dict.update({a: action})

        # Environment step which returns the observations for all agents, their corresponding
        # rewards and whether they are done
        next_obs, all_rewards, done, _ = env.step(action_dict)

        # env_renderer_js.render_env(show=False, show_observations=False, show_predictions=False)
        render_pil.render_env(show=False, show_observations=False, show_predictions=False, show_agents=True)
        #display(PIL.Image.fromarray(render_pil.get_image()))
        wCanvas.data = render_pil.get_image()

        # env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
        frame_step += 1
        # Update replay buffer and train agent
        for a in range(env.get_num_agents()):
            controller.step((observations[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
            score += all_rewards[a]
        observations = next_obs.copy()
        if done['__all__']:
            print("All done!")
            break
        print('Episode: Steps {}\t Score = {}'.format(frame_step, score))

        time.sleep(0.01)

def run_steps_event(widget):
    run_steps()
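
The loop in `run_steps` relies on per-agent dictionaries keyed by agent index, plus a special `'__all__'` key in the done dictionary. The sketch below reproduces only that control flow. `ToyEnv` is a hypothetical stand-in and not `RailEnv`: each agent simply finishes after a fixed number of steps and receives a reward of -1 per step until then. The action names are placeholder strings.

```python
import random

# Placeholder action names; the real notebook uses RailEnvActions members.
ACTIONS = ["MOVE_FORWARD", "MOVE_RIGHT", "MOVE_LEFT", "STOP_MOVING"]


class ToyEnv:
    """Hypothetical stand-in for RailEnv: agent i is done after steps_to_finish[i] steps."""

    def __init__(self, steps_to_finish):
        self.steps_to_finish = dict(enumerate(steps_to_finish))
        self.elapsed = 0

    def get_num_agents(self):
        return len(self.steps_to_finish)

    def step(self, action_dict):
        self.elapsed += 1
        done = {a: self.elapsed >= n for a, n in self.steps_to_finish.items()}
        rewards = {a: 0.0 if done[a] else -1.0 for a in done}
        obs = {a: self.elapsed for a in self.steps_to_finish}
        # Same convention as in the loop above: one extra key for the whole episode.
        done["__all__"] = all(done.values())
        return obs, rewards, done, {}


def run_episode(env, max_steps, rng):
    """Step until every agent is done or max_steps is reached; return (steps, score)."""
    score = 0.0
    action_dict = {}
    for step in range(1, max_steps + 1):
        for a in range(env.get_num_agents()):
            action_dict[a] = rng.choice(ACTIONS)
        obs, rewards, done, _ = env.step(action_dict)
        score += sum(rewards[a] for a in range(env.get_num_agents()))
        if done["__all__"]:
            return step, score
    return max_steps, score


steps_taken, total_score = run_episode(ToyEnv([2, 4, 3]), max_steps=10,
                                       rng=random.Random(14))
print(steps_taken, total_score)
```

Returning the step count and score from a function, instead of updating globals as `run_steps` does, makes the loop easier to test outside the notebook.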

In [93]:
wCanvas = jpy_canvas.Canvas(render_pil.get_image())
wButton = Button(description="run steps")
wButton.on_click(run_steps_event)
display(wButton)
display(wCanvas)

Button(description='run steps', style=ButtonStyle())

Canvas()

Episode: Steps 91	 Score = -257.8333333333337
Episode: Steps 92	 Score = -260.666666666667
Episode: Steps 93	 Score = -263.50000000000034
Episode: Steps 94	 Score = -266.33333333333366
Episode: Steps 95	 Score = -269.16666666666697
Episode: Steps 96	 Score = -272.0000000000003
Episode: Steps 97	 Score = -274.8333333333336
Episode: Steps 98	 Score = -277.6666666666669
Episode: Steps 99	 Score = -280.5000000000002
Episode: Steps 100	 Score = -283.33333333333354
Episode: Steps 101	 Score = -286.16666666666686
Episode: Steps 102	 Score = -289.00000000000017
Episode: Steps 103	 Score = -291.8333333333335
Episode: Steps 104	 Score = -294.6666666666668
Episode: Steps 105	 Score = -297.5000000000001
Episode: Steps 106	 Score = -300.3333333333334
Episode: Steps 107	 Score = -303.16666666666674
Episode: Steps 108	 Score = -306.00000000000006
Episode: Steps 109	 Score = -308.83333333333337
Episode: Steps 110	 Score = -311.6666666666667
Episode: Steps 111	 Score = -314.5
Episode: Steps 112	 Score 

In [89]:
run_steps()

Episode: Steps 88	 Score = -249.33333333333366
Episode: Steps 89	 Score = -252.166666666667
Episode: Steps 90	 Score = -255.00000000000034
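
The logged scores fall by the same amount at every step. The short check below uses only values copied from the output above and confirms a constant change of about -2.83 (that is, -17/6) per step. The notebook does not say how this total is composed from the individual agents' rewards, so the sketch checks the arithmetic and nothing more.

```python
# Scores copied from the printed output above (steps 88 to 92).
logged_scores = {
    88: -249.33333333333366,
    89: -252.166666666667,
    90: -255.00000000000034,
    91: -257.8333333333337,
    92: -260.666666666667,
}

steps = sorted(logged_scores)

# Difference between each pair of consecutive logged steps.
per_step_change = [logged_scores[b] - logged_scores[a]
                   for a, b in zip(steps, steps[1:])]
mean_change = sum(per_step_change) / len(per_step_change)

print(per_step_change)
print(mean_change)
```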


