# Imports

In [1]:
import datetime as dt
import calendar
import time
import numpy as np
import pandas as pd
import scipy.stats as stat
import pyspark.sql.functions as functions
import math
import getpass
import pyspark
from datetime import datetime, date, timedelta
from pyspark.sql import SparkSession
import networkx as nx

spark = SparkSession \
    .builder \
    .master("yarn") \
    .appName('journey_planner-{0}'.format(getpass.getuser())) \
    .config('spark.jars.packages', 'graphframes:graphframes:0.6.0-spark2.3-s_2.11') \
    .config('spark.executor.memory', '8g') \
    .config('spark.executor.instances', '5') \
    .config('spark.port.maxRetries', '100') \
    .getOrCreate()

from graphframes import *

# Generate vertices and edges

In [20]:
# load the data
df = spark.read.csv('/datasets/sbb/2018/*/*istdaten.csv.bz2', sep=';', header=True)

In [21]:
stations = pd.read_csv('data/filtered_stations.csv')
valid_stations = set(stations['Remark'])

## Vertices

In [22]:
vertices_df = stations[['Remark', 'Longitude', 'Latitude']]
vertices_df.columns = ['id', 'lon', 'lat']
vertices = spark.createDataFrame(vertices_df)

## Walk edges

In [23]:
stations = stations[['Longitude', 'Latitude', 'Remark']];
stations['key'] = 0

earth_radius = 6371e3

def haversine(row):
    phi1         = 2 * math.pi * float(row['Latitude_x']) / 360
    phi2         = 2 * math.pi * float(row['Latitude_y']) / 360
    delta_phi    = 2 * math.pi * (float(row['Latitude_y']) - float(row['Latitude_x'])) / 360
    delta_lambda = 2 * math.pi * (float(row['Longitude_y']) - float(row['Longitude_x'])) / 360
    
    a = (math.sin(delta_phi/2) ** 2) + \
        math.cos(phi1) * math.cos(phi2) * (math.sin(delta_lambda/2) ** 2)
    
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    
    d = earth_radius * c
    
    return d / 1000

prod = pd.merge(stations, stations, on='key')
prod['dist'] = prod.apply(lambda row: haversine(row), axis=1)

In [24]:
# We don't consider walking to stops that are more than 1 kilometers away
max_walking_distance = 1
walk_df = prod[prod['dist'] <= max_walking_distance]
walk_df = walk_df[walk_df['Remark_x'] != walk_df['Remark_y']]

walk_df = walk_df[['Remark_x', 'Remark_y', 'dist']]
walk_df['type'] = 'walk'
walk_df['line'] = 'walk'
walk_df['departure_day']  = 'null'
walk_df['departure_time'] = 'null'
walk_df['arrival_time']   = 'null'
# We assume an average walking speed of 5 kilometers per hour
walk_df['lateAvg'] = walk_df.apply(lambda row: 3600 * float(row['dist']) / 5, axis=1)
walk_df['lateStd'] = 0.0
walk_df.drop('dist', axis=1, inplace=True)
walk_df.columns = ['src', 'dst', 'type', 'line', 'departure_day', 'departure_time', 'arrival_time', 'lateAvg', 'lateStd']

In [25]:
walk_edges = spark.createDataFrame(walk_df)

## Transport edges

In [26]:
dateFormat = 'dd.MM.yyyy HH:mm'
timeLate = (functions.unix_timestamp('AN_PROGNOSE', format=dateFormat)
            - functions.unix_timestamp('ANKUNFTSZEIT', format=dateFormat))

@functions.udf
def clamp(late):
    return 0 if late < 0 else late

valid_stops = df.filter((df.DURCHFAHRT_TF=='false') & 
                        (df.FAELLT_AUS_TF=='false') & 
                        (df.ZUSATZFAHRT_TF=='false') &
                        (df.AN_PROGNOSE_STATUS=='GESCHAETZT') &
                        (df.HALTESTELLEN_NAME.isin(valid_stations))) \
                .select('BETRIEBSTAG',
                        'FAHRT_BEZEICHNER', 
                        'PRODUKT_ID', 
                        'LINIEN_TEXT', 
                        'HALTESTELLEN_NAME', 
                        'AN_PROGNOSE',
                        'ANKUNFTSZEIT', 
                        'ABFAHRTSZEIT') \
                .withColumn('AN_PROGNOSE',  functions.to_timestamp(df.AN_PROGNOSE, dateFormat))  \
                .withColumn('ANKUNFTSZEIT', functions.to_timestamp(df.ANKUNFTSZEIT, dateFormat)) \
                .withColumn('ABFAHRTSZEIT', functions.to_timestamp(df.ABFAHRTSZEIT, dateFormat)) \
                .withColumn('late', clamp(timeLate)) \
                .drop('AN_PROGNOSE')

In [27]:
departures = valid_stops.filter(valid_stops.ABFAHRTSZEIT.isNotNull())\
                        .drop('ANKUNFTSZEIT', 'late')
arrivals   = valid_stops.filter(valid_stops.ANKUNFTSZEIT.isNotNull())\
                        .drop('ABFAHRTSZEIT')

In [28]:
arrivals.createOrReplaceTempView('arrivals')
departures.createOrReplaceTempView('departures')

