In [12]:
# Import required packages
import boto3
import datetime as dt
import multiprocessing as mp
from pyspark.sql import SparkSession
from pyspark.sql.types import BooleanType, DoubleType, IntegerType, StringType, StructType, StructField, TimestampType
import pyspark.sql.functions as F

In [13]:
# Pipeline parameters: source bucket, the files to process and the column groups
bucket_name = "nyc-tlc"  # s3 bucket name
years = ["2019", "2020"]
tlc_colours = ["yellow", "green"]
# Zero-padded month strings "01".."12", as used in the source file names
months = [f"{month_number:02d}" for month_number in range(1, 13)]

# Columns cast to timestamp
dt_columns = [
    "pickup_datetime",
    "dropoff_datetime",
]
# Columns cast to integer
int_columns = [
    "payment_type",
    "RatecodeID",
    "passenger_count",
    "year",
]
# Columns cast to double
num_columns = [
    "trip_distance",
    "fare_amount",
    "extra",
    "mta_tax",
    "improvement_surcharge",
    "tip_amount",
    "tolls_amount",
    "total_amount",
]
# Columns retained (in this order) before derived features are added
initial_columns = [
    "pickup_datetime",
    "dropoff_datetime",
    "RatecodeID",
    "passenger_count",
    "trip_distance",
    "payment_type",
    "fare_amount",
    "extra",
    "mta_tax",
    "improvement_surcharge",
    "tip_amount",
    "tolls_amount",
    "total_amount",
    "taxi_type",
    "year",
    "month",
]

In [14]:
# Get (or create) the Spark session used for every read and write in this notebook
spark = (
    SparkSession.builder
    .appName("nyc-taxi-etl")
    .getOrCreate()
)

## Extraction Function

In [15]:
# Read one monthly trip-data CSV from the S3 bucket
def extract_data_from_bucket(bucket, year, colour, month):
    """Load the trip-data CSV for one taxi colour, year and month.

    Args:
        bucket: S3 bucket name.
        year: year string used in the file name.
        colour: taxi colour prefix used in the file name.
        month: month string used in the file name.

    Returns:
        The DataFrame returned by `spark.read.csv` with the first row
        treated as the header.
    """
    source_path = f"s3a://{bucket}/trip data/{colour}_tripdata_{year}-{month}.csv"
    return spark.read.csv(source_path, header=True)

## Transform Data

### Drop columns
* **VendorID** - needs processing intensive one hot encoding
* **store_and_fwd_flag** - not informative for model training
* **PULocationID** - needs processing intensive one hot encoding
* **DOLocationID** - needs processing intensive one hot encoding
* **trip_type** - is not in other data set
* **congestion_surcharge** - is not in other data set and should be included in extras per dictionary
* **ehail_fee** - is not in other data set

### Modify data types

* pickup_datetime: string -> timestamp
* dropoff_datetime: string -> timestamp
* RatecodeID: string -> integer
* payment_type: string -> integer
* passenger_count: string -> integer
* trip_distance: string -> double
* fare_amount: string -> double
* extra: string -> double
* mta_tax: string -> double
* tip_amount: string -> double
* tolls_amount: string -> double
* improvement_surcharge: string -> double
* total_amount: string -> double

### Rename Columns

* **lpep_pickup_datetime** -> pickup_datetime
* **lpep_dropoff_datetime** -> dropoff_datetime
* **tpep_pickup_datetime** -> pickup_datetime
* **tpep_dropoff_datetime** -> dropoff_datetime

### Create new features

* **taxi_type**: whether is a green or yellow cab - created in extract
* **trip_duration_seconds**: time, in seconds, between trip start and trip end
* **trip_duration_category**: bins of trip durations: under 5 mins, 5-10 mins, 10-20 mins, 20-30 mins, above 30 mins
* **year**: the year the trip took place in - created in extract
* **month**: the month the trip took place in
* **pickup_hour**: the hour of day in which the trip started
* **trip_distance_km**: the trip distance converted from miles to kilometres

In [16]:
# Function to calculate trip duration category
def get_trip_duration_category(time):
    """Bucket a trip duration into a labelled range.

    Args:
        time: trip duration in seconds, or None when the duration is missing.

    Returns:
        "Under 5 mins", "5-10 mins", "10-20 mins", "20-30 mins" or
        "Above 30 mins"; None when `time` is None.
    """
    # A null column value reaches a Python UDF as None; dividing it would
    # raise TypeError and fail the whole Spark job, so pass the null through.
    if time is None:
        return None

    minutes = time / 60
    if minutes < 5:
        return "Under 5 mins"
    if minutes < 10:
        return "5-10 mins"
    if minutes < 20:
        return "10-20 mins"
    if minutes < 30:
        return "20-30 mins"
    return "Above 30 mins"

