In [1]:
import datetime as dt
import calendar
import time
import pandas as pd
import pyspark.sql.functions as functions
import math
import getpass
import pyspark
from datetime import datetime, date, timedelta
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .master("yarn") \
    .appName('journey_planner-{0}'.format(getpass.getuser())) \
    .config('spark.jars.packages', 'graphframes:graphframes:0.6.0-spark2.3-s_2.11') \
    .config('spark.executor.memory', '8g') \
    .config('spark.executor.instances', '5') \
    .config('spark.port.maxRetries', '100') \
    .getOrCreate()

from graphframes import *

In [None]:
# load the data
df = spark.read.csv('/datasets/sbb/2018/*/*istdaten.csv.bz2', sep=';', header=True)

In [216]:
stations = pd.read_csv('data/filtered_stations.csv')
valid_stations = set(stations['Remark'])

In [None]:
stations = stations[['Longitude', 'Latitude', 'Remark']];
stations['key'] = 0

earth_radius = 6371e3

def haversine(row):
    phi1         = 2 * math.pi * float(row['Latitude_x']) / 360
    phi2         = 2 * math.pi * float(row['Latitude_y']) / 360
    delta_phi    = 2 * math.pi * (float(row['Latitude_y']) - float(row['Latitude_x'])) / 360
    delta_lambda = 2 * math.pi * (float(row['Longitude_y']) - float(row['Longitude_x'])) / 360
    
    a = (math.sin(delta_phi/2) ** 2) + \
        math.cos(phi1) * math.cos(phi2) * (math.sin(delta_lambda/2) ** 2)
    
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    
    d = earth_radius * c
    
    return d / 1000

prod = pd.merge(stations, stations, on='key')
prod['dist'] = prod.apply(lambda row: haversine(row), axis=1)

In [None]:
# We don't consider walking to stops that are more than 3 kilometers away
max_walking_distance = 3
walk_df = prod[prod['dist'] <= max_walking_distance]
walk_df = walk_df[walk_df['Remark_x'] != walk_df['Remark_y']]

walk_df = walk_df[['Remark_x', 'Remark_y', 'dist']]
walk_df['type'] = 'walk'
walk_df['line'] = 'walk'
walk_df['departure_day']  = 'null'
walk_df['departure_time'] = 'null'
walk_df['arrival_time']   = 'null'
# We assume an average walking speed of 5 kilometers per hour
walk_df['lateAvg'] = walk_df.apply(lambda row: 3600 * float(row['dist']) / 5, axis=1)
walk_df['lateStd'] = 0.0
walk_df.drop('dist', axis=1, inplace=True)
walk_df.columns = ['src', 'dst', 'type', 'line', 'departure_day', 'departure_time', 'arrival_time', 'lateAvg', 'lateStd']

In [None]:
walk_edges = spark.createDataFrame(walk_df)

In [None]:
vertices_df = stations[['Remark', 'Longitude', 'Latitude']]
vertices_df.columns = ['id', 'lon', 'lat']
vertices = spark.createDataFrame(vertices_df)

In [None]:
dateFormat = 'dd.MM.yyyy HH:mm'
timeLate = (functions.unix_timestamp('AN_PROGNOSE', format=dateFormat)
            - functions.unix_timestamp('ANKUNFTSZEIT', format=dateFormat))

@functions.udf
def clamp(late):
    return 0 if late < 0 else late

valid_stops = df.filter((df.DURCHFAHRT_TF=='false') & 
                        (df.FAELLT_AUS_TF=='false') & 
                        (df.ZUSATZFAHRT_TF=='false') &
                        (df.AN_PROGNOSE_STATUS=='GESCHAETZT') &
                        (df.HALTESTELLEN_NAME.isin(valid_stations))) \
                .select('BETRIEBSTAG',
                        'FAHRT_BEZEICHNER', 
                        'PRODUKT_ID', 
                        'LINIEN_TEXT', 
                        'HALTESTELLEN_NAME', 
                        'AN_PROGNOSE',
                        'ANKUNFTSZEIT', 
                        'ABFAHRTSZEIT') \
                .withColumn('AN_PROGNOSE',  functions.to_timestamp(df.AN_PROGNOSE, dateFormat))  \
                .withColumn('ANKUNFTSZEIT', functions.to_timestamp(df.ANKUNFTSZEIT, dateFormat)) \
                .withColumn('ABFAHRTSZEIT', functions.to_timestamp(df.ABFAHRTSZEIT, dateFormat)) \
                .withColumn('late', clamp(timeLate)) \
                .drop('AN_PROGNOSE')

