In [None]:
%%spark

In [11]:
import pyspark.sql.functions as F
from pyspark.sql.types import StructType as R, StructField as Fld, DoubleType as Dbl, \
StringType as Str, IntegerType as Int, DateType as Date, LongType as Long, TimestampType as TS
import os

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [18]:


def _join_uri(base, key):
    """Join a bucket/prefix URI and an object key with exactly one '/' between them."""
    if not base:
        return key
    if not key:
        return base
    return base.rstrip('/') + '/' + key.lstrip('/')


def etl(input_data, input_key, output_data, output_key='analytics', start_datetime='1900-01-01 00:00:00', end_datetime='2200-12-31 12:59:59',
        dq_checks=[], load_on_error=False):
    """
    Join the trip records data with the hourly weather data, run optional
    data quality checks on the result, and write the resulting analytics
    table to the output location in parquet format (partitioned by year and
    month, 'overwrite' mode).

    Besides the trip records, the following objects are read from input_data:
    'weather/weather.json', 'codes/vendor.json', 'codes/payment.json',
    'codes/ratecode.json' and 'codes/taxizone.csv'.

    Arguments:
    input_data: source bucket of the trip record data (with or without a trailing '/')
    input_key: key of the trip record data
    output_data: destination bucket for the analytics data
    output_key: key for the analytics data
    start_datetime: The expected start datetime for the trip records data. All records with earlier pickup_datetime
    will be removed.
    end_datetime: The expected end datetime for the trip records data. All records with later pickup_datetime
    will be removed.
    NOTE(review): the default ends at 12:59:59, not 23:59:59 - confirm this is intended.
    dq_checks: data quality checks on the analytics table. should be a list of dictionary with the following keys:
    {'name': name of the test,
    'sql_query': test sql query (the name of the view is 'sql_view'),
    'expected_result': expected output of the sql_query
    }
    The first column of the first row returned by sql_query is compared with
    expected_result. A query that returns no rows counts as a failed check.
    None is treated as "no checks".
    load_on_error: whether to load the analytics table if there is error

    Returns:
    None. Progress and check results are printed.

    Raises:
    KeyError if a dq_checks entry lacks one of the required keys.
    """

    # Copy so the (shared) default list is never touched and None is accepted.
    checks = list(dq_checks or [])

    print('Processing data...')

    ######################
    #process weather data
    ######################

    #load data
    weather = spark.read.json(_join_uri(input_data, 'weather/weather.json'))

    #create datetime column in Eastern local time to match trip record data:
    #a 'timezone' offset of -18000 s means EST (UTC-5), anything else is treated as EDT (UTC-4).
    #NOTE(review): from_unixtime renders in the Spark session time zone, so the fixed
    #offsets below assume that session time zone is UTC - TODO confirm.
    utc_time = F.from_unixtime(F.col('dt'))
    weather = weather.withColumn(
        'datetime',
        F.when(F.col('timezone') == -18000, utc_time - F.expr('INTERVAL 5 HOURS'))
         .otherwise(utc_time - F.expr('INTERVAL 4 HOURS')))

    #remove duplicate due to daylight saving (e.g. record for 2019-11-03 01:00:00 EST appears twice)
    weather = weather.drop_duplicates(['datetime'])

    #create year, month, day and hour columns (these are the join keys to the trips)
    weather = weather.withColumn('hour', F.hour('datetime')) \
        .withColumn('day', F.dayofmonth('datetime')) \
        .withColumn('month', F.month('datetime')) \
        .withColumn('year', F.year('datetime'))

    #select relevant weather columns
    weather = weather.select(['year',
                              'month',
                              'day',
                              'hour',
                              F.col('clouds.all').alias('cloudiness_pct'),
                              F.col('main.temp').alias('temperature'),
                              F.col('main.feels_like').alias('temp_feel_like'),
                              F.col('main.temp_max').alias('temperature_max'),
                              F.col('main.temp_min').alias('temperature_min'),
                              F.col('main.humidity').alias('humidity'),
                              F.col('main.pressure').alias('pressure'),
                              F.col('rain.1h').alias('rain_last_1h'),
                              F.col('rain.3h').alias('rain_last_3h'),
                              F.col('snow.1h').alias('snow_last_1h'),
                              F.col('snow.3h').alias('snow_last_3h'),
                              F.col('weather')[0].main.alias('weather_type'),
                              F.col('weather')[0].description.alias('weather_description'),
                              F.col('wind').speed.alias('wind_speed'),
                              F.col('wind').deg.alias('wind_degree')
    ])

    ######################
    #Process trip record data
    ######################

    #load trips record data
    trips = spark.read.csv(_join_uri(input_data, input_key), header=True, inferSchema=True)

    #remove all data with NULL VendorID
    trips = trips.filter(F.col('VendorID').isNotNull())

    #remove all data out of date range (both bounds inclusive)
    trips = trips.filter((F.col('tpep_pickup_datetime') >= start_datetime) & (F.col('tpep_pickup_datetime') <= end_datetime))

    #load other dimensions table
    vendors = spark.read.json(_join_uri(input_data, 'codes/vendor.json'))
    payments = spark.read.json(_join_uri(input_data, 'codes/payment.json'))
    rates = spark.read.json(_join_uri(input_data, 'codes/ratecode.json'))

    #need separate table for pickup and dropoff to avoid cross join issue
    taxizone_path = _join_uri(input_data, 'codes/taxizone.csv')
    taxizones_pickup = spark.read.csv(taxizone_path, header=True, inferSchema=True)
    taxizones_dropoff = spark.read.csv(taxizone_path, header=True, inferSchema=True)

    #join reference tables (left joins keep trips whose code has no match)
    trips = trips.join(vendors, (trips.VendorID == vendors.id), 'left')
    trips = trips.join(payments, (trips.payment_type == payments.id), 'left')
    trips = trips.join(rates, (trips.RatecodeID == rates.id), 'left')

    #join taxizones for pickup location
    trips = trips.join(taxizones_pickup.select(['LocationID', F.col('Borough').alias('pickup_borough'),
                                                F.col('Zone').alias('pickup_zone')]),
                       trips.PULocationID == taxizones_pickup.LocationID, 'left')

    #join taxizones for dropoff location
    trips = trips.join(taxizones_dropoff.select(['LocationID', F.col('Borough').alias('dropoff_borough'),
                                                 F.col('Zone').alias('dropoff_zone')]),
                       trips.DOLocationID == taxizones_dropoff.LocationID, 'left')

    #calculate trip duration in minutes (timestamps cast to epoch seconds, then / 60)
    trips = trips.withColumn('trip_duration',
                             (F.col('tpep_dropoff_datetime').cast(Long()) - F.col('tpep_pickup_datetime').cast(Long()))/60)

    trips = trips.select(['provider',
                          F.col('tpep_pickup_datetime').alias('pickup_datetime'),
                          F.col('tpep_dropoff_datetime').alias('dropoff_datetime'),
                          'pickup_borough',
                          'pickup_zone',
                          'dropoff_borough',
                          'dropoff_zone',
                          'trip_duration',
                          'trip_distance',
                          'fare_amount',
                          'extra',
                          'mta_tax',
                          'tip_amount',
                          'tolls_amount',
                          'improvement_surcharge',
                          'congestion_surcharge',
                          'total_amount',
                          'payment',
                          'rate',
                          'store_and_fwd_flag'])

    #############################
    #Create time dimension table
    ############################

    #one row per distinct pickup timestamp with its calendar parts
    time = trips.select(F.col('pickup_datetime')).distinct().withColumn('hour', F.hour('pickup_datetime')) \
        .withColumn('day', F.dayofmonth('pickup_datetime')) \
        .withColumn('week', F.weekofyear('pickup_datetime')) \
        .withColumn('month', F.month('pickup_datetime')) \
        .withColumn('year', F.year('pickup_datetime')) \
        .withColumn('weekday', F.dayofweek('pickup_datetime'))

    #############################
    #Join trip and weather tables
    ############################

    #join through the time table: trips -> calendar parts -> hourly weather
    #NOTE(review): the calendar columns could be derived directly on trips without
    #the self-join; kept as is so the output column order does not change.
    analytics = trips.join(time, ['pickup_datetime'], 'left').join(weather, ['year', 'month', 'day', 'hour'], 'left')

    print('Data processing completed...')

    #############################
    #Perform check
    ############################

    analytics.createOrReplaceTempView("sql_view")

    error_count = 0
    num_of_test = len(checks)

    for test_num, dq_check in enumerate(checks, 1):

        name = dq_check['name']
        sql_query = dq_check['sql_query']
        expected_result = dq_check['expected_result']

        print(f'running data quality test {test_num} of {num_of_test}...')
        print(f'test name: {name}')

        #only the first column of the first row is checked; a query that
        #returns no rows yields None instead of raising IndexError
        rows = spark.sql(sql_query).collect()
        result = rows[0][0] if rows else None

        if rows and result == expected_result:
            print('result: test passed')
        else:
            print(f'result: test failed. Result is {result}. Expected result is {expected_result}.')
            error_count += 1

    #############################
    #load analytics table to output s3
    ############################

    if error_count > 0 and not load_on_error:
        print('Data tables not loaded due to failed data quality cuecks')
    else:
        print('loading data to s3...')
        #explicit '/' join: os.path.join would drop the bucket if output_key started with '/'
        analytics.write.partitionBy('year', 'month').parquet(_join_uri(output_data, output_key), 'overwrite')
        print('loading completed.')

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [22]:
# Run the ETL function on the 2019 yellow-cab trip records.
# Source and destination are the same bucket; results go under 'analytics'.
input_data = "s3://nyc-yellow-cab-project/"
input_key = "tripdata/2019/*.csv"
output_data = "s3://nyc-yellow-cab-project/"
output_key = 'analytics'