joinQuery = 'SELECT d.HALTESTELLEN_NAME AS src, a.HALTESTELLEN_NAME AS dst,              \
                    d.PRODUKT_ID AS type, d.LINIEN_TEXT AS line,                         \
                    date_format(d.ABFAHRTSZEIT, \'EEEE\') AS departure_day,              \
                    SUBSTRING(d.ABFAHRTSZEIT, 12, 8) AS departure_time,                  \
                    SUBSTRING(a.ANKUNFTSZEIT, 12, 8) AS arrival_time,                    \
                    a.late                                                               \
             FROM arrivals AS a INNER JOIN departures AS d                               \
             ON a.BETRIEBSTAG == d.BETRIEBSTAG                                           \
             AND a.FAHRT_BEZEICHNER == d.FAHRT_BEZEICHNER                                \
             WHERE a.HALTESTELLEN_NAME != d.HALTESTELLEN_NAME                            \
             AND d.ABFAHRTSZEIT < a.ANKUNFTSZEIT'

edges = spark.sql(joinQuery)

In [29]:
edges.createOrReplaceTempView('edges')

query = 'SELECT src, dst, type, line, departure_day, departure_time, arrival_time,              \
         AVG(late) AS lateAvg, STD(late) AS lateStd                                             \
         FROM edges GROUP BY src, dst, type, line, departure_day, departure_time, arrival_time'

aggregated = spark.sql(query)
aggregated_edges = aggregated.na.fill(0.0)

all_edges = aggregated_edges.union(walk_edges)

## Write data to hdfs

In [30]:
all_edges.write.parquet('/homes/schmutz/edges', mode='overwrite')

Py4JJavaError: An error occurred while calling o1457.parquet.
: org.apache.hadoop.security.AccessControlException: Permission denied: user=kgerard, access=WRITE, inode="/homes/schmutz":schmutz:hadoop:drwxr-xr-x
	at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.check(FSPermissionChecker.java:399)
	at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkPermission(FSPermissionChecker.java:258)
	at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkPermission(FSPermissionChecker.java:193)
	at org.apache.hadoop.hdfs.server.namenode.FSDirectory.checkPermission(FSDirectory.java:1857)
	at org.apache.hadoop.hdfs.server.namenode.FSDirDeleteOp.delete(FSDirDeleteOp.java:110)
	at org.apache.hadoop.hdfs.server.namenode.FSNamesystem.delete(FSNamesystem.java:3002)
	at org.apache.hadoop.hdfs.server.namenode.NameNodeRpcServer.delete(NameNodeRpcServer.java:1095)
	at org.apache.hadoop.hdfs.protocolPB.ClientNamenodeProtocolServerSideTranslatorPB.delete(ClientNamenodeProtocolServerSideTranslatorPB.java:692)
	at org.apache.hadoop.hdfs.protocol.proto.ClientNamenodeProtocolProtos$ClientNamenodeProtocol$2.callBlockingMethod(ClientNamenodeProtocolProtos.java)
	at org.apache.hadoop.ipc.ProtobufRpcEngine$Server$ProtoBufRpcInvoker.call(ProtobufRpcEngine.java:524)
	at org.apache.hadoop.ipc.RPC$Server.call(RPC.java:1025)
	at org.apache.hadoop.ipc.Server$RpcCall.run(Server.java:876)
	at org.apache.hadoop.ipc.Server$RpcCall.run(Server.java:822)
	at java.security.AccessController.doPrivileged(Native Method)
	at javax.security.auth.Subject.doAs(Subject.java:422)
	at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1730)
	at org.apache.hadoop.ipc.Server$Handler.run(Server.java:2682)

	at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
	at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
	at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
	at org.apache.hadoop.ipc.RemoteException.instantiateException(RemoteException.java:121)
	at org.apache.hadoop.ipc.RemoteException.unwrapRemoteException(RemoteException.java:88)
	at org.apache.hadoop.hdfs.DFSClient.delete(DFSClient.java:1603)
	at org.apache.hadoop.hdfs.DistributedFileSystem$19.doCall(DistributedFileSystem.java:953)
	at org.apache.hadoop.hdfs.DistributedFileSystem$19.doCall(DistributedFileSystem.java:950)
	at org.apache.hadoop.fs.FileSystemLinkResolver.resolve(FileSystemLinkResolver.java:81)
	at org.apache.hadoop.hdfs.DistributedFileSystem.delete(DistributedFileSystem.java:960)
	at org.apache.spark.internal.io.FileCommitProtocol.deleteWithJob(FileCommitProtocol.scala:123)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.deleteMatchingPartitions(InsertIntoHadoopFsRelationCommand.scala:210)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:117)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:104)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:102)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.doExecute(commands.scala:122)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:80)
	at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:80)
	at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:656)
	at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:656)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:656)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:273)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:267)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:225)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:549)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:745)