In [None]:
departures = valid_stops.filter(valid_stops.ABFAHRTSZEIT.isNotNull())\
                        .drop('ANKUNFTSZEIT', 'late')
arrivals   = valid_stops.filter(valid_stops.ANKUNFTSZEIT.isNotNull())\
                        .drop('ABFAHRTSZEIT')

In [None]:
arrivals.createOrReplaceTempView('arrivals')
departures.createOrReplaceTempView('departures')

joinQuery = 'SELECT d.HALTESTELLEN_NAME AS src, a.HALTESTELLEN_NAME AS dst,              \
                    d.PRODUKT_ID AS type, d.LINIEN_TEXT AS line,                         \
                    date_format(d.ABFAHRTSZEIT, \'EEEE\') AS departure_day,              \
                    SUBSTRING(d.ABFAHRTSZEIT, 12, 8) AS departure_time,                  \
                    SUBSTRING(a.ANKUNFTSZEIT, 12, 8) AS arrival_time,                    \
                    a.late                                                               \
             FROM arrivals AS a INNER JOIN departures AS d                               \
             ON a.BETRIEBSTAG == d.BETRIEBSTAG                                           \
             AND a.FAHRT_BEZEICHNER == d.FAHRT_BEZEICHNER                                \
             WHERE a.HALTESTELLEN_NAME != d.HALTESTELLEN_NAME                            \
             AND d.ABFAHRTSZEIT < a.ANKUNFTSZEIT'

edges = spark.sql(joinQuery)

In [None]:
edges.createOrReplaceTempView('edges')

query = 'SELECT src, dst, type, line, departure_day, departure_time, arrival_time,              \
         AVG(late) AS lateAvg, STD(late) AS lateStd                                             \
         FROM edges GROUP BY src, dst, type, line, departure_day, departure_time, arrival_time'

aggregated = spark.sql(query)
aggregated_edges = aggregated.na.fill(0.0)

all_edges = aggregated_edges.union(walk_edges)

In [None]:
all_edges.write.parquet('/homes/schmutz/edges', mode='overwrite')

In [None]:
vertices.write.parquet('/homes/schmutz/vertices', mode='overwrite')

## How far can we go in M minutes

In [3]:
vertices = spark.read.parquet('/homes/schmutz/vertices')

In [4]:
vertices.show(n=5)

+--------------------+-----------------+-----------------+
|                  id|              lon|              lat|
+--------------------+-----------------+-----------------+
|   Zumikon, Gössikon|         8.614773|        47.332474|
|   Zumikon, Waltikon|         8.618188|        47.336109|
|Zumikon, Dorfzentrum|         8.622922|        47.332976|
|Zürich, Meierhofp...|         8.499375|        47.402009|
|  Zürich, Heizenholz|8.483903999999999|47.41229600000001|
+--------------------+-----------------+-----------------+
only showing top 5 rows



In [5]:
edges = spark.read.parquet('/homes/schmutz/edges')

In [6]:
edges.show(n=5)

