In [1]:
import pandas as pd
import os
import numpy as np
import geopandas as gdp
import folium

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StructType, StructField, DateType, IntegerType
from pyspark.sql.functions import col,\
        to_timestamp,hour,dayofmonth,date_format, split, date_trunc, avg,\
        to_date, max, min, desc, lit, sequence, explode, isnan, isnull, \
        unix_timestamp, udf, sum




In [2]:
# Build (or reuse) the notebook's Spark session. The options live in one
# mapping so they can be scanned and tuned in a single place.
_SPARK_SETTINGS = {
    "spark.sql.repl.eagerEval.enabled": True,   # render DataFrames as tables
    "spark.sql.parquet.cacheMetadata": "true",
    "spark.sql.session.timeZone": "Etc/UTC",
    'spark.driver.memory': '4g',
    'spark.executor.memory': '2g',
}
_builder = SparkSession.builder.appName("MAST30034 landing to curated")
for _option, _setting in _SPARK_SETTINGS.items():
    _builder = _builder.config(_option, _setting)
spark = _builder.getOrCreate()

your 131072x1 screen size is bogus. expect trouble
23/08/21 03:29:32 WARN Utils: Your hostname, DESKTOP-LHMPQFC resolves to a loopback address: 127.0.1.1; using 172.19.194.216 instead (on interface eth0)
23/08/21 03:29:32 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/08/21 03:29:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Weather datasets

In [3]:
# Quality check: eyeball the first rows of the raw 2023 weather extract.
# `header` is a CSV reader option and has no meaning for Parquet, so it is
# no longer passed.
sample_2023 = spark.read.parquet("../data/raw/weather_2023")
sample_2023.show(20, truncate=False)

                                                                                

+-------------------+--------------+-------+
|date               |wnd           |tmp    |
+-------------------+--------------+-------+
|2023-01-01 00:00:00|999,9,C,0000,1|+0100,1|
|2023-01-01 00:51:00|200,5,N,0021,5|+0089,5|
|2023-01-01 01:49:00|190,5,N,0026,5|+0090,5|
|2023-01-01 01:51:00|190,5,N,0026,5|+0094,5|
|2023-01-01 02:03:00|190,5,N,0031,5|+0094,5|
|2023-01-01 02:51:00|230,5,N,0026,5|+0094,5|
|2023-01-01 03:00:00|230,1,N,0026,1|+0094,1|
|2023-01-01 03:51:00|180,5,N,0036,5|+0094,5|
|2023-01-01 04:05:00|160,5,N,0036,5|+0094,5|
|2023-01-01 04:18:00|240,5,N,0041,5|+0094,5|
|2023-01-01 04:49:00|230,5,N,0021,5|+0090,5|
|2023-01-01 04:51:00|240,5,N,0021,5|+0094,5|
|2023-01-01 04:59:00|999,9,9,9999,9|+9999,9|
|2023-01-01 04:59:00|999,9,9,9999,9|+9999,9|
|2023-01-01 05:51:00|240,5,N,0046,5|+0089,5|
|2023-01-01 06:00:00|240,1,N,0046,1|+0089,1|
|2023-01-01 06:49:00|230,5,N,0041,5|+0080,5|
|2023-01-01 06:51:00|240,5,N,0041,5|+0078,5|
|2023-01-01 07:15:00|250,5,N,0046,5|+0078,5|
|2023-01-0

In [4]:
def extract_information(column, df):
    """
    Unpack a comma-separated raw weather field into separate columns.

    Args:
        column (str): The raw field to unpack, 'tmp' or 'wnd'. Any other
            value returns `df` unchanged.
        df (DataFrame): The input DataFrame containing the raw field.

    Returns:
        DataFrame: `df` with the extracted columns appended. The raw field
        itself is kept (nothing is dropped):
            'tmp' -> tmp_observation (int), tmp_quality
            'wnd' -> direction, direction_quality, observation_type,
                     speed (int), speed_quality
    """
    if column == 'tmp':
        # Example value: "+0078,1".
        # First part: temperature in Celsius, scaled by 10 to keep the
        # decimal place. Second part: quality code (1 = passed quality
        # control checks).
        tmp_parts = split(col("tmp"), ",")
        return (
            df.withColumn("tmp_observation", tmp_parts[0].cast("int"))
              .withColumn("tmp_quality", tmp_parts[1])
        )

    if column == 'wnd':
        # Example value: "200,5,N,0021,5" -> five positional components.
        wnd_parts = split(col("wnd"), ",")
        return (
            df.withColumn("direction", wnd_parts[0])
              .withColumn("direction_quality", wnd_parts[1])
              .withColumn("observation_type", wnd_parts[2])
              .withColumn("speed", wnd_parts[3].cast("int"))
              .withColumn("speed_quality", wnd_parts[4])
        )

    # Unknown field name: leave the DataFrame untouched, as before.
    return df

def count(df):
    """Print how many rows and how many columns `df` has (returns None)."""
    n_rows, n_cols = df.count(), len(df.columns)
    print("Number of rows:", n_rows)
    print("Number of columns:", n_cols)
    
# Load both raw weather years and unpack the packed 'tmp' and 'wnd' fields.
# `header` is a CSV reader option and has no meaning for Parquet, so it is
# no longer passed.
sample_2022 = spark.read.parquet("../data/raw/weather_2022")
sample_2023 = spark.read.parquet("../data/raw/weather_2023")
for _field in ("tmp", "wnd"):
    sample_2022 = extract_information(_field, sample_2022)
    sample_2023 = extract_information(_field, sample_2023)

# Report the size of each year after extraction.
count(sample_2022)
count(sample_2023)

def filter_invalid_rows(df, quality = None):
    """
    Drop rows whose temperature or wind-speed reading is the 9999 placeholder.

    Args:
        df (DataFrame): The input DataFrame; must contain the
            'tmp_observation' and 'speed' columns.
        quality (bool, optional): Accepted for backward compatibility but
            currently unused: no quality-code filtering is implemented, so
            the result is the same for every value. Defaults to None.

    Returns:
        DataFrame: The rows where neither 'tmp_observation' nor 'speed'
        equals 9999.
    """
    is_placeholder = (col("tmp_observation") == 9999) | (col("speed") == 9999)
    return df.filter(~is_placeholder)

print("Filtering applied")
# Remove the 9999 placeholder readings from both years, then report the
# remaining sizes.
sample_2022, sample_2023 = (
    filter_invalid_rows(_year_df) for _year_df in (sample_2022, sample_2023)
)
count(sample_2022)
count(sample_2023)

def aggregate_hourly(df, cols):
    """
    Average each requested column per calendar date and hour.

    Args:
        df (DataFrame): The input DataFrame; must contain a 'date' column.
        cols (list): Names of the columns to average.

    Returns:
        list: One DataFrame per entry of `cols`, in the same order, with the
        columns 'year_month_date', 'hour' and 'avg_<column>', sorted by date
        and hour.
    """
    group_keys = ["year_month_date", "hour"]
    keyed = (
        df.withColumn("year_month_date", to_date(col("date")))
          .withColumn("hour", hour(col("date")))
    )

    hourly_frames = []
    for name in cols:
        print(name)  # progress marker: which column is being aggregated
        hourly_frames.append(
            keyed.groupBy(*group_keys)
                 .agg(avg(col(name)).alias(f"avg_{name}"))
                 .orderBy(*group_keys)
        )
    return hourly_frames

print("Hourly aggregation applied")
# Average temperature and wind speed per (date, hour) for each year.
_weather_columns = ['tmp_observation', 'speed']
sample_2022_tmp, sample_2022_speed = \
    aggregate_hourly(sample_2022, cols = _weather_columns)
sample_2023_tmp, sample_2023_speed = \
    aggregate_hourly(sample_2023, cols = _weather_columns)

# Row counts of the four hourly tables. 2022 is a full year, so
# 8760 = 365 days * 24 hrs rows are expected; 2023 only covers part of the
# year and is therefore shorter.
for _hourly in (sample_2022_tmp, sample_2023_tmp,
                sample_2022_speed, sample_2023_speed):
    count(_hourly)

print("Merging aggregation applied")
# Put temperature and speed side by side, one row per (date, hour).
_hourly_keys = ['year_month_date', 'hour']
sample_2022 = (
    sample_2022_tmp
    .join(sample_2022_speed, on = _hourly_keys, how = 'inner')
    .orderBy(*_hourly_keys)
)
sample_2023 = (
    sample_2023_tmp
    .join(sample_2023_speed, on = _hourly_keys, how = 'inner')
    .orderBy(*_hourly_keys)
)
count(sample_2022)
count(sample_2023)

