# Analyze delays based on historical data

In this notebook, we try to come up with a sensible estimation of the delay of every type of connection or trip by analysing the historical data on the public transport system of Zurich. This delay probability will play a key role in our predictive model given that it is based on an exponential distribution. Therefore , it is essential to have a good estimation and that is why we calculate three different levels of probabilities. The notebook is structured as follows: 

*   **[Start Spark](#spark)** 
*   **[Load and filter historical SBB data](#sbb)**  
*   **[Three-level delay probability](#prob)** 
*   **[Final probability](#final)**



<a id = 'spark'></a>
### 1. Start Spark

We will be using a Spark Session for performing different transformations and actions on dataframes

In [63]:
%%configure -f
{
    "conf": {
        "spark.app.name": "datavirus_delay_analysis",
    }
}

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
8318,application_1589299642358_2850,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
8293,application_1589299642358_2825,pyspark,idle,Link,Link,
8295,application_1589299642358_2827,pyspark,idle,Link,Link,
8302,application_1589299642358_2834,pyspark,busy,Link,Link,
8305,application_1589299642358_2837,pyspark,idle,Link,Link,
8306,application_1589299642358_2838,pyspark,idle,Link,Link,
8310,application_1589299642358_2842,pyspark,idle,Link,Link,
8312,application_1589299642358_2844,pyspark,busy,Link,Link,
8315,application_1589299642358_2847,pyspark,idle,Link,Link,
8316,application_1589299642358_2848,pyspark,busy,Link,Link,
8317,application_1589299642358_2849,pyspark,idle,Link,Link,


In [64]:
# Initialization
spark

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<pyspark.sql.session.SparkSession object at 0x7f0b7d6c75d0>

In [66]:
import pyspark.sql.functions as f

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<a id = 'sbb'></a>
### 2.Load and filter the historical SBB data

All historical data is loaded. We filter out irrelevant data and we compute the delay between expected arrival time and real (if AN_PROGNOSE STATUS = REAL) arrival time or estimated (if AN_PROGNOSE STATUS = PROGNOSE) arrival time for every stop in every trip. Of couse, we only take into consideratiosn stops within 15 km of Zurich

In [67]:
sbb = spark.read.orc('/data/sbb/orc/istdaten')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [70]:
sbb_filtered = (
    sbb
    
    # filter out any trips that we are not interested in
    .where(sbb.zusatzfahrt_tf == 'false')
    .where(sbb.faellt_aus_tf == 'false')
    .where(sbb.durchfahrt_tf == 'false')
    .where(sbb.ankunftszeit.isNotNull())
    .where(sbb.an_prognose.isNotNull())
    .where(sbb.an_prognose_status.isin('REAL', 'PROGNOSE'))
    .where(sbb.betreiber_id.substr(0, 2) == '85')
    
    # convert produkt_id to lowercase
    .withColumn('produkt_id', f.lower(sbb.produkt_id))
    .where(f.col('produkt_id').isin('bus', 'tram', 'zug'))
    
    # create `id` column that is used
    # to identify trips from this dataframe
    # with trips in the connections dataframe
    .withColumn(
        'id',
        f.when(f.col('produkt_id') == 'zug', f.concat_ws(':', sbb.betreiber_id, sbb.linien_id))
        .otherwise(f.concat_ws(':', sbb.betreiber_id, sbb.linien_text))
    )
    
    # calculate delay
    .withColumn('ankunftszeit_ts', f.unix_timestamp(f.col('ankunftszeit'), "dd.MM.yyyy HH:mm").cast('long'))
    .withColumn('an_prognose_ts', f.unix_timestamp(f.col('an_prognose'), "dd.MM.yyyy HH:mm:ss").cast('long'))
    .withColumn('delay', f.col('an_prognose_ts') - f.col('ankunftszeit_ts'))
    .where(f.col('delay').isNotNull())
    
    # sometimes there are more multiple REAL and PROGNOSE values for the exactly same trip on the same day
    # and time. There is no way which one is the closest to the actual arrival time of the train so we 
    # just drop all duplicates and keep one "randomly".
    
    .repartition(150, 'id')
    
    .select(['betriebstag', 'id', 'produkt_id', 'linien_text', 'bpuic', 'ankunftszeit', 'an_prognose_status', 'delay'])
    .dropDuplicates(['id', 'betriebstag', 'bpuic', 'ankunftszeit', 'an_prognose_status'])
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [74]:
# remove all stations that are not in 15km radius around zurich 

stations = spark.read.csv("../data/zurich_stations_ids.csv")
stations = (
    stations
    .select(stations._c0.alias('stop_id'))
)
stop_ids = [row['stop_id'] for row in stations.collect()]
stop_ids = spark.sparkContext.broadcast(stop_ids)


sbb_zurich = (
    sbb_filtered.
    where(sbb_filtered.bpuic.isin(stop_ids.value))
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [76]:
delays = (
    sbb_zurich
    
    # for each trip we only kept one REAL and one PROGNOSE row
    # now we aggregate them into a single same row
    .groupBy(['id', 'ankunftszeit', 'bpuic'])
    .agg(
        f.first('produkt_id').alias('produkt_id'),
        f.first('linien_text').alias('linien_text'),
        f.first(f.when(f.col('an_prognose_status') == 'REAL', f.col('delay')).otherwise(None)).alias('real_delay'),
        f.first(f.when(f.col('an_prognose_status') == 'PROGNOSE', f.col('delay')).otherwise(None)).alias('prognose_delay'),
    )
    
    # if there is REAL delay we select it, otherwise we take the PROGNOSE delay 
    .withColumn('delay', f.when(f.col('real_delay').isNotNull(), f.col('real_delay')).otherwise(f.col('prognose_delay')))
    .withColumn('is_delayed', f.when(f.col('delay') > 0, 1).otherwise(0))
    
    # create extra columns that contain time without date
    .withColumn('ankunftszeit_minute', f.col('ankunftszeit').substr(12, 5))
    .withColumn('ankunftszeit_hour', f.col('ankunftszeit').substr(12, 2))
    .cache()
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<a id = 'prob'></a>
### 3.Three-level delay probability

The delay probabilities are calculated in three different levels, as follows:

- Low-level probability (most detailed one) -->  Different parameter for every: <code>line, station, minute of the day</code> 
- Medium-level probability --> Different parameter  for every: <code>line, hour of the day </code>
- High level probability --> Different parameter for every <code> transport_type, hour of the day</code>

In [77]:
id_probability = (
    delays
    
    .groupBy(['id', 'ankunftszeit_minute', 'bpuic'])
    
    .agg(
        f.first('produkt_id').alias('produkt_id'),
        f.first('linien_text').alias('linien_text'),
        f.mean(delays.is_delayed).alias('id_delay_probability'),
        (1.0 / f.mean(f.when(delays.is_delayed == 1, f.when(delays.delay > 30*60, 30*60).otherwise(delays.delay)))).alias('id_delay_parameter'),
        f.count(delays.delay).alias('id_n')
    )
)

id_probability.write.format('orc').save('/user/datavirus/id_probability.orc', mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [81]:
line_probability = (
    delays
    
    .groupBy(['id', 'ankunftszeit_hour'])
    
    .agg(
        f.mean(delays.is_delayed).alias('line_delay_probability'),
        (1.0 / f.mean(f.when(delays.is_delayed == 1, f.when(delays.delay > 30*60, 30*60).otherwise(delays.delay)))).alias('line_delay_parameter'),
        f.count(delays.delay).alias('line_n')
    )
    
)
line_probability.write.format('orc').save('/user/datavirus/line_probability.orc', mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [84]:
transport_probability = (
    delays
    
    .groupBy(['ankunftszeit_hour', 'produkt_id'])
    
    .agg(
        f.mean(delays.is_delayed).alias('transport_delay_probability'),
        (1.0 / f.mean(f.when(delays.is_delayed == 1, f.when(delays.delay > 30*60, 30*60).otherwise(delays.delay)))).alias('transport_delay_parameter'),
        f.count(delays.delay).alias('transport_n')
    )
)
transport_probability.write.format('orc').save('/user/datavirus/transport_probability.orc', mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<a id = 'final'></a>
### 4.Final probability

For each connection in our connections dataframe we check whether the most detailed probability data has enough data points **(>500)**. If this is the case we assume that the calculated probabilities are good and use them for that  connection. If there is not enough data we move one level of detail lower and check if there is enough data (again >500 points). If this is the case we use that data. Otherwise we use the <code>transport_type, hour of day</code>  data that covers all the remaining connections.

In [88]:
full_probability = (
    id_probability
    .withColumn('ankunftszeit_hour', f.col('ankunftszeit_minute').substr(0, 2))
    .join(f.broadcast(line_probability), ['id', 'ankunftszeit_hour'])
    .join(f.broadcast(transport_probability), ['produkt_id', 'ankunftszeit_hour'])
    .select(
        id_probability.id,
        id_probability.ankunftszeit_minute.alias('arrival_time_minute'),
        id_probability.bpuic.alias('station_id'),
        id_probability.produkt_id.alias('transport_type'),
        id_probability.linien_text.alias('line_text'),
        id_probability.id_delay_probability,
        id_probability.id_delay_parameter,
        id_probability.id_n,
        line_probability.line_delay_probability,
        line_probability.line_delay_parameter,
        line_probability.line_n,
        transport_probability.transport_delay_probability,
        transport_probability.transport_delay_parameter,
        transport_probability.transport_n,
    )
)
full_probability.write.format('orc').save('/user/datavirus/full_probability.orc', mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [92]:
probability = (
    full_probability
    .select(
        full_probability.id,
        full_probability.arrival_time_minute,
        full_probability.station_id,
        full_probability.transport_type,
        full_probability.line_text,
        (
            f.when(full_probability.id_n > 500, full_probability.id_delay_probability)
            .otherwise(
                f.when(full_probability.line_n > 500, full_probability.line_delay_probability)
                .otherwise(full_probability.transport_delay_probability)
            )
        ).alias('delay_probability'),
        (
            f.when(full_probability.id_n > 500, full_probability.id_delay_parameter)
            .otherwise(
                f.when(full_probability.line_n > 500, full_probability.line_delay_parameter)
                .otherwise(full_probability.transport_delay_parameter)
            )
        ).alias('delay_parameter')  
    )
)
probability.write.format('orc').save('/user/datavirus/probability.orc', mode='overwrite')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [95]:
probability.show(100, False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+-------------------+----------+--------------+---------+-------------------+--------------------+
|id         |arrival_time_minute|station_id|transport_type|line_text|delay_probability  |delay_parameter     |
+-----------+-------------------+----------+--------------+---------+-------------------+--------------------+
|85:11:13792|01:47              |8503126   |zug           |SN9      |0.8125874125874126 |0.0127487766879512  |
|85:11:13792|01:51              |8503127   |zug           |SN9      |0.8125874125874126 |0.0127487766879512  |
|85:11:13792|01:54              |8503128   |zug           |SN9      |0.8125874125874126 |0.0127487766879512  |
|85:11:13792|01:58              |8503147   |zug           |SN9      |0.8125874125874126 |0.0127487766879512  |
|85:11:13792|02:03              |8503003   |zug           |SN9      |0.7451117318435754 |0.012608716203441103|
|85:11:13792|02:06              |8503000   |zug           |SN9      |0.7451117318435754 |0.012608716203441103|
|