# NYC Taxi PySpark Pipeline (Databricks)
*Dataset:* `/databricks-datasets/nyctaxi/` (yellow tripdata + taxi zone lookup)

This notebook implements:
- Data loading (>1GB)
- Apply transformations: 
    - 2+ filter operations
    - complex aggregation
    - 1+ groupBy with aggregations
    - Column transformations using withColumn
- 2+ SQL queries
- Optimization: filter pushdown, partitioning, avoiding shuffles, caching
- Write results to Parquet
- Actions vs Transformations
- MLlib model

# Load data

In [0]:
# Setup: inspect dataset structure
display(dbutils.fs.ls("/databricks-datasets/nyctaxi"))

path,name,size,modificationTime
dbfs:/databricks-datasets/nyctaxi/readme_nyctaxi.txt,readme_nyctaxi.txt,916,1596568072000
dbfs:/databricks-datasets/nyctaxi/reference/,reference/,0,1762914946558
dbfs:/databricks-datasets/nyctaxi/sample/,sample/,0,1762914946558
dbfs:/databricks-datasets/nyctaxi/tables/,tables/,0,1762914946558
dbfs:/databricks-datasets/nyctaxi/taxizone/,taxizone/,0,1762914946558
dbfs:/databricks-datasets/nyctaxi/tripdata/,tripdata/,0,1762914946558


In [0]:
display(dbutils.fs.ls("/databricks-datasets/nyctaxi/tripdata/yellow/"))

path,name,size,modificationTime
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-01.csv.gz,yellow_tripdata_2009-01.csv.gz,504262564,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-02.csv.gz,yellow_tripdata_2009-02.csv.gz,480034681,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-03.csv.gz,yellow_tripdata_2009-03.csv.gz,521102719,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-04.csv.gz,yellow_tripdata_2009-04.csv.gz,515435466,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-05.csv.gz,yellow_tripdata_2009-05.csv.gz,531133739,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-06.csv.gz,yellow_tripdata_2009-06.csv.gz,508802995,1596568313000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-07.csv.gz,yellow_tripdata_2009-07.csv.gz,487731497,1596568318000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-08.csv.gz,yellow_tripdata_2009-08.csv.gz,490825210,1596568318000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-09.csv.gz,yellow_tripdata_2009-09.csv.gz,503121179,1596568318000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-10.csv.gz,yellow_tripdata_2009-10.csv.gz,567109604,1596568319000


In [0]:
# Read the dataset
yellow = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("/databricks-datasets/nyctaxi/tripdata/yellow/*.csv.gz")
)

print("Rows:", yellow.count())
yellow.printSchema()
display(yellow.limit(5))

No such comm: LSP_COMM_ID


