In [None]:
# Sanity check: confirm the notebook's SparkSession (`spark`) is available and show its version.
spark_version = spark.version
print(spark_version)

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, LongType
import pyspark.sql.functions as F
from pyspark.sql import DataFrame
import boto3
from functools import reduce
import pyarrow.parquet as pq
from datetime import datetime


# Source data layout: <S3>://<BUCKET_NAME>/<_NYC_TAXI_ROOT><taxi type>/...
S3 = "s3"
BUCKET_NAME = "robot-dreams-source-data"
_NYC_TAXI_ROOT = "home-work-1/nyc_taxi/"
YELLOW_TAXI_DIR = _NYC_TAXI_ROOT + "yellow/"
GREEN_TAXI_DIR = _NYC_TAXI_ROOT + "green/"
# Key prefixes that are searched (recursively) for Parquet trip files.
PARQUET_DIRS = [YELLOW_TAXI_DIR, GREEN_TAXI_DIR]
# Zone dictionary CSV that lives next to the per-taxi-type directories.
TAXI_ZONE_LOOKUP = f"{S3}://{BUCKET_NAME}/{_NYC_TAXI_ROOT}taxi_zone_lookup.csv"

In [None]:
def s3_files_search(s3, bucket_name, dirs, files_pathes=None, files_extension=None, paginator=None):
    """Recursively list object URIs under one or more S3 key prefixes.

    Prefixes are walked breadth-first with ``list_objects_v2`` and ``Delimiter="/"``:
    every ``CommonPrefixes`` entry is queued as a sub-directory, and every
    ``Contents`` key is turned into a ``<s3>://<bucket_name>/<key>`` URI.

    Args:
        s3: URI scheme prepended to each key (e.g. ``"s3"``).
        bucket_name: Bucket to list.
        dirs: Iterable of key prefixes to start from. It is NOT modified.
        files_pathes: Optional list to append results to (and return). A new
            list is created when omitted, so repeated calls never share results.
        files_extension: Keep only keys ending with this suffix. ``None`` keeps
            every object key. Keys ending in ``/`` (directory markers) are
            always skipped.
        paginator: A boto3 ``list_objects_v2`` paginator. When omitted, one is
            created from a default ``boto3`` S3 client.

    Returns:
        The list of matching object URIs, in breadth-first discovery order.
    """
    from collections import deque

    if files_pathes is None:
        files_pathes = []
    if paginator is None:
        paginator = boto3.client("s3").get_paginator("list_objects_v2")

    # Work on a private queue so the caller's list (e.g. PARQUET_DIRS) survives.
    pending_dirs = deque(dirs)
    while pending_dirs:
        current_dir = pending_dirs.popleft()
        pages = paginator.paginate(Bucket=bucket_name, Prefix=current_dir, Delimiter="/")
        for page in pages:
            # Sub-"directories" are queued and visited after the current level.
            pending_dirs.extend(entry["Prefix"] for entry in page.get("CommonPrefixes", []))
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if key.endswith("/"):
                    continue
                if files_extension is None or key.endswith(files_extension):
                    files_pathes.append(f"{s3}://{bucket_name}/{key}")

    return files_pathes

In [None]:
def get_parquet_schema_summary(parquet_files_pathes):
    """Summarise column names and data types across a set of Parquet files.

    Column names are lower-cased, so differently-cased spellings of the same
    column are merged. For each column the Parquet logical type name is used,
    or the physical type name when the logical type is "NONE"/"UNKNOWN".

    Args:
        parquet_files_pathes: Paths/URIs accepted by ``pq.ParquetFile``.

    Returns:
        dict mapping lower-cased column name to
        ``{"col_counter": <number of columns seen with that name>,
           "dtypes": [<distinct type names, in first-seen order>]}``.
    """
    summary = {}

    for path in parquet_files_pathes:
        schema = pq.ParquetFile(path).schema

        for idx in range(len(schema)):
            column = schema.column(idx)
            name = column.name.lower()
            logical = column.logical_type.type
            # No logical annotation -> describe the column by its physical type.
            dtype = (None if logical in ("NONE", "UNKNOWN") else logical) or column.physical_type

            entry = summary.get(name)
            if entry is None:
                summary[name] = {"col_counter": 1, "dtypes": [dtype]}
                continue

            entry["col_counter"] += 1
            if dtype not in entry["dtypes"]:
                entry["dtypes"].append(dtype)

    return summary