+--------------------+--------------------+----+----+-------------+--------------+------------+------------------+-------+
|                 src|                 dst|type|line|departure_day|departure_time|arrival_time|           lateAvg|lateStd|
+--------------------+--------------------+----+----+-------------+--------------+------------+------------------+-------+
|Thalwil, Archstrasse|Thalwil, Feldstrasse|walk|walk|         null|          null|        null|472.61572587575955|    0.0|
|Thalwil, Archstrasse|Thalwil, Mühlebac...|walk|walk|         null|          null|        null| 421.5532234300196|    0.0|
|Thalwil, Archstrasse|    Thalwil, Zentrum|walk|walk|         null|          null|        null|175.33795811343845|    0.0|
|Thalwil, Archstrasse|    Thalwil, Bahnhof|walk|walk|         null|          null|        null|184.83110825655203|    0.0|
|Thalwil, Archstrasse|Küsnacht ZH, Ob. ...|walk|walk|         null|          null|        null| 2109.991732919267|    0.0|
+---------------

In [7]:
graph = GraphFrame(vertices, edges)

In [6]:
def getSubGraph(graph, startDay, finishDay, startTime, finishTime, duration):
    def valid(day, depTime, arrTime, walkTime):
        if startDay==finishDay:
            return ((day=='null') & (walkTime<=duration)) | \
                    ((day==startDay) & (depTime>=startTime) & (arrTime<=finishTime) & (depTime<=arrTime))
        else:
            return ((day=='null') & (walkTime<=duration)) | \
                    (((day==startDay) & (depTime>=startTime) & ((depTime<=arrTime) | (arrTime<=finishTime))) | \
                     ((day==finishDay) & (depTime<finishTime) & (arrTime<=finishTime)))

    return graph.filterEdges(valid(graph.edges.departure_day, 
                                graph.edges.departure_time,
                                graph.edges.arrival_time,
                                graph.edges.lateAvg))  \
                .dropIsolatedVertices()

In [15]:
def howFarNaive(graph, fromStation, startDateTime, duration):
    
    if duration >= 120:
        print('You can walk anywhere in that time')
        return
    
    finishDateTime = startDateTime + timedelta(minutes=duration)

    startTime  = str(startDateTime.time())
    finishTime = str(finishDateTime.time())

    startDay  = calendar.day_name[startDateTime.weekday()]
    finishDay = calendar.day_name[finishDateTime.weekday()]

    print(startDay, startTime)
    print(finishDay, finishTime)
    
    @functions.udf
    def addTime(arr_time, dep_time, late):
        if arr_time=='null':
            tmp = dep_time.split(':')
            return str((datetime.combine(date.today(), dt.time(int(tmp[0]), int(tmp[1]), int(tmp[2]))) + 
                        timedelta(seconds=int(late))).time())
        else:
            return arr_time
    
    @functions.udf
    def checkDay(day, dep_time, arr_time):
        return finishDay if arr_time<dep_time else day
    
    @functions.udf
    def checkWalk(ttype):
        return 1 if ttype=='walk' else 0
    
    @functions.udf
    def checkIfValid(arr_time, day):
        tmp = arr_time.split(':')
        arr_date = startDateTime.date() if day==calendar.day_name[startDateTime.weekday()] else finishDateTime.date()
        arrival = datetime.combine(arr_date, dt.time(int(tmp[0]), int(tmp[1]), int(tmp[2])))
        return arrival < finishDateTime
    
    reachable = vertices.filter(vertices.id==fromStation)             \
                        .withColumn('time', functions.lit(startTime)) \
                        .withColumn('day', functions.lit(startDay))   \
                        .withColumn('just_walked', functions.lit(0))
    
    g = getSubGraph(graph, startDay, finishDay, startTime, finishTime, 60*duration)
    g.persist();
    g.edges.createOrReplaceTempView('edges')
    g.vertices.createOrReplaceTempView('vertices')
    
    curr = reachable
    
    #while len(curr.head(1)) > 0:
    for i in range(1):
        curr.createOrReplaceTempView('curr')

        query = 'SELECT v.*, c.time AS past_time, c.just_walked, c.day, e.type,          \
                        e.departure_time, e.arrival_time, e.lateAvg                      \
                 FROM curr AS c INNER JOIN edges AS e INNER JOIN vertices AS v           \
                 ON c.id==e.src                                                          \
                 AND e.dst==v.id                                                         \
                 WHERE (e.type!=\'walk\' OR c.just_walked==0)                            \
                 AND (e.type==\'walk\'                                                   \
                 OR (e.departure_time>=c.time AND c.day==e.departure_day)                \
                 OR (e.departure_time<c.time AND c.day!=e.departure_day))'

        curr = spark.sql(query).withColumn('time', addTime('arrival_time', 'past_time', 'lateAvg')) \
                               .withColumn('day', checkDay('day', 'past_time', 'time'))             \
                               .filter(checkIfValid('time', 'day')=='true')                         \
                               .withColumn('just_walked', checkWalk('type'))                        \
                               .select('id', 'lon', 'lat', 'time', 'day', 'just_walked')
        curr.persist()
        reachable = reachable.union(curr)
    
    @functions.udf
    def computeRadius(arr_time, day):
        tmp = arr_time.split(':')
        arr_date = startDateTime.date() if day==calendar.day_name[startDateTime.weekday()] else finishDateTime.date()
        arrival = datetime.combine(arr_date, dt.time(int(tmp[0]), int(tmp[1]), int(tmp[2])))
        return (finishDateTime - arrival).seconds * 5 // 3.6
    
    reachable = reachable.withColumn('radius', computeRadius('time', 'day')) \
                         .select('id', 'lon', 'lat', 'radius', 'time', 'just_walked')                     \
                         .toPandas()
    
    g.unpersist();
    
    return reachable