Caused by: org.apache.hadoop.ipc.RemoteException(org.apache.hadoop.security.AccessControlException): Permission denied: user=kgerard, access=WRITE, inode="/homes/schmutz":schmutz:hadoop:drwxr-xr-x
	at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.check(FSPermissionChecker.java:399)
	at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkPermission(FSPermissionChecker.java:258)
	at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkPermission(FSPermissionChecker.java:193)
	at org.apache.hadoop.hdfs.server.namenode.FSDirectory.checkPermission(FSDirectory.java:1857)
	at org.apache.hadoop.hdfs.server.namenode.FSDirDeleteOp.delete(FSDirDeleteOp.java:110)
	at org.apache.hadoop.hdfs.server.namenode.FSNamesystem.delete(FSNamesystem.java:3002)
	at org.apache.hadoop.hdfs.server.namenode.NameNodeRpcServer.delete(NameNodeRpcServer.java:1095)
	at org.apache.hadoop.hdfs.protocolPB.ClientNamenodeProtocolServerSideTranslatorPB.delete(ClientNamenodeProtocolServerSideTranslatorPB.java:692)
	at org.apache.hadoop.hdfs.protocol.proto.ClientNamenodeProtocolProtos$ClientNamenodeProtocol$2.callBlockingMethod(ClientNamenodeProtocolProtos.java)
	at org.apache.hadoop.ipc.ProtobufRpcEngine$Server$ProtoBufRpcInvoker.call(ProtobufRpcEngine.java:524)
	at org.apache.hadoop.ipc.RPC$Server.call(RPC.java:1025)
	at org.apache.hadoop.ipc.Server$RpcCall.run(Server.java:876)
	at org.apache.hadoop.ipc.Server$RpcCall.run(Server.java:822)
	at java.security.AccessController.doPrivileged(Native Method)
	at javax.security.auth.Subject.doAs(Subject.java:422)
	at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1730)
	at org.apache.hadoop.ipc.Server$Handler.run(Server.java:2682)

	at org.apache.hadoop.ipc.Client.getRpcResponse(Client.java:1497)
	at org.apache.hadoop.ipc.Client.call(Client.java:1443)
	at org.apache.hadoop.ipc.Client.call(Client.java:1353)
	at org.apache.hadoop.ipc.ProtobufRpcEngine$Invoker.invoke(ProtobufRpcEngine.java:228)
	at org.apache.hadoop.ipc.ProtobufRpcEngine$Invoker.invoke(ProtobufRpcEngine.java:116)
	at com.sun.proxy.$Proxy11.delete(Unknown Source)
	at org.apache.hadoop.hdfs.protocolPB.ClientNamenodeProtocolTranslatorPB.delete(ClientNamenodeProtocolTranslatorPB.java:634)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at org.apache.hadoop.io.retry.RetryInvocationHandler.invokeMethod(RetryInvocationHandler.java:422)
	at org.apache.hadoop.io.retry.RetryInvocationHandler$Call.invokeMethod(RetryInvocationHandler.java:165)
	at org.apache.hadoop.io.retry.RetryInvocationHandler$Call.invoke(RetryInvocationHandler.java:157)
	at org.apache.hadoop.io.retry.RetryInvocationHandler$Call.invokeOnce(RetryInvocationHandler.java:95)
	at org.apache.hadoop.io.retry.RetryInvocationHandler.invoke(RetryInvocationHandler.java:359)
	at com.sun.proxy.$Proxy12.delete(Unknown Source)
	at org.apache.hadoop.hdfs.DFSClient.delete(DFSClient.java:1601)
	... 37 more


In [None]:
vertices.write.parquet('/homes/schmutz/vertices', mode='overwrite')

## Load data from hdfs

In [2]:
vertices = spark.read.parquet('/homes/schmutz/vertices')

edges = spark.read.parquet('/homes/schmutz/edges')

graph = GraphFrame(vertices, edges)

In [3]:
vertices.show(n=5)
edges.show(n=5)

+--------------------+-----------------+---------+
|                  id|              lon|      lat|
+--------------------+-----------------+---------+
|  Stallikon, Loomatt|         8.485294|47.339415|
|  Bonstetten, Lätten|         8.455791|  47.3189|
|Wettswil a.A., Sc...|8.473424000000001|47.327783|
|Birmensdorf ZH, W...|         8.445963|47.356073|
|Stallikon, Langfuren|         8.490171|47.323629|
+--------------------+-----------------+---------+
only showing top 5 rows

+--------------------+--------------------+----+----+-------------+--------------+------------+-----------------+-------+
|                 src|                 dst|type|line|departure_day|departure_time|arrival_time|          lateAvg|lateStd|
+--------------------+--------------------+----+----+-------------+--------------+------------+-----------------+-------+
|Zürich, Carl-Spit...|Zürich, Stodolast...|walk|walk|         null|          null|        null|243.1882193019385|    0.0|
|Zürich, Carl-Spit...|      

# Naive Journey Planner

In [80]:
MINUTES_PER_DAY = 1440
MINUTES_PER_HOUR = 60
SECONDS_PER_MINUTE = 60

def computeLengthInMinutes(departure, arrival):
    dep = (departure[2:]).split(':')
    arr = (arrival[2:]).split(':')
    a = (int(arrival[:1]) - int(departure[:1])) * MINUTES_PER_DAY
    b = (int(arr[0]) - int(dep[0])) * MINUTES_PER_HOUR
    c = (int(arr[1]) - int(dep[1]))
    return a + b + c