In [None]:
def get_df_cast_schema(parquet_files_schema, parquet_to_df_dtype_cast):
    """Choose a Spark cast type for every column whose type differs between files.

    Args:
        parquet_files_schema: Mapping ``column name -> {"col_counter": int,
            "dtypes": [str, ...]}`` with dtypes in first-seen order (the shape
            built by ``get_parquet_schema_summary``).
        parquet_to_df_dtype_cast: Mapping from Parquet type name to a Spark SQL
            type name (e.g. ``"INT64" -> "long"``).

    Returns:
        ``column name -> Spark type name`` for columns seen with more than one
        dtype. When every mapped candidate is numeric (long/float/double), the
        widest wins. For example, a column stored as INT64 in some files and
        DOUBLE in others is cast to ``double`` instead of being truncated to
        ``long``. Otherwise the type mapped from the first-seen dtype is used.

    Raises:
        KeyError: if a conflicting column's first-seen dtype has no entry in
            ``parquet_to_df_dtype_cast``.
    """
    # Spark's numeric widening order for the numeric targets used by the cast map.
    numeric_rank = {"long": 0, "float": 1, "double": 2}
    df_cast_schema = {}

    for col_name, col_data in parquet_files_schema.items():
        dtypes = col_data["dtypes"]
        if len(dtypes) <= 1:
            continue

        targets = [parquet_to_df_dtype_cast[dtypes[0]]]
        # Later dtypes without a mapping are ignored rather than raising, so an
        # incomplete mapping only fails on the first-seen dtype.
        targets += [
            parquet_to_df_dtype_cast[dtype]
            for dtype in dtypes[1:]
            if dtype in parquet_to_df_dtype_cast
        ]

        if all(target in numeric_rank for target in targets):
            df_cast_schema[col_name] = max(targets, key=numeric_rank.__getitem__)
        else:
            df_cast_schema[col_name] = targets[0]

    return df_cast_schema

In [None]:
def load_parquets_with_cast(parquet_files_pathes, df_cast_schema, taxi_type_path_index=5):
    """Read each Parquet file into its own DataFrame, harmonise types and tag taxi type.

    Args:
        parquet_files_pathes: Parquet file URIs to read with ``spark.read.parquet``.
        df_cast_schema: ``column name -> Spark type name``. Names are matched
            case-insensitively against each file's columns, and matching
            columns are cast so the per-file DataFrames agree before a union.
        taxi_type_path_index: Index of the ``path.split("/")`` segment holding
            the taxi type. The default 5 selects ``"yellow"`` from
            ``s3://<bucket>/home-work-1/nyc_taxi/yellow/...``.

    Returns:
        One DataFrame per input file, each with an extra literal ``taxi_type`` column.
    """
    # The summary that feeds df_cast_schema lower-cases column names, while
    # df.columns keeps each file's own casing (e.g. "PULocationID"). Index the
    # casts by lower-cased name so mixed-case columns are not silently skipped.
    casts_by_lower_name = {name.lower(): dtype for name, dtype in df_cast_schema.items()}

    dfs = []
    for parquet_file_path in parquet_files_pathes:
        taxi_type = parquet_file_path.split("/")[taxi_type_path_index]
        df = spark.read.parquet(parquet_file_path)
        for actual_name in df.columns:
            target_type = casts_by_lower_name.get(actual_name.lower())
            if target_type is not None:
                df = df.withColumn(actual_name, F.col(actual_name).cast(target_type))
        df = df.withColumn("taxi_type", F.lit(taxi_type))
        dfs.append(df)
    return dfs

