In [74]:
#%load_ext sparkmagic.magics
%load_ext sparkmagic.magics

The sparkmagic.magics extension is already loaded. To reload it, use:
  %reload_ext sparkmagic.magics


## Initialize Spark session

In [2]:
import os
from IPython import get_ipython

# Endpoint of the Spark server used by sparkmagic, and the current user.
server = "http://iccluster029.iccluster.epfl.ch:8998"
username = os.environ['RENKU_USERNAME']

# Session settings handed to the `%%spark config` cell magic. The doubled
# braces are literal JSON braces that survive str.format().
session_config = """{{ "name": "{0}-week7", "executorMemory": "4G", "executorCores": 4, "numExecutors": 10, "driverMemory": "4G"}}""".format(username)

get_ipython().run_cell_magic('spark', line='config', cell=session_config)

In [3]:
# Register a Python Spark session named "<username>-week7" against `server`
# through the sparkmagic `%spark add` line magic.
# NOTE(review): the recorded output shows the session failing to start within
# 60 seconds; the %%spark cells below need this call to succeed.
get_ipython().run_line_magic(
    "spark", "add -s {0}-week7 -l python -u {1} -k".format(username, server)
)

Starting Spark application


An error was encountered:
Session 6993 did not start up in 60 seconds.


In [None]:
%%spark?

## Get connected to Hive

In [1]:
import os
import pandas as pd
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)

from pyhive import hive

# Read the user name and the "host:port" HiveServer2 endpoint from the environment
username = os.environ['RENKU_USERNAME']
hive_endpoint = os.environ['HIVE_SERVER2'].split(':')
hive_host = hive_endpoint[0]
hive_port = hive_endpoint[1]

# open the connection, then obtain the cursor used to issue every query below
conn = hive.connect(host=hive_host, port=hive_port, username=username)
cur = conn.cursor()

In [2]:
# create database, only run one time
# WARNING: the first statement drops `<username>_final` with CASCADE, so
# re-running this cell removes every table already defined in that database.
query = """
    drop database if exists {0}_final cascade
""".format(username)
cur.execute(query)

# Recreate the database with its location set to /user/<username>/hive.
query = """
    create database {0}_final location "/user/{0}/hive"
""".format(username)
cur.execute(query)

# NOTE(review): this switches to `<username>`, not to the `<username>_final`
# database created above. The later queries in this notebook use fully
# qualified table names, so confirm which database was intended here.
query = """
    use {0}
""".format(username)
cur.execute(query)

In [45]:
# create table from allstops data
# Drop any previous definition first so the cell can be re-run.
query = """
    drop table if exists {0}_final.allstop
""".format(username)
cur.execute(query)

# Declare an external ORC table over the files under /data/sbb/orc/allstops;
# the statement only defines the table on top of that location.
query = """
    create external table {0}_final.allstop(
        STOP_ID        string,
        STOP_NAME      string,
        STOP_LAT       double,
        STOP_LON       double,
        LOCATION_TYPE  string,
        PARENT_STATION string
    )
    stored as orc
    location '/data/sbb/orc/allstops'
    tblproperties ('orc.compress'='SNAPPY','immutable'='true')
""".format(username)
cur.execute(query)

In [35]:
# verify schema
cur.execute("DESCRIBE {0}_final.allstop".format(username))
cur.fetchall()

[('stop_id', 'string', ''),
 ('stop_name', 'string', ''),
 ('stop_lat', 'double', ''),
 ('stop_lon', 'double', ''),
 ('location_type', 'string', ''),
 ('parent_station', 'string', '')]

In [49]:
# Query text selecting every row and column of <username>_final.allstop.
query = """
    select * from {0}_final.allstop
""".format(username)

In [52]:
# Run the query through the Hive connection and load the result into pandas.
df = pd.read_sql(query, conn)
# Overwrite the column labels with plain names.
# presumably the driver returns them prefixed with the table name — verify
df.columns = ['stop_id','stop_name', 'stop_lat', 'stop_lon', 'location_type', 'parent_station']

In [53]:
df.head()