# Keep only trips picked up within calendar year 2019 (bounds are inclusive).
start_datetime = '2019-01-01 00:00:00'
end_datetime = '2019-12-31 23:59:59'

# Each check's query runs against the 'sql_view' temp view created by etl();
# the first value returned is compared with expected_result.
dq_checks = [
    {
        "name": "check number of missing provider",
        "sql_query": """SELECT COUNT(*) FROM sql_view WHERE PROVIDER IS NULL""",
        "expected_result": 0,
    },
    {
        "name": "check number of missing temperature value",
        "sql_query": """SELECT COUNT(*) FROM sql_view WHERE TEMPERATURE IS NULL""",
        "expected_result": 0,
    },
]

# Do not write the analytics table if any data quality check fails.
load_on_error = False

# Keyword arguments so a change in parameter order cannot silently swap values.
etl(
    input_data=input_data,
    input_key=input_key,
    output_data=output_data,
    output_key=output_key,
    start_datetime=start_datetime,
    end_datetime=end_datetime,
    dq_checks=dq_checks,
    load_on_error=load_on_error,
)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Processing data...
Data processing completed...
running data quality test 1 of 2...
test name: check number of missing provider
result: test passed
running data quality test 2 of 2...
test name: check number of missing temperature value
result: test passed
loading data to s3...
loading completed.