In [1]:
import sys

MODULES_PATH = "../../modules"
if MODULES_PATH not in sys.path:
    sys.path.append(MODULES_PATH)
    
from utils import get_env_vars, setup_spark    

%reload_ext autoreload
%autoreload 2

In [2]:
%reload_ext sparkmagic.magics

Cleaning up livy sessions on exit is enabled


In [3]:
username, _, _ = get_env_vars()

In [4]:
setup_spark(username)

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
7472,application_1680948035106_6865,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


In [5]:
%%spark
usernames = ["boukil", "hdasilva", "elmalki", "ouerghem", "berkane", "sly"]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
%%spark
print('We are using Spark %s' % spark.version)

import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import (
    col, lit, concat,
    to_date, date_format, to_timestamp, unix_timestamp, dayofweek,
    when, isnan,
    count, lower, trim,
    toRadians, sin, cos, sqrt, asin,
    lead
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

We are using Spark 2.4.8.7.1.8.0-801

#### Preprocess stops

In [7]:
%%spark
import numpy as np

EARTH_RADIUS = 6371 * 1e3 # in meters
def haversine_dist(
    x_lat,
    x_lon,
    y_lat,
    y_lon, 
    in_radian =False):
    # link to the formula: https://fr.wikipedia.org/wiki/Formule_de_haversine
    
    # convert to radians if needed
    if not in_radian:
        x_lon = toRadians(x_lon)
        x_lat = toRadians(x_lat)
        y_lon = toRadians(y_lon)
        y_lat = toRadians(y_lat)
    
    sin2_lat = F.pow(sin(0.5*(y_lat - x_lat)), 2)
    sin2_lon = F.pow(sin(0.5*(y_lon - x_lon)), 2)
    cos_lat_prod = cos(y_lat) * cos(x_lat)
    
    # distance in meters
    dist = 2 * EARTH_RADIUS * asin(sqrt(sin2_lat + cos_lat_prod * sin2_lon))
    return dist

# load all the stops
stops = spark.read.format("orc").load("/data/sbb/orc/allstops").dropDuplicates()

# drop null stops
stops = stops.dropna(how="any", subset=["stop_id"])

# lowercase the stop names
stops = stops.withColumn("stop_name", trim(lower(col("stop_name"))))

# only keep stops within 15km (=15'000m) of Zurich
ZURICH_TRAIN_STATION_LAT = 47.378177
ZURICH_TRAIN_STATION_LON = 8.540192
stops = stops.withColumn("dist_to_zurich_train_station", 
                         haversine_dist(F.col("stop_lat"), F.col("stop_lon"), F.lit(ZURICH_TRAIN_STATION_LAT), F.lit(ZURICH_TRAIN_STATION_LON)))
stops = stops.where("dist_to_zurich_train_station <= 16000")

# drop useless columns
stops = stops.drop("location_type", "dist_to_zurich_train_station")

# caching
stops = stops.cache()

# count
print("final number of stops =", stops.count())

stops.printSchema()
stops.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of stops = 2260
root
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- stop_lat: double (nullable = true)
 |-- stop_lon: double (nullable = true)
 |-- parent_station: string (nullable = true)

+-----------+-------------------+----------------+----------------+--------------+
|    stop_id|          stop_name|        stop_lat|        stop_lon|parent_station|
+-----------+-------------------+----------------+----------------+--------------+
|8503015:0:1|   zürich wipkingen|47.3930409564212|8.52937785976238| Parent8503015|
|    8503152|uster, brandschenke|47.3611410570261|8.71073873249851|              |
+-----------+-------------------+----------------+----------------+--------------+
only showing top 2 rows

#### Preprocess stop times

In [8]:
%%spark
# most recent stop times file
stop_times = spark.read.format("orc").load("/data/sbb/part_orc/timetables/stop_times/year=2023/month=3/day=29")

# filter times that are not in reasonable work hours (time should be between 6am included and 10pm excluded)
stop_times = stop_times.where(
    """
    departure_time >= '06:00:00' AND departure_time < '22:00:00'
    AND arrival_time >= '06:00:00' AND arrival_time < '22:00:00'""")

# keep only stop times of the stops in our area
stop_times = stop_times.join(stops, on="stop_id", how="leftsemi")

stop_times = stop_times.withColumn("arrival_time", to_timestamp(col("arrival_time"), "HH:mm:ss"))
stop_times = stop_times.withColumn("arrival_time", date_format(col("arrival_time"), "HH:mm:ss"))

stop_times = stop_times.withColumn("departure_time", to_timestamp(col("departure_time"), "HH:mm:ss"))
stop_times = stop_times.withColumn("departure_time", date_format(col("departure_time"), "HH:mm:ss"))

stop_times = stop_times.withColumn('stop_sequence', col('stop_sequence').cast('int'))\
    .withColumnRenamed("stop_sequence", "stop_seqnum")

# drop unused columns
stop_times = stop_times.drop("pickup_type", "drop_off_type")

# drop duplicates if any
stop_times = stop_times.dropDuplicates()

# cache
stop_times = stop_times.cache()

# count
print("final number of stop times (all trips combined) =", stop_times.count())

stop_times.printSchema()
stop_times.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of stop times (all trips combined) = 1666786
root
 |-- stop_id: string (nullable = true)
 |-- trip_id: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable = true)

