In [1]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append("../")
from src.graph import NetworkGraph

In [2]:
from tqdm import tqdm
from PIL import Image


In [3]:
from src.visualization.graphic import draw_path

In [4]:
from flatland.envs.rail_env import RailEnv
from flatland.envs.observations import *
from flatland.envs.rail_generators import complex_rail_generator,rail_from_manual_specifications_generator,random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator, ScheduleGenerator
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv
from flatland.envs.observations import *
from flatland.envs.rail_generators import complex_rail_generator,rail_from_manual_specifications_generator,random_rail_generator, RailGenerator,sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator, ScheduleGenerator, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant

In [5]:
def create_env_sparse(height,width,seed,number_of_agents=5):
    """Create and reset a flatland ``RailEnv`` built with the sparse generators.

    Args:
        height: number of grid rows of the environment.
        width: number of grid columns of the environment.
        seed: seed passed to both the sparse rail generator and the sparse
            schedule generator.
        number_of_agents: number of trains; the rail generator is allowed up
            to ``number_of_agents + 2`` cities.

    Returns:
        The ``RailEnv`` after ``reset()`` has been called.
    """
    # With prop_malfunction at 0 no agent is marked defective (per the inline
    # comment), so the remaining malfunction settings are presumably inert —
    # NOTE(review): confirm against the flatland version in use.
    stochastic_data = {'prop_malfunction': 0,  # Percentage of defective agents
                       'malfunction_rate': 30,  # Rate of malfunction occurrence
                       'min_duration': 3,  # Minimal duration of malfunction
                       'max_duration': 20  # Max duration of malfunction
                       }
    # Only the fastest speed class gets a non-zero weight.
    speed_ratio_map = {1.: 1,  # Fast passenger train
                       1. / 2.: 0,  # Fast freight train
                       1. / 3.: 0,  # Slow commuter train
                       1. / 4.: 0}  # Slow freight train
    # width/height are passed through unswapped, matching create_env; swapping
    # them only went unnoticed because every caller uses square grids.
    env = RailEnv(width=width,
                  height=height,
                  rail_generator=sparse_rail_generator(max_num_cities=number_of_agents + 2,
                                                       seed=seed,
                                                       grid_mode=False,
                                                       max_rails_between_cities=4,
                                                       max_rails_in_city=4,
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ratio_map, seed=seed),
                  number_of_agents=number_of_agents,
                  stochastic_data=stochastic_data,
                  obs_builder_object=GlobalObsForRailEnv(),
                  remove_agents_at_target=True
                  )
    env.reset()

    return env

In [6]:
def create_env(height,width,seed,number_of_agents=5):
    """Create and reset a flatland ``RailEnv`` built with the complex generators.

    Args:
        height: number of grid rows of the environment.
        width: number of grid columns of the environment.
        seed: seed for the complex rail generator.
        number_of_agents: number of trains placed in the environment.

    Returns:
        The ``RailEnv`` after ``reset()`` has been called.
    """
    rail_gen = complex_rail_generator(nr_start_goal=20, nr_extra=1,
                                      min_dist=6, max_dist=99999, seed=seed)
    environment = RailEnv(width=width,
                          height=height,
                          rail_generator=rail_gen,
                          schedule_generator=complex_schedule_generator(),
                          number_of_agents=number_of_agents)
    environment.reset()
    return environment

## Visualization

In [7]:
# Small 5x7 demo environment with two agents, seed 1.
env = create_env(height=5, width=7, seed=1, number_of_agents=2)

In [8]:
# NOTE(review): presumably resets every agent to its initial state; the next
# cell's output shows both agents READY_TO_DEPART with position=None.
env.restart_agents()

In [9]:
# Inspect the agents: initial position, direction, target and status.
env.agents