In [None]:
# Inspect parquet files schema.
# Pass a fresh result list and a copy of PARQUET_DIRS so this cell stays
# re-runnable even if s3_files_search mutates its `dirs` argument or falls back
# to a default result list shared between calls.
s3_boto = boto3.client("s3")
paginator = s3_boto.get_paginator("list_objects_v2")
parquet_files_pathes = s3_files_search(
    S3, BUCKET_NAME, list(PARQUET_DIRS), files_pathes=[], files_extension=".parquet"
)
parquet_files_schema_summary = get_parquet_schema_summary(parquet_files_pathes)
parquet_files_schema_summary

In [None]:
# Map Parquet type names (logical type, or the physical type when no logical
# type is annotated) to the Spark SQL type used for columns whose type differs
# between files. Every name get_parquet_schema_summary can emit for such a
# column must be present, otherwise get_df_cast_schema raises KeyError.
parquet_to_df_dtype_cast = {
    # Physical types
    "INT32": "long",
    "INT64": "long",
    "INT96": "timestamp",  # legacy Parquet timestamp encoding
    "DOUBLE": "double",
    "BYTE_ARRAY": "string",
    "BOOLEAN": "boolean",
    "FLOAT": "float",
    "FIXED_LEN_BYTE_ARRAY": "binary",
    # Logical types
    "INT": "long",  # annotated integer columns, widened to long
    "DATE": "date",
    "TIMESTAMP": "timestamp",
    "STRING": "string",
}

df_cast_schema = get_df_cast_schema(parquet_files_schema_summary, parquet_to_df_dtype_cast)
df_cast_schema

In [None]:
# Load every parquet file into its own DataFrame, applying the type-harmonising casts.
dfs = load_parquets_with_cast(
    parquet_files_pathes=parquet_files_pathes,
    df_cast_schema=df_cast_schema,
)

In [None]:
# Union all per-file DataFrames by column name. Columns missing from some files
# (e.g. the tpep_* vs lpep_* timestamp columns) are filled with nulls.
if not dfs:
    raise ValueError("No parquet files were loaded; nothing to union.")
raw_trips_df = reduce(
    lambda left, right: left.unionByName(right, allowMissingColumns=True), dfs
)

In [None]:
# Add additional computations needed for further aggregations.
# withColumns evaluates every expression against its *input* DataFrame, so a
# column created in one call cannot be referenced by another expression in that
# same call. The derived columns that use pickup_datetime / dropoff_datetime are
# therefore added in a second withColumns, after those columns exist.
raw_trips_df_enriched = (
    raw_trips_df
    .withColumns({
        "trip_distance": F.coalesce(F.col("trip_distance"), F.lit(0)),
        "fare_amount": F.coalesce(F.col("fare_amount"), F.lit(0)),
        # Yellow and green files name their timestamps differently (tpep_* vs
        # lpep_*); take whichever one is populated for the row.
        "pickup_datetime": F.coalesce(F.col("tpep_pickup_datetime"), F.col("lpep_pickup_datetime")),
        "dropoff_datetime": F.coalesce(F.col("tpep_dropoff_datetime"), F.col("lpep_dropoff_datetime")),
    })
    .withColumns({
        "duration_min": (F.unix_timestamp("dropoff_datetime") - F.unix_timestamp("pickup_datetime")) / 60,
        "pickup_hour": F.hour("pickup_datetime"),
        "pickup_day_of_week": F.date_format("pickup_datetime", "E"),
        "yellow_trip": F.when(F.col("taxi_type") == "yellow", 1).otherwise(0),
        "green_trip": F.when(F.col("taxi_type") == "green", 1).otherwise(0),
        "high_fare_trip": F.when(F.col("fare_amount") > 30, 1).otherwise(0),
    })
)

In [None]:
# Apply filtering conditions: keep trips with a non-trivial distance, fare and duration.
MIN_TRIP_DISTANCE = 0.1
MIN_FARE_AMOUNT = 2.0
MIN_DURATION_MIN = 1.0
trips_df_filtered = raw_trips_df_enriched.where(
    (F.col("trip_distance") > MIN_TRIP_DISTANCE)
    & (F.col("fare_amount") > MIN_FARE_AMOUNT)
    & (F.col("duration_min") > MIN_DURATION_MIN)
)

