# Imports

##### General imports

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../")

In [None]:
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm

##### Import from flatland environment 

In [None]:
from flatland.envs.rail_env import RailEnv
from flatland.envs.observations import *
from flatland.envs.rail_generators import complex_rail_generator,rail_from_manual_specifications_generator,random_rail_generator, RailGenerator,sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator, ScheduleGenerator, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant

##### Import from our framework

In [None]:
from src.graph import NetworkGraph

In [None]:
from src.flows import *

# Test of time expanded network

##### Create a flatland network

Sparse network

In [None]:
# Sparse scenario: a square grid of side `size_side` populated with
# `number_agents` trains.
number_agents = 5
size_side = 40

# Malfunction settings handed to RailEnv through `stochastic_data`.
stochastic_data = dict(
    prop_malfunction=0.3,  # Percentage of defective agents
    malfunction_rate=30,  # Rate of malfunction occurence
    min_duration=3,  # Minimal duration of malfunction
    max_duration=20,  # Max duration of malfunction
)

# Speed classes handed to the schedule generator; every class gets the
# same value (0.25).
speed_ration_map = {
    1.: 0.25,  # Fast passenger train
    1. / 2.: 0.25,  # Fast freight train
    1. / 3.: 0.25,  # Slow commuter train
    1. / 4.: 0.25,  # Slow freight train
}

env = RailEnv(
    width=size_side,
    height=size_side,
    rail_generator=sparse_rail_generator(
        max_num_cities=4,  # Number of cities in map (where train stations are)
        seed=14,  # Random seed
        grid_mode=False,
        max_rails_between_cities=2,
        max_rails_in_city=8,
    ),
    schedule_generator=sparse_schedule_generator(speed_ration_map),
    number_of_agents=number_agents,
    stochastic_data=stochastic_data,  # Malfunction data generator
    obs_builder_object=GlobalObsForRailEnv(),
    remove_agents_at_target=True,
)

# Disabled experiment, kept for reference:
# RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0)
env.reset()

# Renderer for the freshly reset environment; it is reset before the
# first draw.
env_renderer = RenderTool(
    env,
    agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
    show_debug=True,
    screen_height=1100,
    screen_width=1800,
)
env_renderer.reset()

env_renderer.render_env(show=True, show_observations=False, show_predictions=False)



usual network

In [None]:
# Complex-generator scenario. Running this cell replaces the `env`,
# `env_renderer`, `number_agents` and `size_side` created by the sparse
# cell above.

# Explicit import: without it `np` is only in scope if one of the
# wildcard imports at the top of the notebook happens to export it.
import numpy as np

number_agents = 40

size_side = 50

# Draw the layout seed once and keep it in a variable, so the map produced
# by a given run can be inspected and rebuilt later.
rail_seed = np.random.randint(0, 2000)

env = RailEnv(width=size_side,
              height=size_side,
              rail_generator=complex_rail_generator(nr_start_goal=number_agents, nr_extra=1,
                                                    min_dist=6, max_dist=99999,
                                                    seed=rail_seed),
              schedule_generator=complex_schedule_generator(),
              number_of_agents=number_agents,
              obs_builder_object=GlobalObsForRailEnv())

env.reset()

# Render the generated map without predictions or observations.
env_renderer = RenderTool(env)
env_renderer.render_env(show=True, show_predictions=False, show_observations=False)

In [None]:
# Copy the environment's rail grid into a fresh NumPy array (round-trip
# through nested Python lists) and build the framework graph from it.
# NOTE(review): the meaning of the [(0,1)] and [(1,0)] arguments is defined
# by NetworkGraph in src.graph, which is not visible here — confirm there.
matrix_rail = np.array(env.rail.grid.tolist())
flatlandNetwork = NetworkGraph(matrix_rail,[(0,1)],[(1,0)])

In [None]:
# Collect, per agent, its initial position (source), its target (sink)
# and, when available, its `direction` attribute.
sources = []
sinks = []
directions = []
for agent in env.agents:
    sources.append(agent.initial_position)
    sinks.append(agent.target)
    try:
        directions.append(agent.direction)
    except AttributeError:
        # Only a missing `direction` attribute is tolerated; anything else
        # (including KeyboardInterrupt) now propagates instead of being
        # silently swallowed.
        # NOTE(review): if only some agents lack the attribute,
        # `directions` ends up shorter than `sources` and the indices no
        # longer line up.
        pass

In [None]:
# Normalise "no directions collected" to None.
# A truthiness test (instead of len()) keeps the cell safe to re-run:
# once `directions` is None, len(directions) would raise TypeError.
if not directions:
    directions = None

##### create a time expanded network

In [None]:
import time

In [None]:
# Build the time-expanded network on top of the Flatland graph and report
# how long the construction takes (in seconds).
# perf_counter() is monotonic and high-resolution, so the measured interval
# cannot be distorted by system clock adjustments the way time.time() can.
# NOTE(review): depth=200 is presumably the number of time layers — confirm
# against TimeNetwork.
start = time.perf_counter()
TestNetworkTime = TimeNetwork(flatlandNetwork, depth=200)
stop = time.perf_counter()
print(f'time taken to build the graph: {stop-start}')

