In [24]:
import os

# Environment defaults for running Spark from this notebook.
# setdefault() only fills in a value when the variable is not already set, so
# a machine that exports its own HADOOP_HOME / PYSPARK_PYTHON keeps them
# instead of having them overwritten by these local defaults.
os.environ.setdefault('HADOOP_HOME', 'C:\\Hadoop')
os.environ.setdefault('PYSPARK_PYTHON', 'python')

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a local SparkSession
# This is the entry point to all Spark functionality.
#   master("local[*]")      -> run locally using all available cores
#   appName(...)            -> name shown for this application
#   spark.driver.memory=6g  -> memory requested for the driver
# getOrCreate() hands back an already-running session if one exists; in that
# case the builder settings above may not be applied to it.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("NYCTransitDataAnalysis")
    .config("spark.driver.memory", "6g")
    .getOrCreate()
)

# Print the SparkSession object to confirm it's created
print(spark)

<pyspark.sql.session.SparkSession object at 0x00000209EACD4CA0>


In [25]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType, TimestampType

# Based on the verified schemas from the Parquet files

# Yellow Cab Schema (19 columns)
# Declared as (column name, Spark type) pairs; every field is nullable.
yellow_cab_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in [
        ("VendorID", IntegerType()),
        ("tpep_pickup_datetime", TimestampType()),
        ("tpep_dropoff_datetime", TimestampType()),
        ("passenger_count", LongType()),
        ("trip_distance", DoubleType()),
        ("RatecodeID", LongType()),
        ("store_and_fwd_flag", StringType()),
        ("PULocationID", IntegerType()),
        ("DOLocationID", IntegerType()),
        ("payment_type", LongType()),
        ("fare_amount", DoubleType()),
        ("extra", DoubleType()),
        ("mta_tax", DoubleType()),
        ("tip_amount", DoubleType()),
        ("tolls_amount", DoubleType()),
        ("improvement_surcharge", DoubleType()),
        ("total_amount", DoubleType()),
        ("congestion_surcharge", DoubleType()),
        ("Airport_fee", DoubleType()),
    ]
])

# Green Cab Schema (20 columns)
# Declared as (column name, Spark type) pairs; every field is nullable.
green_cab_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in [
        ("VendorID", IntegerType()),
        ("lpep_pickup_datetime", TimestampType()),
        ("lpep_dropoff_datetime", TimestampType()),
        ("store_and_fwd_flag", StringType()),
        ("RatecodeID", LongType()),
        ("PULocationID", IntegerType()),
        ("DOLocationID", IntegerType()),
        ("passenger_count", LongType()),
        ("trip_distance", DoubleType()),
        ("fare_amount", DoubleType()),
        ("extra", DoubleType()),
        ("mta_tax", DoubleType()),
        ("tip_amount", DoubleType()),
        ("tolls_amount", DoubleType()),
        ("ehail_fee", DoubleType()),
        ("improvement_surcharge", DoubleType()),
        ("total_amount", DoubleType()),
        ("payment_type", LongType()),
        ("trip_type", LongType()),
        ("congestion_surcharge", DoubleType()),
    ]
])

# Detailed HVFHV Schema (24 columns)
# Declared as (column name, Spark type) pairs; every field is nullable.
fhvhv_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in [
        ("hvfhs_license_num", StringType()),
        ("dispatching_base_num", StringType()),
        ("originating_base_num", StringType()),
        ("request_datetime", TimestampType()),
        ("on_scene_datetime", TimestampType()),
        ("pickup_datetime", TimestampType()),
        ("dropoff_datetime", TimestampType()),
        ("PULocationID", IntegerType()),
        ("DOLocationID", IntegerType()),
        ("trip_miles", DoubleType()),
        ("trip_time", LongType()),
        ("base_passenger_fare", DoubleType()),
        ("tolls", DoubleType()),
        ("bcf", DoubleType()),
        ("sales_tax", DoubleType()),
        ("congestion_surcharge", DoubleType()),
        ("airport_fee", DoubleType()),
        ("tips", DoubleType()),
        ("driver_pay", DoubleType()),
        ("shared_request_flag", StringType()),
        ("shared_match_flag", StringType()),
        ("access_a_ride_flag", StringType()),
        ("wav_request_flag", StringType()),
        ("wav_match_flag", StringType()),
    ]
])