[EnvAgent(initial_position=(1, 0), direction=3, target=(3, 6), moving=False, speed_data={'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}, malfunction_data={'malfunction': 0, 'malfunction_rate': 0.0, 'next_malfunction': 0, 'nr_malfunctions': 0}, status=<RailAgentStatus.READY_TO_DEPART: 0>, position=None, handle=0, old_direction=None, old_position=None),
 EnvAgent(initial_position=(4, 1), direction=3, target=(0, 6), moving=False, speed_data={'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}, malfunction_data={'malfunction': 0, 'malfunction_rate': 0.0, 'next_malfunction': 0, 'nr_malfunctions': 0}, status=<RailAgentStatus.READY_TO_DEPART: 0>, position=None, handle=1, old_direction=None, old_position=None)]

In [10]:
# Manually place both agents in row 1, in adjacent cells with different
# directions (presumably facing each other — see blocked_agents.png).
agent_a, agent_b = env.agents[0], env.agents[1]
agent_a.position, agent_a.direction = (1, 3), 3
agent_b.position, agent_b.direction = (1, 2), 1

In [11]:
# Render the manually placed agents. The captured output shows that no
# observation/prediction dicts are populated, so those overlays draw nothing.
env_renderer = RenderTool(env,screen_height=1200,screen_width=1200)
env_renderer.render_env(show=True, show_predictions=True, show_observations=True)

  Observation builder needs to populate: env.dev_obs_dict")
  Predictors builder needs to populate: env.dev_pred_dict")


In [None]:
# Maps agent handle -> action code.
# NOTE(review): presumably flatland RailEnvActions (1 = MOVE_LEFT,
# 4 = STOP_MOVING) — confirm and prefer the enum names over raw ints.
actions_dict = {0:1,1:4}

In [None]:
# Advance the environment one step; the return value is discarded.
_ = env.step(actions_dict)

In [None]:
# Clear the renderer and draw the state after the step.
env_renderer.reset()
env_renderer.render_env(show=True, show_predictions=True, show_observations=True)

In [None]:
# Overlay a hand-written list of (row, col) cells on the rendering, then redraw.
# NOTE(review): the step (2,6) -> (2,4) skips a column — confirm the path is intended.
# NOTE(review): meaning of the 0.3 argument is defined in src.visualization.graphic.draw_path.
draw_path(env_renderer,[(1,1),(1,2),(1,3),(1,4),(1,5),(2,5),(2,6),(2,4),(3,4)],0.3)
env_renderer.render_env(show=True, show_predictions=True, show_observations=True)

In [12]:
# Save the current rendering to disk as a PNG.
img = env_renderer.get_image()
image = Image.fromarray(img)
image.save("blocked_agents.png")

In [None]:
# Inspect the agents again after stepping.
env.agents

## Size of the time-expanded networks

In [None]:
# environment size -> list of time-expanded edge counts (complex generator)
results = dict()

In [None]:
# Sweep square complex environments of side 10..490, three seeds each, and
# record len(edges) * T, where T is the number of time steps of the expansion.
for i in tqdm(range(1, 50)):
    size = 10 * i
    results[size] = []
    for repet in range(3):
        env = create_env(size, size, 1 + repet, 1)
        T = 4 * 2 * (env.width + env.height + 20)
        test = NetworkGraph(np.array(env.rail.grid.tolist()))
        results[size].append(len(test.edges) * T)

In [None]:
# Raw per-seed measurements of the complex sweep.
results

In [None]:
# Mean time-expanded edge count per environment size (complex generator).
# Size 470 is excluded — NOTE(review): the reason is not recorded here
# (presumably an outlier); document it or drop the exclusion.
x_axis = []
y_axis = []
for x,ys in results.items():
    # Skip sizes without samples so the mean never divides by zero.
    if x != 470 and ys:
        x_axis.append(x)
        # Average over however many repetitions were run, not a fixed three.
        y_axis.append(sum(ys)/len(ys))

In [None]:
# Mean time-expanded edge count per environment size (sparse generator).
# NOTE: `results_sparse` is created and filled by the sparse sweep cells
# further down the notebook; on Restart & Run All this cell raises NameError
# unless those cells have run first.
x_axis_sparse = []
y_axis_sparse = []
for x,ys in results_sparse.items():
    # Keep sizes strictly between 10 and 470 that have at least one sample.
    if 10 < x < 470 and ys:
        x_axis_sparse.append(x)
        # Average over however many repetitions were run, not a fixed three.
        y_axis_sparse.append(sum(ys)/len(ys))

In [None]:
# Every 7th measured size as an x tick, followed by 460.
# NOTE(review): the following cell overwrites `ticks`, so this value is unused.
ticks = x_axis[:7 * (len(x_axis) // 7):7]
ticks.append(460)

In [None]:
# Hand-picked x-axis ticks (environment side lengths) for the comparison plot;
# replaces any previously computed `ticks`.
ticks = [10, 80, 220, 360, 460]

In [None]:
# Apply seaborn's default styling to all subsequent matplotlib figures.
import seaborn as sns
sns.set()

In [None]:
import matplotlib

In [None]:
# Global matplotlib defaults: base font size 10, axes titles at 12.
matplotlib.rc('font', size=10)
matplotlib.rc('axes', titlesize=12)

In [None]:
# Compare the size of the time-expanded network for the complex vs. sparse
# generators and save the figure for the report.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

# Tick labels read "(h,w)" because both sweeps use square environments.
squads = [f'({x},{x})' for x in ticks]
ax.set_xticks(ticks)
ax.set_xticklabels(squads)

ax.plot(x_axis, y_axis, label="complex")
ax.plot(x_axis_sparse, y_axis_sparse, label="sparse")
# Fixed typos in the published labels ("widht", "netowrk").
ax.set_xlabel("Size of the environment (height,width)")
ax.set_ylabel("Edges in time expanded network")

for item in ax.get_xticklabels() + ax.get_yticklabels():
    item.set_fontsize(12)

for item in (ax.title, ax.xaxis.label, ax.yaxis.label):
    item.set_fontsize(12)

# Create the legend once (it used to be created twice) and enlarge its text.
l = ax.legend()
for text in l.get_texts():
    text.set_fontsize(12)

plt.tight_layout()
plt.savefig("../reports/reportCharles/img/number_variables.jpg", dpi=300)

In [None]:
# environment size -> list of time-expanded edge counts (sparse generator)
results_sparse = dict()

In [None]:
# Sweep square sparse environments of side 50..460, three seeds each, and
# record len(edges) * T, where T is the number of time steps of the expansion.
for i in tqdm(range(5, 47)):
    size = 10 * i
    results_sparse[size] = []
    for repet in range(3):
        env = create_env_sparse(size, size, 1 + repet, 1)
        T = 4 * 2 * (env.width + env.height + 20)
        test = NetworkGraph(np.array(env.rail.grid.tolist()))
        results_sparse[size].append(len(test.edges) * T)

In [None]:
# Raw per-seed measurements of the sparse sweep.
results_sparse

In [None]:
# Larger 30x30 complex environment with 20 agents, seed 1.
env = create_env(height=30, width=30, seed=1, number_of_agents=20)

In [None]:
# Render the 30x30 environment with the debug view and the
# AGENT_SHOWS_OPTIONS_AND_BOX agent variant; predictions/observations are off.
env_renderer = RenderTool(env,screen_height=1200,screen_width=1200,
                          agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
                          show_debug=True)
env_renderer.render_env(show=True, show_predictions=False, show_observations=False)

In [None]:
from src.flows import *
from src.graph import *

In [None]:
# Build the project NetworkGraph from a fresh numpy copy of the rail grid.
test = NetworkGraph(np.array(env.rail.grid.tolist()))

In [None]:
# Display the network graph.
# NOTE(review): a file name is passed as `title` — confirm whether show() uses it as a title or a save path.
test.show(title= "test_flatland_network.jpg")

In [None]:

import networkx as nx

# Minimal networkx example: draw a small star graph with edges coloured
# through the CMRmap colormap, then save and display it.
G = nx.star_graph(2)
pos = nx.spring_layout(G)
colors = range(2)
nx.draw(
    G,
    pos,
    node_color='#A0CBE2',
    edge_color=colors,
    width=4,
    edge_cmap=plt.get_cmap("CMRmap"),
    with_labels=False,
)
plt.savefig("edge_colormap.png")  # save as png
plt.show()  # display


In [None]:
# Integers 0..7 used as sample data for the styling experiments below.
x = list(range(8))

In [None]:
# Quick y = x plot saved to test.png.
# NOTE(review): the commented-out line below appears to be an abandoned
# dark-facecolor experiment; consider deleting it.
#plt.subplot(111, facecolor=(46/255., 48/255., 55/255.))
plt.plot(x,x)
plt.savefig("test.png")

In [None]:
# Turn the sample data into a numpy array (a no-op copy-wise if it already is one).
x = np.asarray(x)

In [None]:
# Dark-theme experiment: backgrounds use a slate grey and every axis
# decoration is white so it stays readable on it. Note the rcParams changes
# are global and affect every later figure.
plt.rcParams['savefig.facecolor'] = (46/255., 48/255., 55/255.)
plt.rcParams['axes.facecolor'] = (46/255., 48/255., 55/255.)
fig, ax = plt.subplots(nrows=1, ncols=1)
for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color('white')
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, colors='white')
for element in (ax.title, ax.xaxis.label, ax.yaxis.label):
    element.set_color('white')
fig.set_facecolor((46/255., 48/255., 55/255.))
ax.set_facecolor((46/255., 48/255., 55/255.))

plt.plot(x, x, label="wtf")
plt.xlabel("wtf")
plt.title("coucou")

ax = plt.gca()
plt.savefig("test.png")

In [None]:
# Dark axes background via rcParams. Setting it after plt.figure() still takes
# effect because the axes are only created by plt.plot below; the change is
# global and persists for later figures.
fig=plt.figure()
plt.rcParams['axes.facecolor'] = (46/255., 48/255., 55/255.)

# Plot the data in red and label the axes.
plt.plot(x,x,color='r', label ="wtf")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()

# Save the figure to face.png.
plt.savefig("face.png")

In [None]:
# Plot y = x with whatever rcParams the previous cells left in place.
plt.plot(x,x)
plt.show()

In [None]:
# Use pyplot explicitly: `import pylab as plt` rebinds the notebook-wide `plt`
# to the pylab module, which matplotlib discourages.
import matplotlib.pyplot as plt

# NOTE(review): the label says "randn" but the data is the 0..7 sample `x`.
plt.plot(x, label="randn")

# Transparent legend frame with white text, for use on a dark background.
leg = plt.legend(framealpha=0, loc='best')
for text in leg.get_texts():
    plt.setp(text, color='w')