+-----------+--------------------+------------+--------------+-----------+
|    stop_id|             trip_id|arrival_time|departure_time|stop_seqnum|
+-----------+--------------------+------------+--------------+-----------+
|8573726:0:A|490.TA.96-199-4-j...|    17:32:00|      17:32:00|          1|
|    8573713|493.TA.96-199-4-j...|    21:03:00|      21:03:00|          2|
+-----------+--------------------+------------+--------------+-----------+
only showing top 2 rows

#### Preprocess calendar

In [10]:
%%spark
# most recent calendar file (needs to be filtered a bit)
calendar = spark.read.format("orc").load("/data/sbb/part_orc/timetables/calendar/year=2023/month=3/day=29")

# drop null columns
calendar = calendar.dropna(how="any", subset=["service_id", "monday", "tuesday", "wednesday", "thursday", "friday"])

# format columns
calendar = calendar.withColumn("start_date", to_date(col("start_date"),"yyyyMMdd"))
calendar = calendar.withColumn("end_date", to_date(col("end_date"),"yyyyMMdd"))

# only include services that are valid for all days between monday and friday
calendar = calendar.filter((col("monday") == 'TRUE') & (col("tuesday") == 'TRUE') 
                           & (col("wednesday") == 'TRUE') & (col("thursday") == 'TRUE') & (col("friday") == 'TRUE'))

# drop days columns
calendar = calendar.drop("monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", 
                         "start_date", "end_date")

# drop duplicates if any
calendar = calendar.dropDuplicates(["service_id"])

# caching
calendar = calendar.cache()

# count
print("final number of services =", calendar.count())

calendar.printSchema()
calendar.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of services = 12263
root
 |-- service_id: string (nullable = true)

+----------+
|service_id|
+----------+
|  TA+1ax20|
|  TA+23520|
+----------+
only showing top 2 rows

#### Preprocess trips

In [11]:
%%spark
# most recent trips file
trips = spark.read.format("orc").load("/data/sbb/part_orc/timetables/trips/year=2023/month=3/day=29")

# drop null trip_id and service_id columns, as we cannot use those trips
trips = trips.dropna(how="any", subset=["trip_id", "service_id"])

# drop duplicates
trips = trips.dropDuplicates(["trip_id"])

# drop useless columns
trips = trips.drop("trip_short_name")

# keep trips with service_id in calendar
trips = trips.join(calendar, on="service_id", how="leftsemi")

# caching
trips.cache()

# count
print("final number of trips =", trips.count())

trips.printSchema()
trips.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of trips = 643351
root
 |-- service_id: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- trip_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- direction_id: string (nullable = true)

+----------+-------------+--------------------+------------------+------------+
|service_id|     route_id|             trip_id|     trip_headsign|direction_id|
+----------+-------------+--------------------+------------------+------------+
|     TA+gz| 91-4-A-j23-1|1.TA.91-4-A-j23-1...|     Zürich HB SZU|           0|
|  TA+3hg00|91-4A-Y-j23-1|1.TA.91-4A-Y-j23-...|Genova P. Principe|           0|
+----------+-------------+--------------------+------------------+------------+
only showing top 2 rows

#### Trips with Stop Times

In [12]:
%%spark
trips_with_stop_times = trips.join(stop_times, on="trip_id", how="inner")
trips_with_stop_times = trips_with_stop_times.cache()

# count
print("final number of edges (excluding walkable edges) =", trips_with_stop_times.count())

trips_with_stop_times.printSchema()
trips_with_stop_times.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of edges (excluding walkable edges) = 387220
root
 |-- trip_id: string (nullable = true)
 |-- service_id: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- direction_id: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable = true)

+--------------------+----------+------------+-------------+------------+------------+------------+--------------+-----------+
|             trip_id|service_id|    route_id|trip_headsign|direction_id|     stop_id|arrival_time|departure_time|stop_seqnum|
+--------------------+----------+------------+-------------+------------+------------+------------+--------------+-----------+
|1.TA.91-4-A-j23-1...|     TA+gz|91-4-A-j23-1|Zürich HB SZU|           0| 8503091:0:1|    17:56:00|      17:56:00|          8|
|1.TA.91-4-A-j23-1...|     TA+gz|91-4-A-j23-1

#### Preprocess istdaten

In [13]:
%%spark
istdaten = spark.read.format("orc").load("/data/sbb/part_orc/istdaten")

cols_deu2eng = {
    "haltestellen_name": "stop_name",
    
    "betriebstag": "trip_date",

    "produkt_id": "transport_type",
    "verkehrsmittel_text": "transport_subtype",
    "linien_text": "route_name",
    
    "an_prognose": "real_arrival_time",
    "ab_prognose": "real_departure_time",
    
    "ankunftszeit": "sch_arrival_time",
    "abfahrtszeit": "sch_departure_time",
    
    "durchfahrt_tf": "does_not_stop_here",
    "zusatzfahrt_tf": "irregular",
    "faellt_aus_tf": "failed"
}