print("All schemas defined successfully.")

All schemas defined successfully.


In [26]:
# Update these file names to match the files in your /data folder
base_path = "../data/"

# Each source is read with its explicit schema rather than an inferred one.
yellow_cab_df = (
    spark.read
    .schema(yellow_cab_schema)
    .parquet(base_path + "yellow_tripdata_2024-01.parquet")
)
green_cab_df = (
    spark.read
    .schema(green_cab_schema)
    .parquet(base_path + "green_tripdata_2024-01.parquet")
)
fhvhv_df = (
    spark.read
    .schema(fhvhv_schema)
    .parquet(base_path + "fhvhv_tripdata_2024-01.parquet")
)

print("All source data loaded into DataFrames.")

# Verify by checking the schema of one DataFrame
print("\nYellow Cab DataFrame Schema:")
yellow_cab_df.printSchema()

All source data loaded into DataFrames.

Yellow Cab DataFrame Schema:
root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)



In [27]:
from pyspark.sql.functions import col, lit

# Standardize Yellow Cab data
# ===========================
# Renames the tpep_* timestamps to the shared pickup/dropoff names, mirrors the
# yellow-cab measures under the FHVHV-style names (trip_miles, tolls, tips,
# base_passenger_fare), and adds typed NULL placeholders for every column that
# only exists in another source.
yc_transformed = (
    yellow_cab_df
    # Shared timestamp names
    .withColumnRenamed("tpep_pickup_datetime", "pickup_datetime")
    .withColumnRenamed("tpep_dropoff_datetime", "dropoff_datetime")
    # FHVHV-only columns: identifiers and timestamps
    .withColumn("hvfhs_license_num", lit(None).cast(StringType()))
    .withColumn("dispatching_base_num", lit(None).cast(StringType()))
    .withColumn("originating_base_num", lit(None).cast(StringType()))
    .withColumn("request_datetime", lit(None).cast(TimestampType()))
    .withColumn("on_scene_datetime", lit(None).cast(TimestampType()))
    # FHVHV-style metrics and financials (copied where yellow has an equivalent)
    .withColumn("trip_miles", col("trip_distance"))
    .withColumn("trip_time", lit(None).cast(LongType()))
    .withColumn("base_passenger_fare", col("fare_amount"))
    .withColumn("tolls", col("tolls_amount"))
    .withColumn("bcf", lit(None).cast(DoubleType()))
    .withColumn("sales_tax", lit(None).cast(DoubleType()))
    .withColumn("tips", col("tip_amount"))
    .withColumn("driver_pay", lit(None).cast(DoubleType()))
    # FHVHV-only flags
    .withColumn("shared_request_flag", lit(None).cast(StringType()))
    .withColumn("shared_match_flag", lit(None).cast(StringType()))
    .withColumn("access_a_ride_flag", lit(None).cast(StringType()))
    .withColumn("wav_request_flag", lit(None).cast(StringType()))
    .withColumn("wav_match_flag", lit(None).cast(StringType()))
    .withColumn("SR_Flag", lit(None).cast(IntegerType()))
    .withColumn("Affiliated_base_number", lit(None).cast(StringType()))
    # Green-cab-only columns
    .withColumn("ehail_fee", lit(None).cast(DoubleType()))
    # LongType matches trip_type in green_cab_schema, so the later union does
    # not have to widen an integer placeholder to long.
    .withColumn("trip_type", lit(None).cast(LongType()))
    # Provenance tag
    .withColumn("DataSource", lit("yellow_cab"))
)