Unnamed: 0,stop_id,stop_name,stop_lat,stop_lon,location_type,parent_station
0,1100008,"Zell (Wiesental), Wilder Mann",47.710084,7.859648,,
1,1100009,"Zell (Wiesental), Grönland",47.713191,7.862909,,
2,1100010,Atzenbach,47.714618,7.87235,,
3,1100011,"Mambach, Brücke",47.728209,7.87747,,
4,1100012,"Mambach, Mühlschau",47.734082,7.881387,,


In [58]:
# Note: copied from https://stackoverflow.com/questions/19412462/getting-distance-between-two-points-based-on-latitude-longitude

from math import sin, cos, sqrt, atan2, radians

def dist(lat1, lon1, lat2, lon2):
    '''
    Great-circle distance between two coordinates (haversine formula).

    lat1, lon1: latitude and longitude of the first point, in degrees
    lat2, lon2: latitude and longitude of the second point, in degrees
    return: float, distance in the unit of R (kilometres for R = 6373.0)
    '''

    # Earth radius in km, the value used by the referenced Stack Overflow answer
    R = 6373.0

    lat1 = radians(lat1)
    lon1 = radians(lon1)
    lat2 = radians(lat2)
    lon2 = radians(lon2)

    dlon = lon2 - lon1
    dlat = lat2 - lat1

    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * atan2(sqrt(a), sqrt(1 - a))

    return R * c

In [54]:
# Keep only the stops whose dist() to Zurich HB is below 15
zurich_hb = (47.3781762039461, 8.54021154209037)
within_radius = df.apply(lambda s: dist(s['stop_lat'], s['stop_lon'], *zurich_hb) < 15, axis = 1)
df = df[within_radius]

In [55]:
df

Unnamed: 0,stop_id,stop_name,stop_lat,stop_lon,location_type,parent_station
7895,176,Zimmerberg-Basistunnel,47.351678,8.521958,,
10877,8500926,"Oetwil a.d.L., Schweizäcker",47.423627,8.403183,,
12362,8502075,"Zürich Flughafen, Carterminal",47.451024,8.563729,,
12533,8502186,Dietikon Stoffelbach,47.393327,8.398960,,Parent8502186
12534,8502186:0,Dietikon Stoffelbach,47.393400,8.398942,,Parent8502186
...,...,...,...,...,...,...
46613,Parent8587020,Bahnhof,47.405834,8.404521,1,
46621,Parent8587651,Glatt (Bus),47.409209,8.595548,1,
46652,Parent8590279,Shopping Center,47.420448,8.368517,1,
46658,Parent8590464,Bahnhof,47.311781,8.524311,1,


In [62]:
df[df['stop_name'] == 'Dietikon Stoffelbach']

Unnamed: 0,stop_id,stop_name,stop_lat,stop_lon,location_type,parent_station
12533,8502186,Dietikon Stoffelbach,47.393327,8.39896,,Parent8502186
12534,8502186:0,Dietikon Stoffelbach,47.3934,8.398942,,Parent8502186
12535,8502186:0:1/2,Dietikon Stoffelbach,47.3934,8.398942,,Parent8502186
12536,8502186P,Dietikon Stoffelbach,47.3934,8.398942,1.0,
45662,Parent8502186,Dietikon Stoffelbach,47.393327,8.39896,1.0,


## Build connections with Spark

In [4]:
%%spark
# Note: copied from https://stackoverflow.com/questions/19412462/getting-distance-between-two-points-based-on-latitude-longitude

from math import sin, cos, sqrt, atan2, radians

def dist(lat1, lon1, lat2, lon2):
    '''
    Great-circle distance between two coordinates (haversine formula).

    lat1, lon1: latitude and longitude of the first point, in degrees
    lat2, lon2: latitude and longitude of the second point, in degrees
    return: float, distance in the unit of R (kilometres for R = 6373.0)
    '''

    # Earth radius in km, the value used by the referenced Stack Overflow answer
    R = 6373.0

    lat1 = radians(lat1)
    lon1 = radians(lon1)
    lat2 = radians(lat2)
    lon2 = radians(lon2)

    dlon = lon2 - lon1
    dlat = lat2 - lat1

    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * atan2(sqrt(a), sqrt(1 - a))

    return R * c

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
%%spark
from pyspark.sql.types import StructType, IntegerType, StringType, DoubleType, BooleanType