# select translated columns, other columns are not useful
istdaten = istdaten.select(*cols_deu2eng.keys())

# rename columns with english names
istdaten_translated = istdaten
for col_deu, col_eng in cols_deu2eng.items():
    istdaten_translated = istdaten_translated.withColumnRenamed(col_deu, col_eng)
    
# lower casing some entries
istdaten_translated = istdaten_translated.withColumn("stop_name", trim(lower(col("stop_name"))))
istdaten_translated = istdaten_translated.withColumn("transport_type", trim(lower(col("transport_type"))))
istdaten_translated = istdaten_translated.withColumn("transport_subtype", trim(lower(col("transport_subtype"))))
istdaten_translated = istdaten_translated.withColumn("route_name", trim(lower(col("route_name"))))

# only three types of transport means: train (zug), bus (bus) or tram (tram, standseilbahn)
train_type_ids = ['en', 'rj', 'ic', 'ag', 're', 'ext', 'ir', 'rb', 'mat', 'te2', 'ice', 'ec', 'cis', 
                'rjx', 'nj', 'zug', 'p', 'd', 'flx', 'ter', 'pe', 'bex', 'nz', 'mp', 'gex', 'ire', 
                'sn', 'at', 'vae', 'r', '', 'atz', 'arz', 'tgv', 's']
bus_type_ids = ['nb', 'bus', 'exb', 'bn', 'nfb', 'ev', 'rub', 'kb', 'b', 'car', 't']
istdaten_translated = istdaten_translated.withColumn(
    "transport_type",
    when((col("transport_type") == "zug") | (col("transport_subtype").isin(train_type_ids)), "train")\
    .when((col("transport_type") != "train") 
          & ((col("transport_type") == "bus") | (col("transport_subtype").isin(bus_type_ids))), "bus")\
    .otherwise("tram")
)

# caching
istdaten_translated = istdaten_translated.cache()

istdaten_translated.printSchema()
istdaten_translated.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- stop_name: string (nullable = true)
 |-- trip_date: string (nullable = true)
 |-- transport_type: string (nullable = false)
 |-- transport_subtype: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- real_arrival_time: string (nullable = true)
 |-- real_departure_time: string (nullable = true)
 |-- sch_arrival_time: string (nullable = true)
 |-- sch_departure_time: string (nullable = true)
 |-- does_not_stop_here: string (nullable = true)
 |-- irregular: string (nullable = true)
 |-- failed: string (nullable = true)

+--------------------+----------+--------------+-----------------+----------+-------------------+-------------------+----------------+------------------+------------------+---------+------+
|           stop_name| trip_date|transport_type|transport_subtype|route_name|  real_arrival_time|real_departure_time|sch_arrival_time|sch_departure_time|does_not_stop_here|irregular|failed|
+--------------------+----------+--------------+-----------------+-

In [14]:
%%spark
from pyspark.sql.functions import (
    dayofweek, from_unixtime, unix_timestamp, when, col, hour
    )

# FILTERING ISTDATEN

# drop irregular trips
# drop trips where transport does not stop in that stop
# drop trips that failed
# finally, drop these columns as we don't use need them anymore
istdaten_filtered = istdaten_translated.filter(col("irregular") == "false").drop("irregular")\
    .filter(col("does_not_stop_here") == "false").drop("does_not_stop_here")\
    .filter(col("failed") == "false").drop("failed")

# drop trips where arrival time is null or empty, since we cannot use it to estimate arrival delay
istdaten_filtered = istdaten_filtered.dropna(how="any", subset=["sch_arrival_time", "real_arrival_time"])\
    .withColumn("sch_arrival_time", trim(col("sch_arrival_time")))\
    .withColumn("real_arrival_time", trim(col("real_arrival_time")))\
    .filter("sch_arrival_time != '' and real_arrival_time != ''")
# drop trips where departure time is null or empty, since we cannot use them to estimate day period
istdaten_filtered = istdaten_filtered.dropna(how="any", subset=["sch_departure_time", "real_departure_time"])\
    .withColumn("sch_departure_time", trim(col("sch_departure_time")))\
    .withColumn("real_departure_time", trim(col("real_departure_time")))\
    .filter("sch_departure_time != '' and real_departure_time != ''")
# drop no longer needed real departure time
istdaten_filtered = istdaten_filtered.drop("real_departure_time")

# remove the trips on Sundays (dayofweek=1) and Saturdays (dayofweek=7) since we removed them from our calendar
format_ = "dd.MM.yyyy"
istdaten_filtered = istdaten_filtered.where(~(dayofweek(to_timestamp(col("trip_date"), format_)).isin([1, 7])))
# we do not need the trip date anymore
istdaten_filtered = istdaten_filtered.drop("trip_date")

