# Module 8: PySpark Optimization Lab

## Objective

The goal of this lab is to improve the performance and memory efficiency of a PySpark job
by applying the following optimization techniques:

- **Caching**: Store intermediate results in memory to avoid recomputation
- **Repartitioning**: Distribute data more evenly across the cluster to enhance parallelism
- **UDF Replacement**: Replace Python-based UDFs with native Spark built-in functions that execute faster


## Dataset Summary

Below is an overview of the dataset used in this lab.

| Attribute                 | Description                                              |
|--------------------------|----------------------------------------------------------|
| **Filename**             | original_cleaned_nyc_taxi_data_2018.csv                 |
| **Size**                 | ~750 MB                                                  |
| **Format**               | CSV (Comma-Separated Values)                             |
| **Number of Rows**       | ~5 million (depending on the full source)                |
| **Columns**              | 21                                                       |
| **Contains**             | NYC yellow taxi trip records for the year 2018          |
| **Key Features**         | Trip distance, fare, tip, duration, pickup/dropoff IDs  |
| **Time Dimensions**      | Year, Month, Day, Day of Week, Hour                      |
| **Target Fields Used**   | tip_amount, fare_amount, trip_duration, trip_distance   |
| **Use Case**             | Performance tuning via Spark DataFrame transformations  |

## Lab Steps

This lab will consist of the following steps:

1. **Load and inspect the dataset**  
2. **Run an unoptimized pipeline**  
   - Uses a Python UDF
   - No caching or explicit repartitioning
3. **Refactor into an optimized pipeline**  
   - Replace UDF with built-in functions
   - Apply caching
   - Apply repartitioning
4. **Benchmark performance**  
   - Compare execution time between unoptimized and optimized jobs
5. **Analyze execution plans**  
   - Use `.explain(True)` to understand Spark's physical and logical execution strategies

## Spark Setup

We begin by importing required libraries and initializing a Spark session.  
This session serves as the entry point for all PySpark operations.


In [21]:
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, expr

# Create a Spark session (entry point to Spark functionality)
spark = SparkSession.builder \
    .appName("NYC Taxi Optimization Lab") \
    .getOrCreate()

In [22]:
# Install the gdown package to enable downloading from Google Drive
!pip install -q gdown

# Download the public dataset using its file ID from Google Drive
!gdown --id 1p03CbxCZAahZN7eWSZ8QwCFz9DWxiedc --output original_cleaned_nyc_taxi_data_2018.csv

Downloading...
From (original): https://drive.google.com/uc?id=1p03CbxCZAahZN7eWSZ8QwCFz9DWxiedc
From (redirected): https://drive.google.com/uc?id=1p03CbxCZAahZN7eWSZ8QwCFz9DWxiedc&confirm=t&uuid=0a5c1c59-35b0-4a25-a747-e97bac8d2a80
To: /content/original_cleaned_nyc_taxi_data_2018.csv
100% 754M/754M [00:14<00:00, 51.2MB/s]


## Load and Inspect the Dataset
We will now load the NYC Taxi dataset using `spark.read.csv`.  
- The `header=True` option ensures that the first row is treated as column headers.  
- The `inferSchema=True` option makes Spark scan the file an extra time to detect each column's data type; without it, every column is read as a string.

After loading, we will inspect the schema and a few sample rows.

In [23]:
# Load the CSV file into a Spark DataFrame
trips = spark.read.csv("original_cleaned_nyc_taxi_data_2018.csv", header=True, inferSchema=True)

# Display the schema to understand data types and column names
trips.printSchema()

# Show the first 5 rows to preview the data
trips.show(5)

root
 |-- _c0: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- rate_code: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- imp_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- hour_of_day: integer (nullable = true)
 |-- trip_duration: double (nullable = true)
 |-- calculated_total_amount: double (nullable = true)

+---+-------------+---------+------------------+------------+-----------+-----+-
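
Schema inference costs an extra read of the CSV before the real load. As a possible alternative, the sketch below builds a DDL schema string from the column names and types printed above and passes it to the reader's `schema` argument. The names `TAXI_COLUMNS`, `build_ddl`, and `load_trips` are illustrative and not part of the original lab.

```python
# Column names and types copied from the printSchema() output above.
TAXI_COLUMNS = [
    ("_c0", "INT"),
    ("trip_distance", "DOUBLE"),
    ("rate_code", "INT"),
    ("store_and_fwd_flag", "STRING"),
    ("payment_type", "INT"),
    ("fare_amount", "DOUBLE"),
    ("extra", "DOUBLE"),
    ("mta_tax", "DOUBLE"),
    ("tip_amount", "DOUBLE"),
    ("tolls_amount", "DOUBLE"),
    ("imp_surcharge", "DOUBLE"),
    ("total_amount", "DOUBLE"),
    ("pickup_location_id", "INT"),
    ("dropoff_location_id", "INT"),
    ("year", "INT"),
    ("month", "INT"),
    ("day", "INT"),
    ("day_of_week", "INT"),
    ("hour_of_day", "INT"),
    ("trip_duration", "DOUBLE"),
    ("calculated_total_amount", "DOUBLE"),
]


def build_ddl(columns):
    """Return a DDL schema string such as '`a` INT, `b` DOUBLE'."""
    # Backticks keep names like `year` or `_c0` from being read as keywords.
    return ", ".join(f"`{name}` {dtype}" for name, dtype in columns)


TAXI_SCHEMA_DDL = build_ddl(TAXI_COLUMNS)


def load_trips(spark, path="original_cleaned_nyc_taxi_data_2018.csv"):
    """Load the taxi CSV with an explicit schema instead of inferSchema=True."""
    return spark.read.csv(path, header=True, schema=TAXI_SCHEMA_DDL)
```

Calling `trips = load_trips(spark)` in the notebook should give the same schema as the inferred one, provided the CSV columns appear in this order. With an explicit schema, a value that does not parse as the declared type should be read as null under the reader's default permissive mode, rather than changing the column's type.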

## Unoptimized Pipeline

In this step, we build an initial version of the Spark job **without any optimizations**.  
Key characteristics of this unoptimized pipeline:

- Uses a Python-based User Defined Function (UDF) to classify trips based on distance
- No caching is applied, so any intermediate result that is reused is recomputed from the CSV
- No explicit repartitioning is applied, so the data keeps whatever partitioning Spark derives from the input file splits
- Performs group-by aggregations on `pickup_location_id` and `trip_category`
- Calculates:
  - Average tip percentage
  - Total fare revenue
  - Average trip duration

In [24]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import time

# Define a Python UDF to classify trip distances into categories
@udf(StringType())
def trip_category(distance):
    if distance is None:
        return "unknown"
    elif distance < 2:
        return "short"
    elif distance < 10:
        return "medium"
    else:
        return "long"
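
A function decorated with `@udf` expects Column arguments, so its thresholds cannot be checked directly with plain Python values. One option, sketched here, is to keep the logic in an undecorated function, verify the boundaries, and wrap it afterwards. The name `categorize_distance` is hypothetical; the thresholds are the ones used above.

```python
def categorize_distance(distance):
    """Plain-Python version of the trip_category logic above."""
    if distance is None:
        return "unknown"
    if distance < 2:
        return "short"
    if distance < 10:
        return "medium"
    return "long"


# Boundary values: each threshold itself belongs to the higher category.
boundary_cases = [None, 0.0, 1.99, 2.0, 9.99, 10.0]
labels = [categorize_distance(d) for d in boundary_cases]
print(list(zip(boundary_cases, labels)))

# In the notebook, the same function could then be wrapped as a UDF:
#   from pyspark.sql.functions import udf
#   from pyspark.sql.types import StringType
#   trip_category = udf(categorize_distance, StringType())
```

Keeping the plain function around also gives a reference to compare against when the UDF is later replaced by a native expression.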

In [25]:
# Start timer to measure unoptimized job runtime
start_time = time.time()

# Apply UDF, group by pickup location and trip category, and calculate aggregations
df_unopt = trips.withColumn("trip_category", trip_category(col("trip_distance"))) \
    .groupBy("pickup_location_id", "trip_category") \
    .agg(
        expr("avg(tip_amount / fare_amount) AS avg_tip_pct"),       # Mean tip-to-fare ratio (a fraction: 0.18 = 18%)
        expr("sum(fare_amount) AS total_fare"),                     # Total fare collected
        expr("avg(trip_duration) AS avg_trip_duration"),            # Average duration of trips
        expr("count(*) AS num_trips")                               # Number of trips in each group
    )

# show() is the action that triggers the job; display the first 10 grouped results
df_unopt.show(10, truncate=False)

# Print the runtime for benchmarking
print(f"Unoptimized runtime: {time.time() - start_time:.2f} seconds")

+------------------+-------------+--------------------+-------------------+------------------+---------+
|pickup_location_id|trip_category|avg_tip_pct         |total_fare         |avg_trip_duration |num_trips|
+------------------+-------------+--------------------+-------------------+------------------+---------+
|164               |medium       |0.18146879695241228 |2407198.2800000003 |2204.711435696473 |92631    |
|83                |medium       |0.15294876111007924 |19831.63           |2043.0618131868132|728      |
|169               |medium       |0.06651791474826965 |4358.5599999999995 |2600.923076923077 |169      |
|229               |short        |0.17797636285644572 |78406.2            |2080.5712018620397|9452     |
|132               |long         |0.17016430811661032 |3.669964357999997E7|2211.642234573835 |730269   |
|71                |medium       |0.06982431609781146 |11986.61           |2105.740909090909 |440      |
|51                |short        |0.027346938775510202|

## Optimized Pipeline

In this step, we refactor the Spark job to apply standard optimization techniques:

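- **Replaces the Python UDF** with a native Spark `when` / `otherwise` expression  
  (which Catalyst can optimize, and which avoids passing every row between the JVM and a Python worker)
- **Caches** the intermediate DataFrame so that later actions reuse it instead of re-reading and re-parsing the CSV
- **Repartitions** the data by `pickup_location_id`; this costs one full shuffle up front, after which an aggregation keyed on `pickup_location_id` can typically reuse that partitioning instead of shuffling again

The aggregation in this version should run faster than in the unoptimized version. The trade-off is the executor memory held by the cache and the one-time cost of the repartition shuffle.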
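
Repartitioning by a key places all rows with the same key in one partition, so a very busy pickup location can make one partition far larger than the rest. The toy model below illustrates the effect in plain Python. It is only a sketch: it assigns partitions with a simple modulo, whereas Spark hashes the key, so the actual assignments will differ. The helper name `rows_per_partition` is hypothetical.

```python
# num_trips for a few (pickup_location_id, trip_category) groups, copied from
# the unoptimized output above. Each location has other categories as well,
# so these figures understate the true per-location totals.
group_counts = {164: 92631, 83: 728, 169: 169, 229: 9452, 132: 730269, 71: 440}


def rows_per_partition(counts_by_key, num_partitions):
    """Toy partitioner: key modulo partition count (not Spark's real hash)."""
    sizes = {p: 0 for p in range(num_partitions)}
    for key, rows in counts_by_key.items():
        sizes[key % num_partitions] += rows
    return sizes


sizes = rows_per_partition(group_counts, num_partitions=4)
largest_share = max(sizes.values()) / sum(sizes.values())
print(sizes)
print(f"largest partition holds {largest_share:.0%} of these rows")
```

Whatever hash is used, the rows for location 132 end up together, so that partition dominates. In the notebook, the real distribution could be inspected with something like `df_opt.groupBy(spark_partition_id()).count()`, using `spark_partition_id` from `pyspark.sql.functions`.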

In [26]:
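# Replace the UDF with Spark's native when/otherwise logic.
# The explicit isNull branch mirrors the UDF's "unknown" category; without it,
# null distances fail both comparisons and fall through to "long".
df_opt = trips.withColumn("trip_category",
    when(col("trip_distance").isNull(), "unknown")
    .when(col("trip_distance") < 2, "short")
    .when(col("trip_distance") < 10, "medium")
    .otherwise("long")
)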

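# Repartition by pickup_location_id (this performs one full shuffle) so that
# rows for the same pickup location share a partition.
# Mark the result for caching so later actions can reuse it.
df_opt = df_opt.repartition("pickup_location_id").cache()

# cache() is lazy: count() is the action that actually materializes the cache
df_opt.count()

# Start the timer only after the cache is populated, so the measurement below
# covers the aggregation alone and excludes the CSV read, shuffle, and caching
start_time = time.time()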

### Run Optimized Aggregation and Benchmark

Now we perform the same group-by and aggregation as before,  
but on the optimized DataFrame that uses native functions, caching, and repartitioning.

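We measure the execution time to compare performance against the unoptimized job. Because the timer starts after `df_opt.count()` has populated the cache, this measurement covers only the aggregation, whereas the unoptimized timing also includes reading and parsing the CSV.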
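
A single wall-clock measurement is noisy, and a first run can also include warm-up costs. A minimal sketch of a repeatable timer follows. `time_action` and `summarize` are hypothetical helpers, not part of the lab, and the sketch assumes the callable passed in triggers a Spark action such as `show()` or `count()`.

```python
import statistics
import time


def time_action(action, repeats=3):
    """Run a zero-argument callable several times; return seconds per run."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic clock, suited to intervals
        action()
        durations.append(time.perf_counter() - start)
    return durations


def summarize(label, durations):
    """Format the fastest and median run for side-by-side comparison."""
    return (f"{label}: min {min(durations):.2f}s, "
            f"median {statistics.median(durations):.2f}s "
            f"over {len(durations)} runs")


# Stand-in workload so the sketch runs without Spark. In the notebook this
# could be, for example, lambda: df_optimized.show(10, truncate=False).
demo = time_action(lambda: sum(range(100_000)), repeats=3)
print(summarize("demo", demo))
```

Timing both pipelines with the same helper, and timing the `count()` that fills the cache separately, would show how much of the gain comes from the aggregation itself and how much is one-time setup.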

This pipeline computes:
- `avg_tip_pct`: Mean of the per-trip tip-to-fare ratio, expressed as a fraction (0.18 means 18%)
- `total_fare`: Sum of fare amounts
- `avg_trip_duration`: Average trip duration
- `num_trips`: Total number of trips in each group
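
Note that `avg(tip_amount / fare_amount)` averages the per-trip ratios, which is not the same as total tips divided by total fares: a cheap trip weighs as much as an expensive one. The plain-Python sketch below uses made-up trips (not rows from the dataset) to show the difference. It skips zero-fare trips to mimic Spark's non-ANSI behavior, where division by zero yields NULL and `avg` ignores NULLs; with ANSI mode enabled, the same division raises an error instead.

```python
# Made-up (tip, fare) pairs for illustration only.
sample_trips = [(1.0, 5.0), (2.0, 40.0), (0.0, 0.0)]

# Per-trip ratios; zero-fare trips are dropped, as avg() would drop a NULL.
ratios = [tip / fare for tip, fare in sample_trips if fare != 0]
mean_of_ratios = sum(ratios) / len(ratios)

# Revenue-weighted alternative: total tips over total fares.
total_tips = sum(tip for tip, _ in sample_trips)
total_fares = sum(fare for _, fare in sample_trips)
ratio_of_sums = total_tips / total_fares

print(f"mean of ratios: {mean_of_ratios:.4f}")
print(f"ratio of sums:  {ratio_of_sums:.4f}")
```

If the revenue-weighted figure is the one of interest, the Spark equivalent would be an aggregate such as `expr("sum(tip_amount) / sum(fare_amount)")`.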

In [27]:
# Perform aggregations on the optimized DataFrame
df_optimized = df_opt.groupBy("pickup_location_id", "trip_category") \
    .agg(
        expr("avg(tip_amount / fare_amount) AS avg_tip_pct"),
        expr("sum(fare_amount) AS total_fare"),
        expr("avg(trip_duration) AS avg_trip_duration"),
        expr("count(*) AS num_trips")
    )

# Show first 10 result groups
df_optimized.show(10, truncate=False)

# Print runtime of optimized pipeline
print(f"Optimized runtime: {time.time() - start_time:.2f} seconds")

+------------------+-------------+---------------------+------------------+------------------+---------+
|pickup_location_id|trip_category|avg_tip_pct          |total_fare        |avg_trip_duration |num_trips|
+------------------+-------------+---------------------+------------------+------------------+---------+
|148               |medium       |0.16581733792892117  |1701843.6600000006|2200.412935109583 |69810    |
|148               |long         |0.21089685731788563  |688346.0600000005 |2207.885403954113 |16388    |
|148               |short        |0.2695573443554674   |41056.01000000001 |2217.6414307658547|4557     |
|243               |medium       |0.1704259153301238   |61693.76999999999 |2194.324134910206 |2283     |
|243               |long         |0.14743155331239505  |41722.48000000001 |2210.81626187962  |947      |
|243               |short        |0.12962486299747908  |2150.0299999999997|1919.111111111111 |171      |
|31                |long         |0.1259056250569908   

## Plan Analysis

To understand how Spark processes and optimizes each version of our job,  
we examine the plans printed by `.explain(True)`: the parsed, analyzed, and optimized logical plans, followed by the physical plan.

This helps identify performance bottlenecks such as:
- Unnecessary shuffling
- Inefficient scans or wide transformations
- Redundant computations due to missing caching
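
The full `explain(True)` output is long, so it can help to count a few physical operators rather than read every line. The sketch below rests on two assumptions: that PySpark's `explain` writes through Python's `print` (so `contextlib.redirect_stdout` can capture it), and that the operator names `Exchange`, `BatchEvalPython`, and `InMemoryTableScan` mark shuffles, Python UDF evaluation, and cached reads in the physical plan. `capture_plan`, `physical_section`, and `count_operators` are hypothetical helper names.

```python
import contextlib
import io
import re

OPERATORS = ("Exchange", "BatchEvalPython", "InMemoryTableScan")


def capture_plan(df):
    """Return the text that df.explain(True) would print."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        df.explain(True)
    return buffer.getvalue()


def physical_section(plan_text):
    """Keep only the text after the '== Physical Plan ==' marker, if present."""
    _, marker, tail = plan_text.partition("== Physical Plan ==")
    return tail if marker else plan_text


def count_operators(plan_text, operators=OPERATORS):
    """Count whole-word occurrences of each operator name in a plan string."""
    return {
        op: len(re.findall(rf"\b{re.escape(op)}\b", plan_text))
        for op in operators
    }
```

In the notebook this could be used as `count_operators(physical_section(capture_plan(df_unopt)))` and likewise for `df_optimized`. The counts are only a rough signal: a cached relation prints the plan that produced it, so operators inside the cache are counted too.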

In [28]:
# Print the execution plan for the unoptimized pipeline
print("--- UNOPTIMIZED PLAN ---")
df_unopt.explain(True)

# Print the execution plan for the optimized pipeline
print("--- OPTIMIZED PLAN ---")
df_optimized.explain(True)

--- UNOPTIMIZED PLAN ---
== Parsed Logical Plan ==
'Aggregate ['pickup_location_id, 'trip_category], ['pickup_location_id, 'trip_category, 'avg(('tip_amount / 'fare_amount)) AS avg_tip_pct#5415, 'sum('fare_amount) AS total_fare#5416, 'avg('trip_duration) AS avg_trip_duration#5417, 'count(1) AS num_trips#5418]
+- Project [_c0#5220, trip_distance#5221, rate_code#5222, store_and_fwd_flag#5223, payment_type#5224, fare_amount#5225, extra#5226, mta_tax#5227, tip_amount#5228, tolls_amount#5229, imp_surcharge#5230, total_amount#5231, pickup_location_id#5232, dropoff_location_id#5233, year#5234, month#5235, day#5236, day_of_week#5237, hour_of_day#5238, trip_duration#5239, calculated_total_amount#5240, trip_category(trip_distance#5221)#5369 AS trip_category#5370]
   +- Relation [_c0#5220,trip_distance#5221,rate_code#5222,store_and_fwd_flag#5223,payment_type#5224,fare_amount#5225,extra#5226,mta_tax#5227,tip_amount#5228,tolls_amount#5229,imp_surcharge#5230,total_amount#5231,pickup_location_id#5232