In [17]:
graph = GraphFrame(vertices, edges)
fromStation = 'Dietlikon'
startDateTime  = datetime(2019, 5, 31, 23, 45)
duration = 60

start = time.time()
reachable = howFarNaive(graph, fromStation, startDateTime, duration)
print(time.time() - start)
reachable

Friday 23:45:00
Saturday 00:45:00
3.202773094177246


Unnamed: 0,id,lon,lat,radius,time,just_walked
0,Dietlikon,8.619255,47.420195,4999.0,23:45:00,0
1,"Wallisellen, Glatt",8.595545,47.409209,2838.0,00:10:56,1
2,"Wallisellen, Zentrum Glatt",8.595545,47.409209,2838.0,00:10:56,1
3,Dübendorf,8.623407,47.400076,2741.0,00:12:06,1
4,Wallisellen,8.591911,47.412717,2781.0,00:11:37,1
5,Bassersdorf,8.626199,47.438564,2893.0,00:10:17,1
6,"Bassersdorf, Bahnhof",8.626136,47.438718,2876.0,00:10:29,1
7,"Bassersdorf, Talgüetli",8.620484,47.445099,2230.0,00:18:14,1
8,"Bassersdorf, Schmitte",8.628763,47.442945,2372.0,00:16:32,1
9,"Bassersdorf, Löwen",8.628127,47.443582,2315.0,00:17:13,1


In [210]:
MINUTES_PER_DAY = 1440
MINUTES_PER_HOUR = 60
SECONDS_PER_MINUTE = 60

def computeDiff(departure, arrival):
    dep = (departure[2:]).split(':')
    arr = (arrival[2:]).split(':')
    a = (int(arrival[:1]) - int(departure[:1])) * MINUTES_PER_DAY
    b = (int(arr[0]) - int(dep[0])) * MINUTES_PER_HOUR
    c = (int(arr[1]) - int(dep[1]))
    tot = a + b + c
    hours = tot // MINUTES_PER_HOUR
    minutes = tot % MINUTES_PER_HOUR
    return "{:02d}".format(int(hours)) + ':' + "{:02d}".format(int(minutes)) + ':00'

def computeCost(cost, late):
    tmp = cost.split(':')
    a = int(tmp[0][2:]) * MINUTES_PER_HOUR + int(tmp[1]) + MINUTES_PER_DAY - late // SECONDS_PER_MINUTE
    prefix = '0-' if a < MINUTES_PER_DAY else '1-'
    a = a % MINUTES_PER_DAY
    minutes = a % MINUTES_PER_HOUR
    hours = (a - minutes) // MINUTES_PER_HOUR
    return prefix + "{:02d}".format(int(hours)) + ':' + "{:02d}".format(int(minutes)) + ':00'