# keep only trips for which the time is reasonable (between 6am and 10pm)
format = "dd.MM.yyyy HH:mm"
istdaten_filtered = istdaten_filtered.where(
    (hour(to_timestamp(col('sch_arrival_time'), format)) >= 6) & (hour(to_timestamp(col('sch_arrival_time'), format)) < 22))
istdaten_filtered = istdaten_filtered.where(
    (hour(to_timestamp(col('sch_departure_time'), format)) >= 6) & (hour(to_timestamp(col('sch_departure_time'), format)) < 22))
# drop no longer needed real departure time
istdaten_filtered = istdaten_filtered.drop("sch_departure_time")

# keep trips for which we have stop names
istdaten_filtered = istdaten_filtered.join(stops, on="stop_name", how="leftsemi")

# caching
istdaten_filtered = istdaten_filtered.cache()

print("number of datapoints used to estimate delay distributions =", istdaten_filtered.count())
istdaten_filtered.printSchema()
istdaten_filtered.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

number of datapoints used to estimate delay distributions = 343310441
root
 |-- stop_name: string (nullable = true)
 |-- transport_type: string (nullable = false)
 |-- transport_subtype: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- real_arrival_time: string (nullable = true)
 |-- sch_arrival_time: string (nullable = true)

+--------------------+--------------+-----------------+----------+-------------------+----------------+
|           stop_name|transport_type|transport_subtype|route_name|  real_arrival_time|sch_arrival_time|
+--------------------+--------------+-----------------+----------+-------------------+----------------+
|killwangen, mühle...|           bus|              nfb|         4|13.02.2019 19:43:12|13.02.2019 19:43|
| killwangen, bahnhof|           bus|              nfb|         4|13.02.2019 19:45:54|13.02.2019 19:45|
+--------------------+--------------+-----------------+----------+-------------------+----------------+
only showing top 2 rows

In [15]:
%%spark

# COMPUTE DELAYS

# parse schedule/real arrival times to unix timestamp
istdaten_delays = istdaten_filtered.withColumn("sch_arrival_unix_t", unix_timestamp(col("sch_arrival_time"), "dd.MM.yyyy HH:mm"))
istdaten_delays = istdaten_delays.withColumn("real_arrival_unix_t", unix_timestamp(col("real_arrival_time"), "dd.MM.yyyy HH:mm:ss"))

# compute arrival delay, keep only non-negative delays
istdaten_delays = istdaten_delays.withColumn("arrival_delay", col("real_arrival_unix_t") - col("sch_arrival_unix_t"))\
    .withColumn("arrival_delay", when(col("arrival_delay") >= 0, col("arrival_delay")).otherwise(0))\
    .drop("real_arrival_unix_t", "sch_arrival_unix_t")\
    .drop("real_arrival_time")

# DEFINE DAY PERIODS
istdaten_delays = istdaten_delays.withColumn("sch_arrival_time", to_timestamp(col("sch_arrival_time"), "dd.MM.yyyy HH:mm"))

day_periods = {
    "morning": ("06:00", "08:30"),
    "prenoon": ("08:30", "12:00"),
    "afternoon": ("12:00", "15:00"),
    "latenoon": ("15:00", "18:00"),
    "evening": ("18:00", "22:00")
}

istdaten_delays_periods = istdaten_delays.withColumn(
    "day_period", 
    when(date_format(col("sch_arrival_time"), "HH:mm") < day_periods["morning"][1], "morning")\
    .when((day_periods["prenoon"][0] <= date_format(col("sch_arrival_time"), "HH:mm")) 
          & (date_format(col("sch_arrival_time"), "HH:mm") < day_periods["prenoon"][1]), "prenoon")\
    .when((day_periods["afternoon"][0] <= date_format(col("sch_arrival_time"), "HH:mm"))
          & (date_format(col("sch_arrival_time"), "HH:mm") < day_periods["afternoon"][1]), "afternoon")\
    .when((day_periods["latenoon"][0] <= date_format(col("sch_arrival_time"), "HH:mm"))
          & (date_format(col("sch_arrival_time"), "HH:mm") < day_periods["latenoon"][1]), "latenoon")\
    .otherwise("evening")
).drop("sch_arrival_time")

# caching
istdaten_delays_periods.cache()

istdaten_delays_periods.printSchema()
istdaten_delays_periods.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- stop_name: string (nullable = true)
 |-- transport_type: string (nullable = false)
 |-- transport_subtype: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- arrival_delay: long (nullable = true)
 |-- day_period: string (nullable = false)

+--------------------+--------------+-----------------+----------+-------------+----------+
|           stop_name|transport_type|transport_subtype|route_name|arrival_delay|day_period|
+--------------------+--------------+-----------------+----------+-------------+----------+
|killwangen, mühle...|           bus|              nfb|         4|           12|   evening|
| killwangen, bahnhof|           bus|              nfb|         4|           54|   evening|
+--------------------+--------------+-----------------+----------+-------------+----------+
only showing top 2 rows

In [16]:
%%spark
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType

# Define UDF
@F.udf(DoubleType())
def std_udf(delays):
    import numpy as np
    return float(np.std(delays))