def computeTime(start, duration):
    tmp = start.split(':')
    a = int(tmp[0][2:]) * MINUTES_PER_HOUR + int(tmp[1])
    b = duration // SECONDS_PER_MINUTE
    prefix = tmp[0][:2] if a + b < MINUTES_PER_DAY else '1-'
    a = (a + b) % MINUTES_PER_DAY
    minutes = a % MINUTES_PER_HOUR
    hours = (a - minutes) // MINUTES_PER_HOUR
    return prefix + "{:02d}".format(int(hours)) + ':' + "{:02d}".format(int(minutes)) + ':00'

def computeProb(depTime, lateAvg, lateStd, arrTime):
    length = computeLengthInMinutes(depTime, arrTime) * 60
    if lateStd != 0.0:
        return stat.norm(loc=lateAvg, scale=lateStd).cdf(length)
    elif lateAvg <= length:
        return 1.0
    else:
        return 0.0

def computeDiff(departure, arrival):
    dep = (departure[2:]).split(':')
    arr = (arrival[2:]).split(':')
    a = (int(arrival[:1]) - int(departure[:1])) * MINUTES_PER_DAY
    b = (int(arr[0]) - int(dep[0])) * MINUTES_PER_HOUR
    c = (int(arr[1]) - int(dep[1]))
    tot = a + b + c
    hours = tot // MINUTES_PER_HOUR
    minutes = tot % MINUTES_PER_HOUR
    return "{:02d}".format(int(hours)) + ':' + "{:02d}".format(int(minutes)) + ':00'

def computeCost(cost, late):
    tmp = cost.split(':')
    a = int(tmp[0][2:]) * MINUTES_PER_HOUR + int(tmp[1])
    b = late // SECONDS_PER_MINUTE
    prefix = tmp[0][:2] if a > b else '0-'
    a = (a - b) % MINUTES_PER_DAY
    minutes = a % MINUTES_PER_HOUR
    hours = (a - minutes) // MINUTES_PER_HOUR
    return prefix + "{:02d}".format(int(hours)) + ':' + "{:02d}".format(int(minutes)) + ':00'

def getFilteredEdges(startDay, finishDay, startTime, finishTime, duration):
    def valid(day, depTime, arrTime, walkTime):
        if startDay==finishDay:
            return ((day=='null') & (walkTime<=duration)) | \
                    ((day==startDay) & (depTime>=startTime) & (arrTime<=finishTime) & (depTime<=arrTime))
        else:
            return ((day=='null') & (walkTime<=duration)) | \
                    (((day==startDay) & (depTime>=startTime) & ((depTime<=arrTime) | (arrTime<=finishTime))) | \
                     ((day==finishDay) & (depTime<finishTime) & (arrTime<=finishTime)))

    return graph.filterEdges(valid(graph.edges.departure_day, 
                                graph.edges.departure_time,
                                graph.edges.arrival_time,
                                graph.edges.lateAvg)).edges


def add_vertice_to_set(max_set, vertice, vertice_costs, edges, next_vertices, certain_path, earliest):
    
    max_set.add(vertice)
    cost = vertice_costs[vertice]
    
    if earliest:
        vertice_edges = edges.out_edges(vertice, data=True)

        for parallel_paths in vertice_edges:
            edge = parallel_paths[2]
            if edge['type'] == 'walk':
                new_cost = computeCost(cost, -edge['lateAvg'])
                if (vertice not in next_vertices or next_vertices[vertice]['type'] != 'walk') and (edge['dst'] not in vertice_costs or new_cost < vertice_costs[edge['dst']]):
                    next_vertices[edge['dst']] = edge
                    vertice_costs[edge['dst']] = new_cost
            elif edge['departure_time'] > cost and \
                (edge['dst'] not in vertice_costs or edge['arrival_time'] < vertice_costs[edge['dst']]):
                if (not certain_path) or computeProb(edge['departure_time'],  edge['lateAvg'], edge['lateStd'], cost) == 1:
                    vertice_costs[edge['dst']] = edge['arrival_time']
                    next_vertices[edge['dst']] = edge
    else:
        vertice_edges = edges.in_edges(vertice, data=True)

        for parallel_paths in vertice_edges:
            edge = parallel_paths[2]
            if edge['type'] == 'walk':
                new_cost = computeCost(cost, edge['lateAvg'])
                if (vertice not in next_vertices or next_vertices[vertice]['type'] != 'walk') and (edge['src'] not in vertice_costs or new_cost > vertice_costs[edge['src']]):
                    next_vertices[edge['src']] = edge
                    vertice_costs[edge['src']] = new_cost
            elif edge['arrival_time'] < cost and \
                (edge['src'] not in vertice_costs or edge['departure_time'] > vertice_costs[edge['src']]):
                if (not certain_path) or computeProb(edge['arrival_time'],  edge['lateAvg'], edge['lateStd'], cost) == 1:
                    vertice_costs[edge['src']] = edge['departure_time']
                    next_vertices[edge['src']] = edge

def get_max_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time):
    max_vertice = None
    max_cost = min_trip_departure_time
    for vertice in vertice_costs:
        if vertice not in max_set and vertice_costs[vertice] > max_cost:
            max_cost = vertice_costs[vertice]
            max_vertice = vertice
    
    return max_vertice