# Wrap the categoriser as a Spark UDF that returns a string column
udf_get_trip_duration_category = F.udf(
    lambda seconds: get_trip_duration_category(seconds),
    returnType=StringType(),
)

In [17]:
# Function to calculate kilometres from a value in miles - may not be required
def get_kilometres_from_miles(miles):
    """Convert a distance in miles to kilometres.

    Args:
        miles: distance in miles, or None when the distance is missing.

    Returns:
        The distance multiplied by 1.60934; None when `miles` is None.
    """
    # A null column value reaches a Python UDF as None; multiplying it would
    # raise TypeError and fail the whole Spark job, so pass the null through.
    if miles is None:
        return None
    return miles * 1.60934

# Wrap the miles-to-kilometres converter as a Spark UDF that returns a double column
udf_kilometres_from_miles = F.udf(
    lambda distance_miles: get_kilometres_from_miles(distance_miles),
    returnType=DoubleType(),
)

In [18]:
# Yellow-taxi specific step: align the timestamp column names with the shared schema
def transform_yellow_taxi_data(df):
    """Rename the `tpep_*` pickup/dropoff columns to the common names.

    Args:
        df: DataFrame of yellow taxi trips.

    Returns:
        DataFrame with `pickup_datetime` and `dropoff_datetime` columns.
    """
    renames = (
        ("tpep_pickup_datetime", "pickup_datetime"),
        ("tpep_dropoff_datetime", "dropoff_datetime"),
    )
    for old_name, new_name in renames:
        df = df.withColumnRenamed(old_name, new_name)
    return df

In [19]:
# Green-taxi specific step: align the timestamp column names with the shared schema
def transform_green_taxi_data(df):
    """Rename the `lpep_*` pickup/dropoff columns to the common names.

    Args:
        df: DataFrame of green taxi trips.

    Returns:
        DataFrame with `pickup_datetime` and `dropoff_datetime` columns.
    """
    renames = (
        ("lpep_pickup_datetime", "pickup_datetime"),
        ("lpep_dropoff_datetime", "dropoff_datetime"),
    )
    for old_name, new_name in renames:
        df = df.withColumnRenamed(old_name, new_name)
    return df

In [20]:
def transform_integer_columns(df, column):
    """Cast `column` to an integer type when the DataFrame has that column.

    Args:
        df: DataFrame to transform.
        column: name of the column to cast.

    Returns:
        DataFrame with the column cast, or `df` unchanged when it is absent.
    """
    if column not in df.columns:
        return df
    as_integer = F.col(column).astype(IntegerType())
    return df.withColumn(column, as_integer)

In [21]:
def transform_double_columns(df, column):
    """Cast `column` to a double type when the DataFrame has that column.

    Args:
        df: DataFrame to transform.
        column: name of the column to cast.

    Returns:
        DataFrame with the column cast, or `df` unchanged when it is absent.
    """
    if column not in df.columns:
        return df
    as_double = F.col(column).astype(DoubleType())
    return df.withColumn(column, as_double)

In [22]:
# Cast a single column to the timestamp data type
def transform_timestamp_columns(df, column):
    """Cast `column` to a timestamp type when the DataFrame has that column.

    Args:
        df: DataFrame to transform.
        column: name of the column to cast.

    Returns:
        DataFrame with the column cast, or `df` unchanged when it is absent.
    """
    if column not in df.columns:
        return df
    as_timestamp = F.col(column).astype(TimestampType())
    return df.withColumn(column, as_timestamp)

In [25]:
# Transforms shared by every NYC TLC file, regardless of taxi colour
def transform_generic_taxi_data(df, dt_columns, int_columns, num_columns, select_columns):
    """Cast column types, keep the selected columns and add derived features.

    Args:
        df: DataFrame whose pickup/dropoff columns already use the common names.
        dt_columns: names of columns to cast to timestamp.
        int_columns: names of columns to cast to integer.
        num_columns: names of columns to cast to double.
        select_columns: columns to keep before the derived features are added.

    Returns:
        DataFrame with the selected columns plus `trip_duration_seconds`,
        `trip_duration_category`, `pickup_hour` and `trip_distance_km`.
    """
    # Apply the casts group by group: timestamps, then integers, then doubles
    cast_plan = (
        (dt_columns, transform_timestamp_columns),
        (int_columns, transform_integer_columns),
        (num_columns, transform_double_columns),
    )
    for column_names, caster in cast_plan:
        for column_name in column_names:
            df = caster(df, column_name)

    trimmed = df.select(select_columns)

    # Casting a timestamp to long gives whole seconds, so the difference is the duration in seconds
    duration_seconds = F.col("dropoff_datetime").cast("long") - F.col("pickup_datetime").cast("long")
    with_duration = trimmed.withColumn("trip_duration_seconds", duration_seconds)
    with_category = with_duration.withColumn(
        "trip_duration_category",
        udf_get_trip_duration_category(F.col("trip_duration_seconds")),
    )
    with_hour = with_category.withColumn("pickup_hour", F.hour(F.col("pickup_datetime")))
    # Alternative without a UDF: F.col("trip_distance") * 1.60934
    return with_hour.withColumn("trip_distance_km", udf_kilometres_from_miles(F.col("trip_distance")))