# ESTIMATION OF MEAN DELAY AND STANDARD DEVIATION

avg_delay_by_tsptype_tspsubtype_dper = istdaten_delays_periods.groupBy(
    "stop_name", "day_period", "transport_type", "transport_subtype").agg(F.mean("arrival_delay").alias("mean_arrival_delay"))

avg_delay_by_tsptype_dper = istdaten_delays_periods.groupBy("stop_name", "day_period", "transport_type")\
    .agg(F.mean("arrival_delay").alias("mean_arrival_delay"))

avg_delay_by_dper = istdaten_delays_periods.groupBy("stop_name", "day_period")\
    .agg(F.mean("arrival_delay").alias("mean_arrival_delay"))

avg_delay = istdaten_delays_periods.groupBy("stop_name").agg(F.mean("arrival_delay").alias("mean_arrival_delay"))

std_delay_by_tsptype_tspsubtype_dper = istdaten_delays_periods.groupBy(
    "stop_name", "day_period", "transport_type", "transport_subtype").agg(std_udf(F.collect_list("arrival_delay")).alias("std_arrival_delay"))

std_delay_by_tsptype_dper = istdaten_delays_periods.groupBy("stop_name", "day_period", "transport_type")\
    .agg(std_udf(F.collect_list("arrival_delay")).alias("std_arrival_delay"))

std_delay_by_dper = istdaten_delays_periods.groupBy("stop_name", "day_period")\
    .agg(std_udf(F.collect_list("arrival_delay")).alias("std_arrival_delay"))

std_delay = istdaten_delays_periods.groupBy("stop_name")\
    .agg(std_udf(F.collect_list("arrival_delay")).alias("std_arrival_delay"))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [17]:
%%spark
avg_delay_by_tsptype_tspsubtype_dper.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+----------+--------------+-----------------+------------------+
|          stop_name|day_period|transport_type|transport_subtype|mean_arrival_delay|
+-------------------+----------+--------------+-----------------+------------------+
|    uster, buchholz|   evening|           bus|                b|  69.8245228696091|
|erlenbach zh, chapf| afternoon|           bus|                b| 90.60486694677871|
+-------------------+----------+--------------+-----------------+------------------+
only showing top 2 rows

In [18]:
%%spark
avg_delay_by_tsptype_dper.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------------+----------+--------------+------------------+
|        stop_name|day_period|transport_type|mean_arrival_delay|
+-----------------+----------+--------------+------------------+
| meilen, aebleten|  latenoon|           bus| 74.58796111333066|
|zürich, sädlenweg|   prenoon|           bus|  71.4975922953451|
+-----------------+----------+--------------+------------------+
only showing top 2 rows

In [19]:
%%spark
avg_delay_by_dper.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+----------+------------------+
|           stop_name|day_period|mean_arrival_delay|
+--------------------+----------+------------------+
| zürich, sihlstrasse|  latenoon|51.977935486896506|
|zürich, saalsport...|   morning| 69.96632419184942|
+--------------------+----------+------------------+
only showing top 2 rows

In [20]:
%%spark
avg_delay.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------------+
|           stop_name|mean_arrival_delay|
+--------------------+------------------+
|wettswil a.a., st...| 79.44618126272913|
|     illnau, wingert| 99.74798847976857|
+--------------------+------------------+
only showing top 2 rows

In [21]:
%%spark
std_delay_by_tsptype_tspsubtype_dper.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+----------+--------------+-----------------+-----------------+
|           stop_name|day_period|transport_type|transport_subtype|std_arrival_delay|
+--------------------+----------+--------------+-----------------+-----------------+
|adliswil, landolt...|   morning|           bus|              bus|53.92332298653285|
|     buchs zh, linde|   morning|           bus|              bus|48.94721520536078|
+--------------------+----------+--------------+-----------------+-----------------+
only showing top 2 rows

In [22]:
%%spark
std_delay_by_tsptype_tspsubtype_dper.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+----------+--------------+-----------------+-----------------+
|           stop_name|day_period|transport_type|transport_subtype|std_arrival_delay|
+--------------------+----------+--------------+-----------------+-----------------+
|adliswil, landolt...|   morning|           bus|              bus|53.92332298653285|
|     buchs zh, linde|   morning|           bus|              bus|48.94721520536078|
+--------------------+----------+--------------+-----------------+-----------------+
only showing top 2 rows

In [23]:
%%spark
std_delay_by_dper.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+----------+------------------+
|           stop_name|day_period| std_arrival_delay|
+--------------------+----------+------------------+
|   bülach, sonnenhof|   prenoon|167.15751763104709|
|dietlikon, dübend...|   evening|103.97810238124275|
+--------------------+----------+------------------+
only showing top 2 rows

In [24]:
%%spark
std_delay.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------------+------------------+
|      stop_name| std_arrival_delay|
+---------------+------------------+
|illnau, wingert|266.00226904698036|
|      stettbach| 73.48660427804474|
+---------------+------------------+
only showing top 2 rows