def get_min_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time):
    max_vertice = None
    max_cost = min_trip_departure_time
    for vertice in vertice_costs:
        if vertice not in max_set and vertice_costs[vertice] < max_cost:
            max_cost = vertice_costs[vertice]
            max_vertice = vertice
    
    return max_vertice

def find_path(next_vertices, current_vertice, current_path, direction):
    if current_vertice not in next_vertices:
        return current_path
    next_vertice = next_vertices[current_vertice][direction]
    current_path.append(next_vertices[current_vertice])
    return find_path(next_vertices, next_vertice, current_path, direction)
    

def find_shortest_path(departure_station, arrival_station, 
                       startDateTime, endDateTime, 
                       min_probability_of_sucess, get_all_destinations=False,
                       subgraph=None, certain_path=False, earliest=False):
    
    startTime  = str(startDateTime.time())
    endTime = str(endDateTime.time())

    startDay  = calendar.day_name[startDateTime.weekday()]
    endDay = calendar.day_name[endDateTime.weekday()]
    
    min_trip_departure_time = '0-' + startTime
    
    endTimePrefix = '0-' if (startDay == endDay) else '1-'
    requested_arrival_time = endTimePrefix + endTime
    
    duration = (endDateTime - startDateTime).seconds
    
    if subgraph is None:
        filtered_edges = getFilteredEdges(startDay, endDay, startTime, endTime, duration)
        
        filtered_edges = filtered_edges.toPandas()

        def to_dt(time):
            if time == 'null':
                return 'null'
            elif time >= startTime:
                return '0-' + time
            else:
                return '1-' + time

        filtered_edges['departure_time'] = filtered_edges['departure_time'].map(lambda x: to_dt(x))
        filtered_edges['arrival_time']   = filtered_edges['arrival_time'].map(lambda x: to_dt(x))
        
        G = nx.from_pandas_edgelist(filtered_edges, 'src', 'dst', edge_attr=True, create_using=nx.MultiDiGraph())
    else:
        G = subgraph
        
    # as day#-hh-mm-ss
    vertice_costs = {}
    max_set = set()
    next_vertices = {}
    if earliest:
        vertice_costs[departure_station] = min_trip_departure_time
        target = arrival_station
        add_vertice_to_set(max_set, departure_station, vertice_costs, G, next_vertices, certain_path, earliest)
        direction = 'src'
    else:
        vertice_costs[arrival_station] = requested_arrival_time
        target = departure_station
        add_vertice_to_set(max_set, arrival_station, vertice_costs, G, next_vertices, certain_path, earliest)
        direction= 'dst'

    no_solution = False
    
    while((target not in max_set or get_all_destinations) and not no_solution):
        if earliest:
            max_vertice = get_min_vertice_not_in_set(max_set, vertice_costs, requested_arrival_time)
        else:
            max_vertice = get_max_vertice_not_in_set(max_set, vertice_costs, min_trip_departure_time)
        if max_vertice is None:
            no_solution = True
        else:
            add_vertice_to_set(max_set, max_vertice, vertice_costs, G, next_vertices, certain_path, earliest)
    
    if get_all_destinations:
        return vertice_costs
    if no_solution:
        return "no solution"
    departure_time = vertice_costs[departure_station]
    
    if earliest:
        trip_duration = computeDiff(min_trip_departure_time, vertice_costs[target])
    else:
        trip_duration = computeDiff(vertice_costs[target], requested_arrival_time)
    return departure_time, trip_duration, find_path(next_vertices, target, [target], direction)

In [81]:
fromStation = 'Kilchberg'
toStation   = 'Urdorf, Schlierenstrasse'
startDateTime = datetime(2019, 6, 4, 18, 0)
endDateTime   = datetime(2019, 6, 4, 19, 57)

res = find_shortest_path(fromStation, toStation, 
                   startDateTime, 
                   endDateTime, 0, get_all_destinations=False, certain_path=False, earliest=True)
res

('0-18:00:00',
 '00:51:00',
 ['Urdorf, Schlierenstrasse',
  {'arrival_time': 'null',
   'departure_day': 'null',
   'departure_time': 'null',
   'dst': 'Urdorf, Schlierenstrasse',
   'lateAvg': 635.0863749675257,
   'lateStd': 0.0,
   'line': 'walk',
   'src': 'Glanzenberg',
   'type': 'walk'},
  {'arrival_time': '0-18:40:00',
   'departure_day': 'Tuesday',
   'departure_time': '0-18:29:00',
   'dst': 'Glanzenberg',
   'lateAvg': 156.66666666666666,
   'lateStd': 107.21062282749234,
   'line': 'S3',
   'src': 'Zürich HB',
   'type': 'Zug'},
  {'arrival_time': '0-18:23:00',
   'departure_day': 'Tuesday',
   'departure_time': '0-18:09:00',
   'dst': 'Zürich HB',
   'lateAvg': 6.666666666666667,
   'lateStd': 28.284271247461902,
   'line': 'S8',
   'src': 'Kilchberg',
   'type': 'Zug'}])

