In [1]:
%load_ext sparkmagic.magics
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import glob
import os

# The session name is derived from the Renku user so each user gets their own session.
username = os.environ['RENKU_USERNAME']
# URL handed to sparkmagic via `-u` when the session is added (presumably the Livy endpoint — confirm).
server = "http://iccluster029.iccluster.epfl.ch:8998"
from IPython import get_ipython
# Equivalent to a `%%spark config` cell. The doubled braces are str.format escapes,
# so the cell body sent to the magic is a plain JSON object with `username` filled in.
get_ipython().run_cell_magic('spark', line="config", 
                    cell="""{{ "name":"{0}-demo2",
                               "executorMemory":"4G",
                               "executorCores":4,
                               "numExecutors":10 }}""".format(username))

In [3]:
# Equivalent to `%spark add -s <username>-demo2 -l python -u <server> -k`:
# starts the remote session whose name matches the one used in the config cell.
# See the sparkmagic `%spark` help for the exact meaning of each flag.
get_ipython().run_line_magic(
    "spark", "add -s {0}-demo2 -l python -u {1} -k".format(username, server)
)

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
8521,application_1652960972356_4291,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


# Load stop data

In [4]:
%%spark -o df_stops -n 50000
# Runs remotely in the Spark session. `-o df_stops` makes the result available
# locally as `df_stops`; `-n 50000` is the row cap for that transfer
# (see the sparkmagic docs for the precise semantics).
df_stops = spark.read.orc('hdfs:///data/sbb/orc/allstops/000000_0')
# Row count is displayed so it can be checked against the -n cap above.
df_stops.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

46689

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
# Keep only the part of each stop_id before the first ':' and then keep a single
# row per resulting id (the first one encountered).
df_stops = (
    df_stops
    .assign(stop_id=df_stops['stop_id'].map(lambda raw_id: raw_id.split(':')[0]))
    .drop_duplicates(subset=['stop_id'])
)

In [6]:
# Index by stop_id so the lookup helpers below can use df_stops.loc[stop_id].
# Not re-runnable on its own: after this cell 'stop_id' is no longer a column.
df_stops = df_stops.set_index('stop_id')

In [7]:
def stop_id_to_name(stop_id, show_city=False):
    '''
    Look up the name of a stop.

    - stop_id: index label of the stop in `df_stops`.
    - show_city: accepted for callers' convenience but not used here.

    Returns the `stop_name` value of the matching `df_stops` row.
    Raises KeyError if `stop_id` is not in the index.
    '''
    stop_row = df_stops.loc[stop_id]
    return stop_row['stop_name']

def stop_id_to_latlon(stop_id):
    '''
    Look up the coordinates of a stop.

    - stop_id: index label of the stop in `df_stops`.

    Returns the tuple (stop_lon, stop_lat). Note the order: despite the
    function name, longitude comes first.
    Raises KeyError if `stop_id` is not in the index.
    '''
    stop_row = df_stops.loc[stop_id]
    return (stop_row['stop_lon'], stop_row['stop_lat'])
    
# Quick check of both lookup helpers on one known stop id.
print(stop_id_to_name('8503016'))
print(stop_id_to_latlon('8503016'))


Zürich Flughafen
(8.56239992961121, 47.4503866318972)


# Load route data

In [8]:
%%spark -o df_route -n 50000
# Runs remotely in the Spark session. `-o df_route` makes the result available
# locally as `df_route`; `-n 50000` is the row cap for that transfer
# (see the sparkmagic docs for the precise semantics).
# header=True: the first line of routes.txt supplies the column names.
df_route = spark.read.csv('hdfs:///data/sbb/csv/routes/2019/05/15/routes.txt', header=True)
# Row count is displayed so it can be checked against the -n cap above.
df_route.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

5026

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
# Index by route_id so route_id_to_name can use df_route.loc[route_id].
# Not re-runnable on its own: after this cell 'route_id' is no longer a column.
df_route = df_route.set_index('route_id')

In [10]:
def route_id_to_name(route_id):
    '''
    Build a human-readable label for a route.

    - route_id: index label of the route in `df_route`; a falsy value
      (None, empty string, ...) is labelled as a walk.

    Returns "walk" for a falsy id, otherwise "<route_desc> <route_short_name>"
    taken from the matching `df_route` row. If the lookup raises KeyError a
    warning is printed and None is returned.
    '''
    if not route_id:
        return "walk"

    try:
        route_row = df_route.loc[route_id]
        route_name = route_row['route_short_name']
        transport_type = route_row['route_desc']
    except KeyError:
        print(f'WARNING: route id {route_id} not found')
        return None

    return f'{transport_type} {route_name}'

# Quick check of the route label helper on two known route ids.
print(route_id_to_name('11-40-j19-1'))
print(route_id_to_name('26-9-B-j19-1'))

Bus 040
Tram 9


# CSA

In [18]:
import csa
import ast

In [12]:
import json