#### Preprocess routes

In [25]:
%%spark
# most recent routes file
routes = spark.read.format("orc").load("/data/sbb/part_orc/timetables/routes/year=2023/month=3/day=29")

# drop useless columns
# route_long_name is always null, so we use route_short_name as route name
routes = routes.drop("route_long_name", "agency_id", "route_desc")
routes = routes.withColumnRenamed("route_short_name", "route_name")

# lowercase the route names
routes = routes.withColumn("route_name", trim(lower(col("route_name"))))

# infer transport type of each route from istdaten
istdaten_transport_types = istdaten_translated[["transport_type", "transport_subtype", "route_name"]]\
    .withColumnRenamed("route_name", "route_name_istdaten")\
    .dropDuplicates()
routes_with_istdaten_types = routes.join(
    istdaten_transport_types,
    on=((istdaten_transport_types.transport_subtype == routes.route_name) |
        (istdaten_transport_types.route_name_istdaten == routes.route_name)),
    how="left").drop("route_name_istdaten")

# for each (route_id, transport_type), keep the most frequent transport_subtype
tspsubtype_window = Window.partitionBy(["route_id", "transport_type"]).orderBy(F.desc("count"))
routes_with_type_and_subtype = routes_with_istdaten_types.withColumn("count", lit(1))\
    .groupBy("route_id", "transport_type", "transport_subtype")\
    .agg(F.sum("count").alias("count"))\
    .withColumn("rank", F.row_number().over(tspsubtype_window))\
    .filter("rank == 1").drop("rank", "count")

# for each route_id in the previous result, keep the most frequent transport_type
tsptype_window = Window.partitionBy(["route_id"]).orderBy(F.desc("count"))
routes_with_type = routes_with_type_and_subtype.withColumn("count", lit(1))\
    .groupBy("route_id", "transport_type")\
    .agg(F.sum("count").alias("count"))\
    .withColumn("rank", F.row_number().over(tsptype_window))\
    .filter("rank == 1").drop("rank", "count")
routes_with_type_and_subtype = routes_with_type_and_subtype.join(
    routes_with_type, on=["route_id", "transport_type"], how="inner")

# join the initial collection of routes with the obtained one to get the types
routes = routes.join(
    routes_with_type_and_subtype,
    on="route_id",
    how="left")

# keep only the routes in trips
routes = routes.join(trips_with_stop_times.dropDuplicates(["trip_id", "route_id"]), on="route_id", how='leftsemi')

# caching
routes = routes.cache()

routes.printSchema()
routes.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- route_id: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- route_type: string (nullable = true)
 |-- transport_type: string (nullable = true)
 |-- transport_subtype: string (nullable = true)

+--------------+----------+----------+--------------+-----------------+
|      route_id|route_name|route_type|transport_type|transport_subtype|
+--------------+----------+----------+--------------+-----------------+
| 91-10-E-j23-1|        10|       900|          tram|              trm|
|92-305-A-j23-1|       305|       700|         train|                 |
+--------------+----------+----------+--------------+-----------------+
only showing top 2 rows

#### Save pre-processed files

In [35]:
%%spark
for username in usernames:
    stops.write.saveAsTable(f"{username}.pp_stops", path=f"/user/{username}/preprocessed/pp_stops", mode='overwrite', format='orc')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [36]:
%%spark
for username in usernames:
    stop_times.write.saveAsTable(
        f"{username}.pp_stop_times", path=f"/user/{username}/preprocessed/pp_stop_times", mode='overwrite', format='orc')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [37]:
%%spark
for username in usernames:
    trips.write.saveAsTable(
        f"{username}.pp_trips", path=f"/user/{username}/preprocessed/pp_trips", mode='overwrite', format='orc')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [38]:
%%spark
for username in usernames:
    routes.write.saveAsTable(
        f"{username}.pp_routes", path=f"/user/{username}/preprocessed/pp_routes", mode='overwrite', format='orc')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [39]:
%%spark
for username in usernames:
    trips_with_stop_times.write.saveAsTable(
        f"{username}.pp_trips_with_stop_times", path=f"/user/{username}/preprocessed/pp_trips_with_stop_times", mode='overwrite', format='orc')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [40]:
%%spark

fname_to_df = {
    # files with mean values
    "delays/avg_delay_tsptype_tspsubtype_dper": avg_delay_by_tsptype_tspsubtype_dper,
    "delays/avg_delay_tsptype_dper": avg_delay_by_tsptype_dper,
    "delays/avg_delay_dper": avg_delay_by_dper,
    "delays/avg_delay": avg_delay,
    
    # files with std values
    "delays/std_delay_tsptype_tspsubtype_dper": std_delay_by_tsptype_tspsubtype_dper,
    "delays/std_delay_tsptype_dper": std_delay_by_tsptype_dper,
    "delays/std_delay_dper": std_delay_by_dper,
    "delays/std_delay": std_delay,
}

for username in usernames:
    for fname, df in fname_to_df.items():
        df.write.parquet(f"/user/{username}/{fname}", mode="overwrite", compression="gzip")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Create edges

