In [1]:
# import numpy as np
# import networkx as nx
# import PIL
# import itertools
# import time
# import copy
# import sys
# from io import BytesIO
# from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict, Tuple
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

# we need a lot of frames to render larger rail environments running...
#matplotlib.rcParams['animation.embed_limit'] = 2**128 # even this seems problematic when we then do big rendering...
#matplotlib.rcParams['animation.embed_limit'] = 2**512 # too much for colab. Seems to lead to disconnect.

from collections import deque

from IPython.display import HTML, display, clear_output
from IPython.terminal.embed import embed
import ipywidgets as ipw
from ipycanvas import canvas

# from clyngor import ASP, solve
#from clingo.symbol import Function, Number
#from clingo.application import Application, clingo_main

from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import Agent, EnvAgent

from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_manual_specifications_generator
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_generators import sparse_rail_generator

from flatland.envs.observations import TreeObsForRailEnv

from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.transition_utils import check_action_on_agent
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv

from flatland.envs.step_utils.transition_utils import check_action_on_agent
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_generators import rail_from_manual_specifications_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen

from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_simple_rail_unconnected, make_simple_rail_unconnected
# from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2 # careful, this is not a rail environment, just a grid transition map...
#from flatland.utils.env_edit_utils import makeEnv # broken, does not work in flatland V3...
from flatland.utils.simple_rail import make_simple_rail2, make_simple_rail_with_alternatives
from flatland.utils.rendertools import RenderTool, AgentRenderVariant

In [2]:
import numpy as np
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap

In [3]:
def make_simple_rail_with_alternatives2() -> Tuple[GridTransitionMap, np.ndarray, dict]:
    """
    Build the small hand-crafted rail network sketched below (7 rows x 10 columns).

    Returns:
        rail: GridTransitionMap whose grid is ``rail_map``.
        rail_map: the raw (7, 10) uint16 transition grid.
        optionals: ``{'agents_hints': {...}}`` with city positions, train stations and
            city orientations, passed on to ``rail_from_grid_transition_map``.

    Note: the previous annotation declared a 2-tuple, but three values are returned.
    """
    # We instantiate a very simple rail network on a 7x10 grid (rows 0-6, columns 0-9):
    #  0 1 2 3 4 5 6 7 8 9
    # 0        /-------------\
    # 1        |             |
    # 2        |             |
    # 3 _ _ _ /_  _ _        |
    # 4              \   ___ /
    # 5               |/
    # 6               |
    # 7               |
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    dead_end_from_south = cells[7]
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    right_turn_from_south = cells[8]
    right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
    right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
    simple_switch_north_left = cells[2]
    simple_switch_north_right = cells[10]
    simple_switch_left_east = transitions.rotate_transition(simple_switch_north_left, 90)

    # Literal codes below, identified via the type/rotation table printed in this notebook:
    #   16386 = type 8, rotation 0; 17411 = type 2, rotation 270;
    #   32800 = type 1, rotation 0 (vertical straight); 38505 = type 5, rotation 90.
    rail_map = np.array(
        [[empty] * 2 + [16386] + [17411] + [horizontal_straight] * 5 + [right_turn_from_west]] +
        [[empty] * 2 + [32800] + [vertical_straight] + [empty] * 5 + [vertical_straight]] * 2 +
        [[dead_end_from_east] + [horizontal_straight] + [38505] + [simple_switch_left_east] + [horizontal_straight] * 2 + [
            right_turn_from_west] + [empty] * 2 + [vertical_straight]] +
        [[empty] * 6 + [simple_switch_north_right] + [horizontal_straight] * 2 + [right_turn_from_north]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    city_positions = [(0, 2), (6, 6)]
    train_stations = [
                      [( (0, 2), 0 ) ],
                      [( (6, 6), 0 ) ],
                     ]
    # NOTE(review): 5 is outside the 0-3 range of the other orientation values;
    # confirm whether the line generator reduces it modulo 4 or whether 1 was meant.
    city_orientations = [0, 5]
    agents_hints = {'city_positions': city_positions,
                    'train_stations': train_stations,
                    'city_orientations': city_orientations
                   }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals

In [4]:
def make_potsdam() -> Tuple[GridTransitionMap, np.ndarray, dict]:
    """
    Build a 9 x 25 rail network based on Potsdam.

    Tile constants are named ``t<type>_<angle>``; their values are transition codes as
    listed by the type/rotation table printed in this notebook.

    Returns:
        rail: GridTransitionMap whose grid is ``rail_map``.
        rail_map: the raw (9, 25) uint16 transition grid.
        optionals: ``{'agents_hints': {...}}`` (see NOTE on city_positions below).
    """
    transitions = RailEnvTransitions()

    empty = 0
    t1_0 = 32800
    t1_90 = 1025
    t2_0 = 37408
    t2_90 = 17411
    t2_180 = 32872
    # The next four were previously named t10_90, t7_90, t8_180 and t9_270. The first
    # three names were re-assigned further down, so these values were dead stores.
    # Renamed after the tile type the codes belong to in the rotation table
    # (NOTE(review): angle suffixes follow the t2_*/t4_* naming above; confirm).
    t2_270 = 3089
    t3_0 = 33825
    t4_0 = 38433
    t4_90 = 50211
    t4_180 = 33897
    t4_270 = 35889
    t5_0 = 38505
    t5_90 = 52275
    t6_0 = 20994
    t6_90 = 16458
    t6_180 = 2136
    t6_270 = 6672
    t7_0 = 8192
    t7_90 = 4
    t7_180 = 128
    t7_270 = 256
    t8_0 = 4608
    t8_90 = 16386
    t8_180 = 72
    t8_270 = 2064
    t10_0 = 49186
    t10_90 = 1097
    t10_180 = 34864
    t10_270 = 5633

    rail_map = np.array(
        [ [empty]*3 + [t7_0] + [empty]*21 ]
        + [ [empty]*3 + [t1_0] + [empty]*21 ]
        + [ [empty]*3 + [t1_0] + [empty]*9 + [t8_90] + [t1_90] + [t8_0] + [empty]*9 ]
        + [ [t7_90] + [t1_90]*2 + [t10_90] + [t1_90]*4 + [t10_90] + [t1_90]*4 + [t10_90] + [t1_90] + [t10_90] + [t1_90] + [t10_90] + [t1_90] + [t10_90] + [t1_90]*2 + [t10_90] + [t10_90] + [t7_270] ]
        + [ [empty]*8 + [t1_0] + [empty]*8 + [t8_180] + [t1_90] + [t8_270] + [empty]*2 + [t1_0] + [t8_180] + [t7_270] ]
        + [ [empty]*8 + [t1_0] + [empty]*13 + [t1_0] + [empty]*2 ]
        + [ [empty]*8 + [t7_180] + [empty]*13 + [t1_0] + [empty]*2 ]
        + [ [empty]*22 + [t1_0] + [empty]*2 ]
        + [ [empty]*22 + [t7_180] + [empty]*2 ]
        , dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map
    # NOTE(review): these hints look copied from make_simple_rail_with_alternatives2.
    # In this map (0, 2) and (6, 6) are empty cells. Dead ends exist e.g. at (0, 3),
    # (6, 8) and (8, 22). Update before using this map with agents.
    city_positions = [(0,2), (6, 6)]
    train_stations = [
                      [( (0, 2), 0 ) ],
                      [( (6, 6), 0 ) ],
                     ]
    city_orientations = [0, 5]
    agents_hints = {'city_positions': city_positions,
                    'train_stations': train_stations,
                    'city_orientations': city_orientations
                   }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals

In [4]:
#rail, rail_map, optionals = make_simple_rail_with_alternatives()
rail, rail_map, optionals = make_simple_rail_with_alternatives2()
rail_gen = rail_from_grid_transition_map(rail_map=rail, optionals=optionals)
# Take the environment size from the map itself. The previously hard-coded
# width=25, height=9 are the dimensions of make_potsdam(), not of this 7x10 map.
hand_crafted_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                           rail_generator=rail_gen, number_of_agents=1,
                           obs_builder_object=TreeObsForRailEnv(max_depth=0))

In [5]:
# Inspect the raw transition grid of the hand-crafted map (7 rows x 10 columns, uint16 codes).
rail_map

array([[    0,     0, 16386, 17411,  1025,  1025,  1025,  1025,  1025,
         4608],
       [    0,     0, 32800, 32800,     0,     0,     0,     0,     0,
        32800],
       [    0,     0, 32800, 32800,     0,     0,     0,     0,     0,
        32800],
       [    4,  1025, 38505,  3089,  1025,  1025,  4608,     0,     0,
        32800],
       [    0,     0,     0,     0,     0,     0, 49186,  1025,  1025,
         2064],
       [    0,     0,     0,     0,     0,     0, 32800,     0,     0,
            0],
       [    0,     0,     0,     0,     0,     0,   128,     0,     0,
            0]], dtype=uint16)

In [6]:
def deactivate_windows(env:RailEnv) -> None:
    """
    Neutralise departure and arrival windows, modifying ``env`` in place.

    Every agent in ``env.agents`` gets:
      * earliest_departure = 0
      * latest_arrival     = env._max_episode_steps (np.inf was tried and looked problematic)
      * state              = TrainState.READY_TO_DEPART (instead of WAITING)
      * arrival_time       = None
    """
    for train in env.agents:
        train.earliest_departure = 0
        train.latest_arrival = env._max_episode_steps
        train.state = TrainState.READY_TO_DEPART
        # Must be "initialized" with None. Per the author's observation, flatland sets it
        # to the number of steps taken once the agent arrives, and then removes the agent.
        train.arrival_time = None
    # --> the logic in flatland seems to be a bit strange for this.

In [7]:
# TODO: maybe the check of predicted times and positions should be in a separate function,
# independent of the simulation loop.

def run_simulation(env:RailEnv, action_plan:dict, times_positions:dict=None, enable_in_simulation_rendering:bool=False):
    """
    Step ``env`` through ``action_plan`` and render every time step.

    input:
      env: RailEnv that is already reset and has departure/arrival windows deactivated.
      action_plan: dict of the form {(t, agent_id): action}, as built by action_plan(answer).
      times_positions: optional dict {(t, agent_id): (x, y, direction)} extracted from a
          solver solution. If given, each on-grid agent's position and direction are
          asserted to match it, to detect mismatches with actual flatland behaviour.
      enable_in_simulation_rendering: True renders live into an ipywidgets image;
          False collects off-screen frames.

    returns: list of off-screen frames (empty when rendering live).

    usage:
    version1: direct rendering during simulation runs (does not work well with colab, but fine with jupyter)
        run_simulation(env, action_plan, enable_in_simulation_rendering=True)

    version2: run the simulation loop, collect frames, and show as a "video" (works well with colab for small environments, then runs out of space...)
        frames = run_simulation(env, action_plan, enable_in_simulation_rendering=False)  # collect frames
        anim = process_frames(frames)  # prepare a Matplotlib animation
        display(HTML(anim.to_jshtml()))  # render the animation
    """
    assert len(env.agents) > 0, "error, reset environment before usage"
    assert action_plan
    assert env._max_episode_steps > 0

    rendering_area = create_rendering_area() if enable_in_simulation_rendering else None

    env_renderer = RenderTool(env, gl="PILSVG",
                              agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
                              show_debug=True,
                              screen_height=750,
                              screen_width=750)

    offscreen_rendered_frames = []

    def _capture_frame():
        # Render the current env state, either live or into the collected frame list.
        if enable_in_simulation_rendering:
            render_env(env_renderer, rendering_area)
        else:
            offscreen_rendered_frames.append(render_env_to_image(env_renderer))

    # Upper bound only; the loop normally ends earlier on dones["__all__"].
    max_steps = env._max_episode_steps

    for time_step in range(max_steps + 1):
        actions = {agent.handle: action_plan[(time_step, agent.handle)]
                   for agent in env.agents if (time_step, agent.handle) in action_plan}

        if times_positions:
            for agent in env.agents:
                key = (time_step, agent.handle)
                # Agents have position None before activation. The ASP encoding also does
                # not know that agents are removed at their target, so skip off-grid agents.
                if key in times_positions and agent.position is not None:
                    x, y, direction = times_positions[key]
                    assert agent.position == (x, y), f"error, agent {agent.handle} is at {agent.position} instead of {(x, y)} at t={time_step}"
                    assert agent.direction == direction, f"error, agent {agent.handle} has direction {agent.direction} instead of {direction} at t={time_step}."

        # Render before env.step(...) so that time step 0 is captured too.
        _capture_frame()

        # Careful: after this call the environment is at time_step + 1.
        _obs, _rewards, dones, _info = env.step(actions)

        if dones["__all__"]:  # time is up or all agents reached their targets
            _capture_frame()  # also capture the final state
            break

    return offscreen_rendered_frames

In [8]:
def action_plan(answer:list) -> dict:
  """
  Convert a single solver answer into an action plan.

  input: a single answer, i.e. a list of (predicate_name, args) tuples such as
    ("action", (0, "agent0", 2)) or ("at", (1, "agent0", ("", (6, 6, 2)))).
    Only "action" predicates contribute to the plan; all others are ignored.
    Also prints the sorted names of all agents found in 3-argument predicates.

  output: dict of the form { (timestep, agent_id) : action}, ordered by (timestep, agent_id).
  """
  plan = {}
  for predicate_str, args in answer:
    if predicate_str == "action":
      timestep, agent_name, action = args
      # "agent12" --> 12 (works for any number of digits)
      agent_id = int(agent_name.replace("agent", ""))
      plan[(timestep, agent_id)] = action

  # Agent names come from every 3-argument predicate (action/3, at/3, ...). Predicates
  # with another arity are skipped instead of crashing the tuple unpacking.
  all_agent_names = sorted({args[1] for _, args in answer
                            if isinstance(args, tuple) and len(args) == 3})
  print(f"all agent names: {all_agent_names}")

  # TODO (disabled experiment): flatland needs an additional MOVE_FORWARD action to place
  # an agent on its initial position upon activation, while the ASP code only activates
  # agents and gives move actions. The idea was, per agent:
  #   plan[(first_timestep_of_agent - 1, agent_id)] = RailEnvActions.MOVE_FORWARD.value
  # Parse agent_id with int(name.replace("agent", "")), not name[-1], so it works for >10 agents.

  # Sorted for nicer output only; lookups by (timestep, agent_id) do not depend on order.
  return dict(sorted(plan.items()))

In [9]:
# Example solver answer for hand_crafted_env: a list of (predicate, args) tuples.
#   ("action", (timestep, "agentN", action)) -- the only entries action_plan(...) uses.
#   ("at", (timestep, "agentN", ("", (a, b, c)))) -- presumably (row, column, direction)
#   of the agent at that step; confirm against the ASP encoding.
answer_list_hand_crafted_env = [('action', (0, 'agent0', 2)),
 ('at', (1, 'agent0', ('', (6, 6, 2)))),
 ('action', (1, 'agent0', 2)),
 ('action', (2, 'agent0', 2)),
 ('at', (2, 'agent0', ('', (5, 6, 0)))),
 ('at', (3, 'agent0', ('', (4, 6, 0)))),
 ('action', (3, 'agent0', 2)),
 ('action', (4, 'agent0', 2)),
 ('at', (4, 'agent0', ('', (3, 6, 0)))),
 ('action', (5, 'agent0', 2)),
 ('at', (5, 'agent0', ('', (3, 5, 3)))),
 ('action', (6, 'agent0', 2)),
 ('at', (6, 'agent0', ('', (3, 4, 3)))),
 ('at', (7, 'agent0', ('', (3, 3, 3)))),
 ('action', (7, 'agent0', 3)),
 ('at', (8, 'agent0', ('', (2, 3, 0)))),
 ('action', (8, 'agent0', 2)),
 ('at', (9, 'agent0', ('', (1, 3, 0)))),
 ('action', (9, 'agent0', 2)),
 ('at', (10, 'agent0', ('', (0, 3, 0))))]

In [10]:
# Shorter variant of the answer above, containing only its first three steps.
answer_list_hand_crafted_env2 = [
 ('action', (0, 'agent0', 2)),
 ('at', (1, 'agent0', ('', (6, 6, 2)))),
 ('action', (1, 'agent0', 2)),
 ('action', (2, 'agent0', 2))]
# answer_list_hand_crafted_env = []

In [11]:
# Convert the example answer into {(timestep, agent_id): action} (also prints the agent names).
plan_hand_crafted_env = action_plan(answer_list_hand_crafted_env)

all agent names: ['agent0']


In [12]:
def create_rendering_area():
    """Create an empty ipywidgets Image, show it in the cell output and return it."""
    area = ipw.Image()
    display(area)
    return area

def render_env_to_image(flatland_renderer):
    """Render the environment off-screen (no window, no observations) and return the image."""
    flatland_renderer.render_env(show=False, show_observations=False)
    return flatland_renderer.get_image()

def render_env(flatland_renderer, rendering_area : Optional[ipw.Image], delay: float = 0.5):
    """
    Render the current environment state and show it in the notebook.

    flatland_renderer: renderer passed to render_env_to_image.
    rendering_area: ipywidgets Image to update in place. If None, the cell output is
        cleared and the PIL image is displayed directly instead.
    delay: seconds to sleep before updating the widget (default 0.5 as before; 1.0 was
        the earlier standard, without a pause the animation in colab gets "jumpy").
    """
    # The top-level imports of these modules are commented out, so import them here;
    # previously this function raised NameError on PIL / BytesIO / time.
    import time
    from io import BytesIO
    from PIL import Image as PILImage

    pil_image = PILImage.fromarray(render_env_to_image(flatland_renderer))
    if rendering_area is None:
        clear_output(wait=False)
        display(pil_image)
        return

    # Convert to png-format bytes for the widget.
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="png")
        png_bytes = png_buffer.getvalue()

    time.sleep(delay)
    # Setting the png bytes as the widget value updates the image in the browser.
    rendering_area.value = png_bytes

def process_frames(frames, frames_per_second=1000/20):
    """
    Build a Matplotlib animation from a list of rendered frames.

    frames: non-empty sequence of image arrays (e.g. as returned by run_simulation).
        The figure is sized from the first frame at 72 dpi.
    frames_per_second: despite its name, this is the delay between frames in milliseconds
        (default 1000/20 = 50 ms, i.e. 20 frames per second). It is passed to
        FuncAnimation as ``interval``; the name is kept for backward compatibility.

    returns: the FuncAnimation. Its figure is closed so the notebook does not also show
        a static copy; display it with ``HTML(anim.to_jshtml())``.
    raises: ValueError if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("process_frames needs at least one frame")

    dpi = 72
    interval_ms = frames_per_second

    height_px, width_px = frames[0].shape[0], frames[0].shape[1]
    fig = plt.figure(figsize=(width_px / dpi, height_px / dpi), dpi=dpi)
    # The original `plt.axis=('off')` rebound plt.axis to a string instead of hiding the axes.
    plt.axis('off')
    image_artist = plt.imshow(frames[0])

    def init():
        # Nothing to do: the first frame is already drawn above.
        pass

    def update(i):
        image_artist.set_data(frames[i])
        return image_artist,

    anim = FuncAnimation(fig=fig,
                         func=update,
                         frames=len(frames),
                         init_func=init,
                         interval=interval_ms,
                         repeat=False,
                         repeat_delay=20,
                         cache_frame_data=True
                         )

    plt.close(fig)
    return anim

In [13]:
# Reset the env without regenerating rail or schedule, remove the time windows, replay the
# action plan off-screen and show the collected frames as an animation.
hand_crafted_env.reset(regenerate_rail=False, regenerate_schedule=False, random_seed=0)
deactivate_windows(hand_crafted_env)

frames = run_simulation(env=hand_crafted_env, action_plan=plan_hand_crafted_env, enable_in_simulation_rendering=False) # Run the simulation loop and collect frames.
anim = process_frames(frames) # Process the collected frames and prepare a Matplotlib animation.
display(HTML(anim.to_jshtml())) # Render the animation.

?????
0 2
?????
0 3
?????
0 4
?????
0 5
?????
0 6
?????
0 7
?????
0 8
?????
0 9
?????
1 9
?????
2 9
?????
3 9
?????
4 9
?????
4 8
?????
4 7
?????
4 6
?????
5 6
?????
0 2




In [14]:
# Display the rail map again (same grid as the earlier display).
rail_map

array([[    0,     0, 16386, 17411,  1025,  1025,  1025,  1025,  1025,
         4608],
       [    0,     0, 32800, 32800,     0,     0,     0,     0,     0,
        32800],
       [    0,     0, 32800, 32800,     0,     0,     0,     0,     0,
        32800],
       [    4,  1025, 38505,  3089,  1025,  1025,  4608,     0,     0,
        32800],
       [    0,     0,     0,     0,     0,     0, 49186,  1025,  1025,
         2064],
       [    0,     0,     0,     0,     0,     0, 32800,     0,     0,
            0],
       [    0,     0,     0,     0,     0,     0,   128,     0,     0,
            0]], dtype=uint16)

In [None]:
# Reference layout: the 7x10 map of flatland's make_simple_rail(). This cell appears to be
# copied from flatland.utils.simple_rail. It used to rebuild the grid from tile names
# (empty, simple_switch_east_west_north, ...) that are not defined at notebook level, so it
# raised NameError on a fresh kernel. Take the grid from the imported function instead.
# NOTE: this rebinds the global `rail_map`, which then no longer matches `rail`/`hand_crafted_env`.
rail_map = make_simple_rail()[1]

In [372]:
# Print the transition code of each basic cell type (0-10) under all four rotations.
# Uses its own RailEnvTransitions instance. `transitions`/`cells` exist only inside the
# map-building functions, so the old cell depended on leftover kernel state. The loop
# variable is no longer named `type`, which shadowed the built-in for all later cells.
rotation_table_transitions = RailEnvTransitions()
for cell_type in range(0, 11):
    print("#### Type ", cell_type)
    base_transition = rotation_table_transitions.transition_list[cell_type]
    for rotation in [0, 90, 180, 270]:
        rotated = rotation_table_transitions.rotate_transition(base_transition, rotation)
        print("* Rotation ", rotation, ": ", rotated, sep='')
    print()

#### Type  0
* Rotation 0: 0
* Rotation 90: 0
* Rotation 180: 0
* Rotation 270: 0

#### Type  1
* Rotation 0: 32800
* Rotation 90: 1025
* Rotation 180: 32800
* Rotation 270: 1025

#### Type  2
* Rotation 0: 37408
* Rotation 90: 3089
* Rotation 180: 32872
* Rotation 270: 17411

#### Type  3
* Rotation 0: 33825
* Rotation 90: 33825
* Rotation 180: 33825
* Rotation 270: 33825

#### Type  4
* Rotation 0: 38433
* Rotation 90: 35889
* Rotation 180: 33897
* Rotation 270: 50211

#### Type  5
* Rotation 0: 52275
* Rotation 90: 38505
* Rotation 180: 52275
* Rotation 270: 38505

#### Type  6
* Rotation 0: 20994
* Rotation 90: 6672
* Rotation 180: 2136
* Rotation 270: 16458

#### Type  7
* Rotation 0: 8192
* Rotation 90: 256
* Rotation 180: 128
* Rotation 270: 4

#### Type  8
* Rotation 0: 16386
* Rotation 90: 4608
* Rotation 180: 2064
* Rotation 270: 72

#### Type  9
* Rotation 0: 4608
* Rotation 90: 2064
* Rotation 180: 72
* Rotation 270: 16386

#### Type  10
* Rotation 0: 49186
* Rotation 90: 5

In [None]:
# Display the current global rail_map (reassigned by the In [None] cell above if that cell was run).
rail_map