In [26]:
# Entry point for the transform stage: tag the file, then run colour-specific and generic steps
def data_processing_transform(df, year, colour, month, dt_columns, int_columns, num_columns, select_columns):
    """Run every transform step for one extracted file.

    Args:
        df: DataFrame as extracted for one file.
        year: year value stored in the `year` column.
        colour: taxi colour; stored in `taxi_type` and used to pick the
            colour-specific transform.
        month: month value stored in the `month` column.
        dt_columns: names of columns to cast to timestamp.
        int_columns: names of columns to cast to integer.
        num_columns: names of columns to cast to double.
        select_columns: columns to keep before derived features are added.

    Returns:
        The transformed DataFrame.
    """
    # Record which file each row came from
    tagged = (
        df.withColumn("taxi_type", F.lit(colour))
        .withColumn("year", F.lit(year))
        .withColumn("month", F.lit(month))
    )

    if colour == "yellow":
        # Yellow taxi files use tpep_* timestamp columns
        tagged = transform_yellow_taxi_data(tagged)
    elif colour == "green":
        # Green taxi files use lpep_* timestamp columns
        tagged = transform_green_taxi_data(tagged)
    else:
        # Unknown colour: report it and carry on with the generic steps only
        print("Taxi colour not defined")

    return transform_generic_taxi_data(tagged, dt_columns, int_columns, num_columns, select_columns)

## Data Clean
### Remove records

* **payment_type**: remove trips with a voided payment type and out of bounds
* **RatecodeID**: remove trips out of bounds
* **mta_tax**: remove incorrect non 0.50 mta_tax
* **passenger_count**: remove less than one
* **tip_amount**: remove negatives
* **improvement_surcharge**: remove incorrect non 0.30 improvement_surcharge   
* **fare_amount**: remove trips with a fare amount of zero or below
* **extra**: remove trips negative or excessively high extras
* **tolls_amount**: remove trips negative
* **total_amount**: remove trips negative    
* **trip_duration_seconds**: remove trips with a duration of zero seconds or less, or of 10 hours or more
* **pickup_datetime**: remove trips outside month of stated period in file name


In [28]:
### Filter on valid payment types and also remove voided trips ie those labelled 6
def clean_payment_type(df):
    """Keep rows whose `payment_type` is at least 1 and below 6.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged. Its condition also required the value to
    be below 1, which contradicts keeping valid payment types.

    Args:
        df: DataFrame with a `payment_type` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    payment_type = F.col("payment_type")
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter((payment_type >= 1) & (payment_type < 6))

In [29]:
### Filter on valid rate code ids
def clean_rate_code_id(df):
    """Keep rows whose `RatecodeID` is at least 1 and below 7.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged. Its condition also required the value to
    be below 1, which would have rejected every in-range code.

    Args:
        df: DataFrame with a `RatecodeID` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    rate_code = F.col("RatecodeID")
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter((rate_code >= 1) & (rate_code < 7))

In [30]:
### Filter for correct mta_tax
def clean_mta_tax(df):
    """Keep rows whose `mta_tax` equals 0.50.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with an `mta_tax` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("mta_tax") == 0.50)

In [32]:
### Filter for realistic passenger_count
def clean_passenger_count(df):
    """Keep rows with at least one passenger.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with a `passenger_count` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("passenger_count") >= 1)

In [33]:
### Filter for realistic tip_amount
def clean_tip_amount(df):
    """Keep rows whose `tip_amount` is not negative.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with a `tip_amount` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("tip_amount") >= 0)

In [34]:
### Filter for correct improvement_surcharge
def clean_improvement_surcharge(df):
    """Keep rows whose `improvement_surcharge` equals 0.30.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with an `improvement_surcharge` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("improvement_surcharge") == 0.30)

In [35]:
### Filter for realistic fare_amount
def clean_fare_amount(df):
    """Keep rows whose `fare_amount` is greater than zero.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with a `fare_amount` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("fare_amount") > 0)

In [37]:
### Filter for realistic extras and remove significantly higher extras, data dict says shouldn't be more than
# night time and congestion fees normally but a bit inconsistent with the data
def clean_extra(df):
    """Keep rows whose `extra` is at least 0 and below 3.5.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with an `extra` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    extra = F.col("extra")
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter((extra >= 0) & (extra < 3.5))