##### Add sbb edges

In [26]:
%%spark
from pyspark.sql.window import Window

# join trips_with_stop_times with routes
edges = trips_with_stop_times.join(routes, on="route_id", how="left")

# join edges with stops
edges = edges.join(stops, on="stop_id", how="inner")\
    .select([
    "trip_id", "trip_headsign",
    "route_id", "route_name", "route_type", 
    "transport_type", "transport_subtype",
    "stop_id", "stop_name",
    "arrival_time", "departure_time", "stop_seqnum",
    "parent_station"])

# replicate stop_id, departure_time and arrival_time column and shift it by one row
# and remove last edge of trip that points to None
window = Window.partitionBy("trip_id").orderBy("stop_seqnum")
edges = edges.withColumn("next_stop_id", lead("stop_id").over(window))
edges = edges.withColumn("next_stop_name", lead("stop_name").over(window))
edges = edges.withColumn("next_arrival_time", lead("arrival_time").over(window))
edges = edges.withColumn("next_parent_station", lead("parent_station").over(window))
edges = edges.withColumn("next_departure_time", lead("departure_time").over(window))
edges = edges.dropna(subset=["next_stop_id"])

# add is_walkable attribute
edges = edges.withColumn("is_walkable", lit(False))
# compute duration as difference between next_arrival_time and departure_time in seconds
edges = edges.withColumn("duration_s", 
                         unix_timestamp(col("next_arrival_time"), "HH:mm:ss") - unix_timestamp(col("departure_time"), "HH:mm:ss"))\
                        .withColumn("duration_s", col("duration_s").cast("int"))

# stop waiting time = stop's departure time - stop's arrival time
edges = edges.withColumn("waiting_time_s", 
                         unix_timestamp(col("departure_time"), "HH:mm:ss") - unix_timestamp(col("arrival_time"), "HH:mm:ss"))\
                         .withColumn("waiting_time_s", col("waiting_time_s").cast("int"))

print("final number of non-walkable edges =", edges.count())
edges.printSchema()
edges.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of non-walkable edges = 358009
root
 |-- trip_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- route_type: string (nullable = true)
 |-- transport_type: string (nullable = true)
 |-- transport_subtype: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable = true)
 |-- parent_station: string (nullable = true)
 |-- next_stop_id: string (nullable = true)
 |-- next_stop_name: string (nullable = true)
 |-- next_arrival_time: string (nullable = true)
 |-- next_parent_station: string (nullable = true)
 |-- next_departure_time: string (nullable = true)
 |-- is_walkable: boolean (nullable = false)
 |-- duration_s: integer (nullable = true)
 |-- waiting_time_s: integer (nullable = true)

+----------

##### Add walkable edges

In [27]:
%%spark

# a walkable edge is one between two stops in the area of 16km around Zurich, such that the distance between them is at most 500m
stops_2 = stops
for colname in ["stop_id", "stop_name", "stop_lat", "stop_lon", "parent_station"]:
    stops_2 = stops_2.withColumnRenamed(colname, colname+"_2")

all_stops_pairs = stops.crossJoin(stops_2).where(col("stop_id") != col("stop_id_2"))

# calculate distance between stops using haversine_dist function
all_stops_pairs = all_stops_pairs.withColumn("distance_m",
                                             haversine_dist(col("stop_lat"), col("stop_lon"), col("stop_lat_2"), col("stop_lon_2")))\
    .drop("stop_lat", "stop_lon", "stop_lat_2", "stop_lon_2")

# filter out rows where distance > 500m
walkable_edges = all_stops_pairs.filter(col("distance_m") <= 500)

# transfer time is at least 2 min, we compute it is seconds 
walkable_edges = walkable_edges.withColumn("duration_s", F.round((2 + col("distance_m") / 50) * 60)).drop("distance_m") # in seconds

# create column is_walkable with value to True 
walkable_edges = walkable_edges.withColumn("is_walkable", lit(True))

# create edge columns not used for walkable edges
walkable_edges = walkable_edges.withColumn("departure_time", lit(None))
walkable_edges = walkable_edges.withColumn("arrival_time", lit(None))
walkable_edges = walkable_edges.withColumn("next_departure_time", lit(None))
walkable_edges = walkable_edges.withColumn("next_arrival_time", lit(None))
walkable_edges = walkable_edges.withColumn("waiting_time_s", lit(None))

walkable_edges = walkable_edges.withColumn("trip_id", lit(None))
walkable_edges = walkable_edges.withColumn("trip_headsign", lit(None))
walkable_edges = walkable_edges.withColumn("stop_seqnum", lit(None))

walkable_edges = walkable_edges.withColumn("route_id", lit(None))
walkable_edges = walkable_edges.withColumn("route_name", lit(None))
walkable_edges = walkable_edges.withColumn("route_type", lit(None))

walkable_edges = walkable_edges.withColumn("transport_type", lit("walk"))
walkable_edges = walkable_edges.withColumn("transport_subtype", lit("walk"))