# Loaded once and passed unchanged to csa.csa and csa.build_itinerary_from_profile_entry.
# Presumably maps a station id to the stations reachable from it on foot — confirm against the csa module.
with open("../data/transfer_station.json") as f:
    nearby_station_dic = json.load(f)

In [13]:
# Connections built for the most recently queried deadline. query_csa reuses
# them when it is called again with the same deadline, skipping the CSV load.
cached_connections = None
cached_deadline = None

def query_csa(origin_id, destination_id, deadline, confidence):
    """
    Returns a list of Itineraries.
    - origin_id: id of origin station.
    - destination_id: id of destination station.
    - deadline: latest acceptable arrival time, as an "HH:MM:SS" string.
    - confidence: (In percentage) only itineraries with confidence larger than this threshold will be returned.

    Returns None when the origin station does not appear in any loaded
    connection. Otherwise returns the itineraries whose confidence is at least
    `confidence`/100, ordered by the start time of their first leg, latest first.

    Raises ValueError when no connection CSV files are found.
    """
    global cached_connections
    global cached_deadline

    # Station ids built below are strings, so normalise the query ids up front.
    # This must happen on the cache-hit path too, otherwise a cached call with
    # non-string ids would never match any station.
    origin_id, destination_id = str(origin_id), str(destination_id)

    # generate connections
    if cached_deadline == deadline:
        connections = cached_connections
        print('Using cached connections')
    else:    # get data
        path = r'../data/connections' # use your path
        all_files = glob.glob(os.path.join(path , "*.csv"))
        if not all_files:
            # pd.concat would raise the same exception type with a far less helpful message.
            raise ValueError(f'No connection CSV files found in {path}')

        hour, minute, second = deadline.split(":")
        # Only connections departing within the two hours before the deadline are kept.
        # Times are compared as strings; assumes zero-padded HH:MM:SS values — TODO confirm.
        earliest_departure = ":".join([str(int(hour)-2).zfill(2), minute, second])

        li = []
        for filename in all_files:
            df = pd.read_csv(filename, index_col=None, header=None)
            # Column 0 is the departure time and column 2 the arrival time (see the names assigned below).
            li.append(df[(df[0] >= earliest_departure) & (df[2] <= deadline)])

        df_selected = pd.concat(li, axis=0, ignore_index=True)
        df_selected.columns = ["departure_time", "departure_stop", "arrival_time", "arrival_stop", "trip_id","route_id","delay"]
        # Connections are handed to CSA ordered by departure time, latest first.
        df_selected = df_selected.sort_values(by=['departure_time'], ascending=False)

        connections = []

        print('Creating connections...', end='')
        for idx, row in df_selected.iterrows():
            connections.append(
                csa.Connection(
                    # Stop ids are cut at the first ':' to match the ids used in df_stops.
                    start_station=row['departure_stop'].split(':')[0],
                    start_time=row['departure_time'],
                    end_station=row['arrival_stop'].split(':')[0],
                    end_time=row['arrival_time'],
                    route_id=row['route_id'],
                    trip_id=row['trip_id'],
                    # The delay column holds a Python literal as text; any non-string
                    # value (e.g. a missing entry) falls back to a list of eleven 1s.
                    delay=ast.literal_eval(row['delay']) if isinstance(row['delay'], str) else [1]*11
                )
            )
        print('done')
        cached_connections = connections
        cached_deadline = deadline

    stations = {c.start_station for c in connections} | {c.end_station for c in connections}

    # No routes available
    if origin_id not in stations:
        return None

    deadline = csa.hhmm_to_int(deadline)
    print('Running CSA...', end='')
    results = csa.csa(stations, connections, origin_id, destination_id, deadline, confidence/100, nearby_station_dic)
    print('done')
    itineraries = [csa.build_itinerary_from_profile_entry(r, origin_id, destination_id, nearby_station_dic) for r in results]

    itineraries.sort(key=lambda x: x.legs[0].start_time, reverse=True)

    return [itinerary for itinerary in itineraries if itinerary.confidence >= confidence/100]

Some station IDs to try out:

- 8591123 Zürich, ETH/Universitätsspital
- 8503016 Zürich Flughafen
- 8503006 Zürich Oerlikon
- 8591105 Zürich, Bürkliplatz
- 8503000 Zürich HB
- 8587348 Zürich, Bahnhofplatz/HB
- 8591283 Zürich, Museum Rietberg

## Visualization

In [14]:
import pandas as pd
# Stop list backing the Departure/Arrival dropdowns.
stop_info = pd.read_csv("../data/stop_info_filtered.csv")
# Ids are cast to str so they compare equal to the string ids used elsewhere in the notebook.
stop_info['stop_id'] = stop_info['stop_id'].astype(str)
# Sorted by name for display; the original index labels are kept (no reset_index).
stop_info = stop_info.sort_values("stop_name")

In [15]:
from datetime import datetime, timedelta

# Create starting and end datetime object from string
# Bounds of the selectable arrival-time window, parsed from HH:MM:SS strings.
start = datetime.strptime("08:00:00", "%H:%M:%S")
end = datetime.strptime("22:05:00", "%H:%M:%S")