def getFilteredEdges(startDay, finishDay, startTime, finishTime, duration):
    def valid(day, depTime, arrTime, walkTime):
        if startDay==finishDay:
            return ((day=='null') & (walkTime<=duration)) | \
                    ((day==startDay) & (depTime>=startTime) & (arrTime<=finishTime) & (depTime<=arrTime))
        else:
            return ((day=='null') & (walkTime<=duration)) | \
                    (((day==startDay) & (depTime>=startTime) & ((depTime<=arrTime) | (arrTime<=finishTime))) | \
                     ((day==finishDay) & (depTime<finishTime) & (arrTime<=finishTime)))

    return graph.filterEdges(valid(graph.edges.departure_day, 
                                graph.edges.departure_time,
                                graph.edges.arrival_time,
                                graph.edges.lateAvg)).edges

def add_vertice_to_set(max_set, vertice, vertice_costs, edges, next_vertices):
    
    max_set.add(vertice)
    cost = vertice_costs[vertice]

    vertice_edges = edges[((edges.dst == vertice) & (edges.type == 'walk')) 
                                  | ((edges.dst == vertice) 
                                  & (edges.arrival_time < cost))]

    for i, edge in vertice_edges.iterrows():
        if edge['type'] == 'walk':
            new_cost = computeCost(cost, edge['lateAvg'])
            if edge['src'] not in vertice_costs or new_cost > vertice_costs[edge.dst]:
                next_vertices[edge['src']] = edge
                vertice_costs[edge.src] = new_cost
        elif edge['src'] not in vertice_costs or edge['departure_time'] > vertice_costs[edge['dst']]:
            vertice_costs[edge['src']] = edge['departure_time']
            next_vertices[edge['src']] = edge
            

def get_max_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time):
    max_vertice = None
    max_cost = min_trip_departure_time
    for vertice in vertice_costs:
        if vertice not in max_set and vertice_costs[vertice] > max_cost:
            max_cost = vertice_costs[vertice]
            max_vertice = vertice
    
    return max_vertice

def find_path(next_vertices, current_vertice, current_path):
    if current_vertice not in next_vertices:
        return current_path
    next_vertice = next_vertices[current_vertice]['dst']
    current_path.append(next_vertices[current_vertice])
    return find_path(next_vertices, next_vertice, current_path)
    

def find_shortest_path(departure_station, arrival_station, 
                       startDateTime, endDateTime, 
                       min_probability_of_sucess , edges):
    
    print(startDateTime)
    print(endDateTime)
    
    startTime  = str(startDateTime.time())
    endTime = str(endDateTime.time())

    startDay  = calendar.day_name[startDateTime.weekday()]
    endDay = calendar.day_name[endDateTime.weekday()]
    
    min_trip_departure_time = '0-' + startTime
    
    endTimePrefix = '0-' if (startDay == endDay) else '1-'
    requested_arrival_time = endTimePrefix + endTime
    
    duration = (endDateTime - startDateTime).seconds
    
    print(startDay, startTime)
    print(endDay, endTime)
    
    filtered_edges = getFilteredEdges(startDay, endDay, startTime, endTime, duration).toPandas()
    
    def to_dt(time):
        if time == 'null':
            return 'null'
        elif time >= startTime:
            return '0-' + time
        else:
            return '1-' + time
    
    filtered_edges['departure_time'] = filtered_edges['departure_time'].map(lambda x: to_dt(x))
    filtered_edges['arrival_time']   = filtered_edges['arrival_time'].map(lambda x: to_dt(x))
    
    # in minutes
    vertice_costs = {}
    vertice_costs[arrival_station] = requested_arrival_time

    max_set = set()
    next_vertices = {}
    add_vertice_to_set(max_set, arrival_station, vertice_costs, filtered_edges, next_vertices)
    no_solution = False
    while(departure_station not in max_set and not no_solution):
        max_vertice = get_max_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time)
        if max_vertice is None:
            no_solution = True
        else:
            add_vertice_to_set(max_set, max_vertice, vertice_costs, filtered_edges, next_vertices)
    
    if no_solution:
        print("no solution", vertice_costs)
    
    departure_time = (vertice_costs[departure_station])[2:]
        
    trip_duration = computeDiff(vertice_costs[departure_station], requested_arrival_time)
    
    return departure_time, trip_duration, find_path(next_vertices, departure_station, [departure_station])