In [6]:
MINUTES_PER_DAY = 1440
MINUTES_PER_HOUR = 60
SECONDS_PER_MINUTE = 60

MAX_PATH_LENGTH = 4

def compute_dep_time(curr_time, curr_path, edge=None):
    if len(curr_path) == 1:
        dep = curr_time if edge is None else edge['departure_time']
    elif curr_path[1]['type'] == 'walk':
        if edge is not None and len(curr_path) == 2:
            dep = computeCost(edge['departure_time'], curr_path[1]['lateAvg'])
        elif len(curr_path) > 2:
            dep = computeCost(curr_path[2]['departure_time'], curr_path[1]['lateAvg'])
        else:
            dep = curr_time
    else:
        dep = curr_path[1]['departure_time']
    
    return dep

def compute_paths_between(src, dst, edges, visited, curr_path, 
                          curr_prob, curr_time, curr_lateAvg, curr_lateStd, 
                          min_trip_departure_time, requested_arrival_time, 
                          paths, last_line_taken, max_dep_times, min_prob_success, min_duration, mode):
    visited.add(src)
    
    if src == dst:
        final_prob = computeProb(curr_time, curr_lateAvg, curr_lateStd, requested_arrival_time) * curr_prob
        if final_prob >= min_prob_success:
            final_path = curr_path.copy()
            final_path.append(curr_time)
            final_path.append(final_prob)
            
            dep = compute_dep_time(min_trip_departure_time, final_path[:-2], None)
            duration = computeDiff(dep, final_path[-2])
            if mode == 'both':
                if duration < min_duration['min']:
                    min_duration['min'] = duration
            elif mode == 'arrival':
                if dep > min_duration['dep']:
                    min_duration['dep'] = dep
            
            paths.append(final_path)
            
    elif len(curr_path) < MAX_PATH_LENGTH:
        vertice_edges = edges.out_edges(src, data=True)
        for vertice_edge in vertice_edges:
            edge = vertice_edge[2]
            if edge['dst'] not in visited and edge['line'] != last_line_taken:
                
                if edge['type'] == 'walk':
                    new_time = computeTime(curr_time, edge['lateAvg'])
                    
                    dep = compute_dep_time(curr_time, curr_path)
                    duration = computeDiff(dep, new_time)
                    
                    if (mode == 'both' and duration <= min_duration['min']) or \
                       (mode == 'arrival'):
                        if new_time <= requested_arrival_time and \
                           edge['dst'] in max_dep_times and new_time <= max_dep_times[edge['dst']]:

                            curr_path.append(edge)
                            compute_paths_between(edge['dst'], dst, edges, visited, curr_path, 
                                                  curr_prob, new_time, curr_lateAvg, curr_lateStd, 
                                                  min_trip_departure_time, requested_arrival_time, paths, 
                                                  edge['line'], max_dep_times, min_prob_success, min_duration, mode)
                            curr_path.pop();
                        
                elif edge['departure_time'] >= curr_time and edge['dst'] in max_dep_times and \
                     edge['arrival_time'] <= max_dep_times[edge['dst']]:
                        
                    dep = compute_dep_time(curr_time, curr_path, edge = edge)
                    duration = computeDiff(dep, edge['arrival_time'])
                    
                    prob = computeProb(curr_time, curr_lateAvg, curr_lateStd, edge['departure_time'])
                    new_prob = curr_prob * prob
                    
                    if (mode == 'both' and duration <= min_duration['min']) or \
                       (mode == 'arrival' and dep >= min_duration['dep']):
                        if new_prob >= min_prob_success:
                            curr_path.append(edge)
                            compute_paths_between(edge['dst'], dst, edges, visited, curr_path, 
                                                  new_prob, edge['arrival_time'], edge['lateAvg'], edge['lateStd'],
                                                  min_trip_departure_time, requested_arrival_time, paths, 
                                                  edge['line'], max_dep_times, min_prob_success, min_duration, mode)
                            curr_path.pop();
        
    visited.remove(src)
    
    