# Spacing between consecutive options, in minutes.
min_gap = 5

# One HH:MM:SS label every `min_gap` minutes, from `start` up to but not including `end`.
arr = [
    (start + timedelta(minutes=min_gap * slot)).strftime("%H:%M:%S")
    for slot in range(int((end - start) / timedelta(minutes=min_gap)))
]

In [16]:
import matplotlib.pyplot as plt
%matplotlib inline

In [17]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display

import pandas as pd
import plotly.express as px

def plot_on_map(route):
    """
    Print the legs of an itinerary and draw its waypoints on an OpenStreetMap view.

    - route: an itinerary exposing `legs` (each with start/end station, start/end
      time and route_id) and `waypoints_df(...)`, as returned by query_csa.

    Nothing is returned; the text summary goes to stdout and the figure is shown.
    """
    # print itinerary
    for leg in route.legs:
        if leg.route_id:
            # route_id_to_name returns None for an unknown route id, and formatting
            # None with '{:10}' raises TypeError, so fall back to the raw id.
            route_label = route_id_to_name(leg.route_id) or str(leg.route_id)
        else:
            route_label = 'Walk'
        print('\t{} {:30} => {} {:30} {:10}'.format(
            csa.to_hhmm(leg.start_time),
            stop_id_to_name(leg.start_station),
            csa.to_hhmm(leg.end_time),
            stop_id_to_name(leg.end_station),
            route_label,
        ))
    print('')

    waypoints_df = route.waypoints_df(stop_id_to_name, stop_id_to_latlon, route_id_to_name)
    # Draw route on map; lat/lon/is_walking are hidden from the hover box.
    fig = px.line_mapbox(waypoints_df, lat="lat", lon="lon", hover_data = {'station':True,'time':True,'route':True,'lat':False,'lon':False,'is_walking':False},
                         zoom=10, height=500, width=500, color_discrete_sequence=px.colors.qualitative.Set1)

    fig.update_layout(mapbox_style="open-street-map", mapbox_zoom=10, showlegend=False, margin={"r":0,"t":0,"l":0,"b":0})
    fig.update_traces(line=dict(width=4))

    fig.show()

def query(departure, arrival, t, confidence):
    """
    Run a journey query and show the results as an interactive route picker.

    - departure: stop_id of the origin station.
    - arrival: stop_id of the destination station.
    - t: latest acceptable arrival time, as an "HH:MM:SS" string.
    - confidence: minimum itinerary confidence, in percent.

    Prints a short message when origin and destination coincide or when no
    itinerary can be offered; otherwise displays a dropdown of itineraries,
    each plotted by plot_on_map when selected.
    """
    if departure == arrival:
        print('Walk')
        return

    departure_name = stop_info[stop_info['stop_id']==departure]['stop_name'].values[0]
    arrival_name = stop_info[stop_info['stop_id']==arrival]['stop_name'].values[0]
    print(f"Query from {departure_name} to {arrival_name} before {t} with confidence {confidence}%")
    results = query_csa(departure, arrival, t, confidence)

    # query_csa returns None when the origin has no connection in the loaded time
    # window, and an empty list when no itinerary reaches the requested confidence.
    # The two cases must be told apart explicitly: `not results` is true for both.
    if results is None:
        print("Arrival too early")
    elif len(results)==0:
        print("No route available")
    else:
        route = widgets.Dropdown(
            options=[(itinerary.summary(), itinerary) for itinerary in results],
            value=results[0],
            description='Route',
            disabled=False,
            width='auto'
        )

        interact(plot_on_map, route=route)

# Query choice
                          
# Origin picker: shows stop names, yields the matching stop_id string.
# NOTE: stop_info was sorted without resetting its index, so `stop_info.stop_id[0]`
# is a label lookup (the row labelled 0), not the first entry of the sorted list.
departure = widgets.Dropdown(
    options=list(zip(stop_info.stop_name,stop_info.stop_id)),
    value=stop_info.stop_id[0],
    description='Departure',
    disabled=False,
)

# Destination picker: same options and default as the origin picker.
arrival = widgets.Dropdown(
    options=list(zip(stop_info.stop_name,stop_info.stop_id)),
    value=stop_info.stop_id[0],
    description='Arrival',
)

# Latest acceptable arrival time, chosen from the precomputed HH:MM:SS list `arr`.
arr_time = widgets.Dropdown(
    options=arr,
    value='08:00:00',
    description='Arrival Time',
)

# Minimum itinerary confidence in percent; continuous_update=False means the
# value is only reported once the slider is released.
confidence = widgets.IntSlider(
    value=80,
    min=0,
    max=100,
    step=1,
    description='Confidence',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Manual interaction: `query` only runs when the button (labelled "search") is pressed.
my_interact_manual = interact_manual.options(manual_name="search")
# The return value is bound to `_` so the cell does not display it a second time.
_ = my_interact_manual(query,departure = departure, arrival = arrival, t = arr_time, confidence = confidence)

interactive(children=(Dropdown(description='Departure', index=672, options=(('Adlikon b. R., Dorf', '8576253')…