In [213]:
fromStation = 'Dietlikon'
toStation   = 'Zürich HB'
startDateTime = datetime(2019, 6, 2, 23, 15)
endDateTime   = datetime(2019, 6, 3, 0, 10)

find_shortest_path('Dietlikon', 'Zürich HB', 
                   startDateTime, 
                   endDateTime, 0, edges)

2019-06-02 23:15:00
2019-06-03 00:10:00
Sunday 23:15:00
Monday 00:10:00


RecursionError: maximum recursion depth exceeded in comparison

In [223]:
MAX_TRIP_DURATION = 180
NUMBER_OF_MINUTES_IN_WEEK = 1440
NUMBER_OF_SECONDS_IN_MINUTE = 60
NUMBER_OF_MINUTES_IN_HOUR = 60
def string_to_minute_of_week(day_of_week, hour_of_day):
    if hour_of_day == 'null':
        return 0
    week_days = {'Monday':0, 'Tuesday': 1, 'Wednesday':2, 'Thursday':3, 'Friday':4, 'Saturday':5, 'Sunday':6}
    day = week_days[day_of_week] * NUMBER_OF_MINUTES_IN_WEEK
    hour = int(hour_of_day[:2]) * NUMBER_OF_MINUTES_IN_HOUR
    minutes = int(hour_of_day[3:5])
    return day + hour + minutes

def minute_to_hour(minutes_):
    week_days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
    minutes = minutes_ % NUMBER_OF_MINUTES_IN_WEEK
    hour = minutes / NUMBER_OF_MINUTES_IN_HOUR
    minutes = int(minutes % NUMBER_OF_MINUTES_IN_HOUR)
    return week_days[int(minutes_ / NUMBER_OF_MINUTES_IN_WEEK)], "{:02d}".format(int(hour)) + ':' + "{:02d}".format(int(minutes)) + ":00"

def add_vertice_to_set(max_set, vertice, vertice_costs, edges, next_vertices):
    startTime = time.time()
    max_set.add(vertice)
    cost = vertice_costs[vertice]

    vertice_edges = edges[((edges.dst == vertice) & (edges.departure_time == 'null')) 
                                  | ((edges.dst == vertice) 
                                  & (edges.arrival_time_min < cost))]

    for i, edge in vertice_edges.iterrows():
        if edge['type'] == 'walk':
            new_cost = cost - edge['lateAvg'] / NUMBER_OF_SECONDS_IN_MINUTE
            if edge['src'] not in vertice_costs or new_cost > vertice_costs[edge.dst]:
                next_vertices[edge['src']] = edge
                vertice_costs[edge.src] = new_cost
        elif edge['src'] not in vertice_costs or edge['departure_time_min'] > vertice_costs[edge['dst']]:
            vertice_costs[edge['src']] = edge['departure_time_min']
            next_vertices[edge['src']] = edge
    print(time.time() - startTime)
            

def get_max_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time):
    max_vertice = None
    max_cost = min_trip_departure_time
    for vertice in vertice_costs:
        if vertice not in max_set and vertice_costs[vertice] > max_cost:
            max_cost = vertice_costs[vertice]
            max_vertice = vertice
    
    return max_vertice

def find_path(next_vertices, current_vertice, current_path):
    if current_vertice not in next_vertices:
        return current_path
    next_vertice = next_vertices[current_vertice]['dst']
    current_path.append(next_vertices[current_vertice])
    return find_path(next_vertices, next_vertice, current_path)
    