def dfs(departure_station, arrival_station, 
        startDateTime, endDateTime, 
        min_probability_of_sucess, mode='both'):
    
    startTime  = str(startDateTime.time())
    endTime = str(endDateTime.time())

    startDay  = calendar.day_name[startDateTime.weekday()]
    endDay = calendar.day_name[endDateTime.weekday()]
    
    min_trip_departure_time = '0-' + startTime
    
    endTimePrefix = '0-' if (startDay == endDay) else '1-'
    requested_arrival_time = endTimePrefix + endTime
    
    duration = (endDateTime - startDateTime).seconds
    
    filtered_edges = getFilteredEdges(startDay, endDay, startTime, endTime, duration).toPandas()
    
    def to_dt(time):
        if time == 'null':
            return 'null'
        elif time >= startTime:
            return '0-' + time
        else:
            return '1-' + time

    
    filtered_edges['departure_time'] = filtered_edges['departure_time'].map(lambda x: to_dt(x))
    filtered_edges['arrival_time']   = filtered_edges['arrival_time'].map(lambda x: to_dt(x))
    
    G = nx.from_pandas_dataframe(filtered_edges, 'src', 'dst', edge_attr=True, create_using=nx.MultiDiGraph())
    
    max_dep_times = find_shortest_path(departure_station, arrival_station, 
                                       startDateTime, 
                                       endDateTime, min_probability_of_sucess, 
                                       get_all_destinations=True, 
                                       subgraph=G)
    
    fastest_certain_path = find_shortest_path(fromStation, toStation, 
                                              startDateTime, endDateTime, 0, 
                                              get_all_destinations=False, certain_path=True,
                                              subgraph=G)
    
    visited = set()
    curr_time = min_trip_departure_time
    curr_path = [departure_station]
    paths = []
    
    if fastest_certain_path != "no solution":
        if mode == 'both':
            min_duration = {'min': fastest_certain_path[1]}
        elif mode == 'arrival':
            min_duration = {'dep': fastest_certain_path[0]}
    else:
        if mode == 'both':
            min_duration = {'min': '24:00:00'}
        elif mode == 'arrival':
            min_duration = {'dep': min_trip_departure_time}
    
    compute_paths_between(departure_station, arrival_station, G, 
                          visited, curr_path, 1.0, curr_time, 0.0, 0.0, min_trip_departure_time, 
                          requested_arrival_time, paths, '', max_dep_times, 
                          min_probability_of_sucess, min_duration, mode)
    
    if len(paths) == 0:
        return {'departure time' : '', 'arrival_time' : '', 
                'duration' : '', 'path': []}
    
    times = [computeDiff(compute_dep_time(min_trip_departure_time, path[:-2], None), path[-2]) for path in paths]
    dep_times = [compute_dep_time(min_trip_departure_time, path[:-2], None) for path in paths]
    
    if mode == 'both':
        best_path_idx = np.argmin(times)
    elif mode == 'arrival':
        best_path_idx = np.argmax(dep_times)
    
    best_path = paths[best_path_idx]
    
    path_edges = best_path[1:-2]
    
    # Compute departure and arrival time for walk edges and removing unnecessary data
    for idx, edge in enumerate(path_edges):
        if edge['type'] == 'walk':
            if idx == 0:
                if len(path_edges) == 1:
                    edge['arrival_time'] = requested_arrival_time
                else:
                    edge['arrival_time'] = path_edges[idx + 1]['departure_time']
                    
                edge['departure_time'] = computeCost(edge['arrival_time'], edge['lateAvg'])
            else:
                edge['departure_time'] = path_edges[idx - 1]['arrival_time']
                edge['arrival_time'] = computeTime(edge['departure_time'], edge['lateAvg'])
            
            edge.pop('line');
                
        edge.pop('departure_day');
        edge.pop('lateAvg');
        edge.pop('lateStd');
    
    # Remove prefix from edges
    for edge in path_edges:
        edge['departure_time'] = edge['departure_time'][2:]
        edge['arrival_time'] = edge['arrival_time'][2:]
    
    departure_time = path_edges[0]['departure_time']
    arrival_time = path_edges[-1]['arrival_time']
    
    return {'departure time' : departure_time, 'arrival_time' : arrival_time, 
            'duration' : times[best_path_idx], 'path': path_edges}

In [224]:
fromStation = 'Kilchberg'
toStation   = 'Urdorf, Schlierenstrasse'
startDateTime = datetime(2019, 6, 4, 18, 20)
endDateTime   = datetime(2019, 6, 4, 19, 57)

res = dfs(fromStation, toStation, 
          startDateTime, 
          endDateTime, 0.95, mode='both')
res

{'arrival_time': '19:05:00',
 'departure time': '18:26:00',
 'duration': '00:39:00',
 'path': [{'arrival_time': '18:39:00',
   'departure_time': '18:26:00',
   'dst': 'Zürich HB',
   'line': 'S24',
   'src': 'Kilchberg',
   'type': 'Zug'},
  {'arrival_time': '18:55:00',
   'departure_time': '18:44:00',
   'dst': 'Glanzenberg',
   'line': 'S12',
   'src': 'Zürich HB',
   'type': 'Zug'},
  {'arrival_time': '19:05:00',
   'departure_time': '18:55:00',
   'dst': 'Urdorf, Schlierenstrasse',
   'src': 'Glanzenberg',
   'type': 'walk'}]}

# ARCHIVED CODE FOR POSSIBLE REUSE

In [62]:
def getSubGraph(graph, startDay, finishDay, startTime, finishTime, duration):
    def valid(day, depTime, arrTime, walkTime):
        if startDay==finishDay:
            return ((day=='null') & (walkTime<=duration)) | \
                    ((day==startDay) & (depTime>=startTime) & (arrTime<=finishTime) & (depTime<=arrTime))
        else:
            return ((day=='null') & (walkTime<=duration)) | \
                    (((day==startDay) & (depTime>=startTime) & ((depTime<=arrTime) | (arrTime<=finishTime))) | \
                     ((day==finishDay) & (depTime<finishTime) & (arrTime<=finishTime)))

    return graph.filterEdges(valid(graph.edges.departure_day, 
                                graph.edges.departure_time,
                                graph.edges.arrival_time,
                                graph.edges.lateAvg))  \
                .dropIsolatedVertices()