Number of rows: 13344
Number of columns: 10
Number of rows: 8475
Number of columns: 10
Filtering applied
Number of rows: 12967
Number of columns: 10
Number of rows: 8235
Number of columns: 10
Hourly aggregation applied
tmp_observation
speed
tmp_observation
speed
Number of rows: 8760
Number of columns: 3
Number of rows: 5528
Number of columns: 3
Number of rows: 8760
Number of columns: 3
Number of rows: 5528
Number of columns: 3
Merging aggregation applied
Number of rows: 8760
Number of columns: 4
Number of rows: 5528
Number of columns: 4


In [6]:
# Trim both years to the study window (2022-02-01 .. 2023-02-28).
sample_2023 = sample_2023.where(col('year_month_date') < '2023-03-01')
sample_2022 = sample_2022.where(col('year_month_date') > ('2022-01-31'))

# Sanity checks: the last 2023 row, then the row count of each part.
for _check in (sample_2023.tail(1), sample_2023.count(), sample_2022.count()):
    print(_check)

# Both frames come out of the same pipeline with the same column order,
# so a positional union lines the columns up correctly.
merged_weather = sample_2022.union(sample_2023)
print(merged_weather.count())

[Row(year_month_date=datetime.date(2023, 2, 28), hour=23, avg_tmp_observation=22.0, avg_speed=51.0)]
1416
8016
9432


In [7]:
# Collect the merged weather table to pandas and write it out as one CSV.
weather_pdf = merged_weather.toPandas()
weather_pdf.to_csv("../data/curated/weather.csv", header=True, index=False)

In [8]:
# Display the merged hourly weather table (rendered as a table because
# spark.sql.repl.eagerEval.enabled is set on the session).
merged_weather

                                                                                

year_month_date,hour,avg_tmp_observation,avg_speed
2022-02-01,0,-47.0,26.0
2022-02-01,1,-50.0,21.0
2022-02-01,2,-56.0,26.0
2022-02-01,3,-56.0,26.0
2022-02-01,4,-61.0,31.0
2022-02-01,5,-67.0,31.0
2022-02-01,6,-67.0,33.5
2022-02-01,7,-61.0,31.0
2022-02-01,8,-67.0,26.0
2022-02-01,9,-64.0,31.0


## MTA subway

In [9]:
# Load the raw MTA ridership extract and display its first rows.
# `header` is a CSV reader option and has no meaning for Parquet, so it is
# no longer passed.
mta = spark.read.parquet("../data/raw/mta")
mta