def find_shortest_path(departure_station, arrival_station, 
                       requested_arrival_time, min_trip_departure_time, 
                       min_probability_of_sucess , edges):
    
    print(min_trip_departure_time)
    print(requested_arrival_time)
    
    day_departure, cost_string_departure = minute_to_hour(min_trip_departure_time)
    day_arrival, cost_string_arrival = minute_to_hour(requested_arrival_time)
    
    print(day_departure, cost_string_departure)
    print(day_arrival, cost_string_arrival)
    
    filtered_edges = edges.filter((edges.departure_time == 'null') 
                                  | ((edges.departure_time > cost_string_departure) 
                                  & (edges.departure_day == day_departure)
                                  & (edges.arrival_time > cost_string_departure)
                                  & (edges.arrival_time < cost_string_arrival))).toPandas()
    
    print(len(filtered_edges.index))
    
    start = time.time()
    filtered_edges['arrival_time_min'] = filtered_edges['arrival_time'].map(lambda x: string_to_minute_of_week(day_departure, x))
    filtered_edges['departure_time_min'] = filtered_edges['departure_time'].map(lambda x: string_to_minute_of_week(day_departure, x))
    print(time.time() - start)
    
    # in minutes
    vertice_costs = {}
    vertice_costs[arrival_station] = requested_arrival_time

    max_set = set()
    next_vertices = {}
    add_vertice_to_set(max_set, arrival_station, vertice_costs, filtered_edges, next_vertices)
    no_solution = False
    while(departure_station not in max_set and not no_solution):
        max_vertice = get_max_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time)
        if max_vertice is None:
            no_solution = True
        else:
            add_vertice_to_set(max_set, max_vertice, vertice_costs, filtered_edges, next_vertices)

    if no_solution:
        print("no solution", vertice_costs)
        return
    day, departure_time = minute_to_hour(vertice_costs[departure_station])
    day, trip_duration = minute_to_hour(requested_arrival_time-vertice_costs[departure_station])
    return departure_time, trip_duration, find_path(next_vertices, departure_station, [departure_station])


In [224]:
find_shortest_path('Dietlikon', 'Zürich HB', 
                   string_to_minute_of_week('Monday','00:10:00'), 
                   string_to_minute_of_week('Sunday','23:15:00'), 0, edges)

10035
10
Sunday 23:15:00
Monday 00:10:00
119508
0.10329103469848633
0.06613564491271973
no solution {'Zürich HB': 10, 'Zürich, Albisriederplatz': -16.92239394957554, 'Zürich, Bäckeranlage': -5.650032178932243, 'Zürich Enge, Bahnhof': -10.299471352285227, 'Zürich Enge, Bahnhof/Bederstr.': -10.196969351718202, 'Zürich Hardbrücke, Bahnhof': -12.78366600971242, 'Zürich Selnau, Bahnhof': -2.4461688781234177, 'Zürich Wipkingen, Bahnhof': -12.535379860442397, 'Zürich, Bahnhofstrasse/HB': 7.830954334366682, 'Zürich, Beckenhof': 2.114718609167875, 'Zürich, Berninaplatz': -24.684590911626465, 'Zürich, Bernoulli-Häuser': -25.236340174229298, 'Zürich, Bertastrasse': -14.600055494515622, 'Zürich, Berufswahlschule': -25.500189058695724, 'Zürich, Bethanien': -4.152229911498626, 'Zürich, Bezirksgebäude': -4.265368586549478, 'Zürich, Grubenstrasse': -22.907914835571106, 'Zürich, Binz Center': -24.13751837433888, 'Zürich, Bircher-Benner': -12.470728814924733, 'Zürich, Börsenstrasse': -3.3587851772228525

In [None]:
#grouped = edges.groupBy([edges.src, edges.dst, edges.type, edges.subtype])
#grouped_edges = grouped.agg({'departure_time': 'collect_list',
#                             'arrival_time'  : 'collect_list'})\
#                        .withColumnRenamed('collect_list(arrival_time)', 'arrival_times')\
#                        .withColumnRenamed('collect_list(departure_time)', 'departure_times')