In [0]:
# Sanity-check the runtime by reporting the Spark version of the active session.
# `spark` and `dbutils` are not defined in the visible cells — presumably they are
# provided by the Databricks notebook runtime.
print(f"Spark version: {spark.version}")
# Plain string literal: there is nothing to interpolate, so no f-prefix is needed.
print("Cluster configured successfully!")

# List the sample-datasets directory; as the cell's last expression, the listing
# is rendered as the cell output.
dbutils.fs.ls("/databricks-datasets/")

Spark version: 4.0.0
Cluster configured successfully!


[FileInfo(path='dbfs:/databricks-datasets/COVID/', name='COVID/', size=0, modificationTime=1762918301348),
 FileInfo(path='dbfs:/databricks-datasets/README.md', name='README.md', size=976, modificationTime=1596557781000),
 FileInfo(path='dbfs:/databricks-datasets/Rdatasets/', name='Rdatasets/', size=0, modificationTime=1762918301348),
 FileInfo(path='dbfs:/databricks-datasets/SPARK_README.md', name='SPARK_README.md', size=3359, modificationTime=1596557823000),
 FileInfo(path='dbfs:/databricks-datasets/adult/', name='adult/', size=0, modificationTime=1762918301348),
 FileInfo(path='dbfs:/databricks-datasets/airlines/', name='airlines/', size=0, modificationTime=1762918301348),
 FileInfo(path='dbfs:/databricks-datasets/amazon/', name='amazon/', size=0, modificationTime=1762918301348),
 FileInfo(path='dbfs:/databricks-datasets/asa/', name='asa/', size=0, modificationTime=1762918301348),
 FileInfo(path='dbfs:/databricks-datasets/atlas_higgs/', name='atlas_higgs/', size=0, modificationTime=

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, DoubleType, TimestampType
import time

In [0]:
# Read every Parquet file matching the 2019 glob into a single DataFrame.
nyc_taxi_df = (
    spark.read
    .format("parquet")
    .load("/Volumes/workspace/default/nyc_taxi/yellow_tripdata_2019-*.parquet")
)

# count() is an action (it runs a Spark job); show() then prints the first five rows.
row_total = nyc_taxi_df.count()
print(f"Rows: {row_total}")
nyc_taxi_df.show(n=5)

Rows: 84598444
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2019-03-01 00:24:41|  2019-03-01 00:25:31|            1.0|          0.0|       1.0|                 N|         145|         145|           2|        2.5|  0.5|  

In [0]:
# Print the raw DataFrame's column names, types and nullability as a tree.
nyc_taxi_df.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: integer (nullable = true)



In [0]:
# Columns the later cells work with; everything else in the raw schema is dropped.
cols_needed = [
    "tpep_pickup_datetime",
    "tpep_dropoff_datetime",
    "passenger_count",
    "trip_distance",
    "PULocationID",
    "DOLocationID",
    "payment_type",
    "fare_amount",
    "tip_amount",
    "total_amount",
]

# Keep only the requested columns that exist in the source, in the order listed above.
# A requested column that is absent is skipped silently rather than raising.
available_cols = set(nyc_taxi_df.columns)
nyc_taxi_df_trimmed = nyc_taxi_df.select(
    *[name for name in cols_needed if name in available_cols]
)


In [0]:
# Derive timing and cost features, then drop rows that fail basic sanity checks.
nyc_taxi_df_trimmed = (
    nyc_taxi_df_trimmed
    # Pickup / drop-off instants as separate columns produced by to_timestamp.
    .withColumn("pickup_dt", F.to_timestamp(F.col("tpep_pickup_datetime")))
    .withColumn("dropoff_dt", F.to_timestamp(F.col("tpep_dropoff_datetime")))
    # Trip duration in whole seconds: difference of the two epoch-second values.
    .withColumn(
        "trip_seconds",
        (
            F.unix_timestamp(F.col("dropoff_dt")) - F.unix_timestamp(F.col("pickup_dt"))
        ).cast("int"),
    )
    # Total amount per mile; null when the distance is not positive, which also
    # avoids dividing by zero.
    .withColumn(
        "cost_per_mile",
        F.when(
            F.col("trip_distance") > 0,
            F.col("total_amount") / F.col("trip_distance"),
        ).otherwise(F.lit(None)),
    )
    .withColumn("year", F.year(F.col("pickup_dt")))
    .withColumn("month", F.month(F.col("pickup_dt")))
    # Keep only trips with positive distance, positive duration and a
    # non-negative total amount.
    .where(F.col("trip_distance") > 0)
    .where(F.col("trip_seconds") > 0)
    .where(F.col("total_amount") >= 0)
)


In [0]:
# Restrict to pickups in 2019 with a trip distance under 100 and 1 to 6 passengers.
# Both passenger-count bounds are inclusive, matching between(1, 6).
nyc_taxi_df_filtered = nyc_taxi_df_trimmed.where(
    (F.col("year") == 2019)
    & (F.col("trip_distance") < 100)
    & (F.col("passenger_count") >= 1)
    & (F.col("passenger_count") <= 6)
)

# count() is an action, so this line executes the pipeline built so far.
filtered_total = nyc_taxi_df_filtered.count()
print("Filtered count:", filtered_total)


Filtered count: 81761875


In [0]:
# Trip volume plus average tip and total per (pickup location, hour of pickup).
nyc_taxi_df_filtered_aggregate = (
    nyc_taxi_df_filtered
    .withColumn("hour", F.hour(F.col("pickup_dt")))
    .groupBy("PULocationID", "hour")
    .agg(
        F.count("*").alias("n_trips"),
        F.avg(F.col("tip_amount")).alias("avg_tip"),
        F.avg(F.col("total_amount")).alias("avg_total"),
    )
    # Discard groups with 50 trips or fewer, then rank the busiest groups first.
    .where(F.col("n_trips") > 50)
    .sort(F.col("n_trips").desc())
)

# Render only the 50 busiest groups.
display(nyc_taxi_df_filtered_aggregate.limit(50))


PULocationID,hour,n_trips,avg_tip,avg_total
237,18,262763,1.782155478511069,15.052447871272758
237,14,259886,1.6719072208583865,14.461724025146651
237,15,254915,1.6386837573308837,14.349240413468396
237,17,251552,1.78843686394862,15.366496549418374
161,19,249342,2.0351082449005955,16.897644480269065
236,15,248721,1.6846729065901234,14.685151233708943
161,18,244925,2.092352271103415,17.685739430434705
230,22,238005,2.048803008340178,17.261519463873707
162,18,235888,2.1765007122024165,17.429331759137142
237,13,234691,1.6399080066981762,14.365251202642654


In [0]:
# One row per pickup date: trip count, total distance, mean fare and mean cost per mile.
nyc_taxi_df_filtered_daily_stats = (
    nyc_taxi_df_filtered
    # Group directly on the date part of the pickup timestamp.
    .groupBy(F.to_date(F.col("pickup_dt")).alias("pickup_date"))
    .agg(
        F.count("*").alias("trips"),
        F.sum(F.col("trip_distance")).alias("sum_distance"),
        F.avg(F.col("fare_amount")).alias("avg_fare"),
        F.avg(F.col("cost_per_mile")).alias("avg_cost_per_mile"),
    )
    # Chronological order (ascending by date).
    .sort("pickup_date")
)

# Preview the first ten days.
display(nyc_taxi_df_filtered_daily_stats.limit(10))


pickup_date,trips,sum_distance,avg_fare,avg_cost_per_mile
2019-01-01,184733,658065.8899999907,13.576036008726112,8.691470048130315
2019-01-02,193401,625974.8899999988,12.987656630524144,9.877036052276798
2019-01-03,217898,649918.3700000047,12.510801108775649,9.776335915169437
2019-01-04,229688,650575.460000009,12.047309350074878,9.30571368593544
2019-01-05,231092,620001.3000000006,11.357736096446423,9.086749183085296
2019-01-06,203634,630749.2899999899,12.29189658897825,9.212683429959348
2019-01-07,222514,639902.1300000112,12.11741719622136,9.602806992143757
2019-01-08,230858,644410.6800000038,12.07237453326285,9.339118097465365
2019-01-09,248989,686445.9599999953,12.07002373598833,9.650481696055488
2019-01-10,274325,752054.4800000101,12.168949931650404,9.652038471709073


In [0]:
# Register the filtered trips under the view name the SQL queries below refer to.
nyc_taxi_df_filtered.createOrReplaceTempView("trips_filtered")

# Query 1: the 50 pickup locations with the most trips, with their average total amount.
top_pickups_query = """
SELECT PULocationID, COUNT(*) AS trip_count, AVG(total_amount) AS avg_total
FROM trips_filtered
GROUP BY PULocationID
ORDER BY trip_count DESC
LIMIT 50
"""
sql_1 = spark.sql(top_pickups_query)
display(sql_1)

# Query 2: trips per hour of day and the mean tip share of the total amount.
# Rows whose total is not positive contribute NULL, which AVG ignores.
hourly_tip_query = """
SELECT hour(pickup_dt) AS hour_of_day, 
       COUNT(*) AS trips,
       AVG(CASE WHEN total_amount > 0 THEN tip_amount/total_amount ELSE NULL END) AS avg_tip_rate
FROM trips_filtered
GROUP BY hour_of_day
ORDER BY hour_of_day
"""
sql_2 = spark.sql(hourly_tip_query)
display(sql_2)

PULocationID,trip_count,avg_total
237,3558623,14.786242178466004
161,3367811,17.806076103988737
236,3218419,14.903735212816118
162,2974213,17.42958801871112
186,2943209,17.81265454810834
230,2823856,19.07906584467504
132,2613346,56.40362840207042
48,2606595,16.941972600238717
170,2534856,17.173273243108266
234,2472886,16.392720978624222


hour_of_day,trips,avg_tip_rate
0,2426028,0.109730295534205
1,1685005,0.1077820716214895
2,1167226,0.1048905511873369
3,811732,0.0986166809867727
4,625218,0.0863772029217584
5,736134,0.0924036997773783
6,1689658,0.101971968401445
7,2969422,0.1117405040272496
8,3726209,0.1148528229531605
9,3802137,0.1114190562835887


In [0]:
# Print the extended query plans (starting with the parsed and analyzed logical
# plans) for the SQL query and for the DataFrame aggregation, in that order.
for plan_label, plan_df in (
    ("sql_1", sql_1),
    ("agg", nyc_taxi_df_filtered_aggregate),
):
    print(f"=== EXPLAIN for {plan_label} ===")
    plan_df.explain(extended=True)
    if plan_label == "sql_1":
        # Blank line between the two plan dumps, as in the original layout.
        print()


=== EXPLAIN for sql_1 ===
== Parsed Logical Plan ==
'GlobalLimit 50
+- 'LocalLimit 50
   +- 'Sort ['trip_count DESC NULLS LAST], true
      +- 'Aggregate ['PULocationID], ['PULocationID, 'COUNT(1) AS trip_count#11330, 'AVG('total_amount) AS avg_total#11331]
         +- 'UnresolvedRelation [trips_filtered], [], false

== Analyzed Logical Plan ==
PULocationID: bigint, trip_count: bigint, avg_total: double
GlobalLimit 50
+- LocalLimit 50
   +- Sort [trip_count#11330L DESC NULLS LAST], true
      +- Aggregate [PULocationID#11045L], [PULocationID#11045L, count(1) AS trip_count#11330L, avg(total_amount#11054) AS avg_total#11331]
         +- SubqueryAlias trips_filtered
            +- View (`trips_filtered`, [tpep_pickup_datetime#11039, tpep_dropoff_datetime#11040, passenger_count#11041, trip_distance#11042, PULocationID#11045L, DOLocationID#11046L, payment_type#11047L, fare_amount#11048, tip_amount#11051, total_amount#11054, pickup_dt#11126, dropoff_dt#11128, trip_seconds#11130, cost_per_mil

In [0]:
# Repartition the filtered trips on the pickup location into `nparts` partitions,
# then aggregate per pickup location on the repartitioned data.
nparts = 200
nyc_taxi_df_filtered_partitioned = nyc_taxi_df_filtered.repartition(
    nparts, F.col("PULocationID")
)

# Trip count and average total amount per pickup location.
nyc_taxi_df_groupby_PUID = (
    nyc_taxi_df_filtered_partitioned
    .groupBy("PULocationID")
    .agg(
        F.count("*").alias("n_trips"),
        F.avg(F.col("total_amount")).alias("avg_total"),
    )
)

# No ordering is applied, so which 20 groups are shown is not fixed.
display(nyc_taxi_df_groupby_PUID.limit(20))


PULocationID,n_trips,avg_total
26,4479,22.17868943960843
29,1433,28.094173063503177
65,94187,20.34511036556757
191,2544,38.39559355345994
222,2386,30.31531014249889
243,13553,20.473413266432168
54,2470,21.188829959514447
19,1094,27.11773308957913
113,1243216,16.256658714345622
112,22210,18.043027915349462


In [0]:
# Destination directory for the per-day statistics.
out_path = "/Volumes/workspace/default/nyc_taxi/processed/2019_stats"

# Persist the daily stats as Parquet; "overwrite" replaces whatever is already at the path.
(
    nyc_taxi_df_filtered_daily_stats.write
    .format("parquet")
    .mode("overwrite")
    .save(out_path)
)

In [0]:
# List the files under the output directory to confirm the Parquet write produced data.
display(dbutils.fs.ls(out_path))

path,name,size,modificationTime
dbfs:/Volumes/workspace/default/nyc_taxi/processed/2019_stats/_SUCCESS,_SUCCESS,0,1762918347000
dbfs:/Volumes/workspace/default/nyc_taxi/processed/2019_stats/_committed_4149569641990917045,_committed_4149569641990917045,124,1762915069000
dbfs:/Volumes/workspace/default/nyc_taxi/processed/2019_stats/_committed_7918402241361222417,_committed_7918402241361222417,234,1762918346000
dbfs:/Volumes/workspace/default/nyc_taxi/processed/2019_stats/_committed_vacuum1728349989606344054,_committed_vacuum1728349989606344054,96,1762918347000
dbfs:/Volumes/workspace/default/nyc_taxi/processed/2019_stats/_started_7918402241361222417,_started_7918402241361222417,0,1762918346000
dbfs:/Volumes/workspace/default/nyc_taxi/processed/2019_stats/part-00000-tid-7918402241361222417-44fe9199-520f-43c7-99d0-91b4b7487201-258-1.c000.snappy.parquet,part-00000-tid-7918402241361222417-44fe9199-520f-43c7-99d0-91b4b7487201-258-1.c000.snappy.parquet,13464,1762918346000


In [0]:
# Lazy-evaluation demo: transformations only build a plan; a Spark job runs when
# an action such as count() or display() is invoked.
transform = (nyc_taxi_df_filtered
     .withColumn("hour", F.hour("pickup_dt"))
     .filter(F.col("trip_distance") > 1)
    )
# The messages name the real variables (`transform`, `transform2`); the previous
# text referred to 't' and 't2', which do not exist in this notebook.
print("Declared 'transform' — no action yet.")

# count() is the first action on `transform`, so this is where work actually happens.
print("Trigger an action (count):", transform.count())

# Tip share of the total amount, computed with try_divide rather than `/`.
transform2 = transform.withColumn(
    "tip_rate",
    F.try_divide(F.col("tip_amount"), F.col("total_amount"))
)
print("Declared 'transform2' — still lazy.")
# display() is an action; the limit keeps the preview to five rows.
display(transform2.limit(5))

Declared 't' — no action yet. 
Trigger an action (count): 60016055
Declared 't2' — still lazy.


tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,PULocationID,DOLocationID,payment_type,fare_amount,tip_amount,total_amount,pickup_dt,dropoff_dt,trip_seconds,cost_per_mile,year,month,hour,tip_rate
2019-03-01T00:25:27.000,2019-03-01T00:36:37.000,2.0,3.7,95,130,1,13.0,0.7,15.0,2019-03-01T00:25:27.000Z,2019-03-01T00:36:37.000Z,670,4.0540540540540535,2019,3,0,0.0466666666666666
2019-03-01T00:05:21.000,2019-03-01T00:38:23.000,1.0,14.1,249,28,1,41.0,10.1,60.66,2019-03-01T00:05:21.000Z,2019-03-01T00:38:23.000Z,1982,4.302127659574468,2019,3,0,0.1665018133860863
2019-03-01T00:48:55.000,2019-03-01T01:06:03.000,1.0,9.6,138,98,2,27.0,0.0,28.3,2019-03-01T00:48:55.000Z,2019-03-01T01:06:03.000Z,1028,2.947916666666667,2019,3,0,0.0
2019-03-01T00:45:03.000,2019-03-01T00:49:38.000,1.0,1.2,246,48,2,6.0,0.0,9.8,2019-03-01T00:45:03.000Z,2019-03-01T00:49:38.000Z,275,8.166666666666668,2019,3,0,0.0
2019-02-28T19:52:45.000,2019-02-28T20:01:54.000,1.0,5.65,132,197,2,17.0,0.0,18.3,2019-02-28T19:52:45.000Z,2019-02-28T20:01:54.000Z,549,3.238938053097345,2019,2,19,0.0


In [0]:
# Time the same row count twice: once as a DataFrame action and once through a
# temp view queried with SQL.
#
# createOrReplaceTempView only registers a name for the DataFrame's query plan; it
# does not cache or materialize any data. Both runs therefore recompute the
# aggregation, and similar timings are the expected result. The labels say so
# explicitly instead of implying that caching took place.
start = time.perf_counter()  # monotonic clock meant for measuring short intervals
nyc_taxi_df_filtered_daily_stats.count()
print("First count (direct DataFrame action):", round(time.perf_counter() - start, 2), "seconds")

nyc_taxi_df_filtered_daily_stats.createOrReplaceTempView("daily_stats_temp")

start = time.perf_counter()
spark.sql("SELECT COUNT(*) FROM daily_stats_temp").collect()
print("Second count (via temp view, no caching):", round(time.perf_counter() - start, 2), "seconds")



First count (before caching): 1.83 seconds
Second count (after temp view reuse): 1.92 seconds