# Explicit schema for the allstops CSV; every field is declared nullable.
schema = (
    StructType()
    .add("stop_id", StringType(), True)
    .add("stop_name", StringType(), True)
    .add("stop_lat", DoubleType(), True)
    .add("stop_lon", DoubleType(), True)
    .add("location_type", StringType(), True)
    .add("parent_station", StringType(), True)
)

# Read the stop catalogue with a header row, applying the schema above.
stop_info = (
    spark.read.format("csv")
    .options(header=True)
    .schema(schema)
    .load('/data/sbb/csv/allstops')
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [105]:
%%spark
stop_info.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+----------------+----------------+-------------+--------------+
|stop_id|           stop_name|        stop_lat|        stop_lon|location_type|parent_station|
+-------+--------------------+----------------+----------------+-------------+--------------+
|1100008|Zell (Wiesental),...|47.7100842702352|7.85964788274668|         null|          null|
|1100009|Zell (Wiesental),...|47.7131911044794|7.86290876722849|         null|          null|
|1100010|           Atzenbach|47.7146175266411| 7.8723500608659|         null|          null|
|1100011|     Mambach, Brücke|47.7282088873189| 7.8774704579861|         null|          null|
|1100012|  Mambach, Mühlschau|47.7340818684375| 7.8813871126254|         null|          null|
+-------+--------------------+----------------+----------------+-------------+--------------+
only showing top 5 rows

In [6]:
%%spark
# Trip table for the 2019/05/15 timetable, read as CSV with a header row.
trip_info = spark.read.options(header = True).csv("/data/sbb/csv/trips/2019/05/15/trips.txt")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [107]:
%%spark
trip_info.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
|1-1-C-j19-1|  TA+b0001|7.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            112|           1|
|1-1-C-j19-1|  TA+b0001|9.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            116|           1|
|1-1-C-j19-1|  TA+b0001|11.TA.1-1-C-j19-1...|Zofingen, Altachen|            120|           1|
|1-1-C-j19-1|  TA+b0001|13.TA.1-1-C-j19-1...|Zofingen, Altachen|            124|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 5 rows

In [7]:
%%spark

# Explicit schema for stop_times.txt; every field is declared nullable.
# NOTE: the misspelled column name "depearture_time" is kept as is because the
# following cells refer to the column by exactly this name.
schema = (
    StructType()
    .add("trip_id", StringType(), True)
    .add("arrival_time", StringType(), True)
    .add("depearture_time", StringType(), True)
    .add("stop_id", StringType(), True)
    .add("stop_sequence", IntegerType(), True)
    .add("pickup_type", StringType(), True)
    .add("drop_off_type", StringType(), True)
)

# Read the stop times of the 2019/05/15 timetable with the schema above.
stops = (
    spark.read.format("csv")
    .options(header=True)
    .schema(schema)
    .load('/data/sbb/csv/stop_times/2019/05/15/stop_times.txt')
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [109]:
%%spark
stops.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------+---------------+-----------+-------------+-----------+-------------+
|             trip_id|arrival_time|depearture_time|    stop_id|stop_sequence|pickup_type|drop_off_type|
+--------------------+------------+---------------+-----------+-------------+-----------+-------------+
|1.TA.1-1-B-j19-1.1.R|    04:20:00|       04:20:00|8500010:0:3|            1|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:24:00|       04:24:00|8500020:0:3|            2|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:28:00|       04:28:00|8500021:0:5|            3|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:30:00|       04:30:00|8517131:0:2|            4|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:32:00|       04:32:00|8500300:0:5|            5|          0|            0|
+--------------------+------------+---------------+-----------+-------------+-----------+-------------+
only showing top 5 rows

In [8]:
%%spark
from pyspark.sql import functions as F
from pyspark.sql.functions import col

# Boolean UDF: True when a coordinate lies less than 15 (dist() units) from Zurich HB
zurich_hb = (47.3781762039461, 8.54021154209037)
dist_udf = F.udf(lambda lat,lon: dist(lat, lon, *zurich_hb) < 15, BooleanType())

# Attach coordinates to every stop time, keep the columns needed downstream,
# drop rows with any missing value, then apply the distance filter.
stops_15 = (
    stops.join(stop_info, ['stop_id'], how='left')
    .select('trip_id', 'arrival_time', 'depearture_time', 'stop_id', 'stop_sequence', 'stop_lat', 'stop_lon')
    .dropna()
    .filter(dist_udf('stop_lat', 'stop_lon'))
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [12]:
%%spark
stops_15.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- trip_id: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- depearture_time: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- stop_sequence: integer (nullable = true)
 |-- stop_lat: double (nullable = true)
 |-- stop_lon: double (nullable = true)

In [48]:
%%spark
stops_15.show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------+---------------+---------------+-------------+----------------+----------------+
|             trip_id|arrival_time|depearture_time|        stop_id|stop_sequence|        stop_lat|        stop_lon|
+--------------------+------------+---------------+---------------+-------------+----------------+----------------+
|5.TA.1-1-A-j19-1.3.H|    02:42:00|       02:42:00|    8503305:0:3|            2|47.4258149579094|8.68668184918644|
|5.TA.1-1-A-j19-1.3.H|    02:46:00|       02:46:00|    8503306:0:3|            3|47.4201990480087|8.61927227025673|
|5.TA.1-1-A-j19-1.3.H|    02:50:00|       02:50:00|    8503147:0:1|            4|47.3972125517017|8.59614065168743|
|5.TA.1-1-A-j19-1.3.H|    02:55:00|       02:55:00|    8503003:0:1|            5|47.3666111556789|8.54848502585826|
|5.TA.1-1-A-j19-1.3.H|    02:58:00|       03:00:00|8503000:0:41/42|            6|47.3781762039461|8.54021154209037|
|5.TA.1-1-A-j19-1.3.H|    03:02:00|       03:02:00|    8503020:0:4|     

In [9]:
%%spark
# Build the departure and arrival timetables.
# Each departure row has its stop_sequence shifted by one, so that it carries
# the sequence number of the following stop of the same trip.
depearture = (
    stops_15
    .select('trip_id', 'depearture_time', 'stop_id', 'stop_sequence')
    .withColumn("stop_sequence_next", stops_15.stop_sequence + 1)
    .select('trip_id', 'depearture_time', 'stop_id', 'stop_sequence_next')
    .withColumnRenamed("stop_sequence_next", "stop_sequence")
    .withColumnRenamed("stop_id", "depearture_stop")
)

# Arrival rows keep their own stop_sequence; only the stop column is renamed.
arrival = (
    stops_15
    .select('trip_id', 'arrival_time', 'stop_id', 'stop_sequence')
    .withColumnRenamed('stop_id', 'arrival_stop')
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [28]:
%%spark
depearture.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+---------------+---------------+-------------+
|             trip_id|depearture_time|depearture_stop|stop_sequence|
+--------------------+---------------+---------------+-------------+
|5.TA.1-1-A-j19-1.3.H|       02:42:00|    8503305:0:3|            3|
|5.TA.1-1-A-j19-1.3.H|       02:46:00|    8503306:0:3|            4|
|5.TA.1-1-A-j19-1.3.H|       02:50:00|    8503147:0:1|            5|
|5.TA.1-1-A-j19-1.3.H|       02:55:00|    8503003:0:1|            6|
|5.TA.1-1-A-j19-1.3.H|       03:00:00|8503000:0:41/42|            7|
+--------------------+---------------+---------------+-------------+
only showing top 5 rows

In [29]:
%%spark
arrival.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------+---------------+-------------+
|             trip_id|arrival_time|   arrival_stop|stop_sequence|
+--------------------+------------+---------------+-------------+
|5.TA.1-1-A-j19-1.3.H|    02:42:00|    8503305:0:3|            2|
|5.TA.1-1-A-j19-1.3.H|    02:46:00|    8503306:0:3|            3|
|5.TA.1-1-A-j19-1.3.H|    02:50:00|    8503147:0:1|            4|
|5.TA.1-1-A-j19-1.3.H|    02:55:00|    8503003:0:1|            5|
|5.TA.1-1-A-j19-1.3.H|    02:58:00|8503000:0:41/42|            6|
+--------------------+------------+---------------+-------------+
only showing top 5 rows

In [10]:
%%spark
# Match each departure with the arrival that has the same trip_id and
# stop_sequence, then drop rows where any value is missing (for instance
# departures that found no matching arrival).
connections = (
    depearture
    .join(arrival, ['trip_id', 'stop_sequence'], how='left')
    .dropna()
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [12]:
%%spark
# get route_id for each connection
# (left join on trip_id: every connection is kept and gains the trip_info columns)
connections = connections.join(trip_info, ['trip_id'], how = 'left')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [13]:
%%spark
# select relevant columns
connections = connections.select('depearture_time','depearture_stop','arrival_time','arrival_stop','trip_id','route_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [116]:
%%spark
connections.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------------+---------------+------------+------------+--------------------+-----------+
|depearture_time|depearture_stop|arrival_time|arrival_stop|             trip_id|   route_id|
+---------------+---------------+------------+------------+--------------------+-----------+
|       28:50:00|        8591429|    28:50:00|     8591180|1.TA.26-18-j19-1.1.H|26-18-j19-1|
|       28:45:00|        8591315|    28:46:00|     8591142|1.TA.26-18-j19-1.1.H|26-18-j19-1|
|       28:50:00|        8591180|    28:52:00|     8530812|1.TA.26-18-j19-1.1.H|26-18-j19-1|
|       28:52:00|        8530812|    28:53:00|     8591364|1.TA.26-18-j19-1.1.H|26-18-j19-1|
|       28:46:00|        8591142|    28:47:00|     8530811|1.TA.26-18-j19-1.1.H|26-18-j19-1|
+---------------+---------------+------------+------------+--------------------+-----------+
only showing top 5 rows

In [122]:
%%spark
# select connections within typical working hours (07:00:00 to 22:00:00)
# The times are compared as strings; this works because they are zero-padded
# HH:MM:SS values.
select_connections = connections.filter("depearture_time >= '07:00:00' and depearture_time <= '22:00:00'")
# sort dataframe by departure time, latest first.
# DataFrame.sort returns a new DataFrame rather than sorting in place, so the
# result has to be assigned back for the ordering to take effect.
select_connections = select_connections.sort(select_connections.depearture_time.desc())

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
%%spark
# save results to hdfs
# NOTE(review): no write mode is set, so this presumably fails when the target
# path already exists — confirm before re-running.
select_connections.write.format("csv").save("hdfs:///group/Big-data-projects/connections.csv")

## Generate Transfer station list

In [105]:
stop_15.head()

Unnamed: 0,stop_id,stop_name,stop_lat,stop_lon,location_type,parent_station
2395,8500926,"Oetwil a.d.L., Schweizäcker",47.423627,8.403183,,
3793,8502186,Dietikon Stoffelbach,47.393406,8.398942,,8502186P
3797,8502187,Rudolfstetten Hofacker,47.364695,8.377095,,8502187P
3801,8502188,Zufikon Hammergut,47.355835,8.354727,,8502188P
3870,8502208,Horgen Oberdorf,47.258748,8.589799,,8502208P


In [92]:
# Get unique coordinate for each stop
# Drop entries whose id contains ':'. The explicit .copy() makes the result an
# independent frame, so the column assignment below does not write into a
# slice of the original (SettingWithCopyWarning).
stop_15 = stop_15[~stop_15['stop_id'].str.contains(':')].copy()
# Keep only the first 7 characters of each id.
# NOTE(review): ids of the form 'Parent8502186' would be cut to 'Parent8' by
# this slice — confirm that no such ids reach this cell.
stop_15['stop_id'] = stop_15['stop_id'].str[:7]

# One row per id; the first occurrence is kept.
stop_15 = stop_15.drop_duplicates(subset='stop_id')

In [106]:
stop_15.head()

Unnamed: 0,stop_id,stop_name,stop_lat,stop_lon,location_type,parent_station
2395,8500926,"Oetwil a.d.L., Schweizäcker",47.423627,8.403183,,
3793,8502186,Dietikon Stoffelbach,47.393406,8.398942,,8502186P
3797,8502187,Rudolfstetten Hofacker,47.364695,8.377095,,8502187P
3801,8502188,Zufikon Hammergut,47.355835,8.354727,,8502188P
3870,8502208,Horgen Oberdorf,47.258748,8.589799,,8502208P


In [107]:
import json

stop_15 = stop_15.reset_index(drop=True)

# Pull the needed columns into plain lists once; the pairwise loop below is
# O(n^2) and per-element DataFrame lookups inside it are needlessly slow.
stop_ids = stop_15['stop_id'].tolist()
lats = stop_15['stop_lat'].tolist()
lons = stop_15['stop_lon'].tolist()

# stop_id -> list of stop_ids reachable on foot
transfer_st = {s:[] for s in stop_ids}

# Visit every unordered pair (i, j) with i < j exactly once and record the
# relation in both directions.
for i in range(len(stop_ids)-1):
    for j in range(i+1, len(stop_ids)):

        # calculate distance
        dis = dist(lats[i], lons[i], lats[j], lons[j])

        # possible to transfer by walking
        if dis<0.5:
            transfer_st[stop_ids[i]].append(stop_ids[j])
            transfer_st[stop_ids[j]].append(stop_ids[i])

In [111]:
# Write the stop -> nearby stops mapping to disk as JSON.
# NOTE: a later cell writes a differently shaped dictionary to this same path.
with open("../data/transfer_station.json", 'w') as f:
    json.dump(transfer_st,f)

## Transfer station dictionary

In [119]:
%%spark
# Note: copied from https://stackoverflow.com/questions/19412462/getting-distance-between-two-points-based-on-latitude-longitude

from math import sin, cos, sqrt, atan2, radians

def dist(lat1, lon1, lat2, lon2):
    '''
    Great-circle distance between two coordinates (haversine formula).

    lat1, lon1: latitude and longitude of the first point, in degrees
    lat2, lon2: latitude and longitude of the second point, in degrees
    return: float, distance in the unit of R (kilometres for R = 6373.0)
    '''

    # Earth radius in km, the value used by the referenced Stack Overflow answer
    R = 6373.0

    lat1 = radians(lat1)
    lon1 = radians(lon1)
    lat2 = radians(lat2)
    lon2 = radians(lon2)

    dlon = lon2 - lon1
    dlat = lat2 - lat1

    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * atan2(sqrt(a), sqrt(1 - a))

    return R * c

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [120]:
%%spark
from pyspark.sql.types import StructType, IntegerType, StringType, DoubleType, BooleanType


# Explicit schema for the allstops CSV; every field is declared nullable.
schema = (
    StructType()
    .add("stop_id", StringType(), True)
    .add("stop_name", StringType(), True)
    .add("stop_lat", DoubleType(), True)
    .add("stop_lon", DoubleType(), True)
    .add("location_type", StringType(), True)
    .add("parent_station", StringType(), True)
)

# Read the stop catalogue with a header row, applying the schema above.
stop_info = (
    spark.read.format("csv")
    .options(header=True)
    .schema(schema)
    .load('/data/sbb/csv/allstops')
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [65]:
%%spark
stop_info.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+----------------+----------------+-------------+--------------+
|stop_id|           stop_name|        stop_lat|        stop_lon|location_type|parent_station|
+-------+--------------------+----------------+----------------+-------------+--------------+
|1100008|Zell (Wiesental),...|47.7100842702352|7.85964788274668|         null|          null|
|1100009|Zell (Wiesental),...|47.7131911044794|7.86290876722849|         null|          null|
|1100010|           Atzenbach|47.7146175266411| 7.8723500608659|         null|          null|
|1100011|     Mambach, Brücke|47.7282088873189| 7.8774704579861|         null|          null|
|1100012|  Mambach, Mühlschau|47.7340818684375| 7.8813871126254|         null|          null|
+-------+--------------------+----------------+----------------+-------------+--------------+
only showing top 5 rows

In [121]:
%%spark
from pyspark.sql import functions as F
from pyspark.sql.functions import col

# filter by distance
# Reference point (lat, lon) and a boolean UDF that is True when dist() from
# that point is below 15 (kilometres, given dist() uses R = 6373.0).
zurich_hb = (47.3781762039461, 8.54021154209037)
dist_udf = F.udf(lambda lat,lon: dist(lat, lon, *zurich_hb) < 15, BooleanType())

# Keep only the stops that satisfy the UDF.
stop_15 = stop_info.filter(dist_udf('stop_lat', 'stop_lon'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [85]:
%%spark
stop_15.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

2122

In [38]:
def process_id(stop_id):
    '''
    Process stop_id to only contain the main station number.

    Handled formats:
        127            plain id, returned unchanged
        127:0:1        platform-level id, everything from the first ':' is dropped
        127P           trailing 'P' marker is removed
        Parent127      leading 'Parent' marker is removed

    The ':' suffix is stripped first, so combined forms such as
    'Parent127:0' are reduced to the station number as well. An empty string
    is returned unchanged instead of raising IndexError.

    stop_id: string
    return: string
    '''
    # Drop the platform part, if any.
    main_id = stop_id.split(':', 1)[0]

    if main_id.startswith('Parent'):
        return main_id[len('Parent'):]
    if main_id.endswith('P'):
        return main_id[:-1]
    return main_id

In [123]:
# Get unique coordinate for each stop
# Reduce every stop_id to its main station number. Series.map applies
# process_id to the single column directly, instead of building a row object
# for every record with DataFrame.apply(axis=1); it also works on an empty frame.
stop_15.loc[:,'stop_id'] = stop_15['stop_id'].map(process_id)

# Several raw ids can collapse to the same station; keep the last row of each.
stop_15 = stop_15.drop_duplicates(subset='stop_id', keep = 'last')

In [None]:
stop_15.head()

In [128]:
# Save the first four columns without the index.
# presumably stop_id, stop_name, stop_lat, stop_lon — verify the column order
stop_15.iloc[:,:4].to_csv("../data/stop_info.csv",index=False)

In [113]:
import json
import math
# Generate transfer station dictionary
# Each item:
# station id:[(station_id, walking time),...]
stop_15 = stop_15.reset_index(drop=True)

# Pull the needed columns into plain lists once; the pairwise loop below is
# O(n^2) and per-element DataFrame lookups inside it are needlessly slow.
stop_ids = stop_15['stop_id'].tolist()
lats = stop_15['stop_lat'].tolist()
lons = stop_15['stop_lon'].tolist()

transfer_st = {s:[] for s in stop_ids}

# Visit every unordered pair (i, j) with i < j exactly once and record the
# relation in both directions.
for i in range(len(stop_ids)-1):
    for j in range(i+1, len(stop_ids)):

        # calculate distance
        dis = dist(lats[i], lons[i], lats[j], lons[j])

        # possible to transfer by walking
        if dis<0.5:
            # dis*20 rounded up to a whole number.
            # presumably minutes at 20 min per km — TODO confirm
            walk_time = math.ceil(dis*20)
            transfer_st[stop_ids[i]].append((stop_ids[j], walk_time))
            transfer_st[stop_ids[j]].append((stop_ids[i], walk_time))

In [114]:
# Write the transfer dictionary to disk; json.dump serialises the
# (station_id, walking time) tuples as JSON arrays.
with open("../data/transfer_station.json", 'w') as f:
    json.dump(transfer_st,f)

## Filter stops that appear in the connection data

In [None]:
!git lfs pull

In [25]:
import pandas as pd
import glob
import os

path = r'../data/connections' # use your path
all_files = glob.glob(os.path.join(path , "*.csv"))

# Collect every stop id that occurs as a departure stop (column 1) or an
# arrival stop (column 3) in any connection file.
li = set()
for filename in all_files:
    # dtype=str keeps the ids as text. Without it, a file whose ids are all
    # purely numeric is parsed as integers, and the string handling applied to
    # these ids afterwards (process_id) fails on them.
    df = pd.read_csv(filename, index_col=None, header=None, dtype=str)
    li.update(df[1])
    li.update(df[3])

In [73]:
import pandas as pd
stop_info = pd.read_csv("../data/stop_info.csv")

# Main station numbers (as integers) of all stops seen in the connections
stopid = pd.Series(list(li),name = 'stop_id')
stop_id = {int(process_id(raw_id)) for raw_id in stopid}

# keep only the stop_info rows whose stop_id appears in the connections
in_connections = stop_info['stop_id'].isin(stop_id)
stop_info_filtered = stop_info[in_connections]

In [71]:
stop_info_filtered.to_csv("../data/stop_info_filtered.csv",index=False)

## Merge delay predictions into the connections

In [12]:
import pandas as pd
import glob
import os

In [3]:
# Inputs for the merge: the filtered stop list and the delay prediction table.
stop_info = pd.read_csv("../data/stop_info_filtered.csv")
delay = pd.read_csv("../data/dist_all_delay_prediction.csv")

In [4]:
# Only the stop_id -> stop_name mapping is used by the merge below.
stop_info = stop_info[['stop_id','stop_name']]

In [5]:
# Keep columns 1-3 and 6-16 of the raw table, selected by position.
# NOTE(review): positional selection depends on the exact CSV layout, and this
# cell is not re-runnable once `delay` has been reduced by the cells below.
delay = delay.iloc[:,[1,2,3,6,7,8,9,10,11,12,13,14,15,16]]

In [6]:
# Pack every column from position 3 onward into one Python list per row.
delay['delay'] = delay.iloc[:,3:].values.tolist()

In [8]:
# Keep the first three columns plus the last one (the list column added above).
delay = delay.iloc[:,[0,1,2,-1]]

In [24]:
# Keep only the rows with day == 3 and the columns used as merge keys/payload.
# NOTE(review): presumably day 3 matches the weekday of the 2019/05/15
# timetable used for the connections — confirm the day encoding.
delay = delay[delay['day']==3][['stop_name','hour','delay']].reset_index(drop=True)

In [37]:
path = r'../data/connections' # use your path
all_files = glob.glob(os.path.join(path , "*.csv"))

# Column layout of the header-less connection files.
connection_cols = ["departure_time", "departure_stop", "arrival_time", "arrival_stop", "trip_id","route_id"]

# WARNING: every file is rewritten in place with the delay column replacing
# nothing but adding a 7th field, so this loop must not be run twice on the
# same files (the 6-name column assignment would no longer match).
for filename in all_files:
    # This one file is deliberately left out of the loop.
    # NOTE(review): presumably because it was already processed — confirm.
    if filename == '../data/connections/part-00160-613e9e3a-5030-4ca1-9cd0-20a9c88748fe-c000.csv':
        continue
    # dtype=str keeps the stop ids as text; a file whose ids are all purely
    # numeric would otherwise be parsed as integers and the .str accessor
    # used below would fail.
    temp = pd.read_csv(filename, index_col=None, header=None, dtype=str)
    temp.columns = connection_cols

    # get stop name: match on the part of the arrival stop id before the first ':'
    temp['main_arrival'] = temp['arrival_stop'].str.split(':').str[0].astype(int)
    temp = temp.merge(stop_info, left_on='main_arrival',right_on = 'stop_id', how = 'left')\
           [connection_cols + ["stop_name"]]

    # get delay: match on the arrival hour and the stop name
    temp['hour'] = temp['arrival_time'].str.split(":").str[0].astype(int)
    temp = temp.merge(delay, left_on = ['hour','stop_name'], right_on = ['hour','stop_name'],how='left')\
          [connection_cols + ["delay"]]

    # save to file (overwrites the input, still without a header row)
    temp.to_csv(filename, index=False, header=False)


In [56]:
# Re-read the one connection file that the merge loop skips, dropping rows
# with missing values.
# NOTE(review): pd.eval evaluates text taken from the file; ast.literal_eval
# is the safer parser for list literals. The following cells access the last
# column as 'Unnamed: 6', so confirm that a 'delay' column actually exists for
# this converter to apply to.
temp = pd.read_csv('../data/connections/part-00160-613e9e3a-5030-4ca1-9cd0-20a9c88748fe-c000.csv',converters={'delay': pd.eval}).dropna()

In [55]:
import ast
len(ast.literal_eval(temp['Unnamed: 6'][20]))

11

In [57]:
temp['Unnamed: 6'][20]

'[0.0, 0.4132354931299031, 0.7349567438082295, 0.8802791804373937, 0.9459217531406592, 0.9755726965947596, 0.9889661150960812, 0.995015961686266, 0.9977486952121508, 0.9989830789956357, 0.9995406537867736]'