### SETUP & CONSTANTS

In [14]:
import re
import os
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from pyspark.sql import Window
from functools import reduce
from collections import defaultdict
import json

# CONSTANTS
# Directory holding the raw trip parquet files; read with a `{table_name}_2025*.parquet` glob.
PARQUET_INPUT_PATH = 'storage/taxi_industry/parquet'
# Directory holding CSV inputs (the taxi zone lookup file is read from here).
CSV_INPUT_PATH = 'storage/taxi_industry/csv'
# Directory the processed DataFrames are written to as parquet, one sub-directory per table.
OUTPUT_PATH = 'storage/taxi_industry/parquet_processed'
# File-name prefixes of the raw trip tables; also used as the keys of `df_raw`.
TABLE_NAMES = [
    'yellow_tripdata',
    'green_tripdata',
]

In [15]:
def unique(df, group_by, select_cols, min=1):
    """Return the groups of `df` that have more than `min` distinct `select_cols` combinations.

    Args:
        df: DataFrame to inspect.
        group_by: iterable of columns to group by.
        select_cols: iterable of columns whose distinct combinations are collected per group.
        min: threshold; only groups with strictly more than this many distinct
            combinations are kept (default 1).

    Returns:
        A DataFrame with the `group_by` columns plus `objs`, an array holding the
        distinct `select_cols` structs seen in each group.
    """
    # Pack the selected columns into one struct so collect_set de-duplicates whole combinations.
    distinct_combinations = F.collect_set(F.struct(*select_cols)).alias('objs')
    per_group = df.groupBy(*group_by).agg(distinct_combinations)
    return per_group.where(F.array_size('objs') > min)

In [16]:
# Reuse the active Spark session if one exists, otherwise start one named 'transform'.
session_builder = SparkSession.builder.appName('transform')
spark = session_builder.getOrCreate()

### GENERIC PARQUET LOAD

In [17]:
# Load every raw trip table (all 2025 files) into `df_raw`, keyed by table name,
# printing the row count and a two-row preview for each.
df_raw = {}
for table_name in TABLE_NAMES:
    raw_table = spark.read.parquet(f'{PARQUET_INPUT_PATH}/{table_name}_2025*.parquet')
    df_raw[table_name] = raw_table
    print(f'{table_name}:')
    print(f'df_raw[table_name].count()={raw_table.count()}')
    raw_table.limit(2).show()


yellow_tripdata:
df_raw[table_name].count()=19760424
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-05-01 00:07:06|  2025-05-01 00:24:15|              1|        

### COLUMN INSPECTION

In [18]:
# Map each column name to the taxi types ('yellow' / 'green') whose raw table contains it.
columns = defaultdict(list)
for taxi_label, raw_table_name in (('yellow', 'yellow_tripdata'), ('green', 'green_tripdata')):
    for column_name in df_raw[raw_table_name].columns:
        columns[column_name].append(taxi_label)

# A column is shared when both tables have it, and unique when only one does.
in_both_tables = [name for name, owners in columns.items() if len(owners) == 2]
in_one_table = [(name, owners) for name, owners in columns.items() if len(owners) == 1]

print('shared columns:')
print(in_both_tables)
print('unique columns:')
print(in_one_table)