In [38]:
### Filter for realistic tolls_amount
def clean_tolls_amount(df):
    """Keep rows whose `tolls_amount` is not negative.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with a `tolls_amount` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("tolls_amount") >= 0)

In [39]:
### Filter for realistic total_amount
def clean_total_amount(df):
    """Keep rows whose `total_amount` is greater than zero.

    The previous version discarded the result of `df.filter(...)` and
    returned the input unchanged.

    Args:
        df: DataFrame with a `total_amount` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    # filter() returns a new DataFrame, so the result must be returned
    return df.filter(F.col("total_amount") > 0)

In [40]:
# Keep only trips lasting more than 0 seconds and less than 10 hours (36000 seconds)
def clean_trip_duration_seconds(df):
    """Drop rows whose `trip_duration_seconds` is not strictly between 0 and 36000.

    Args:
        df: DataFrame with a `trip_duration_seconds` column.

    Returns:
        A new DataFrame containing only the rows that pass the check.
    """
    duration = F.col("trip_duration_seconds")
    is_positive = duration > 0
    under_ten_hours = duration < 36000
    return df.filter(is_positive & under_ten_hours)

In [41]:
# Keep only trips whose timestamp falls in the year and month recorded for the source file
def clean_trips_outside_file_period(df, dt_field):
    """Drop rows whose `dt_field` year/month differ from the `year`/`month` columns.

    Args:
        df: DataFrame with `year`, `month` and `dt_field` columns.
        dt_field: name of the timestamp column to compare against.

    Returns:
        A new DataFrame containing only the rows that pass both checks.
    """
    trip_timestamp = F.col(dt_field)
    same_year = F.col("year") == F.year(trip_timestamp)
    same_month = F.col("month") == F.month(trip_timestamp)
    return df.filter(same_year & same_month)

In [42]:
# Function to bring all clean processes into one
def data_processing_clean(df):
    """Apply every cleaning step whose column is present, then the file-period check.

    Fixes relative to the previous version:
    * `df.columns(df)` tried to call the column list and raised TypeError
      for the `tolls_amount` and `total_amount` checks.
    * `clean_improvement_surcharge` was defined but never applied.

    Args:
        df: transformed DataFrame.

    Returns:
        The DataFrame after all applicable cleaning steps.
    """
    # (column that must be present, cleaner to apply), in execution order
    column_cleaners = (
        ("payment_type", clean_payment_type),
        ("RatecodeID", clean_rate_code_id),
        ("mta_tax", clean_mta_tax),
        ("passenger_count", clean_passenger_count),
        ("tip_amount", clean_tip_amount),
        ("improvement_surcharge", clean_improvement_surcharge),
        ("fare_amount", clean_fare_amount),
        ("extra", clean_extra),
        ("tolls_amount", clean_tolls_amount),
        ("total_amount", clean_total_amount),
        ("trip_duration_seconds", clean_trip_duration_seconds),
    )
    for column_name, cleaner in column_cleaners:
        # Skip a step when its column is absent from this DataFrame
        if column_name in df.columns:
            df = cleaner(df)

    # Remove records outside file month year
    return clean_trips_outside_file_period(df, "pickup_datetime")

## Clean and Write Data into a parquet file

In [43]:
# Persist a DataFrame as parquet under ./output, partitioned by year and month
def write_data_to_parquet(df, mode):
    """Write `df` to the `./output` parquet dataset.

    Args:
        df: DataFrame to write; partitioned on its `year` and `month` columns.
        mode: save mode passed through to the parquet writer.
    """
    # An explicit repartition on year/month was tried here and left disabled
    partitioned_writer = df.write.partitionBy("year", "month")
    partitioned_writer.parquet("./output", mode=mode)

In [None]:
# Counts the files processed so far; the first write replaces any existing output
loop_num = 1

# For each applicable year, month and taxi colour process files and load into parquet
for year in years:
    for tlc_colour in tlc_colours:
        for month in months:
            start = dt.datetime.now()

            # Extract -> transform -> clean for this single monthly file
            df_extract = extract_data_from_bucket(bucket_name, year, tlc_colour, month)
            df_transform = data_processing_transform(
                df_extract,
                year,
                tlc_colour,
                month,
                dt_columns,
                int_columns,
                num_columns,
                initial_columns,
            )
            df_clean = data_processing_clean(df_transform)

            # Overwrite on the first file so a re-run starts clean, then append
            mode = "overwrite" if loop_num == 1 else "append"
            write_data_to_parquet(df_clean, mode)

            loop_num += 1
            end = dt.datetime.now()
            # total_seconds() covers the whole elapsed time; .seconds would drop whole days
            process_time = int((end - start).total_seconds())
            string = "Data file for month: {}, year: {} and taxi colour: {} successfully loaded in {} seconds".format(month, year, tlc_colour, process_time)
            print(string)