In [None]:
# Load the taxi zone dictionary (CSV with a header row, no schema inference).
taxi_zone_df = spark.read.csv(TAXI_ZONE_LOOKUP, header=True)

In [None]:
# Cast LocationID to long so it can be joined against the (presumably numeric)
# trip location id columns.
taxi_zone_df = taxi_zone_df.withColumn("LocationID", taxi_zone_df["LocationID"].cast(LongType()))

In [None]:
# Build two lookup tables from the zone dictionary, one for pickup and one for
# dropoff locations. The distinct column names avoid ambiguity when both are
# joined to the same trips DataFrame.
def _zone_lookup(id_alias, zone_alias):
    """Project taxi_zone_df to (LocationID AS id_alias, Zone AS zone_alias)."""
    return taxi_zone_df.select(
        F.col("LocationID").alias(id_alias),
        F.col("Zone").alias(zone_alias),
    )


pickup_zone_df = _zone_lookup("PU_LocationID", "pickup_zone")
dropoff_zone_df = _zone_lookup("DO_LocationID", "dropoff_zone")

In [None]:
# Enrich the trips df with pickup/dropoff zone names from the taxi zone dictionary.
# Left joins keep trips whose location id has no match in the dictionary; the
# helper id columns from the lookup tables are dropped afterwards.
pickup_match = trips_df_filtered["PULocationID"] == pickup_zone_df["PU_LocationID"]
dropoff_match = trips_df_filtered["DOLocationID"] == dropoff_zone_df["DO_LocationID"]
trips_df_enriched = (
    trips_df_filtered
    .join(pickup_zone_df, on=pickup_match, how="left")
    .join(dropoff_zone_df, on=dropoff_match, how="left")
    .drop("PU_LocationID", "DO_LocationID")
)

In [None]:
# Per-pickup-zone summary: trip volume, distance / amount / tip statistics and
# the yellow vs green split (counts and shares of total trips).
zone_summary = (
    trips_df_enriched
    .groupBy("pickup_zone")
    .agg(
        F.count("*").alias("total_trips"),
        F.avg("trip_distance").alias("avg_trip_distance"),
        F.avg("total_amount").alias("avg_total_amount"),
        F.avg("tip_amount").alias("avg_tip_amount"),
        F.sum("yellow_trip").alias("yellow_trips"),
        F.sum("green_trip").alias("green_trips"),
        F.max("trip_distance").alias("max_trip_distance"),
        F.min("tip_amount").alias("min_tip_amount"),
    )
    .withColumns({
        "yellow_share": F.col("yellow_trips") / F.col("total_trips"),
        "green_share": F.col("green_trips") / F.col("total_trips"),
    })
)

In [None]:
# Per (weekday, pickup zone) counts of all trips and of high-fare trips, plus
# their ratio. Only Monday rows ("Mon" from the "E" date pattern) are kept.
zone_days_statistic = (
    trips_df_enriched
    .groupBy("pickup_day_of_week", "pickup_zone")
    .agg(
        F.count("*").alias("total_trips"),
        F.sum("high_fare_trip").alias("high_fare_trips"),
    )
    .withColumn("high_fare_share", F.col("high_fare_trips") / F.col("total_trips"))
    .where(F.col("pickup_day_of_week") == "Mon")
)

In [None]:
# Save both result tables as parquet under S3 prefixes stamped with today's date.
date_str = datetime.today().strftime("%Y-%m-%d")
output_path_zone_s = f"s3://dhalahan-emr-studio/zone_statistic/{date_str}/"
output_path_zone_d = f"s3://dhalahan-emr-studio/zone_days_statistic/{date_str}/"
for result_df, output_path in (
    (zone_summary, output_path_zone_s),
    (zone_days_statistic, output_path_zone_d),
):
    # coalesce(1) -> a single output partition; "overwrite" replaces a same-day rerun.
    result_df.coalesce(1).write.mode("overwrite").parquet(output_path)