# rename columns
walkable_edges = walkable_edges.withColumnRenamed("stop_id_2", "next_stop_id")
walkable_edges = walkable_edges.withColumnRenamed("stop_name_2", "next_stop_name")
walkable_edges = walkable_edges.withColumnRenamed("parent_station_2", "next_parent_station")

walkable_edges = walkable_edges.cache()

print("final number of walkable edges =", walkable_edges.count())
walkable_edges.printSchema()
walkable_edges.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final number of walkable edges = 19012
root
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- parent_station: string (nullable = true)
 |-- next_stop_id: string (nullable = true)
 |-- next_stop_name: string (nullable = true)
 |-- next_parent_station: string (nullable = true)
 |-- duration_s: double (nullable = true)
 |-- is_walkable: boolean (nullable = false)
 |-- departure_time: null (nullable = true)
 |-- arrival_time: null (nullable = true)
 |-- next_departure_time: null (nullable = true)
 |-- next_arrival_time: null (nullable = true)
 |-- waiting_time_s: null (nullable = true)
 |-- trip_id: null (nullable = true)
 |-- trip_headsign: null (nullable = true)
 |-- stop_seqnum: null (nullable = true)
 |-- route_id: null (nullable = true)
 |-- route_name: null (nullable = true)
 |-- route_type: null (nullable = true)
 |-- transport_type: string (nullable = false)
 |-- transport_subtype: string (nullable = false)

+-----------+----------------+---------

##### Merge edges

In [28]:
%%spark
all_edges = edges.unionByName(walkable_edges)

print("final total number of edges in the network =", all_edges.count())
all_edges.printSchema()
all_edges.filter("is_walkable == 'true'").show(2)
all_edges.filter("is_walkable == 'false'").show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

final total number of edges in the network = 377021
root
 |-- trip_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- route_type: string (nullable = true)
 |-- transport_type: string (nullable = true)
 |-- transport_subtype: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable = true)
 |-- parent_station: string (nullable = true)
 |-- next_stop_id: string (nullable = true)
 |-- next_stop_name: string (nullable = true)
 |-- next_arrival_time: string (nullable = true)
 |-- next_parent_station: string (nullable = true)
 |-- next_departure_time: string (nullable = true)
 |-- is_walkable: boolean (nullable = false)
 |-- duration_s: double (nullable = true)
 |-- waiting_time_s: integer (nullable = true)

+---

In [29]:
%%spark
from pyspark.sql.types import IntegerType

# create a mapping from string values to integer values
string_to_int = {}
next_int = 0
for row in all_edges.select(col("stop_id"), col("next_stop_id")).distinct().collect():
    source = row["stop_id"]
    target = row["next_stop_id"]
    if source not in string_to_int:
        string_to_int[source] = next_int
        next_int += 1
    if target not in string_to_int:
        string_to_int[target] = next_int
        next_int += 1
        
# define a UDF to map the node ID to the corresponding integer
string_to_int_udf = F.udf(lambda x: string_to_int[x], IntegerType())

# add new columns with the mapped integer values
all_edges = all_edges.withColumn("node_id", string_to_int_udf(col("stop_id")))
all_edges = all_edges.withColumn("next_node_id", string_to_int_udf(col("next_stop_id")))

all_edges.printSchema()
all_edges.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- trip_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- route_type: string (nullable = true)
 |-- transport_type: string (nullable = true)
 |-- transport_subtype: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable = true)
 |-- parent_station: string (nullable = true)
 |-- next_stop_id: string (nullable = true)
 |-- next_stop_name: string (nullable = true)
 |-- next_arrival_time: string (nullable = true)
 |-- next_parent_station: string (nullable = true)
 |-- next_departure_time: string (nullable = true)
 |-- is_walkable: boolean (nullable = false)
 |-- duration_s: double (nullable = true)
 |-- waiting_time_s: integer (nullable = true)
 |-- node_id: integer (nullable = true)
 |-- next_node_id

#### Store Edges

In [30]:
%%spark
for username in usernames:
    all_edges.write.orc(f"/user/{username}/network_data/all_edges", mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [31]:
%%spark
for username in usernames:
    edges.write.orc(f"/user/{username}/network_data/sbb_edges", mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Load All Edges

In [32]:
%%spark
username = "sly"
test_1 = spark.read.parquet(f"/user/{username}/delays/avg_delay")
test_1.printSchema()

username = "ouerghem"
test_2 = spark.read.orc(f"/user/{username}/preprocessed/pp_trips_with_stop_times")
test_2.printSchema()

username = "berkane"
test_3 = spark.read.orc(f"/user/{username}/network_data/all_edges")
test_3.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- stop_name: string (nullable = true)
 |-- mean_arrival_delay: double (nullable = true)

root
 |-- trip_id: string (nullable = true)
 |-- service_id: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- direction_id: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable = true)

root
 |-- trip_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- route_name: string (nullable = true)
 |-- route_type: string (nullable = true)
 |-- transport_type: string (nullable = true)
 |-- transport_subtype: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_seqnum: integer (nullable 