transit_timestamp,borough,ridership,georeference
2023-08-12 11:00:00,M,283.0,POINT (-73.937965...
2023-08-12 13:00:00,BK,148.0,POINT (-73.92261 ...
2022-10-08 21:00:00,M,117.0,POINT (-73.94748 ...
2022-05-20 10:00:00,M,356.0,POINT (-73.94748 ...
2022-03-14 08:00:00,M,845.0,POINT (-73.968376...
2022-05-03 22:00:00,M,470.0,POINT (-73.98163 ...
2023-01-28 18:00:00,M,,POINT (-73.98163 ...
2023-07-07 08:00:00,M,730.0,POINT (-73.968376...
2022-02-28 20:00:00,M,482.0,POINT (-73.98163 ...
2022-03-03 00:00:00,M,43.0,POINT (-73.968376...


In [12]:
# Print the first 30 raw MTA rows as a text table.
mta.show(30)

+-------------------+-------+---------+--------------------+
|  transit_timestamp|borough|ridership|        georeference|
+-------------------+-------+---------+--------------------+
|2023-08-12 11:00:00|      M|      283|POINT (-73.937965...|
|2023-08-12 13:00:00|     BK|      148|POINT (-73.92261 ...|
|2022-10-08 21:00:00|      M|      117|POINT (-73.94748 ...|
|2022-05-20 10:00:00|      M|      356|POINT (-73.94748 ...|
|2022-03-14 08:00:00|      M|      845|POINT (-73.968376...|
|2022-05-03 22:00:00|      M|      470|POINT (-73.98163 ...|
|2023-01-28 18:00:00|      M|     null|POINT (-73.98163 ...|
|2023-07-07 08:00:00|      M|      730|POINT (-73.968376...|
|2022-02-28 20:00:00|      M|      482|POINT (-73.98163 ...|
|2022-03-03 00:00:00|      M|       43|POINT (-73.968376...|
|2022-09-08 13:00:00|      M|      396|POINT (-73.94748 ...|
|2022-12-18 02:00:00|      M|       16|POINT (-73.968376...|
|2023-03-23 17:00:00|      M|      680|POINT (-73.968376...|
|2022-06-17 11:00:00|   

In [13]:
# Exploratory filters on the raw MTA data.
# NOTE(review): only the last expression of a cell is displayed, so the
# result of the borough == 'M' filter below is built and then discarded.
# Confirm whether it was meant to be shown separately or combined with the
# ridership filter.
mta.filter(col("borough") == 'M')
# Rows with more than 800 riders in the hour (this is what the cell shows).
mta.filter(col("ridership") > 800)


transit_timestamp,borough,ridership,georeference
2022-03-14 08:00:00,M,845,POINT (-73.968376...
2023-01-25 19:00:00,M,997,POINT (-73.98163 ...
2022-05-31 09:00:00,M,899,POINT (-73.98163 ...
2023-08-12 17:00:00,M,978,POINT (-73.98163 ...
2022-09-21 09:00:00,M,805,POINT (-73.968376...
2022-08-10 08:00:00,M,876,POINT (-73.968376...
2023-08-02 15:00:00,M,902,POINT (-73.98163 ...
2022-02-15 08:00:00,M,802,POINT (-73.94748 ...
2022-04-14 14:00:00,M,911,POINT (-73.98163 ...
2022-09-18 16:00:00,M,983,POINT (-73.98163 ...


In [14]:
# Rows whose hourly ridership is exactly 999: print how many there are,
# then display a sample of them.
ridership_999 = mta.filter(col("ridership") == 999)
print(ridership_999.count())
ridership_999

431


transit_timestamp,borough,ridership,georeference
2023-05-08 14:00:00,M,999,POINT (-73.98163 ...
2022-07-30 17:00:00,M,999,POINT (-73.98163 ...
2022-10-27 17:00:00,M,999,POINT (-73.94748 ...
2023-08-05 18:00:00,M,999,POINT (-73.98163 ...
2023-07-01 15:00:00,M,999,POINT (-73.98163 ...
2023-06-26 09:00:00,M,999,POINT (-73.98163 ...
2022-11-15 15:00:00,Q,999,POINT (-73.8627 4...
2023-05-11 15:00:00,Q,999,POINT (-73.8627 4...
2022-08-26 09:00:00,Q,999,POINT (-73.8627 4...
2023-04-13 15:00:00,Q,999,POINT (-73.8627 4...


In [15]:
# Print column names, types and nullability of the raw MTA data.
mta.printSchema()

root
 |-- transit_timestamp: timestamp (nullable = true)
 |-- borough: string (nullable = true)
 |-- ridership: integer (nullable = true)
 |-- georeference: string (nullable = true)



In [27]:
# Re-read the raw MTA data so this cell does not depend on earlier cells.
# `header` is a CSV reader option and has no meaning for Parquet, so it is
# no longer passed.
mta = spark.read.parquet("../data/raw/mta")

def min_max_timestamp(df, column = 'transit_timestamp'):
    """
    Print the largest and then the smallest value of one column.

    Each value is printed under its own heading ("MAX_TIMESTAMP" first,
    "MIN_TIMESTAMP" second) as a one-row aggregate DataFrame.

    Args:
        df (DataFrame): The input DataFrame.
        column (str, optional): Name of the column to summarise.
            Defaults to 'transit_timestamp'.

    Returns:
        None
    """
    target = col(column)
    summaries = (("MAX_TIMESTAMP", max), ("MIN_TIMESTAMP", min))
    for heading, aggregate in summaries:
        print(heading)
        print(df.agg(aggregate(target)))
    
def find_nan(df):
    """
    Print the rows whose 'sum(ridership)' value is NULL or NaN.

    Args:
        df (DataFrame): The input DataFrame; must contain a
            'sum(ridership)' column.

    Returns:
        None
    """
    ridership_total = col("sum(ridership)")
    is_missing = isnull(ridership_total) | isnan(ridership_total)
    print(df.filter(is_missing))

def count(df):
    """
    Print the row count of a DataFrame.

    Args:
        df (DataFrame): The input DataFrame.

    Returns:
        None
    """
    total_rows = df.count()
    print(total_rows)

def time_filter(df, date):
    """
    Keep only the rows whose 'transit_timestamp' is strictly earlier than
    the given date.

    Args:
        df (DataFrame): The input DataFrame; must contain a
            'transit_timestamp' column.
        date: The exclusive upper bound used for filtering.

    Returns:
        DataFrame: The filtered DataFrame.
    """
    cutoff = lit(date)
    is_before_cutoff = col("transit_timestamp") < cutoff
    return df.filter(is_before_cutoff)

def extract_borough(df, borough):
    """
    Select the rows that belong to one borough.

    Args:
        df (DataFrame): The input DataFrame; must contain a 'borough' column.
        borough (str): The borough value to match exactly.

    Returns:
        DataFrame: The rows whose 'borough' equals `borough`.
    """
    is_requested_borough = df.borough == borough
    return df.filter(is_requested_borough)

# Restrict to rows before 2023-03-01, discard the location column and total
# the ridership per borough and timestamp.
mta_filtered = time_filter(mta, '2023-03-01').drop("georeference")
grouped = (
    mta_filtered
    .groupBy("borough", "transit_timestamp")
    .agg(sum(col('ridership')))
    .orderBy("borough", 'transit_timestamp')
)

# One DataFrame per borough code.
bk, queens, manhattan, bx = (
    extract_borough(grouped, _code) for _code in ('BK', 'Q', 'M', 'BX')
)

# Row count of every borough series, then its timestamp range.
for _borough_df in (bk, queens, manhattan, bx):
    count(_borough_df)
for _borough_df in (bk, queens, manhattan, bx):
    min_max_timestamp(_borough_df)

                                                                                

9431
9431


                                                                                

9431


                                                                                

9431
MAX_TIMESTAMP


                                                                                

+----------------------+
|max(transit_timestamp)|
+----------------------+
|   2023-02-28 23:00:00|
+----------------------+

MIN_TIMESTAMP


                                                                                

+----------------------+
|min(transit_timestamp)|
+----------------------+
|   2022-02-01 00:00:00|
+----------------------+

MAX_TIMESTAMP
+----------------------+
|max(transit_timestamp)|
+----------------------+
|   2023-02-28 23:00:00|
+----------------------+

MIN_TIMESTAMP
+----------------------+
|min(transit_timestamp)|
+----------------------+
|   2022-02-01 00:00:00|
+----------------------+

MAX_TIMESTAMP
+----------------------+
|max(transit_timestamp)|
+----------------------+
|   2023-02-28 23:00:00|
+----------------------+

MIN_TIMESTAMP
+----------------------+
|min(transit_timestamp)|
+----------------------+
|   2022-02-01 00:00:00|
+----------------------+

MAX_TIMESTAMP


                                                                                

+----------------------+
|max(transit_timestamp)|
+----------------------+
|   2023-02-28 23:00:00|
+----------------------+

MIN_TIMESTAMP
+----------------------+
|min(transit_timestamp)|
+----------------------+
|   2022-02-01 00:00:00|
+----------------------+



                                                                                

In [28]:
# (row count, column count) of the filtered MTA data.
mta_filtered.count(), len(mta_filtered.columns)

(3913645, 3)

In [26]:
# Peek at one row of the filtered MTA data.
# NOTE(review): the saved output of this cell still lists `georeference`,
# which the preparation cell drops, so the output appears to predate that
# cell's latest run. Re-run in order to refresh it.
mta_filtered.show(1)

+-------------------+-------+---------+--------------------+
|  transit_timestamp|borough|ridership|        georeference|
+-------------------+-------+---------+--------------------+
|2022-10-08 21:00:00|      M|      117|POINT (-73.94748 ...|
+-------------------+-------+---------+--------------------+
only showing top 1 row



In [19]:
# Study window: 2022-02-01 00:00 to 2023-02-28 23:00 inclusive, i.e.
# 334 days in 2022 + 59 days in 2023 = 393 days = 9432 hourly slots.
# A complete hourly grid is built so that hours absent from the MTA data
# show up as nulls after the left joins below.
start_date = "2022-02-01 00:00:00"
end_date = "2023-02-28 23:00:00"  # last hour of the final day

# One row holding the whole hourly sequence as an array ...
timestamps_df = spark.sql(f"SELECT sequence(to_timestamp('{start_date}'), \
                          to_timestamp('{end_date}'), interval 1 hour) AS transit_timestamp")

# ... exploded into one row per hour.
exploded_df = timestamps_df.select(
    explode("transit_timestamp").alias("transit_timestamp")
)
print(exploded_df.count())

# Left-join every borough series onto the grid, keeping chronological order.
joined_bk, joined_queens, joined_manhattan, joined_bx = (
    exploded_df
    .join(_borough_df, "transit_timestamp", "left")
    .orderBy('transit_timestamp')
    for _borough_df in (bk, queens, manhattan, bx)
)

9432


In [20]:
# 5 datasets in total
# NOTE(review): the list below holds four frames; presumably the fifth
# dataset is one prepared elsewhere - confirm.
# quality check: done -->
# The four per-borough series, each already left-joined onto the hourly grid.
data_lst = [joined_bk, joined_queens, joined_manhattan, joined_bx]
def inspect_first_row(df):
    """
    Display the first row of a DataFrame.

    `DataFrame.show` prints the table itself and returns None, so it is
    called directly; wrapping it in print() would add a stray "None" line
    after every table.

    Args:
        df (DataFrame): The DataFrame to inspect.

    Returns:
        None
    """
    df.show(1)
# For each borough frame: show its first row, then print its row count.
for data in data_lst:
    inspect_first_row(data)
    count(data)

                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-02-01 00:00:00|     BK|          1796|
+-------------------+-------+--------------+
only showing top 1 row

None
9432


                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-02-01 00:00:00|      Q|          1237|
+-------------------+-------+--------------+
only showing top 1 row

None


                                                                                

9432


                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-02-01 00:00:00|      M|          7407|
+-------------------+-------+--------------+
only showing top 1 row

None


                                                                                

9432


                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-02-01 00:00:00|     BX|           712|
+-------------------+-------+--------------+
only showing top 1 row

None




9432


                                                                                

In [21]:
# Spot-check a single hour (2022-03-13 03:00) in the Bronx series.
# NOTE(review): presumably probing the hour right after the 02:00 slot that
# has no ridership data - confirm.
joined_bx.filter(col("transit_timestamp") == '2022-03-13 03:00:00')

                                                                                

transit_timestamp,borough,sum(ridership)
2022-03-13 03:00:00,BX,490


In [22]:
# List the hours with no ridership total, one borough at a time.
for _joined in (joined_bx, joined_bk, joined_queens, joined_manhattan):
    find_nan(_joined)

                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-03-13 02:00:00|   null|          null|
+-------------------+-------+--------------+



                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-03-13 02:00:00|   null|          null|
+-------------------+-------+--------------+



                                                                                

+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-03-13 02:00:00|   null|          null|
+-------------------+-------+--------------+





+-------------------+-------+--------------+
|  transit_timestamp|borough|sum(ridership)|
+-------------------+-------+--------------+
|2022-03-13 02:00:00|   null|          null|
+-------------------+-------+--------------+



                                                                                

In [23]:
def save_to_csv(df, name):
    """
    Write a DataFrame to ../data/curated/<name>.csv via pandas.

    'transit_timestamp' is rendered as "yyyy-MM-dd HH:mm:ss" text first; the
    file is written with a header row and without the pandas index.

    Args:
        df (DataFrame): Spark DataFrame containing 'transit_timestamp'.
        name (str): File name without the .csv extension.

    Returns:
        None
    """
    stamp_as_text = date_format("transit_timestamp", "yyyy-MM-dd HH:mm:ss")
    pandas_df = df.withColumn("transit_timestamp", stamp_as_text).toPandas()
    pandas_df.to_csv(f"../data/curated/{name}.csv", header= True, index = False)
# Output file name -> borough series to export (written in this order).
_borough_exports = {
    "mta_brooklyn": joined_bk,
    "mta_queens": joined_queens,
    "mta_manhattan": joined_manhattan,
    "mta_bronx": joined_bx,
}
for _file_name, _borough_frame in _borough_exports.items():
    save_to_csv(_borough_frame, name = _file_name)

                                                                                

In [24]:
# Read the curated Queens CSV back with pandas to check what was written.
# NOTE(review): the saved output of this cell shows an 'Unnamed: 0' column
# although save_to_csv writes with index=False; the output may predate that
# setting - re-run to confirm.
mta_queens = pd.read_csv("../data/curated/mta_queens.csv")
mta_queens

Unnamed: 0,transit_timestamp,borough,sum(ridership)
0,2022-02-01 00:00:00,Q,1237.0
1,2022-02-01 01:00:00,Q,610.0
2,2022-02-01 02:00:00,Q,533.0
3,2022-02-01 03:00:00,Q,886.0
4,2022-02-01 04:00:00,Q,3753.0
...,...,...,...
9427,2023-02-28 19:00:00,Q,11867.0
9428,2023-02-28 20:00:00,Q,10234.0
9429,2023-02-28 21:00:00,Q,7861.0
9430,2023-02-28 22:00:00,Q,6025.0


## Taxi

In [29]:
# Use pandas to load the taxi zone lookup (LocationID -> Borough, ...).
zone_map = pd.read_csv("../data/landing/taxi_zones.csv")

# Keep only LocationIDs 2..263. `.copy()` makes the subset an independent
# frame, so assigning a new column to it afterwards does not raise
# SettingWithCopyWarning.
_in_scope = (zone_map['LocationID'] > 1) & (zone_map['LocationID'] < 264)
zone_map = zone_map[_in_scope].copy()
def abbreviation_column(x):
    """
    Translate a borough name into its short code.

    Args:
        x: Borough name, matched exactly (case-sensitive).

    Returns:
        str: 'Q', 'BX', 'M', 'S' or 'BK' for the five known boroughs,
        otherwise "Others".
    """
    known_boroughs = (
        ('Queens', 'Q'),
        ('Bronx', 'BX'),
        ('Manhattan', "M"),
        ('Staten Island', 'S'),
        ('Brooklyn', "BK"),
    )
    for borough_name, code in known_boroughs:
        if x == borough_name:
            return code
    return "Others"
    
# Attach the short borough code to every zone and show how many zones fall
# in each borough.
borough_codes = zone_map['Borough'].apply(abbreviation_column)
zone_map['abbr_col'] = borough_codes
print(zone_map['abbr_col'].value_counts())

# Reduce to the two columns needed and turn them into a plain lookup:
# LocationID -> borough code.
zone_map = zone_map[['LocationID', 'abbr_col']]
zone_dict = zone_map.set_index('LocationID')['abbr_col'].to_dict()
zone_dict

abbr_col
Q     69
M     69
BK    61
BX    43
S     20
Name: count, dtype: int64


{2: 'Q',
 3: 'BX',
 4: 'M',
 5: 'S',
 6: 'S',
 7: 'Q',
 8: 'Q',
 9: 'Q',
 10: 'Q',
 11: 'BK',
 12: 'M',
 13: 'M',
 14: 'BK',
 15: 'Q',
 16: 'Q',
 17: 'BK',
 18: 'BX',
 19: 'Q',
 20: 'BX',
 21: 'BK',
 22: 'BK',
 23: 'S',
 24: 'M',
 25: 'BK',
 26: 'BK',
 27: 'Q',
 28: 'Q',
 29: 'BK',
 30: 'Q',
 31: 'BX',
 32: 'BX',
 33: 'BK',
 34: 'BK',
 35: 'BK',
 36: 'BK',
 37: 'BK',
 38: 'Q',
 39: 'BK',
 40: 'BK',
 41: 'M',
 42: 'M',
 43: 'M',
 44: 'S',
 45: 'M',
 46: 'BX',
 47: 'BX',
 48: 'M',
 49: 'BK',
 50: 'M',
 51: 'BX',
 52: 'BK',
 53: 'Q',
 54: 'BK',
 55: 'BK',
 56: 'Q',
 57: 'Q',
 58: 'BX',
 59: 'BX',
 60: 'BX',
 61: 'BK',
 62: 'BK',
 63: 'BK',
 64: 'Q',
 65: 'BK',
 66: 'BK',
 67: 'BK',
 68: 'M',
 69: 'BX',
 70: 'Q',
 71: 'BK',
 72: 'BK',
 73: 'Q',
 74: 'M',
 75: 'M',
 76: 'BK',
 77: 'BK',
 78: 'BX',
 79: 'M',
 80: 'BK',
 81: 'BX',
 82: 'Q',
 83: 'Q',
 84: 'S',
 85: 'BK',
 86: 'Q',
 87: 'M',
 88: 'M',
 89: 'BK',
 90: 'M',
 91: 'BK',
 92: 'Q',
 93: 'Q',
 94: 'BX',
 95: 'Q',
 96: 'Q',
 97: 'BK',

In [30]:
from pyspark.sql.types import StructType, StructField, IntegerType, \
    TimestampNTZType, LongType, DoubleType, StringType

# Explicit schema for the taxi trip records: (column name, Spark type) in
# file order. Every field is declared nullable.
_TAXI_COLUMNS = (
    ('vendorid', IntegerType),
    ('tpep_pickup_datetime', TimestampNTZType),
    ('tpep_dropoff_datetime', TimestampNTZType),
    ('passenger_count', IntegerType),
    ('trip_distance', DoubleType),
    ('ratecodeid', IntegerType),
    ('store_and_fwd_flag', StringType),
    ('pulocationid', IntegerType),
    ('dolocationid', IntegerType),
    ('payment_type', IntegerType),
    ('fare_amount', DoubleType),
    ('extra', DoubleType),
    ('mta_tax', DoubleType),
    ('tip_amount', DoubleType),
    ('tolls_amount', DoubleType),
    ('improvement_surcharge', DoubleType),
    ('total_amount', DoubleType),
    ('congestion_surcharge', DoubleType),
    ('airport_fee', DoubleType),
)
schema = StructType(
    [StructField(_name, _dtype(), True) for _name, _dtype in _TAXI_COLUMNS]
)

# Print the generated schema
print(schema)


StructType([StructField('vendorid', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampNTZType(), True), StructField('tpep_dropoff_datetime', TimestampNTZType(), True), StructField('passenger_count', IntegerType(), True), StructField('trip_distance', DoubleType(), True), StructField('ratecodeid', IntegerType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('pulocationid', IntegerType(), True), StructField('dolocationid', IntegerType(), True), StructField('payment_type', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('airport_fee', DoubleType(), True)])


In [31]:
# Smoke test: read every raw path matching ../data/raw/2* with the explicit
# schema defined above the read, and print one record vertically as a quick
# data overview.
sdf = spark.read.schema(schema).parquet("../data/raw/2*")
sdf.show(1, vertical = True, truncate = 100)

-RECORD 0------------------------------------
 vendorid              | 1                   
 tpep_pickup_datetime  | 2022-10-01 00:03:41 
 tpep_dropoff_datetime | 2022-10-01 00:18:39 
 passenger_count       | 1                   
 trip_distance         | 1.7                 
 ratecodeid            | 1                   
 store_and_fwd_flag    | N                   
 pulocationid          | 249                 
 dolocationid          | 107                 
 payment_type          | 1                   
 fare_amount           | 9.5                 
 extra                 | 3.0                 
 mta_tax               | 0.5                 
 tip_amount            | 2.65                
 tolls_amount          | 0.0                 
 improvement_surcharge | 0.3                 
 total_amount          | 15.95               
 congestion_surcharge  | 2.5                 
 airport_fee           | 0.0                 
only showing top 1 row



In [32]:
# data overview
def count(df):
    """Print the number of rows in `df` (returns None)."""
    print(df.count(), end="\n")
count(sdf)
# Earliest and latest value of the pickup column, then of the dropoff
# column; one blank line separates the two groups.
for _position, _ts_column in enumerate(
        ("tpep_pickup_datetime", "tpep_dropoff_datetime")):
    if _position:
        print()
    for _aggregate in (min, max):
        print(sdf.agg(_aggregate(col(_ts_column))))

43172888


                                                                                

+-------------------------+
|min(tpep_pickup_datetime)|
+-------------------------+
|      2001-01-01 00:03:14|
+-------------------------+

+-------------------------+
|max(tpep_pickup_datetime)|
+-------------------------+
|      2023-04-18 14:30:05|
+-------------------------+



                                                                                

+--------------------------+
|min(tpep_dropoff_datetime)|
+--------------------------+
|       2001-01-01 00:34:17|
+--------------------------+

+--------------------------+
|max(tpep_dropoff_datetime)|
+--------------------------+
|       2023-04-18 23:30:39|
+--------------------------+



In [36]:
# Re-read the raw taxi files so that this cell, which reassigns `sdf` after
# every filter, starts from the unfiltered data each time it is run.
sdf = spark.read.schema(schema).parquet("../data/raw/2*")
def min_max_timestamp(df, column = 'transit_timestamp'):
    """Print the maximum and then the minimum of `column` as one-row aggregates."""
    for aggregate in (max, min):
        print(df.agg(aggregate(col(column))))

def time_filter(df, start_date, end_date, column):
    """Keep the rows with start_date <= `column` < end_date."""
    in_window = (col(column) >= lit(start_date)) & (col(column) < lit(end_date))
    return df.filter(in_window)

def short_distance_trip_filter(df, min_distance):
    """
    Drop trips shorter than `min_distance`.

    The column is referenced as 'trip_distance', exactly as the schema spells
    it, so the filter does not depend on case-insensitive name resolution.

    Args:
        df (DataFrame): Trip data with a 'trip_distance' column.
        min_distance: Smallest distance to keep (inclusive).

    Returns:
        DataFrame: Rows with trip_distance >= min_distance.
    """
    return df.filter(col("trip_distance") >= min_distance)

# assume that max distance is 100 miles
def long_distance_trip_filter(df, max_distance = 100):
    """
    Drop trips longer than `max_distance`.

    The column is referenced as 'trip_distance', exactly as the schema spells
    it, so the filter does not depend on case-insensitive name resolution.

    Args:
        df (DataFrame): Trip data with a 'trip_distance' column.
        max_distance: Largest distance to keep (inclusive). Defaults to 100.

    Returns:
        DataFrame: Rows with trip_distance <= max_distance.
    """
    return df.filter(col("trip_distance") <= max_distance)

# assume that maximum fare amount is 300 dollars = 3 * 100
def fare_amount_filter(df, min_fare, max_fare):
    """Keep the rows with min_fare <= fare_amount <= max_fare (both inclusive)."""
    fare = col("fare_amount")
    within_range = (fare <= max_fare) & (fare >= min_fare)
    return df.filter(within_range)


def short_trip_filter(df, min_second):
    """
    Add the trip duration and drop trips that are too short.

    A 'time_difference' column (dropoff minus pickup, in seconds) is added to
    the result; only rows with time_difference strictly greater than
    `min_second` are kept.
    """
    pickup_epoch = unix_timestamp(col("tpep_pickup_datetime"))
    dropoff_epoch = unix_timestamp(col("tpep_dropoff_datetime"))
    timed = df.withColumn("time_difference", dropoff_epoch - pickup_epoch)
    return timed.filter(col("time_difference") > min_second)

def long_trip_filter(df, max_second = 14400):
    """
    Drop trips that last `max_second` seconds or longer.

    The 'time_difference' column is used when the DataFrame already has it
    (short_trip_filter adds it). Otherwise it is derived here from the pickup
    and dropoff timestamps, so this filter also works when applied on its own
    instead of failing on the missing column.

    Args:
        df (DataFrame): Trip data.
        max_second: Exclusive upper bound on the trip duration in seconds.
            Defaults to 14400 (4 hours).

    Returns:
        DataFrame: Rows with time_difference < max_second; includes the
        'time_difference' column.
    """
    if "time_difference" not in df.columns:
        df = df.withColumn(
            "time_difference",
            unix_timestamp(col("tpep_dropoff_datetime"))
            - unix_timestamp(col("tpep_pickup_datetime")),
        )
    return df.filter(col("time_difference") < max_second)

# later
def pickup_filter(df):
    """Keep the trips whose pickup zone id is greater than 1 and at most 263."""
    pickup_zone = df["pulocationid"]
    in_scope = (pickup_zone > 1) & (pickup_zone <= 263)
    return df.filter(in_scope)

# def passenger_filter(df):
#     # remove passenger
#     return df.filter((df["passenger_count"] > 0) & (df["passenger_count"] <= 10))

# Define a User-Defined Function (UDF) to map IDs to letters
def map_id_to_zone(id):
    """Look up `id` in the global `zone_dict`; return "Others" when it is absent."""
    if id in zone_dict:
        return zone_dict[id]
    return "Others"


# Apply the cleaning filters one at a time and report how many rows survive
# each step. Each count() is a separate Spark action over the data filtered so far.
# NOTE(review): StringType is not among the imports in the notebook's first
# cell; the only visible import is in cell In [20] further down. `zone_dict`
# is likewise only visibly built in In [19]. Confirm this cell still runs
# under Restart & Run All.
print(f"Original dataset: {sdf.count()}")
sdf = time_filter(sdf, start_date = '2022-02-01',end_date = '2023-03-01', \
    column = "tpep_pickup_datetime")
print(f"Time filter applied: {sdf.count()}")
sdf = short_distance_trip_filter(sdf, min_distance=0.4)
print(f"Short distance trip filtered: {sdf.count()}")
sdf = short_trip_filter(sdf, min_second=60)
print(f"Short time trip filtered: {sdf.count()}")
sdf = long_trip_filter(sdf)
print(f"Long time trip filtered: {sdf.count()}")
sdf = long_distance_trip_filter(sdf, max_distance = 100)
print(f"Long distance trip filtered: {sdf.count()}")
sdf = fare_amount_filter(sdf, min_fare = 0, max_fare = 300)
print(f"Fare amount filtered: {sdf.count()}")
sdf = pickup_filter(sdf)
print(f'Pick up filtered: {sdf.count()}')
# sdf = passenger_filter(sdf)
# print(f"Passenger count filter: {sdf.count()}")

# Wrap the id -> abbreviation lookup as a Spark UDF returning a string
map_id_to_zone_udf = udf(map_id_to_zone, StringType())
# Tag every trip with the borough abbreviation of its pickup location
sdf = sdf.withColumn("borough", map_id_to_zone_udf(col("pulocationid")))
# show() prints the record itself and returns None, so it must not be wrapped
# in print() (that would emit a stray "None" line after the record).
sdf.show(1, vertical=True)

# 38912804

Original dataset: 43172888


                                                                                

Time filter applied: 43172300


                                                                                

Short distance trip filtered: 41811443


                                                                                

Short time trip filtered: 41753931


                                                                                

Long time trip filtered: 41704217


                                                                                

Long distance trip filtered: 41702718


                                                                                

Fare amount filtered: 41504885


                                                                                

Pick up filtered: 40984627
-RECORD 0------------------------------------
 vendorid              | 1                   
 tpep_pickup_datetime  | 2022-10-01 00:03:41 
 tpep_dropoff_datetime | 2022-10-01 00:18:39 
 passenger_count       | 1                   
 trip_distance         | 1.7                 
 ratecodeid            | 1                   
 store_and_fwd_flag    | N                   
 pulocationid          | 249                 
 dolocationid          | 107                 
 payment_type          | 1                   
 fare_amount           | 9.5                 
 extra                 | 3.0                 
 mta_tax               | 0.5                 
 tip_amount            | 2.65                
 tolls_amount          | 0.0                 
 improvement_surcharge | 0.3                 
 total_amount          | 15.95               
 congestion_surcharge  | 2.5                 
 airport_fee           | 0.0                 
 time_difference       | 898                 
 boroug

In [37]:
# yyy-mm-dd
# sdf.filter(col("tpep_pickup_datetime")> "2022-10-28 15:00:00")

In [38]:
# Display the filtered trips, now carrying the time_difference and borough columns
sdf

vendorid,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,ratecodeid,store_and_fwd_flag,pulocationid,dolocationid,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee,time_difference,borough
1,2022-10-01 00:03:41,2022-10-01 00:18:39,1,1.7,1,N,249,107,1,9.5,3.0,0.5,2.65,0.0,0.3,15.95,2.5,0.0,898,M
2,2022-10-01 00:14:30,2022-10-01 00:19:48,2,0.72,1,N,151,238,2,5.5,0.5,0.5,0.0,0.0,0.3,9.3,2.5,0.0,318,M
2,2022-10-01 00:27:13,2022-10-01 00:37:41,1,1.74,1,N,238,166,1,9.0,0.5,0.5,2.06,0.0,0.3,12.36,0.0,0.0,628,M
1,2022-10-01 00:32:53,2022-10-01 00:38:55,0,1.3,1,N,142,239,1,6.5,3.0,0.5,2.05,0.0,0.3,12.35,2.5,0.0,362,M
1,2022-10-01 00:44:55,2022-10-01 00:50:21,0,1.0,1,N,238,166,1,6.0,0.5,0.5,1.8,0.0,0.3,9.1,0.0,0.0,326,M
1,2022-10-01 00:22:52,2022-10-01 00:52:14,1,6.8,1,Y,186,41,2,25.5,3.0,0.5,0.0,0.0,0.3,29.3,2.5,0.0,1762,M
2,2022-10-01 00:33:19,2022-10-01 00:44:51,3,1.88,1,N,162,145,2,10.5,0.5,0.5,0.0,0.0,0.3,14.3,2.5,0.0,692,M
1,2022-10-01 00:02:42,2022-10-01 00:50:01,1,12.2,1,N,100,22,1,41.0,3.0,0.5,3.0,0.0,0.3,47.8,2.5,0.0,2839,M
2,2022-10-01 00:06:35,2022-10-01 00:24:38,1,7.79,1,N,138,112,1,23.5,0.5,0.5,4.96,0.0,0.3,31.01,0.0,1.25,1083,Q
2,2022-10-01 00:29:25,2022-10-01 00:43:15,1,4.72,1,N,145,75,1,14.5,0.5,0.5,1.5,0.0,0.3,19.8,2.5,0.0,830,Q


In [39]:
# Inspect the column types after filtering
sdf.printSchema()

root
 |-- vendorid: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- ratecodeid: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- pulocationid: integer (nullable = true)
 |-- dolocationid: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)
 |-- time_difference: long (nullable = true)
 |-- borough: string (nullable = true)



In [40]:
# Cast the pickup time to a timestamp, then truncate it to the start of its
# hour (the date part is kept) in a new rounded_pickup column
sdf = (
    sdf
    .withColumn("tpep_pickup_datetime", col("tpep_pickup_datetime").cast("timestamp"))
    .withColumn("rounded_pickup", date_trunc("hour", "tpep_pickup_datetime"))
)
sdf

vendorid,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,ratecodeid,store_and_fwd_flag,pulocationid,dolocationid,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee,time_difference,borough,rounded_pickup
1,2022-10-01 00:03:41,2022-10-01 00:18:39,1,1.7,1,N,249,107,1,9.5,3.0,0.5,2.65,0.0,0.3,15.95,2.5,0.0,898,M,2022-10-01 00:00:00
2,2022-10-01 00:14:30,2022-10-01 00:19:48,2,0.72,1,N,151,238,2,5.5,0.5,0.5,0.0,0.0,0.3,9.3,2.5,0.0,318,M,2022-10-01 00:00:00
2,2022-10-01 00:27:13,2022-10-01 00:37:41,1,1.74,1,N,238,166,1,9.0,0.5,0.5,2.06,0.0,0.3,12.36,0.0,0.0,628,M,2022-10-01 00:00:00
1,2022-10-01 00:32:53,2022-10-01 00:38:55,0,1.3,1,N,142,239,1,6.5,3.0,0.5,2.05,0.0,0.3,12.35,2.5,0.0,362,M,2022-10-01 00:00:00
1,2022-10-01 00:44:55,2022-10-01 00:50:21,0,1.0,1,N,238,166,1,6.0,0.5,0.5,1.8,0.0,0.3,9.1,0.0,0.0,326,M,2022-10-01 00:00:00
1,2022-10-01 00:22:52,2022-10-01 00:52:14,1,6.8,1,Y,186,41,2,25.5,3.0,0.5,0.0,0.0,0.3,29.3,2.5,0.0,1762,M,2022-10-01 00:00:00
2,2022-10-01 00:33:19,2022-10-01 00:44:51,3,1.88,1,N,162,145,2,10.5,0.5,0.5,0.0,0.0,0.3,14.3,2.5,0.0,692,M,2022-10-01 00:00:00
1,2022-10-01 00:02:42,2022-10-01 00:50:01,1,12.2,1,N,100,22,1,41.0,3.0,0.5,3.0,0.0,0.3,47.8,2.5,0.0,2839,M,2022-10-01 00:00:00
2,2022-10-01 00:06:35,2022-10-01 00:24:38,1,7.79,1,N,138,112,1,23.5,0.5,0.5,4.96,0.0,0.3,31.01,0.0,1.25,1083,Q,2022-10-01 00:00:00
2,2022-10-01 00:29:25,2022-10-01 00:43:15,1,4.72,1,N,145,75,1,14.5,0.5,0.5,1.5,0.0,0.3,19.8,2.5,0.0,830,Q,2022-10-01 00:00:00


In [41]:
# Number of trips per borough and pickup hour
grouped = (
    sdf
    .groupBy("borough", "rounded_pickup")
    .count()
)

def extract_borough(df, borough):
    """Return only the rows of `df` whose `borough` column equals `borough`."""
    is_requested_borough = df["borough"] == borough
    return df.filter(is_requested_borough)

# One hourly-count DataFrame per borough abbreviation
bk_taxi, queens_taxi, manhattan_taxi, bx_taxi, state_taxi = (
    extract_borough(grouped, abbreviation)
    for abbreviation in ('BK', 'Q', 'M', 'BX', 'S')
)

In [42]:
# Build a complete hourly calendar so that hours without trips still get a row

start_date = "2022-02-01 00:00:00"
end_date = "2023-02-28 23:00:00"  # last hourly bucket of the final day

# A single row holding an array of every hourly timestamp in the range
timestamps_df = spark.sql(f"SELECT sequence(to_timestamp('{start_date}'), \
            to_timestamp('{end_date}'), interval 1 hour) AS rounded_pickup")

# Flatten the array into one row per hourly timestamp
exploded_df = timestamps_df.select(explode("rounded_pickup").alias("rounded_pickup"))

# Number of hourly buckets in the calendar
print(exploded_df.count())

# Left-join each borough's counts onto the calendar (hours with no trips keep
# null borough/count) and order the result chronologically
(
    joined_bk_taxi,
    joined_queens_taxi,
    joined_manhattan_taxi,
    joined_bx_taxi,
    joined_state_taxi,
) = (
    exploded_df.join(borough_counts, "rounded_pickup", "left").orderBy('rounded_pickup')
    for borough_counts in (bk_taxi, queens_taxi, manhattan_taxi, bx_taxi, state_taxi)
)

9432


In [43]:

def find_nan(df,column):
    """Print the number of rows of `df` whose `column` is null or NaN."""
    target = col(column)
    missing_rows = df.filter(isnull(target) | isnan(target))
    print(missing_rows.count())

# Count the hourly buckets with no trip count for Brooklyn, Queens, Manhattan
# and the Bronx (in that order)
for hourly_counts in (
    joined_bk_taxi,
    joined_queens_taxi,
    joined_manhattan_taxi,
    joined_bx_taxi,
):
    find_nan(hourly_counts, column="count")


                                                                                

23


                                                                                

3


                                                                                

1




1189


                                                                                

In [46]:
# Display the Bronx hourly counts; hours with no trips show empty borough/count
joined_bx_taxi

                                                                                

rounded_pickup,borough,count
2022-02-01 00:00:00,,
2022-02-01 01:00:00,BX,1.0
2022-02-01 02:00:00,BX,1.0
2022-02-01 03:00:00,,
2022-02-01 04:00:00,BX,2.0
2022-02-01 05:00:00,BX,5.0
2022-02-01 06:00:00,BX,6.0
2022-02-01 07:00:00,BX,12.0
2022-02-01 08:00:00,BX,15.0
2022-02-01 09:00:00,BX,9.0


In [47]:
# Check the column types of the joined Brooklyn frame before exporting
joined_bk_taxi.printSchema()

root
 |-- rounded_pickup: timestamp (nullable = false)
 |-- borough: string (nullable = true)
 |-- count: long (nullable = true)



In [48]:
def save_to_csv(df, name):
    """Write `df` to ../data/curated/<name>.csv via pandas.

    `rounded_pickup` is first rendered as a "yyyy-MM-dd HH:mm:ss" string; the
    CSV is written with a header row and without the pandas index.
    """
    formatted = df.withColumn(
        "rounded_pickup",
        date_format("rounded_pickup", "yyyy-MM-dd HH:mm:ss"),
    )
    pandas_frame = formatted.toPandas()
    pandas_frame.to_csv(f"../data/curated/{name}.csv", header= True, index = False)
# Export one CSV per borough
for hourly_counts, file_stem in (
    (joined_bk_taxi, "brooklyn"),
    (joined_queens_taxi, "queens"),
    (joined_manhattan_taxi, "manhattan"),
    (joined_bx_taxi, "bronx"),
    (joined_state_taxi, "state"),
):
    save_to_csv(hourly_counts, name = file_stem)


                                                                                

In [49]:
# Read the Queens export back and list the hours whose count is missing
sample = pd.read_csv("../data/curated/queens.csv")
sample.loc[sample['count'].isna()]

Unnamed: 0,rounded_pickup,borough,count
962,2022-03-13 02:00:00,,
5491,2022-09-17 19:00:00,,
5502,2022-09-18 06:00:00,,


## Taxi visualization purposes

This section repeats the cleaning pipeline above, with one twist: the filtered trips are aggregated and saved to CSV for visualisation with folium.

In [19]:
# Load the taxi zone lookup with pandas and keep only location ids
# greater than 1 and below 264
zone_map = (
    pd.read_csv("../data/landing/taxi_zones.csv")
    .loc[lambda zones: (zones['LocationID'] > 1) & (zones['LocationID'] < 264)]
)
def abbreviation_column(x):
    """Return the short code for a borough name, or "Others" if it is not recognised.

    Queens -> Q, Bronx -> BX, Manhattan -> M, Staten Island -> S, Brooklyn -> BK.
    """
    known_boroughs = (
        ('Queens', 'Q'),
        ('Bronx', 'BX'),
        ('Manhattan', "M"),
        ('Staten Island', 'S'),
        ('Brooklyn', "BK"),
    )
    # Compare in the same order as the lookup table above; first match wins
    for borough_name, short_code in known_boroughs:
        if x == borough_name:
            return short_code
    return "Others"
# Attach the abbreviation to every zone and report how many zones each one has
zone_map['abbr_col'] = zone_map['Borough'].apply(abbreviation_column)
print(zone_map['abbr_col'].value_counts())
# Keep only the two columns needed for the lookup
zone_map = zone_map.loc[:, ['LocationID', 'abbr_col']]
# LocationID -> abbreviation dictionary
zone_dict = zone_map.set_index('LocationID')['abbr_col'].to_dict()
zone_dict

abbr_col
Q     69
M     69
BK    61
BX    43
S     20
Name: count, dtype: int64


{2: 'Q',
 3: 'BX',
 4: 'M',
 5: 'S',
 6: 'S',
 7: 'Q',
 8: 'Q',
 9: 'Q',
 10: 'Q',
 11: 'BK',
 12: 'M',
 13: 'M',
 14: 'BK',
 15: 'Q',
 16: 'Q',
 17: 'BK',
 18: 'BX',
 19: 'Q',
 20: 'BX',
 21: 'BK',
 22: 'BK',
 23: 'S',
 24: 'M',
 25: 'BK',
 26: 'BK',
 27: 'Q',
 28: 'Q',
 29: 'BK',
 30: 'Q',
 31: 'BX',
 32: 'BX',
 33: 'BK',
 34: 'BK',
 35: 'BK',
 36: 'BK',
 37: 'BK',
 38: 'Q',
 39: 'BK',
 40: 'BK',
 41: 'M',
 42: 'M',
 43: 'M',
 44: 'S',
 45: 'M',
 46: 'BX',
 47: 'BX',
 48: 'M',
 49: 'BK',
 50: 'M',
 51: 'BX',
 52: 'BK',
 53: 'Q',
 54: 'BK',
 55: 'BK',
 56: 'Q',
 57: 'Q',
 58: 'BX',
 59: 'BX',
 60: 'BX',
 61: 'BK',
 62: 'BK',
 63: 'BK',
 64: 'Q',
 65: 'BK',
 66: 'BK',
 67: 'BK',
 68: 'M',
 69: 'BX',
 70: 'Q',
 71: 'BK',
 72: 'BK',
 73: 'Q',
 74: 'M',
 75: 'M',
 76: 'BK',
 77: 'BK',
 78: 'BX',
 79: 'M',
 80: 'BK',
 81: 'BX',
 82: 'Q',
 83: 'Q',
 84: 'S',
 85: 'BK',
 86: 'Q',
 87: 'M',
 88: 'M',
 89: 'BK',
 90: 'M',
 91: 'BK',
 92: 'Q',
 93: 'Q',
 94: 'BX',
 95: 'Q',
 96: 'Q',
 97: 'BK',

In [20]:
from pyspark.sql.types import StructType, StructField, IntegerType, \
    TimestampNTZType, LongType, DoubleType, StringType

# Explicit schema for the trip records: every field is nullable, listed in
# column order as (name, Spark type)
schema = StructType([
    StructField(field_name, field_type(), True)
    for field_name, field_type in (
        ('vendorid', IntegerType),
        ('tpep_pickup_datetime', TimestampNTZType),
        ('tpep_dropoff_datetime', TimestampNTZType),
        ('passenger_count', IntegerType),
        ('trip_distance', DoubleType),
        ('ratecodeid', IntegerType),
        ('store_and_fwd_flag', StringType),
        ('pulocationid', IntegerType),
        ('dolocationid', IntegerType),
        ('payment_type', IntegerType),
        ('fare_amount', DoubleType),
        ('extra', DoubleType),
        ('mta_tax', DoubleType),
        ('tip_amount', DoubleType),
        ('tolls_amount', DoubleType),
        ('improvement_surcharge', DoubleType),
        ('total_amount', DoubleType),
        ('congestion_surcharge', DoubleType),
        ('airport_fee', DoubleType),
    )
])

# Show the resulting schema
print(schema)


StructType([StructField('vendorid', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampNTZType(), True), StructField('tpep_dropoff_datetime', TimestampNTZType(), True), StructField('passenger_count', IntegerType(), True), StructField('trip_distance', DoubleType(), True), StructField('ratecodeid', IntegerType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('pulocationid', IntegerType(), True), StructField('dolocationid', IntegerType(), True), StructField('payment_type', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('airport_fee', DoubleType(), True)])


In [21]:
# Read every raw dataset matching ../data/raw/2* with the explicit schema
sdf = (
    spark.read
    .schema(schema)
    .parquet("../data/raw/2*")
)
def min_max_timestamp(df, column = 'transit_timestamp'):
    """Print the max aggregate of `column`, then the min aggregate.

    `max` and `min` here are the pyspark.sql.functions versions imported at
    the top of the notebook, so each printed object is the DataFrame returned
    by `df.agg(...)`.
    """
    target = col(column)
    for aggregate in (max, min):
        print(df.agg(aggregate(target)))

def time_filter(df, start_date, end_date, column):
    """Keep rows whose `column` value lies in [start_date, end_date)."""
    before_end = col(column) < lit(end_date)
    on_or_after_start = col(column) >= lit(start_date)
    return df.filter(before_end).filter(on_or_after_start)

def short_distance_trip_filter(df, min_distance):
    """Keep trips whose trip distance is at least `min_distance`."""
    far_enough = col("Trip_distance") >= min_distance
    return df.filter(far_enough)

# assume that max distance is 100 miles
def long_distance_trip_filter(df, max_distance = 100):
    """Keep trips whose trip distance is at most `max_distance` (default 100)."""
    not_too_far = col("Trip_distance") <= max_distance
    return df.filter(not_too_far)

# assume that maximum fare amount is 300 dollars = 3 * 100
def fare_amount_filter(df, min_fare, max_fare):
    """Keep trips whose fare_amount lies in [min_fare, max_fare] (both inclusive)."""
    fare = col("fare_amount")
    return df.filter(fare <= max_fare).filter(fare >= min_fare)


def short_trip_filter(df, min_second):
    """Add a `time_difference` column and drop trips that are too short.

    `time_difference` is dropoff minus pickup, computed from unix timestamps
    (seconds). Only trips strictly longer than `min_second` are kept; the new
    column stays on the returned DataFrame.
    """
    dropoff_epoch = unix_timestamp(col("tpep_dropoff_datetime"))
    pickup_epoch = unix_timestamp(col("tpep_pickup_datetime"))
    with_duration = df.withColumn("time_difference", dropoff_epoch - pickup_epoch)
    return with_duration.filter(col("time_difference") > min_second)

def long_trip_filter(df, max_second = 14400):
    """Drop trips lasting `max_second` seconds or more (default 14400 s = 4 h).

    The filter reads the `time_difference` column (dropoff minus pickup, in
    seconds) that `short_trip_filter` adds. When that column is missing it is
    derived here the same way, so this filter no longer fails if it is called
    before, or without, `short_trip_filter`.

    Returns the filtered DataFrame, which always carries `time_difference`.
    """
    if "time_difference" not in df.columns:
        df = df.withColumn(
            "time_difference",
            unix_timestamp(col("tpep_dropoff_datetime"))
            - unix_timestamp(col("tpep_pickup_datetime")),
        )
    return df.filter(col("time_difference") < max_second)

# later
def pickup_filter(df):
    """Keep trips whose pickup location id is greater than 1 and at most 263."""
    pickup_zone = df["pulocationid"]
    within_range = (pickup_zone > 1) & (pickup_zone <= 263)
    return df.filter(within_range)

def passenger_filter(df):
    """Keep trips whose passenger_count is greater than 0 and at most 10."""
    passengers = df["passenger_count"]
    plausible_count = (passengers > 0) & (passengers <= 10)
    return df.filter(plausible_count)

# Define a User-Defined Function (UDF) to map IDs to letters
def map_id_to_zone(id):
    """Look up `id` in the global `zone_dict`; return "Others" when it is absent."""
    if id in zone_dict:
        return zone_dict[id]
    return "Others"


# Apply the cleaning filters one at a time and report how many rows survive
# each step. Each count() is a separate Spark action over the data filtered so far.
print(f"Original dataset: {sdf.count()}")
sdf = time_filter(sdf, start_date = '2022-02-01',end_date = '2023-03-01', \
    column = "tpep_pickup_datetime")
print(f"Time filter applied: {sdf.count()}")
sdf = short_distance_trip_filter(sdf, min_distance=0.4)
print(f"Short distance trip filtered: {sdf.count()}")
sdf = short_trip_filter(sdf, min_second=60)
print(f"Short time trip filtered: {sdf.count()}")
sdf = long_trip_filter(sdf)
print(f"Long time trip filtered: {sdf.count()}")
sdf = long_distance_trip_filter(sdf, max_distance = 100)
print(f"Long distance trip filtered: {sdf.count()}")
sdf = fare_amount_filter(sdf, min_fare = 0, max_fare = 300)
print(f"Fare amount filtered: {sdf.count()}")
sdf = pickup_filter(sdf)
print(f'Pick up filtered: {sdf.count()}')
# sdf = passenger_filter(sdf)
# print(f"Passenger count filter: {sdf.count()}")

# Wrap the id -> abbreviation lookup as a Spark UDF returning a string
map_id_to_zone_udf = udf(map_id_to_zone, StringType())
# Tag every trip with the borough abbreviation of its pickup location
sdf = sdf.withColumn("borough", map_id_to_zone_udf(col("pulocationid")))
# show() prints the record itself and returns None, so it must not be wrapped
# in print() (that would emit a stray "None" line after the record).
sdf.show(1, vertical=True)

# Cast the pickup time to a timestamp so the hour-based columns can be derived
sdf = sdf.withColumn("tpep_pickup_datetime", \
                     col("tpep_pickup_datetime").cast("timestamp"))
# Truncate the pickup time to the start of its hour, keeping the date part
sdf = sdf.withColumn("rounded_pickup", \
                     date_trunc("hour", "tpep_pickup_datetime"))
# Hour of day of the pickup, used to aggregate trips for the folium map
sdf = sdf.withColumn("hour", hour(col("tpep_pickup_datetime")))
sdf
# 38912804

Original dataset: 43172888
Time filter applied: 43172300


                                                                                

Short distance trip filtered: 41811443


                                                                                

Short time trip filtered: 41753931


                                                                                

Long time trip filtered: 41704217


                                                                                

Long distance trip filtered: 41702718


                                                                                

Fare amount filtered: 41504885


                                                                                

Pick up filtered: 40984627
-RECORD 0------------------------------------
 vendorid              | 1                   
 tpep_pickup_datetime  | 2022-10-01 00:03:41 
 tpep_dropoff_datetime | 2022-10-01 00:18:39 
 passenger_count       | 1                   
 trip_distance         | 1.7                 
 ratecodeid            | 1                   
 store_and_fwd_flag    | N                   
 pulocationid          | 249                 
 dolocationid          | 107                 
 payment_type          | 1                   
 fare_amount           | 9.5                 
 extra                 | 3.0                 
 mta_tax               | 0.5                 
 tip_amount            | 2.65                
 tolls_amount          | 0.0                 
 improvement_surcharge | 0.3                 
 total_amount          | 15.95               
 congestion_surcharge  | 2.5                 
 airport_fee           | 0.0                 
 time_difference       | 898                 
 boroug

vendorid,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,ratecodeid,store_and_fwd_flag,pulocationid,dolocationid,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee,time_difference,borough,rounded_pickup,hour
1,2022-10-01 00:03:41,2022-10-01 00:18:39,1,1.7,1,N,249,107,1,9.5,3.0,0.5,2.65,0.0,0.3,15.95,2.5,0.0,898,M,2022-10-01 00:00:00,0
2,2022-10-01 00:14:30,2022-10-01 00:19:48,2,0.72,1,N,151,238,2,5.5,0.5,0.5,0.0,0.0,0.3,9.3,2.5,0.0,318,M,2022-10-01 00:00:00,0
2,2022-10-01 00:27:13,2022-10-01 00:37:41,1,1.74,1,N,238,166,1,9.0,0.5,0.5,2.06,0.0,0.3,12.36,0.0,0.0,628,M,2022-10-01 00:00:00,0
1,2022-10-01 00:32:53,2022-10-01 00:38:55,0,1.3,1,N,142,239,1,6.5,3.0,0.5,2.05,0.0,0.3,12.35,2.5,0.0,362,M,2022-10-01 00:00:00,0
1,2022-10-01 00:44:55,2022-10-01 00:50:21,0,1.0,1,N,238,166,1,6.0,0.5,0.5,1.8,0.0,0.3,9.1,0.0,0.0,326,M,2022-10-01 00:00:00,0
1,2022-10-01 00:22:52,2022-10-01 00:52:14,1,6.8,1,Y,186,41,2,25.5,3.0,0.5,0.0,0.0,0.3,29.3,2.5,0.0,1762,M,2022-10-01 00:00:00,0
2,2022-10-01 00:33:19,2022-10-01 00:44:51,3,1.88,1,N,162,145,2,10.5,0.5,0.5,0.0,0.0,0.3,14.3,2.5,0.0,692,M,2022-10-01 00:00:00,0
1,2022-10-01 00:02:42,2022-10-01 00:50:01,1,12.2,1,N,100,22,1,41.0,3.0,0.5,3.0,0.0,0.3,47.8,2.5,0.0,2839,M,2022-10-01 00:00:00,0
2,2022-10-01 00:06:35,2022-10-01 00:24:38,1,7.79,1,N,138,112,1,23.5,0.5,0.5,4.96,0.0,0.3,31.01,0.0,1.25,1083,Q,2022-10-01 00:00:00,0
2,2022-10-01 00:29:25,2022-10-01 00:43:15,1,4.72,1,N,145,75,1,14.5,0.5,0.5,1.5,0.0,0.3,19.8,2.5,0.0,830,Q,2022-10-01 00:00:00,0


In [22]:
# Trips per pickup location id and hour of day, sorted, for the folium visualisation
grouped = (
    sdf
    .groupBy("pulocationid", "hour")
    .count()
    .orderBy("pulocationid", "hour")
)
grouped

                                                                                

pulocationid,hour,count
2,7,1
2,8,1
2,10,2
2,11,1
2,12,4
2,13,2
2,14,2
2,15,5
2,16,1
2,17,1


In [23]:
# Sanity check: the grouped counts should add up to the number of trips that
# survived the filters
total_sum = grouped.agg(sum(col("count"))).first()[0]
total_sum

                                                                                

40984627

In [24]:
# Collect the aggregated counts to pandas and write them out for the figure
(
    grouped
    .toPandas()
    .to_csv("../data/curated/figure1.csv", index = False, header = True)
)

                                                                                

In [25]:
# Read the exported figure data back with pandas to check the written file
fig1 = pd.read_csv("../data/curated/figure1.csv")
fig1

Unnamed: 0,pulocationid,hour,count
0,2,7,1
1,2,8,1
2,2,10,2
3,2,11,1
4,2,12,4
...,...,...,...
5866,263,19,53627
5867,263,20,46155
5868,263,21,43798
5869,263,22,40632


23/08/18 17:31:35 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 3567295 ms exceeds timeout 120000 ms
23/08/18 17:31:35 WARN SparkContext: Killing executors is not supported by current scheduler.
23/08/18 17:31:35 WARN Executor: Issue communicating with driver in heartbeater
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:322)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:101)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:85)
	at org.apache.spark.storage.BlockManagerMaster.registerBlockManager(BlockManagerMaster.scala:80)
	at org.apache.spark.storage.BlockManager.reregister(BlockManager.scala:641)
	at org.apache.spark.executor.Executor.reportHeartBeat(Executor.scala:1111)
	at org.apache.spark.executor.Executor.$anonfun$heartbeater$1(Executor.scala:244)
	at s