shared columns:
['VendorID', 'passenger_count', 'trip_distance', 'RatecodeID', 'store_and_fwd_flag', 'PULocationID', 'DOLocationID', 'payment_type', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount', 'congestion_surcharge', 'cbd_congestion_fee']
unique columns:
[('tpep_pickup_datetime', ['yellow']), ('tpep_dropoff_datetime', ['yellow']), ('Airport_fee', ['yellow']), ('lpep_pickup_datetime', ['green']), ('lpep_dropoff_datetime', ['green']), ('ehail_fee', ['green']), ('trip_type', ['green'])]


### TABLES

#### df_trip

In [19]:
# Columns that exist under the same name in both the yellow and the green raw tables.
shared_columns = [
    'VendorID',
    'passenger_count',
    'trip_distance', # in miles
    'RatecodeID',
    'store_and_fwd_flag', # Y = store and forward trip, N = not a store and forward trip
    'PULocationID', # zone in which the taximeter was engaged (pickup)
    'DOLocationID', # zone in which the taximeter was disengaged (dropoff)
    'payment_type',

    # costs/charges
    'fare_amount', # The time-and-distance fare calculated by the meter
    'extra',
    'mta_tax',
    'tip_amount', # Automatically populated for credit card tips. Cash tips are not included.
    'tolls_amount',
    'improvement_surcharge',
    'congestion_surcharge',
    'cbd_congestion_fee', # Per-trip charge for MTA's Congestion Relief Zone

    'total_amount',
]

# (taxi_type label, raw table name, prefix of that table's pickup/dropoff datetime columns)
trip_sources = [
    ('yellow', 'yellow_tripdata', 'tpep'),
    ('green', 'green_tripdata', 'lpep'),
]

df_trip = (
    # Stack both taxi types into one table with harmonised datetime column names.
    reduce(
        DataFrame.unionByName,
        [
            df_raw[raw_table_name].select(
                F.lit(taxi_type).alias('taxi_type'),
                *shared_columns,
                F.col(f'{datetime_prefix}_pickup_datetime').alias('pickup_datetime'),
                F.col(f'{datetime_prefix}_dropoff_datetime').alias('dropoff_datetime'),
            )
            for taxi_type, raw_table_name, datetime_prefix in trip_sources
        ]
    )
    # Unique per row, but neither consecutive nor stable across runs (depends on partitioning).
    .withColumn('id', F.monotonically_increasing_id())
    .withColumn('duration', F.round((F.unix_timestamp('dropoff_datetime') - F.unix_timestamp('pickup_datetime'))/60,2)) # in minutes

    # Filter BEFORE deriving the ratio columns: the divisions below then never see a zero
    # duration/distance, which would raise DIVIDE_BY_ZERO when Spark runs in ANSI mode.
    # Rows with a NULL duration or trip_distance are dropped here as well.
    .where('duration > 1 and trip_distance > 0.1') # filter out unrealistic trip durations and distances
    .where('total_amount between 1 and 2000') # filter out unrealistic trip costs

    .withColumn('charge_per_mile', F.round(F.col('total_amount') / F.col('trip_distance'),2)) # in $/mile
    .withColumn('charge_per_minute', F.round(F.col('total_amount') / (F.col('duration')),2)) # in $/minute
    .withColumn('pickup_hour', F.hour('pickup_datetime')) # 0 - 23
    .withColumn('day_of_week', F.dayofweek('pickup_datetime')) # 1=Sunday, 7=Saturday
    .withColumn('non_fare_amount', F.round(F.col('total_amount') - F.col('fare_amount'),2)) # non-fare amount (tips, tolls, etc.)
)
# Sanity check: the most expensive trips per minute.
df_trip.orderBy(F.col('charge_per_minute').desc()).show(10)

+---------+--------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+--------------------+------------------+------------+-------------------+-------------------+-----------+--------+---------------+-----------------+-----------+-----------+---------------+
|taxi_type|VendorID|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|congestion_surcharge|cbd_congestion_fee|total_amount|    pickup_datetime|   dropoff_datetime|         id|duration|charge_per_mile|charge_per_minute|pickup_hour|day_of_week|non_fare_amount|
+---------+--------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+--------------------+------------------+---------

#### df_taxi_zone

In [20]:
# Zone lookup table. No schema is supplied, so the columns are read as strings and
# LocationID is cast to an integer explicitly.
zone_lookup_csv = (
    spark.read
    .option('header', True)
    .csv(f'{CSV_INPUT_PATH}/taxi_zone_lookup.csv')
)
df_taxi_zone = zone_lookup_csv.withColumn('LocationID', zone_lookup_csv['LocationID'].cast('int'))

df_taxi_zone.show(5)
df_taxi_zone.printSchema()

+----------+-------------+--------------------+------------+
|LocationID|      Borough|                Zone|service_zone|
+----------+-------------+--------------------+------------+
|         1|          EWR|      Newark Airport|         EWR|
|         2|       Queens|         Jamaica Bay|   Boro Zone|
|         3|        Bronx|Allerton/Pelham G...|   Boro Zone|
|         4|    Manhattan|       Alphabet City| Yellow Zone|
|         5|Staten Island|       Arden Heights|   Boro Zone|
+----------+-------------+--------------------+------------+
only showing top 5 rows

root
 |-- LocationID: integer (nullable = true)
 |-- Borough: string (nullable = true)
 |-- Zone: string (nullable = true)
 |-- service_zone: string (nullable = true)



#### df_vendor

In [21]:
# Static lookup: VendorID -> vendor name.
vendor_rows = [
    (1, 'Creative Mobile Technologies, LLC'),
    (2, 'Curb Mobility, LLC'),
    (6, 'Myle Technologies Inc'),
    (7, 'Helix'),
]
df_vendor = spark.createDataFrame(vendor_rows, schema=['VendorID', 'VendorName'])

#### df_rate_code

In [22]:
# Static lookup: RatecodeID -> rate name.
rate_code_rows = [
    (1, 'Standard rate'),
    (2, 'JFK'), # John F. Kennedy Airport (JFK)
    (3, 'Newark'), # Newark Airport (EWR): tolls + surcharge
    (4, 'Nassau or Westchester'), # Westchester and Nassau Counties tolls
    (5, 'Negotiated fare'), # Other Points Outside the City
    (6, 'Group ride'),
    (99, 'Null/unknown'),
]
df_rate_code = spark.createDataFrame(rate_code_rows, schema=['RatecodeID', 'RateName'])

#### df_payment_type

In [23]:
# Static lookup: payment_type code -> payment type name.
payment_type_rows = [
    (0, 'Flex Fare trip'),
    (1, 'Credit card'),
    (2, 'Cash'),
    (3, 'No charge'),
    (4, 'Dispute'),
    (5, 'Unknown'),
    (6, 'Voided trip'),
]
df_payment_type = spark.createDataFrame(payment_type_rows, schema=['payment_type', 'PaymentTypeName'])

### ANOMALIES

In [24]:
# Anomaly check for Vendor 7 (Helix). Author's note: this vendor doesn't actually record
# dropoff_datetime, so duration and the per-minute charge can't be calculated for its trips.
# NOTE(review): df_trip is already filtered on `duration > 1`, so rows without a usable
# duration are presumably gone and this may show no rows — check df_raw to see them.
df_trip.where('VendorID = 7').show(5)

+---------+--------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+--------------------+------------------+------------+---------------+----------------+---+--------+---------------+-----------------+-----------+-----------+---------------+
|taxi_type|VendorID|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|congestion_surcharge|cbd_congestion_fee|total_amount|pickup_datetime|dropoff_datetime| id|duration|charge_per_mile|charge_per_minute|pickup_hour|day_of_week|non_fare_amount|
+---------+--------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+--------------------+------------------+------------+---------------+----------

### WRITE TO DISK

In [25]:
# Explicit name -> DataFrame mapping. Referencing the variables directly (instead of looking
# them up through globals()) makes a missing or misspelled DataFrame fail right here, before
# anything is written.
dataframes_to_write = {
    'df_trip': df_trip,
    'df_taxi_zone': df_taxi_zone,
    'df_vendor': df_vendor,
    'df_rate_code': df_rate_code,
    'df_payment_type': df_payment_type,
}
# Kept as a list of names for any later cell that still expects it.
dataframe_list = list(dataframes_to_write)

# exist_ok=True makes this a no-op when the directory is already there, without a
# separate (racy) existence check.
os.makedirs(OUTPUT_PATH, exist_ok=True)

for df_name, df_to_write in dataframes_to_write.items():
    print(f'Writing {df_name} to parquet...')
    # Output sub-directory is the variable name without its 'df_' marker, e.g. 'df_trip' -> 'trip'.
    # mode('overwrite') replaces any previous output for that table.
    df_to_write.write.mode('overwrite').parquet(f'{OUTPUT_PATH}/{df_name.replace("df_", "")}')
print('Finished!')

Writing df_trip to parquet...
Writing df_taxi_zone to parquet...
Writing df_vendor to parquet...
Writing df_rate_code to parquet...
Writing df_payment_type to parquet...
Finished!