def howFarNaive(graph, fromStation, startDateTime, duration):
    
    if duration >= 120:
        print('You can walk anywhere in that time')
        return
    
    finishDateTime = startDateTime + timedelta(minutes=duration)

    startTime  = str(startDateTime.time())
    finishTime = str(finishDateTime.time())

    startDay  = calendar.day_name[startDateTime.weekday()]
    finishDay = calendar.day_name[finishDateTime.weekday()]

    print(startDay, startTime)
    print(finishDay, finishTime)
    
    @functions.udf
    def addTime(arr_time, dep_time, late):
        if arr_time=='null':
            tmp = dep_time.split(':')
            return str((datetime.combine(date.today(), dt.time(int(tmp[0]), int(tmp[1]), int(tmp[2]))) + 
                        timedelta(seconds=int(late))).time())
        else:
            return arr_time
    
    @functions.udf
    def checkDay(day, dep_time, arr_time):
        return finishDay if arr_time<dep_time else day
    
    @functions.udf
    def checkWalk(ttype):
        return 1 if ttype=='walk' else 0
    
    @functions.udf
    def checkIfValid(arr_time, day):
        tmp = arr_time.split(':')
        arr_date = startDateTime.date() if day==calendar.day_name[startDateTime.weekday()] else finishDateTime.date()
        arrival = datetime.combine(arr_date, dt.time(int(tmp[0]), int(tmp[1]), int(tmp[2])))
        return arrival < finishDateTime
    
    reachable = vertices.filter(vertices.id==fromStation)             \
                        .withColumn('time', functions.lit(startTime)) \
                        .withColumn('day', functions.lit(startDay))   \
                        .withColumn('just_walked', functions.lit(0))
    
    g = getSubGraph(graph, startDay, finishDay, startTime, finishTime, 60*duration)
    g.persist();
    g.edges.createOrReplaceTempView('edges')
    g.vertices.createOrReplaceTempView('vertices')
    
    curr = reachable
    
    #while len(curr.head(1)) > 0:
    for i in range(1):
        curr.createOrReplaceTempView('curr')

        query = 'SELECT v.*, c.time AS past_time, c.just_walked, c.day, e.type,          \
                        e.departure_time, e.arrival_time, e.lateAvg                      \
                 FROM curr AS c INNER JOIN edges AS e INNER JOIN vertices AS v           \
                 ON c.id==e.src                                                          \
                 AND e.dst==v.id                                                         \
                 WHERE (e.type!=\'walk\' OR c.just_walked==0)                            \
                 AND (e.type==\'walk\'                                                   \
                 OR (e.departure_time>=c.time AND c.day==e.departure_day)                \
                 OR (e.departure_time<c.time AND c.day!=e.departure_day))'

        curr = spark.sql(query).withColumn('time', addTime('arrival_time', 'past_time', 'lateAvg')) \
                               .withColumn('day', checkDay('day', 'past_time', 'time'))             \
                               .filter(checkIfValid('time', 'day')=='true')                         \
                               .withColumn('just_walked', checkWalk('type'))                        \
                               .select('id', 'lon', 'lat', 'time', 'day', 'just_walked')
        curr.persist()
        reachable = reachable.union(curr)
    
    @functions.udf
    def computeRadius(arr_time, day):
        tmp = arr_time.split(':')
        arr_date = startDateTime.date() if day==calendar.day_name[startDateTime.weekday()] else finishDateTime.date()
        arrival = datetime.combine(arr_date, dt.time(int(tmp[0]), int(tmp[1]), int(tmp[2])))
        return (finishDateTime - arrival).seconds * 5 // 3.6
    
    reachable = reachable.withColumn('radius', computeRadius('time', 'day')) \
                         .select('id', 'lon', 'lat', 'radius', 'time', 'just_walked')                     \
                         .toPandas()
    
    g.unpersist();
    
    return reachable

In [65]:
graph = GraphFrame(vertices, edges)
fromStation = 'Dietlikon'
startDateTime  = datetime(2019, 5, 31, 23, 45)
duration = 11

start = time.time()
reachable = howFarNaive(graph, fromStation, startDateTime, duration)
print(time.time() - start)
reachable

Friday 23:45:00
Friday 23:56:00
9.501955032348633


Unnamed: 0,id,lon,lat,radius,time,just_walked
0,Dietlikon,8.619255,47.420195,916.0,23:45:00,0
1,"Dietlikon, Bahnhof",8.619087,47.420359,895.0,23:45:15,1
2,"Brüttisellen, Gsellhof",8.629667,47.421708,116.0,23:54:36,1
3,"Dietlikon, Bahnhof/Bad",8.620805,47.42196,688.0,23:47:44,1
4,"Dietlikon, Brandbachstrasse",8.624657,47.416965,374.0,23:51:30,1
5,"Dietlikon, Dornenstrasse",8.616896,47.416585,477.0,23:50:16,1
6,"Dietlikon, Dübendorferstrasse",8.619558,47.413509,173.0,23:53:55,1
7,"Dietlikon, Hofwiesen",8.618518,47.426391,226.0,23:53:17,1
8,"Dietlikon, Industriestrasse",8.621946,47.414249,226.0,23:53:17,1
9,"Dietlikon, Fuchshalde",8.612612,47.420029,416.0,23:51:00,1