In [None]:
# Attach the agents' start cells and target cells to the time-expanded network.
# NOTE(review): `directions` is collected from the agents in an earlier cell,
# but None is passed here — confirm that ignoring it is intentional.
TestNetworkTime.connect_sources_and_sink(sources,sinks,directions = None)

### create initial solution

In [None]:
# Constraints of the time-expanded network; the cells below iterate each
# constraint as a collection of edges.
constraints = TestNetworkTime.get_topology_network()

In [None]:
# Reverse index filled in the next cell: edge -> list of constraints containing it.
find_constraints = {}

In [None]:
# Index the constraints by edge: each edge maps to the list of every
# constraint it appears in, in iteration order.
for restriction in constraints:
    for edge in restriction:
        find_constraints.setdefault(edge, []).append(restriction)

In [None]:
# Build the initial-solution generator from the time-expanded network, the
# constraint list, the edge -> constraints index and the number of agents
# (`sources` holds one entry per agent).
test = InitialSolutionGenerator(TestNetworkTime,constraints,find_constraints,len(sources))

In [None]:
# Compute the initial solution; the cells below treat it as a sequence of
# per-agent paths.
solution = test.getInitialSolution()

In [None]:
# Show the generator's statistics (the output is defined by
# InitialSolutionGenerator, which is not visible in this notebook).
test.showStats()

In [None]:
# Echo the raw initial solution as the cell output for inspection.
solution

In [None]:
# Sanity check shown as the cell output: True when the solution has exactly
# one entry per source (i.e. per agent).
len(solution) == len(sources)

In [None]:
# Map each agent index to its path with the first element and the last two
# elements stripped.
# NOTE(review): the stripped entries are presumably the artificial
# source/sink nodes added when connecting sources and sinks — confirm.
solutions = {agent_idx: agent_path[1:-2] for agent_idx, agent_path in enumerate(solution)}

##### Test LP Formulation

test a simple graph

In [None]:
# Print the current wall-clock time before the LP model is built and solved.
import datetime as dt
print(dt.datetime.now())

In [None]:
# Set up the MCFlow model on the time-expanded graph, with the number of
# agents and the network constraints.
# NOTE(review): integer = False presumably selects the relaxed (non-integer)
# formulation — confirm against MCFlow. The constraints are fetched again here
# rather than reusing `constraints` from the earlier cell.
mcflow = MCFlow(TestNetworkTime.graph,len(sources),TestNetworkTime.get_topology_network(),integer = False)

In [None]:
# Solve the flow model built in the previous cell.
mcflow.solve()

In [None]:
# Extract the paths from the solved model; the next cell iterates the result
# as a dict via .items().
paths = mcflow.get_paths_solution()

In [None]:
# Total score: half the length of each path (rounded down), summed over
# every path in the LP solution.
score = sum(len(path) // 2 for path in paths.values())

print(f"score {score}")

In [None]:
# Pad every path of the LP solution with None so that all paths have the
# same length (needed to tabulate them side by side afterwards).
# Work on a deep copy so that mcflow.solution itself is left untouched.
pathsToAllongate = copy.deepcopy(mcflow.solution)
lengths = [len(path) for path in pathsToAllongate.values()]

# default=0 avoids the ValueError that max() raises on an empty solution.
maxLength = max(lengths, default=0)
for path in pathsToAllongate.values():
    # Extending in place is enough: the dict already refers to this list,
    # so no reassignment of the entry is needed.
    path.extend([None] * (maxLength - len(path)))

In [None]:
# Tabulate the padded paths (one column per key of the dict, presumably one
# per agent) and display the table as the cell output.
dfPaths = pd.DataFrame(pathsToAllongate)
dfPaths

In [None]:
# Print a second timestamp; comparing it with the one printed before the model
# was built gives the wall-clock duration of the LP steps.
import datetime as dt
print(dt.datetime.now())

In [None]:
# Palette of 22 RGB colours, written as 0-255 integer triples and scaled
# to floats in [0, 1].
colors = [
    (red / 255., green / 255., blue / 255.)
    for red, green, blue in (
        (230, 25, 75), (60, 180, 75), (255, 225, 25), (0, 130, 200),
        (245, 130, 48), (145, 30, 180), (70, 240, 240), (240, 50, 230),
        (210, 245, 60), (250, 190, 190), (0, 128, 128), (230, 190, 255),
        (170, 110, 40), (255, 250, 200), (128, 0, 0), (170, 255, 195),
        (128, 128, 0), (255, 215, 180), (0, 0, 128), (128, 128, 128),
        (255, 255, 255), (0, 0, 0),
    )
]

In [None]:
# Display the Flatland graph, passing the trimmed initial-solution paths
# (`solutions`) as the paths to show.
flatlandNetwork.show(paths = solutions)