Rows: 1611611035
root
 |-- vendor_id: string (nullable = true)
 |-- pickup_datetime: string (nullable = true)
 |-- dropoff_datetime: string (nullable = true)
 |-- passenger_count: string (nullable = true)
 |-- trip_distance: string (nullable = true)
 |-- pickup_longitude: string (nullable = true)
 |-- pickup_latitude: string (nullable = true)
 |-- rate_code: string (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- dropoff_longitude: string (nullable = true)
 |-- dropoff_latitude: string (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- fare_amount: string (nullable = true)
 |-- surcharge: string (nullable = true)
 |-- mta_tax: string (nullable = true)
 |-- tip_amount: string (nullable = true)
 |-- tolls_amount: string (nullable = true)
 |-- total_amount: string (nullable = true)



vendor_id,pickup_datetime,dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,rate_code,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,surcharge,mta_tax,tip_amount,tolls_amount,total_amount
CMT,2010-05-28 21:09:20,2010-05-28 21:30:50,2,4.6,-74.00168499999998,40.721345,1,0,-73.95820299999998,40.769568,Cas,15.3,0.5,0.5,0,0,16.3
CMT,2010-05-28 15:58:09,2010-05-28 16:01:31,1,0.6999999999999998,-73.95565499999998,40.776583,1,0,-73.94761599999998,40.78279,Cas,4.1,1.0,0.5,0,0,5.6
CMT,2010-05-28 10:42:44,2010-05-28 10:46:14,1,0.4,-73.987708,40.77521,1,0,-73.982101,40.775251,Cas,3.7,0.0,0.5,0,0,4.2
CMT,2010-05-27 23:14:35,2010-05-27 23:22:23,1,2.1,-73.97220199999998,40.755899,1,0,-73.98638699999998,40.730299,Cas,7.3,0.5,0.5,0,0,8.3
CMT,2010-05-28 00:10:10,2010-05-28 00:12:18,1,0.5,-73.99276999999998,40.748281,1,0,-74.000422,40.747921,Cas,3.7,0.5,0.5,0,0,4.7


In [0]:
# Inspect dataset size 
files = dbutils.fs.ls("/databricks-datasets/nyctaxi/tripdata/yellow/")
total_bytes = sum(f.size for f in files)
total_gb_compressed = total_bytes / (1024**3)
estimated_uncompressed_gb = total_gb_compressed * 4  # average ×4 expansion
print(f"Compressed size: {total_gb_compressed:.2f} GB")
print(f"Estimated uncompressed size: {estimated_uncompressed_gb:.2f} GB (≈×4 expansion)")

Compressed size: 46.79 GB
Estimated uncompressed size: 187.17 GB (≈×4 expansion)


In [0]:
# Convert once to Parquet for much faster reloads later
yellow.write.mode("overwrite").parquet("dbfs:/tmp/nyctaxi_parquet")

# Fast reload path (use this in later sessions instead of CSV):
yellow = spark.read.parquet("dbfs:/tmp/nyctaxi_parquet")

print("Parquet converted.")

Parquet converted.


In [0]:
# Randomly keep about 5% of rows
yellow_sampled = yellow.sample(withReplacement=False, fraction=0.05, seed=42)

print("Original count:", yellow.count())
print("Sampled count:", yellow_sampled.count())

Original count: 1611611035
Sampled count: 80576536


In [0]:
# This is about 9-10 GB. I will sample 15% again to keep ~1.5GB.
yellow_sampled_2 = yellow_sampled.sample(withReplacement=False, fraction=0.15, seed=42)
print("Sampled count:", yellow_sampled_2.count())

Sampled count: 12090682


In [0]:
# Read the taxi zone dataset
zones = (spark.read
         .option("header", True)
         .option("inferSchema", True)
         .csv("/databricks-datasets/nyctaxi/taxizone/taxi_zone_lookup.csv"))

print("Rows:", zones.count())
zones.printSchema()
display(zones.limit(5))

Rows: 265
root
 |-- LocationID: integer (nullable = true)
 |-- Borough: string (nullable = true)
 |-- Zone: string (nullable = true)
 |-- service_zone: string (nullable = true)



LocationID,Borough,Zone,service_zone
1,EWR,Newark Airport,EWR
2,Queens,Jamaica Bay,Boro Zone
3,Bronx,Allerton/Pelham Gardens,Boro Zone
4,Manhattan,Alphabet City,Yellow Zone
5,Staten Island,Arden Heights,Boro Zone


In [0]:
yellow_sampled_2.columns

['vendor_id',
 'pickup_datetime',
 'dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'pickup_longitude',
 'pickup_latitude',
 'rate_code',
 'store_and_fwd_flag',
 'dropoff_longitude',
 'dropoff_latitude',
 'payment_type',
 'fare_amount',
 'surcharge',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'total_amount']

# Apply transformations

In [0]:
from pyspark.sql import functions as F

pickup_col  = "pickup_datetime"
dropoff_col = "dropoff_datetime"

df = yellow_sampled_2

# Columns that should be numeric
numeric_cols = ["passenger_count","trip_distance","fare_amount","tip_amount","tolls_amount","total_amount"]

# Remove commas/whitespace/$ but KEEP e/E for scientific notation, then try_cast to double
for c in numeric_cols:
    if c in df.columns:
        df = df.withColumn(
            c,
            F.expr(f"try_cast(regexp_replace(cast({c} as string), '[, $]', '') as double)")
        )

# payment_type -> int (strip non-digits; tolerate bad rows)
if "payment_type" in df.columns:
    df = df.withColumn(
        "payment_type",
        F.expr("try_cast(regexp_replace(cast(payment_type as string), '[^0-9-]', '') as int)")
    )

# Timestamps + engineered columns
df = (df
      .withColumn("pickup_ts",  F.col(pickup_col).cast("timestamp"))
      .withColumn("dropoff_ts", F.col(dropoff_col).cast("timestamp"))
      .withColumn("pickup_date", F.to_date("pickup_ts"))
      .withColumn("pickup_hour", F.hour("pickup_ts"))
      .withColumn("trip_mins", (F.unix_timestamp("dropoff_ts") - F.unix_timestamp("pickup_ts"))/60.0)
      .withColumn("tip_pct",
                  F.when(F.col("fare_amount") > 0,
                         F.col("tip_amount") / F.col("fare_amount"))
                   .otherwise(F.lit(0.0)))
)

df.printSchema()
display(df.select("fare_amount","tip_amount","tip_pct","payment_type").limit(10))

root
 |-- vendor_id: string (nullable = true)
 |-- pickup_datetime: string (nullable = true)
 |-- dropoff_datetime: string (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_longitude: string (nullable = true)
 |-- pickup_latitude: string (nullable = true)
 |-- rate_code: string (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- dropoff_longitude: string (nullable = true)
 |-- dropoff_latitude: string (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- surcharge: string (nullable = true)
 |-- mta_tax: string (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- pickup_ts: timestamp (nullable = true)
 |-- dropoff_ts: timestamp (nullable = true)
 |-- pickup_date: date (nullable = true)
 |-- pickup_hour: integer (nullable = true)
 |-- trip_mi

fare_amount,tip_amount,tip_pct,payment_type
3.7,0.0,0.0,
6.9,1.58,0.2289855072463768,
13.7,2.0,0.145985401459854,
7.7,0.0,0.0,
4.9,0.0,0.0,
6.5,0.0,0.0,
4.5,2.0,0.4444444444444444,
8.5,1.0,0.1176470588235294,
10.5,0.0,0.0,
23.7,0.0,0.0,


In the next cell, I have optimized my queries since I did filter early in my pipeline, before join and groupBy.

In [0]:
# 4 filter and 1 join
lookup = spark.createDataFrame(
    [(0,"Unknown/No Charge"),(1,"Credit Card"),(2,"Cash"),(3,"No Charge/Dispute"),(4,"Dispute"),(5,"Unknown")],
    ["payment_type","payment_desc"]
)
df_f = (df
        .filter(F.col("pickup_date").isNotNull())
        .filter((F.col("trip_distance") > 0) & (F.col("trip_distance") < 200))
        .filter((F.col("fare_amount") > 0) & (F.col("fare_amount") < 1000))
        .filter(F.col("trip_mins") > 0)
       )
dfj = df_f.join(lookup, on="payment_type", how="left")
display(dfj.select("pickup_date","pickup_hour","trip_distance","fare_amount","tip_pct","payment_type","payment_desc").limit(10))

pickup_date,pickup_hour,trip_distance,fare_amount,tip_pct,payment_type,payment_desc
2010-05-28,0,0.5,3.7,0.0,,
2010-05-24,22,1.6,6.9,0.2289855072463768,,
2010-05-30,8,4.7,13.7,0.145985401459854,,
2010-05-12,9,1.8,7.7,0.0,,
2010-05-06,13,0.91,4.9,0.0,,
2010-05-30,19,1.19,6.5,0.0,,
2010-05-04,0,0.86,4.5,0.4444444444444444,,
2010-05-18,8,2.1,8.5,0.1176470588235294,,
2010-05-13,23,3.61,10.5,0.0,,
2010-05-12,11,9.9,23.7,0.0,,


In [0]:
# groupBy and agg and orderBy
# Aggregation by payment method
agg_by_pay = (dfj
              .groupBy("payment_type","payment_desc")
              .agg(
                  F.count("*").alias("trips"),
                  F.avg("trip_distance").alias("avg_miles"),
                  F.avg("fare_amount").alias("avg_fare"),
                  F.avg("tip_pct").alias("avg_tip_pct"),
                  F.sum("total_amount").alias("sum_rev")
              )
              .orderBy(F.desc("trips")))
display(agg_by_pay)
agg_by_pay.explain(mode="formatted")

payment_type,payment_desc,trips,avg_miles,avg_fare,avg_tip_pct,sum_rev
,,7690320,2.8510525594513605,10.94555247766016,0.0911043899581793,99431796.79000305
1.0,Credit Card,1425528,2.922872507590168,9.774187430902794,0.3318986884136245,447613.44999967417
0.0,Unknown/No Charge,1314571,3.000520443551546,0.5000058574242092,0.5999835687838162,442375.25
5.0,Unknown,777037,2.990820244081045,0.5000036034320117,0.5999868732121657,247523.0
2.0,Cash,571342,2.699831344448684,11.943479737180184,1.2685341213578476e-05,170653.20000000668
25.0,,100980,2.716745395127747,0.5000166369578134,0.5999821746882182,251790.0
3.0,No Charge/Dispute,57846,2.7361240189468585,1.4904327006188849,0.556104996837457,134975.8
35.0,,32381,2.08938760384176,0.5000114264537846,0.5999814706154329,80742.5
45.0,,10345,17.237866602223296,0.5,0.5999999999999971,4145.0
4.0,Dispute,1497,3.749899799599198,14.460126920507683,0.0,446.7000000000008


== Physical Plan ==
AdaptiveSparkPlan (29)
+- == Initial Plan ==
   ColumnarToRow (28)
   +- PhotonResultStage (27)
      +- PhotonSort (26)
         +- PhotonShuffleExchangeSource (25)
            +- PhotonShuffleMapStage (24)
               +- PhotonShuffleExchangeSink (23)
                  +- PhotonProject (22)
                     +- PhotonGroupingAgg (21)
                        +- PhotonShuffleExchangeSource (20)
                           +- PhotonShuffleMapStage (19)
                              +- PhotonShuffleExchangeSink (18)
                                 +- PhotonGroupingAgg (17)
                                    +- PhotonProject (16)
                                       +- PhotonBroadcastHashJoin LeftOuter (15)
                                          :- PhotonGroupingAgg (8)
                                          :  +- PhotonProject (7)
                                          :     +- PhotonProject (6)
                                          :        +- P

In [0]:
# groupBy and agg and orderBY
# Aggregation by daily profile
daily = (dfj
         .groupBy("pickup_date")
         .agg(
             F.count("*").alias("trips"),
             F.avg("trip_mins").alias("avg_trip_mins"),
             F.avg("tip_pct").alias("avg_tip_pct")
         )
         .orderBy("pickup_date"))
display(daily.limit(20))
daily.explain(mode="formatted")

pickup_date,trips,avg_trip_mins,avg_tip_pct
2008-12-31,5,24.020000000000003,0.6
2009-01-01,2398,10.878343063664136,0.0308396093056974
2009-01-02,2808,11.85220797720797,0.0316802741246063
2009-01-03,3128,11.219698422847404,0.0327033597041971
2009-01-04,2625,12.030539682539674,0.0393223486867346
2009-01-05,2684,11.893541977148542,0.0369231185789221
2009-01-06,3177,12.096742209631744,0.0394393853173968
2009-01-07,2637,12.506648969788907,0.0450125873621109
2009-01-08,3517,12.09092029191547,0.0424797706641849
2009-01-09,3831,13.73480379361352,0.0414673899692068


== Physical Plan ==
AdaptiveSparkPlan (27)
+- == Initial Plan ==
   ColumnarToRow (26)
   +- PhotonResultStage (25)
      +- PhotonSort (24)
         +- PhotonShuffleExchangeSource (23)
            +- PhotonShuffleMapStage (22)
               +- PhotonShuffleExchangeSink (21)
                  +- PhotonGroupingAgg (20)
                     +- PhotonShuffleExchangeSource (19)
                        +- PhotonShuffleMapStage (18)
                           +- PhotonShuffleExchangeSink (17)
                              +- PhotonGroupingAgg (16)
                                 +- PhotonProject (15)
                                    +- PhotonBroadcastHashJoin LeftOuter (14)
                                       :- PhotonProject (7)
                                       :  +- PhotonProject (6)
                                       :     +- PhotonFilter (5)
                                       :        +- PhotonSample (4)
                                       :           +- PhotonSa

In [0]:
# SQL views
dfj.createOrReplaceTempView("trips")
agg_by_pay.createOrReplaceTempView("agg_by_pay")
daily.createOrReplaceTempView("daily_trips")

In [0]:
# SQL #1: Top payment types by avg tip pct (with volume threshold)
display(spark.sql("""
SELECT payment_desc, trips, avg_tip_pct
FROM agg_by_pay
WHERE trips > 20000
ORDER BY avg_tip_pct DESC
"""))

payment_desc,trips,avg_tip_pct
,32832,0.5999999999999873
,101125,0.5999940667492517
Unknown/No Charge,1315150,0.5999881382344024
Unknown,776747,0.5999814611446724
No Charge/Dispute,58069,0.5562115698059619
Credit Card,1424833,0.3317679149043738
,7690320,0.0911043899581793
Cash,571338,8.484764219513413e-06


In [0]:
# SQL #2: Hourly profile for longer trips
display(spark.sql("""
SELECT pickup_hour, COUNT(*) AS trips, AVG(trip_distance) AS avg_miles, AVG(tip_pct) AS avg_tip_pct
FROM trips
WHERE trip_distance >= 10
GROUP BY pickup_hour
ORDER BY pickup_hour
"""))

pickup_hour,trips,avg_miles,avg_tip_pct
0,20887,14.702872121415234,0.229362961292806
1,13117,14.25104139666082,0.2095982222293419
2,9208,13.742932232841008,0.1890136708753356
3,8370,13.87738470728793,0.173694811091896
4,11496,14.09176409185804,0.2005311970051356
5,16433,15.34563682833323,0.2342713724797737
6,21970,15.601621756941285,0.239391892633447
7,23570,15.166803563852357,0.2401389359604087
8,23357,14.649639508498526,0.236913664080508
9,23381,14.226450536760616,0.2400994761983788


## PySpark Optimization (Faster) 

In [0]:
# PySpark performance on ~1GB (mimics your screenshot)
from pyspark.sql.functions import col, desc
import time

# Choose a reasonably wide/probing selection to make the grouping meaningful
df_large = (dfj
    .select("payment_desc", "pickup_hour", "trip_distance", "fare_amount", "tip_pct")
    .filter((col("fare_amount") > 0) & (col("trip_distance") > 0))
)
_ = df_large.count()  # materialize cache so we're timing the actual aggregation work

start = time.time()

result_large = (df_large
    .filter(col("tip_pct") > 0)     # analogous to delay > 0
    .groupBy("payment_desc")        # analogous to groupBy('origin')
    .count()
    .orderBy(desc("count"))
    .limit(10)
)

result_large.show(truncate=False)

large_time = time.time() - start

print(f"\n⏱️  PySpark on ~1GB sample: {large_time:.2f} seconds")
print(f"\n💡 Pandas would:")
print(f"  - Often take 3–5× longer (~{large_time*4:.1f}s estimate) on the same machine")
print( "  - Or risk MemoryError if the dataset doesn't fit driver RAM")
print("\n✅ PySpark distributes the work across the cluster and handles it smoothly.")

+-----------------+-------+
|payment_desc     |count  |
+-----------------+-------+
|NULL             |3465266|
|Credit Card      |1387827|
|Unknown/No Charge|1315124|
|Unknown          |776723 |
|No Charge/Dispute|53833  |
|Cash             |15     |
+-----------------+-------+


⏱️  PySpark on ~1GB sample: 498.74 seconds

💡 Pandas would:
  - Often take 3–5× longer (~1994.9s estimate) on the same machine
  - Or risk MemoryError if the dataset doesn't fit driver RAM

✅ PySpark distributes the work across the cluster and handles it smoothly.


## Write results to Parquet 

In [0]:
# Save results to Parquet
dfj.write.mode("overwrite").parquet("dbfs:/tmp/nyc_taxi_results")

print("Parquet converted.")

Parquet converted.


# Actions vs Transformations

## Lazy Evaluation

In [0]:
# Lazy Evaluation (using NYC Taxi df)

query = (
    dfj
    .select("pickup_date", "fare_amount", "tip_amount", "trip_distance", "tip_pct")
    # Intentionally compute a column BEFORE filtering (wasteful by design)
    .withColumn(
        "tip_category",
        F.when(F.col("tip_pct") >= 0.3, "high") \
         .when(F.col("tip_pct") >= 0.15, "medium") \
         .otherwise("low")
    )
    # Filters written AFTER withColumn (optimizer may fix it, but this shows the idea)
    .filter(F.col("fare_amount") > 50)           # expensive to do late
    .filter(F.col("trip_distance") > 3)          # expensive to do late
    .groupBy("tip_category")
    .agg(
        F.avg("fare_amount").alias("avg_fare"),
        F.avg("trip_distance").alias("avg_miles"),
        F.count("*").alias("count")
    )
    .orderBy(F.desc("count"))
)

# Nothing has run yet. The next line is the ACTION that triggers execution.
display(query)

tip_category,avg_fare,avg_miles,count
low,56.82622975947818,18.6201293110477,61325
medium,55.33771971077538,18.26595105128054,51033
high,58.0821972711768,18.64293064241045,3518


## Actions vs Transformations

In [0]:
# Actions vs Transformations

# Transformations only (build a plan)
start = time.time()
filtered = df.filter((col("fare_amount") > 100) & (col("trip_distance") > 5))
selected = filtered.select("pickup_date", "fare_amount", "trip_distance", "tip_pct")
print(f"🧱 Transformations (plan built only), took {time.time() - start:.4f}s")

# Action 1: triggers execution
print("\n⚡ Action 1: count()")
start = time.time()
cnt = selected.count()  # ACTION
print(f"count() => {cnt:,} rows, took {time.time() - start:.2f}s")

# Action 2: triggers execution again (recomputes the plan)
print("\n⚡ Action 2: show(5)")
start = time.time()
selected.show(5, truncate=False)  # ACTION
print(f"show(5) took {time.time() - start:.2f}s")

print("\n🔁 Notice: each ACTION re-executes the upstream transformations!")

🧱 Transformations (plan built only), took 0.0005s

⚡ Action 1: count()
count() => 1,860 rows, took 452.86s

⚡ Action 2: show(5)
+-----------+-----------+-------------+-------------------+
|pickup_date|fare_amount|trip_distance|tip_pct            |
+-----------+-----------+-------------+-------------------+
|2010-05-31 |114.0      |49.7         |0.15789473684210525|
|2010-05-01 |120.0      |19.1         |0.0                |
|2010-05-08 |118.0      |32.2         |0.0847457627118644 |
|2010-05-04 |100.9      |30.4         |0.0                |
|2010-05-23 |146.25     |32.73        |0.10256410256410256|
+-----------+-----------+-------------+-------------------+
only showing top 5 rows
show(5) took 30.15s

🔁 Notice: each ACTION re-executes the upstream transformations!


# Optional Machine Learning

In [0]:
# Pick the most feature-rich DataFrame available
try:
    source_df = dfj
except NameError:
    source_df = df

from pyspark.sql import functions as F

# Required columns (adjust if any are missing in your schema)
needed = ["fare_amount","tip_amount","trip_distance","passenger_count","pickup_hour","payment_type","tip_pct"]
have = [c for c in needed if c in source_df.columns]

ml = (source_df
      .select(*have)
      .na.drop(subset=["fare_amount","tip_amount","trip_distance","passenger_count"])
      .filter((F.col("fare_amount") > 0) & (F.col("trip_distance") > 0))
      .withColumn("passenger_count", F.coalesce(F.col("passenger_count").cast("double"), F.lit(1.0)))
     )

print("Rows for ML after cleaning:", ml.count())
display(ml.limit(5))

Rows for ML after cleaning: 11982701


fare_amount,tip_amount,trip_distance,passenger_count,pickup_hour,payment_type,tip_pct
3.7,0.0,0.5,1.0,0,,0.0
6.9,1.58,1.6,1.0,22,,0.2289855072463768
13.7,2.0,4.7,4.0,8,,0.145985401459854
7.7,0.0,1.8,5.0,9,,0.0
4.9,0.0,0.91,1.0,13,,0.0


In [0]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.regression import LinearRegression, GBTRegressor
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator

# Categorical: payment_type (if present)
stages = []
if "payment_type" in ml.columns:
    idx = StringIndexer(inputCol="payment_type", outputCol="payment_type_idx", handleInvalid="keep")
    ohe = OneHotEncoder(inputCols=["payment_type_idx"], outputCols=["payment_type_ohe"])
    stages += [idx, ohe]
    cat_cols = ["payment_type_ohe"]
else:
    cat_cols = []

num_cols = [c for c in ["fare_amount","trip_distance","passenger_count","pickup_hour"] if c in ml.columns]

assembler = VectorAssembler(inputCols=cat_cols + num_cols, outputCol="features_raw")
scaler = StandardScaler(inputCol="features_raw", outputCol="features", withMean=False, withStd=True)

# Try a simple Linear Regression first; you can switch to GBT below
lr = LinearRegression(featuresCol="features", labelCol="tip_amount", maxIter=50, regParam=0.1, elasticNetParam=0.2)

pipeline = Pipeline(stages=stages + [assembler, scaler, lr])

train, test = ml.randomSplit([0.8, 0.2], seed=42)
model = pipeline.fit(train)
pred = model.transform(test)

for metric in ["rmse","mae","r2"]:
    ev = RegressionEvaluator(labelCol="tip_amount", predictionCol="prediction", metricName=metric)
    print(f"{metric.upper():<4}:", ev.evaluate(pred))

display(pred.select("fare_amount","trip_distance","passenger_count","tip_amount","prediction").limit(20))

RMSE: 1.4246476555839174
MAE : 0.8725577183641957
R2  : 0.35764034121529475


fare_amount,trip_distance,passenger_count,tip_amount,prediction
2.5,0.01,4.0,0.0,0.1860603265793722
2.5,0.02,1.0,0.0,0.194470555825471
2.5,0.02,3.0,0.0,0.1889574758468454
2.5,0.04,1.0,0.0,0.1947517743817917
2.5,0.04,1.0,0.0,0.1947517743817917
2.5,0.04,1.0,0.0,0.1947517743817917
2.5,0.04,1.0,0.0,0.1947517743817917
2.5,0.05,1.0,0.0,0.1948923836599521
2.5,0.05,5.0,0.0,0.1838662237027008
2.5,0.06,5.0,0.0,0.1840068329808611