# Standardize Green Cab data
# ==========================
# Same recipe as the other sources: shared timestamp names, FHVHV-style
# aliases for the measures green cabs have, typed NULLs for the rest.
gc_transformed = (
    green_cab_df
    # Shared timestamp names
    .withColumnRenamed("lpep_pickup_datetime", "pickup_datetime")
    .withColumnRenamed("lpep_dropoff_datetime", "dropoff_datetime")
    # Yellow-cab-only column
    .withColumn("Airport_fee", lit(None).cast(DoubleType()))
    # FHVHV-only columns: identifiers and timestamps
    .withColumn("hvfhs_license_num", lit(None).cast(StringType()))
    .withColumn("dispatching_base_num", lit(None).cast(StringType()))
    .withColumn("originating_base_num", lit(None).cast(StringType()))
    .withColumn("request_datetime", lit(None).cast(TimestampType()))
    .withColumn("on_scene_datetime", lit(None).cast(TimestampType()))
    # FHVHV-style metrics and financials (copied where green has an equivalent)
    .withColumn("trip_miles", col("trip_distance"))
    .withColumn("trip_time", lit(None).cast(LongType()))
    .withColumn("base_passenger_fare", col("fare_amount"))
    .withColumn("tolls", col("tolls_amount"))
    .withColumn("bcf", lit(None).cast(DoubleType()))
    .withColumn("sales_tax", lit(None).cast(DoubleType()))
    .withColumn("tips", col("tip_amount"))
    .withColumn("driver_pay", lit(None).cast(DoubleType()))
    # FHVHV-only flags
    .withColumn("shared_request_flag", lit(None).cast(StringType()))
    .withColumn("shared_match_flag", lit(None).cast(StringType()))
    .withColumn("access_a_ride_flag", lit(None).cast(StringType()))
    .withColumn("wav_request_flag", lit(None).cast(StringType()))
    .withColumn("wav_match_flag", lit(None).cast(StringType()))
    .withColumn("SR_Flag", lit(None).cast(IntegerType()))
    .withColumn("Affiliated_base_number", lit(None).cast(StringType()))
    # Provenance tag
    .withColumn("DataSource", lit("green_cab"))
)

# Standardize HVFHV data (Fully Corrected)
# ========================================
# Mirrors the FHVHV measures under the taxi-style names (fare_amount,
# tolls_amount, tip_amount, trip_distance) and adds typed NULL placeholders for
# every column that only exists in the taxi sources.
fhvhv_transformed = (
    fhvhv_df
    # Normalise the location-ID spelling (already canonical under fhvhv_schema).
    .withColumnRenamed("PUlocationID", "PULocationID")
    .withColumnRenamed("DOlocationID", "DOLocationID")
    # The unified column list spells this "Airport_fee". Rename explicitly so the
    # real values are carried over; otherwise a later case-sensitive
    # "is the column present?" check would add a NULL column over them.
    .withColumnRenamed("airport_fee", "Airport_fee")
    # Taxi-style aliases for measures FHVHV already has
    .withColumn("fare_amount", col("base_passenger_fare"))
    .withColumn("tolls_amount", col("tolls"))
    .withColumn("tip_amount", col("tips"))
    .withColumn("VendorID", lit(None).cast(IntegerType()))
    # LongType matches the yellow/green schemas for these columns.
    .withColumn("passenger_count", lit(None).cast(LongType()))
    .withColumn("trip_distance", col("trip_miles"))
    .withColumn("RatecodeID", lit(None).cast(LongType()))
    .withColumn("store_and_fwd_flag", lit(None).cast(StringType()))
    .withColumn("payment_type", lit(None).cast(LongType()))
    .withColumn("extra", lit(None).cast(DoubleType()))
    .withColumn("mta_tax", lit(None).cast(DoubleType()))
    .withColumn("improvement_surcharge", lit(None).cast(DoubleType()))
    .withColumn("total_amount", lit(None).cast(DoubleType()))
    .withColumn("ehail_fee", lit(None).cast(DoubleType()))
    .withColumn("trip_type", lit(None).cast(LongType()))
    # SR_Flag is added as an integer in the yellow and green transforms; add it
    # here with the same type so it is not back-filled later as a string.
    .withColumn("SR_Flag", lit(None).cast(IntegerType()))
    .withColumn("Affiliated_base_number", lit(None).cast(StringType()))
    # Provenance tag
    .withColumn("DataSource", lit("fhvhv"))
)

print("All source DataFrames have been transformed and standardized.")

All source DataFrames have been transformed and standardized.


In [28]:
# The final, definitive list of columns for the unified DataFrame.
# Built by concatenating one sub-list per logical group; the resulting order
# is the column order of the unified DataFrame.
final_columns = (
    # Core Identifiers and Timestamps
    [
        "DataSource",
        "pickup_datetime",
        "dropoff_datetime",
        "request_datetime",
        "on_scene_datetime",
    ]
    # Location and Vendor IDs
    + [
        "PULocationID",
        "DOLocationID",
        "VendorID",
        "hvfhs_license_num",
        "dispatching_base_num",
        "originating_base_num",
        "Affiliated_base_number",
    ]
    # Trip Metrics
    + [
        "passenger_count",
        "trip_distance",
        "trip_miles",
        "trip_time",
    ]
    # Standardized Financials
    + [
        "fare_amount",
        "extra",
        "mta_tax",
        "tip_amount",
        "tolls_amount",
        "improvement_surcharge",
        "congestion_surcharge",
        "total_amount",
        "Airport_fee",
        "ehail_fee",
        "bcf",
        "sales_tax",
        "driver_pay",
    ]
    # Source-Specific Financials (now present in all DFs)
    + [
        "base_passenger_fare",
        "tolls",
        "tips",
    ]
    # Flags and Codes
    + [
        "RatecodeID",
        "payment_type",
        "trip_type",
        "store_and_fwd_flag",
        "SR_Flag",
        "shared_request_flag",
        "shared_match_flag",
        "access_a_ride_flag",
        "wav_request_flag",
        "wav_match_flag",
    ]
)

def ensure_columns(df, required_columns, column_types=None):
    """Return ``df`` with every name in ``required_columns`` present.

    Each required name is handled as follows, so existing data is never lost:

    * An exact name match is left untouched.
    * A column that differs only by letter case (e.g. ``airport_fee`` vs
      ``Airport_fee``) is renamed to the required spelling. Spark resolves
      column names case-insensitively by default, so calling ``withColumn``
      with the required spelling would replace that column's values with NULLs.
    * Anything still missing is added as an all-NULL column.

    Args:
        df: DataFrame to complete.
        required_columns: Iterable of column names that must exist.
        column_types: Optional mapping of column name -> Spark type (a
            ``DataType`` instance or a type string) used when a column has to
            be created. Names absent from the mapping are created as
            ``"string"``, which is also the behaviour when the argument is
            omitted.

    Returns:
        A DataFrame containing all ``required_columns``; ``df`` itself when
        nothing had to be renamed or added.
    """
    column_types = column_types or {}

    # Track what is present locally instead of re-reading df.columns each pass.
    present = set(df.columns)
    by_lowercase = {}
    for existing_name in df.columns:
        by_lowercase.setdefault(existing_name.lower(), existing_name)

    # Loop variable is deliberately not called `col`, which would shadow
    # pyspark.sql.functions.col inside this function.
    for name in required_columns:
        if name in present:
            continue

        case_variant = by_lowercase.get(name.lower())
        if case_variant is not None:
            # Same column, different spelling: keep the data, fix the name.
            df = df.withColumnRenamed(case_variant, name)
            present.discard(case_variant)
        else:
            df = df.withColumn(name, lit(None).cast(column_types.get(name, "string")))

        present.add(name)
        by_lowercase[name.lower()] = name

    return df

# Ensure all DataFrames have the required columns
yc_transformed, gc_transformed, fhvhv_transformed = (
    ensure_columns(source_df, final_columns)
    for source_df in (yc_transformed, gc_transformed, fhvhv_transformed)
)

# Union all the standardized DataFrames
# Each side is projected onto final_columns first so that all three inputs
# expose the same columns in the same order before they are stacked.
final_df = (
    yc_transformed.select(*final_columns)
    .unionByName(gc_transformed.select(*final_columns))
    .unionByName(fhvhv_transformed.select(*final_columns))
)

print("DataFrames unified successfully.")

# Verify the final schema and a sample of the unified data
final_df.printSchema()
final_df.show(5)

DataFrames unified successfully.
root
 |-- DataSource: string (nullable = false)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- request_datetime: timestamp (nullable = true)
 |-- on_scene_datetime: timestamp (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- VendorID: integer (nullable = true)
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- originating_base_num: string (nullable = true)
 |-- Affiliated_base_number: string (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 