# 03 - Shuffle, Joins & Partitioning

This notebook combines three performance-critical Spark topics in one flow:

1. **Shuffle** - what it is, why it creates stage boundaries, and why it is expensive
2. **Joins** - Broadcast Hash Join, Sort-Merge Join, Shuffle Hash Join
3. **Partitioning and data organization** - `repartition()` vs `coalesce()`, partitioning vs bucketing

Dataset: NYC Yellow Taxi trip records for 2024 (monthly Parquet files).

In [1]:
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast

In [2]:
import urllib.request
import os

base_url = "https://d37ci6vzurychx.cloudfront.net/trip-data"
save_dir = r"C:\code\spark-tuning-handbook\data\taxi"
os.makedirs(save_dir, exist_ok=True)

for month in range(1, 13):
    fname = f"yellow_tripdata_2024-{month:02d}.parquet"
    url = f"{base_url}/{fname}"
    dest = os.path.join(save_dir, fname)
    if not os.path.exists(dest):
        print(f"Downloading {fname}...")
        urllib.request.urlretrieve(url, dest)

print("Done")

Done


In [2]:
# Load parquet (local path placeholder style, same as previous notebooks)
taxi = spark.read.parquet(r"C:\code\spark-tuning-handbook\data\taxi")

taxi.printSchema()
print("rows:", taxi.count())
print("partitions:", taxi.rdd.getNumPartitions())
print("columns:", taxi.columns)

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)

rows: 41169720
partitions: 6
columns: ['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime', 'passenger_count', 'trip_

In [3]:
# Keep plans deterministic for learning cells
spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.shuffle.partitions", "8")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(10 * 1024 * 1024))  # 10 MB default
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")

print("spark.sql.adaptive.enabled:", spark.conf.get("spark.sql.adaptive.enabled"))
print("spark.sql.shuffle.partitions:", spark.conf.get("spark.sql.shuffle.partitions"))
print("spark.sql.autoBroadcastJoinThreshold:", spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))
print("spark.sql.join.preferSortMergeJoin:", spark.conf.get("spark.sql.join.preferSortMergeJoin"))

spark.sql.adaptive.enabled: false
spark.sql.shuffle.partitions: 8
spark.sql.autoBroadcastJoinThreshold: 10485760
spark.sql.join.preferSortMergeJoin: true


Spark UI: http://localhost:4040/jobs/

We use Spark UI in this lab after each action to validate:
- how many stages were generated
- where stage boundaries appear
- shuffle write in upstream stage(s)
- shuffle read in downstream stage(s)
- whether one or a few tasks are much slower (skew signal)

## Helper - find Exchange operators quickly

`explain()` is the main tool in this notebook (`explain("formatted")` prints the same plan in a more readable layout).
The helper below extracts Exchange-related nodes from the executed plan so you can quickly confirm shuffle boundaries.


In [4]:
import time


def show_exchange_nodes(df):
    plan = df._jdf.queryExecution().executedPlan().toString()
    lines = [line.strip() for line in plan.splitlines() if "Exchange" in line]

    if not lines:
        print("No Exchange nodes in executed physical plan.")
    else:
        print("Exchange-related nodes:")
        for line in lines:
            print(line)


def show_join_nodes(df):
    plan = df._jdf.queryExecution().executedPlan().toString()
    lines = [line.strip() for line in plan.splitlines() if "Join" in line]

    if not lines:
        print("No Join operator found in executed physical plan.")
    else:
        print("Join-related nodes:")
        for line in lines:
            print(line)


# Executes an action under a unique job group ID (visible in Spark UI), then collects and prints the job and stage IDs it produced.
def run_and_report(action_label, action_fn):
    sc = spark.sparkContext
    tracker = sc.statusTracker()
    group_id = f"notebook_demo_{int(time.time() * 1000)}"

    sc.setJobGroup(group_id, action_label)
    try:
        result = action_fn()
    finally:
        job_ids = list(tracker.getJobIdsForGroup(group_id))
        stage_ids = set()
        for job_id in job_ids:
            job_info = tracker.getJobInfo(job_id)
            if job_info is not None:
                stage_ids.update(list(job_info.stageIds))

        print(
            f"{action_label} -> jobs={job_ids}, stages={sorted(stage_ids)}, stage_count={len(stage_ids)}"
        )
        sc.setJobGroup("", "")

    return result


def reset_join_defaults():
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(10 * 1024 * 1024))
    spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")
    spark.conf.set("spark.sql.adaptive.enabled", "false")

---

## Section 1 - Shuffle

### What is spill?

**Spill** means Spark had to move intermediate in-memory execution data to disk because memory was insufficient for the current operator.

**Where** spill commonly happens:
- Aggregations: groupBy with many unique keys (high cardinality)
- Joins: Sorting phase of Sort-Merge Join or building tables in Shuffle Hash Join
- Sorting: orderBy or sort operations
- Window Functions: Processing large partitions with OVER(...)
- Shuffle Read: Buffering and de-serializing incoming data from other executors

**Why** Spark spills to disk:
- dropping partial state is not allowed
- prevents OOM by using disk as a safety net
- keeps jobs alive at the expense of heavy I/O

**How** spill files are created:
- task builds in-memory buffers: sort buffer (for sorting) or hash map (for aggregation/joins)
- when memory thresholds are exceeded, Spark writes sort/hash chunks to local disk
- later, Spark merges spill files to produce final output for downstream operators
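
The three steps above can be sketched in plain Python. This is a toy external sort, not Spark's implementation: `external_sort` and `max_in_memory` are hypothetical names, and the "memory threshold" is simply a row count.

```python
import heapq
import os
import tempfile


def external_sort(values, max_in_memory):
    """Sort integers while buffering at most `max_in_memory` values at a time."""
    spill_paths = []
    buffer = []

    def spill():
        # Sort the in-memory chunk and write it to a local temp file (one "spill file").
        buffer.sort()
        fd, path = tempfile.mkstemp(suffix=".spill")
        with os.fdopen(fd, "w") as f:
            f.writelines(f"{v}\n" for v in buffer)
        spill_paths.append(path)
        buffer.clear()

    for v in values:
        buffer.append(v)
        if len(buffer) >= max_in_memory:  # stand-in for "memory threshold exceeded"
            spill()

    buffer.sort()  # the remainder stays in memory as the last sorted run
    files = [open(p) for p in spill_paths]
    try:
        runs = [(int(line) for line in f) for f in files]
        # Merge phase: k-way merge of the sorted spill files plus the in-memory run.
        merged = list(heapq.merge(*runs, buffer))
    finally:
        for f in files:
            f.close()
        for p in spill_paths:
            os.remove(p)
    return merged, len(spill_paths)


data = [7, 3, 9, 1, 8, 2, 6, 4, 5, 0]
result, spill_count = external_sort(data, max_in_memory=4)
print(result, "spill files:", spill_count)
```

The extra write, re-read, and merge work is the I/O and CPU overhead that shows up as spill in the task metrics.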

**Performance impact** of spill:
- more local disk I/O
- additional CPU for merge phases
- creates stragglers (long-tail tasks) that delay the entire stage completion

### Spill demo
Spill depends on runtime memory pressure, so exact spill bytes vary by machine.
This experiment is deterministic in plan shape and often shows spill in Spark UI when resources are constrained:
- reduce shuffle partitions (larger per-partition workload)
- run a global sort over the full dataset
- inspect stage task metrics for spill counters

In [54]:
spill_test_partitions_bkp = spark.conf.get("spark.sql.shuffle.partitions")
spark.conf.set("spark.sql.shuffle.partitions", "1")

spill_demo = (
    taxi
    .orderBy("total_amount", "tip_amount", "trip_distance")
)

spill_demo.explain()
show_exchange_nodes(spill_demo)

== Physical Plan ==
*(2) Sort [total_amount#279 ASC NULLS FIRST, tip_amount#276 ASC NULLS FIRST, trip_distance#267 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(total_amount#279 ASC NULLS FIRST, tip_amount#276 ASC NULLS FIRST, trip_distance#267 ASC NULLS FIRST, 1), ENSURE_REQUIREMENTS, [plan_id=916]
   +- *(1) ColumnarToRow
      +- FileScan parquet [VendorID#263,tpep_pickup_datetime#264,tpep_dropoff_datetime#265,passenger_count#266L,trip_distance#267,RatecodeID#268L,store_and_fwd_flag#269,PULocationID#270,DOLocationID#271,payment_type#272L,fare_amount#273,extra#274,mta_tax#275,tip_amount#276,tolls_amount#277,improvement_surcharge#278,total_amount#279,congestion_surcharge#280,Airport_fee#281] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:timestamp_ntz,tpep_dropoff_datetime:timestamp_ntz,passen...


In [55]:
# Sorting 41M rows in a single partition exceeded the executor's available execution memory. In this run the Spark UI reported Spill (Memory) = 6.5 GiB (in-memory size of the spilled data) and Spill (Disk) = 1.6 GiB (its serialized size on disk).
run_and_report("spill experiment", lambda: spill_demo.tail(1))

spark.conf.set("spark.sql.shuffle.partitions", spill_test_partitions_bkp)
print("restored spark.sql.shuffle.partitions:", spark.conf.get("spark.sql.shuffle.partitions"))

spill experiment -> jobs=[23], stages=[39, 40], stage_count=2
restored spark.sql.shuffle.partitions: 6


**Spill verification**:
- Spark UI -> Stages -> stage detail -> Task metrics
- inspect `Spill (Memory)` and `Spill (Disk)`
- if both remain zero/not displayed, increase workload volume or tighten executor memory for this experiment

### Spill practical mitigation

Mitigation checklist:
- increase executor memory when possible
- tune `spark.memory.fraction` (size of the unified execution + storage region) and `spark.memory.storageFraction` (share of that region protected for cache/persist) carefully
- tune `spark.sql.shuffle.partitions` so partition size is reasonable
- pre-aggregate before expensive joins when logic allows
- broadcast truly small side of joins to remove one redistribution path
- fix skew (salting, AQE skew optimization, split hot keys)
- repartition intelligently by the join/aggregation key before expensive steps
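
For the `spark.sql.shuffle.partitions` item, a rough sizing calculation can help. The helper below is a sketch only: `suggest_shuffle_partitions` is a hypothetical name, and the 128 MiB target per partition is an assumed rule of thumb, not a value taken from this notebook or a Spark default.

```python
def suggest_shuffle_partitions(shuffle_bytes, target_partition_bytes=128 * 1024**2):
    """Smallest partition count that keeps each partition at or below the target size."""
    if shuffle_bytes <= 0:
        return 1
    # Ceiling division without floating point.
    return -(-shuffle_bytes // target_partition_bytes)


# Hypothetical example: a stage that shuffles 10 GiB.
print(suggest_shuffle_partitions(10 * 1024**3))
```

The shuffle size itself can be read from the Shuffle Write column of the upstream stage in Spark UI.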

### What shuffle is and why it is expensive

A **shuffle** is the process of redistributing data across partitions so that data with the same key ends up in the same partition. This typically involves copying data across executors and machines, making the shuffle a complex and costly operation. 

Typical shuffle-triggering wide transformations: `groupBy` / aggregations by key, `distinct`, `repartition`, `orderBy` (global sort), and joins (except broadcast joins). These are key-dependent operations that require a global reorganization of the data.

Why shuffle creates a **stage boundary**:
- upstream tasks (map side) must finish producing shuffle files first (buffered in memory, spill to disk if needed)
- shuffle transfers partitioned shuffle blocks fetched from local disk across executors over the network
- downstream tasks (reduce side) cannot start until required shuffle partitions are available

=> Spark DAG scheduler splits execution into separate stages at each `Exchange` (see the physical plan)

Shuffle write vs shuffle read:
- **shuffle write** (map side): each map task writes its partitioned shuffle output to a local shuffle data file (sequential blocks)
- **shuffle read** (reduce side): each reduce task fetches the blocks it needs over the network from many map tasks, then runs the aggregation/join/sort

Why **shuffle is expensive**:
- serialization cost before writing blocks
- disk I/O for spill/write shuffle files
- moving data across the cluster via the network
- deserialization cost on read
- memory pressure while buffering/sorting/hash-building
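
The write and read sides described above can be imitated with plain Python lists. This is a simplified model under stated assumptions: `shuffle_write` and `shuffle_read` are hypothetical helper names, the data stays in memory, and Python's built-in `hash` stands in for Spark's partitioner.

```python
from collections import defaultdict

NUM_REDUCERS = 3


def shuffle_write(map_partition, num_reducers):
    """Map side: split one input partition into per-reducer blocks by key hash."""
    blocks = defaultdict(list)
    for key, value in map_partition:
        blocks[hash(key) % num_reducers].append((key, value))
    return blocks


def shuffle_read(all_map_outputs, reducer_id):
    """Reduce side: fetch this reducer's block from every map output."""
    fetched = []
    for blocks in all_map_outputs:
        fetched.extend(blocks.get(reducer_id, []))
    return fetched


map_partitions = [
    [(1, "a"), (2, "b"), (4, "c")],
    [(1, "d"), (3, "e"), (4, "f")],
]
map_outputs = [shuffle_write(p, NUM_REDUCERS) for p in map_partitions]
reduce_inputs = [shuffle_read(map_outputs, r) for r in range(NUM_REDUCERS)]

for reducer_id, rows in enumerate(reduce_inputs):
    print(reducer_id, rows)
```

Every reducer has to read from every map output, which is why a reduce task cannot start before the map side has finished writing.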

### Demo 1 - `groupBy()` shuffle

`groupBy("PULocationID")` requires all rows for each `PULocationID` to meet in the same partition for final aggregation.
That requires key-based redistribution, so we expect an `Exchange` and a stage boundary.

In [51]:
grouped = (
    taxi
    .groupBy("PULocationID")
    .agg(F.count("*").alias("row_count"), F.sum("total_amount").alias("total_sales"))
)

grouped.explain()
show_exchange_nodes(grouped)

== Physical Plan ==
*(2) HashAggregate(keys=[PULocationID#270], functions=[count(1), sum(total_amount#279)])
+- Exchange hashpartitioning(PULocationID#270, 6), ENSURE_REQUIREMENTS, [plan_id=813]
   +- *(1) HashAggregate(keys=[PULocationID#270], functions=[partial_count(1), partial_sum(total_amount#279)])
      +- *(1) ColumnarToRow
         +- FileScan parquet [PULocationID#270,total_amount#279] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<PULocationID:int,total_amount:double>


Exchange-related nodes:
+- Exchange hashpartitioning(PULocationID#270, 6), ENSURE_REQUIREMENTS, [plan_id=813]


In [56]:
run_and_report("groupBy aggregation action", lambda: grouped.show(10, truncate=False))

+------------+---------+--------------------+
|PULocationID|row_count|total_sales         |
+------------+---------+--------------------+
|186         |1362156  |3.3822911449999966E7|
|234         |1105439  |2.472046954999828E7 |
|263         |766745   |1.6355274009999966E7|
|10          |15770    |1037939.5900000082  |
|90          |652648   |1.4660598599999804E7|
|239         |1142282  |2.47995970299976E7  |
|4           |71461    |1687265.0800000008  |
|209         |87461    |2612556.970000002   |
|161         |1914607  |4.703328904000025E7 |
|45          |65301    |1812732.3700000045  |
+------------+---------+--------------------+
only showing top 10 rows

groupBy aggregation action -> jobs=[24], stages=[41, 42], stage_count=2


In [None]:
# In case lambdas are confusing: this is the same code with the same result

def my_func():
    return grouped.show(10, truncate=False)

# pass the function as an object
run_and_report("action", my_func)

Physical-plan reading:
- Look for `HashAggregate` (partial) -> `Exchange hashpartitioning(PULocationID, ...)` -> `HashAggregate` (final).
- The `Exchange` is the shuffle boundary.

Stage interpretation:
- upstream stage writes shuffle blocks (shuffle write > 0)
- downstream stage reads those blocks (shuffle read > 0)
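
The partial -> exchange -> final pattern can be sketched in plain Python. This is a toy model: `partial_aggregate` and `final_aggregate` are hypothetical names, the sample amounts are made up, and Python's `hash` stands in for the hash partitioner.

```python
def partial_aggregate(partition):
    """Map side: one (count, sum) pair per key within a single partition."""
    acc = {}
    for key, amount in partition:
        count, total = acc.get(key, (0, 0.0))
        acc[key] = (count + 1, total + amount)
    return acc


def final_aggregate(partials, num_partitions):
    """Exchange + final aggregate: route partial rows by key hash, then merge them."""
    outputs = [dict() for _ in range(num_partitions)]
    for partial in partials:
        for key, (count, total) in partial.items():
            target = outputs[hash(key) % num_partitions]
            prev_count, prev_total = target.get(key, (0, 0.0))
            target[key] = (prev_count + count, prev_total + total)
    return outputs


input_partitions = [
    [(186, 10.0), (234, 5.0), (186, 20.0)],
    [(186, 1.0), (234, 2.5)],
]
partials = [partial_aggregate(p) for p in input_partitions]
shuffled_rows = sum(len(p) for p in partials)
result = final_aggregate(partials, num_partitions=2)

print("input rows:", sum(len(p) for p in input_partitions))
print("rows crossing the exchange:", shuffled_rows)
print(result)
```

Only one partial row per key and input partition crosses the exchange, which is why the partial `HashAggregate` keeps shuffle write small for low-cardinality keys.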

### Demo 2 - `distinct()` shuffle

`distinct()` is logically a deduplication (`dropDuplicates()`) by all selected columns. To remove duplicates globally, Spark groups identical keys across partitions, which requires shuffle.

In [58]:
distinct_pairs = taxi.select("PULocationID", "DOLocationID").distinct()

distinct_pairs.explain()
show_exchange_nodes(distinct_pairs)

== Physical Plan ==
*(2) HashAggregate(keys=[PULocationID#270, DOLocationID#271], functions=[])
+- Exchange hashpartitioning(PULocationID#270, DOLocationID#271, 6), ENSURE_REQUIREMENTS, [plan_id=984]
   +- *(1) HashAggregate(keys=[PULocationID#270, DOLocationID#271], functions=[])
      +- *(1) ColumnarToRow
         +- FileScan parquet [PULocationID#270,DOLocationID#271] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<PULocationID:int,DOLocationID:int>


Exchange-related nodes:
+- Exchange hashpartitioning(PULocationID#270, DOLocationID#271, 6), ENSURE_REQUIREMENTS, [plan_id=984]


In [59]:
distinct_rows = run_and_report("distinct count action", lambda: distinct_pairs.count())

print("distinct rows:", distinct_rows)

distinct count action -> jobs=[25], stages=[43, 44, 45], stage_count=3
distinct rows: 50518


Physical-plan reading:
- Expect aggregate-style dedup operators and `Exchange hashpartitioning(...)`.
- Dedup without data movement is not possible when duplicates can be in different partitions.

Stage interpretation:
- map side writes per-key shuffle buckets
- reduce side reads buckets and emits unique keys


### Demo 3 - `orderBy()` shuffle

Global `orderBy()`/`sort()` requires a global ordering guarantee.
A global order cannot be produced partition-locally, so Spark introduces range partitioning + sort work.


In [61]:
ordered = taxi.orderBy(F.col("total_amount").desc())

ordered.explain()
show_exchange_nodes(ordered)

== Physical Plan ==
*(2) Sort [total_amount#279 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_amount#279 DESC NULLS LAST, 6), ENSURE_REQUIREMENTS, [plan_id=1094]
   +- *(1) ColumnarToRow
      +- FileScan parquet [VendorID#263,tpep_pickup_datetime#264,tpep_dropoff_datetime#265,passenger_count#266L,trip_distance#267,RatecodeID#268L,store_and_fwd_flag#269,PULocationID#270,DOLocationID#271,payment_type#272L,fare_amount#273,extra#274,mta_tax#275,tip_amount#276,tolls_amount#277,improvement_surcharge#278,total_amount#279,congestion_surcharge#280,Airport_fee#281] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:timestamp_ntz,tpep_dropoff_datetime:timestamp_ntz,passen...


Exchange-related nodes:
+- Exchange rangepartitioning(total_amount#279 DESC NULLS LAST, 6), ENSURE_REQUIREMENTS, [plan_id=1094]


In [65]:
run_and_report("orderBy total_amount", lambda: ordered.select("total_amount").show(5, truncate=False))

+------------+
|total_amount|
+------------+
|335550.94   |
|334145.3    |
|50558.68    |
|12903.4     |
|9792.0      |
+------------+
only showing top 5 rows

orderBy total_amount -> jobs=[28], stages=[48], stage_count=1


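Physical-plan reading:
- Expect `Exchange rangepartitioning(...)` followed by a `Sort` operator.
- Global ordering introduces an expensive wide dependency.

Stage interpretation:
- upstream stage redistributes rows by range
- downstream stage performs the final sort per output partition
- the action above reported only one stage, most likely because `show(5)` adds a limit on top of the sort and Spark then plans a top-N (`TakeOrderedAndProject`) instead of running the full range-partitioned sort shown by `ordered.explain()`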
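
Range partitioning followed by a per-partition sort can be sketched in plain Python. This is an illustration under assumptions: the data is random, the sample size of 50 and the 4 output partitions are arbitrary choices, and the sort is ascending for simplicity.

```python
import bisect
import random


def range_partition(values, boundaries):
    """Assign each value to the partition whose range contains it."""
    partitions = [[] for _ in range(len(boundaries) + 1)]
    for v in values:
        partitions[bisect.bisect_right(boundaries, v)].append(v)
    return partitions


random.seed(0)
amounts = [round(random.uniform(0, 100), 2) for _ in range(1000)]

# Pick boundaries from a sample of the data (assumed sample size: 50).
sample = sorted(random.sample(amounts, 50))
boundaries = [sample[len(sample) * i // 4] for i in range(1, 4)]

partitions = range_partition(amounts, boundaries)
locally_sorted = [sorted(p) for p in partitions]          # per-partition sort
globally_sorted = [v for p in locally_sorted for v in p]  # read partitions in order

print([len(p) for p in partitions])
print(globally_sorted == sorted(amounts))
```

Because every value in partition `i` is smaller than every value in partition `i + 1`, sorting each partition locally is enough to obtain a global order.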

---

## Section 2 - Joins

### Join strategies

Spark chooses the physical join strategy from the logical join type plus statistics, hints, and configuration.
Three core strategies:

1. **Broadcast Hash Join (BHJ)** - small + any
   - one side is small enough (default threshold - 10 MB) or explicitly hinted
   - small side is collected on driver, broadcast to all executors, and built into an in-memory hash table
   - large side is scanned and probed against that hash table, no shuffle needed
   - trade-off: driver and executor memory usage for broadcast materialization

2. **Sort-Merge Join (SMJ)** - large + large
   - both sides are shuffled by join key
   - both sides are sorted by join key
   - merge phase: scan sorted sides (two pointers 👻), matching and joining keys
   - default scalable strategy for large joins

3. **Shuffle Hash Join (SHJ)** - medium + large
   - both sides are shuffled by join key
   - smaller side is built into a hash table per partition, **no global sort required**
   - larger side is probed against that hash table
   - can be faster than SMJ **when sort cost outweighs hash build**
   - risk: per-partition hash table build can pressure executor memory

**AQE note**:
- with Adaptive Query Execution enabled, Spark can switch strategy at runtime
- typical example: initially planned SMJ can become BHJ when runtime stats reveal a smaller side than expected
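
The two core algorithms behind these strategies can be sketched in plain Python: build-and-probe (used by BHJ and, per partition, by SHJ) and the two-pointer merge used by SMJ. The function names and the zone labels below are hypothetical, and the sketch ignores shuffling and broadcasting entirely.

```python
from collections import defaultdict


def hash_join(probe_rows, build_rows):
    """Build a hash table on the smaller side, then probe it with the larger side."""
    table = defaultdict(list)
    for key, value in build_rows:
        table[key].append(value)
    return [(key, left, right) for key, left in probe_rows for right in table.get(key, [])]


def sort_merge_join(left_rows, right_rows):
    """Sort both sides by key, then advance two pointers and emit matching pairs."""
    left, right = sorted(left_rows), sorted(right_rows)
    out, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        if left[i][0] < right[j][0]:
            i += 1
        elif left[i][0] > right[j][0]:
            j += 1
        else:
            key = left[i][0]
            j_end = j
            while j_end < len(right) and right[j_end][0] == key:
                j_end += 1  # find the end of the matching run on the right
            while i < len(left) and left[i][0] == key:
                out.extend((key, left[i][1], right[k][1]) for k in range(j, j_end))
                i += 1
            j = j_end
    return out


trips = [(10, 12.5), (4, 7.0), (10, 3.2), (99, 1.0)]
zones = [(4, "zone-4"), (10, "zone-10")]

print(sorted(hash_join(trips, zones)))
print(sorted(hash_join(trips, zones)) == sorted(sort_merge_join(trips, zones)))
```

Both functions return the same rows; the difference is where the cost goes: memory for the hash table versus sorting both inputs.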


In [5]:
# Reusable DataFrames - saved to disk for clean physical plans
base_path = r"C:\code\spark-tuning-handbook\data\taxi"

taxi.select(
    "VendorID", "PULocationID", "DOLocationID", "payment_type",
    "trip_distance", "fare_amount", "tip_amount", "total_amount"
).write.mode("overwrite").parquet(f"{base_path}\\trips_fact")

taxi.groupBy("PULocationID", "DOLocationID").agg(
    F.count("*").alias("trip_count"),
    F.avg("total_amount").alias("avg_amount"),
    F.avg("trip_distance").alias("avg_distance")
).write.mode("overwrite").parquet(f"{base_path}\\zone_stats")

taxi.select(
    "PULocationID", "RatecodeID", "congestion_surcharge", "Airport_fee"
).dropDuplicates(["PULocationID"]).write.mode("overwrite").parquet(f"{base_path}\\location_dim")

# Read back - no lineage, clean plans
trips_fact = spark.read.parquet(f"{base_path}\\trips_fact")
zone_stats = spark.read.parquet(f"{base_path}\\zone_stats")
location_dim = spark.read.parquet(f"{base_path}\\location_dim")

print("trips_fact rows (big):", trips_fact.count())
print("zone_stats rows (med):", zone_stats.count(), "- yes, we could broadcast it but autoBroadcastJoinThreshold = -1 for demo purposes")
print("location_dim rows (small):", location_dim.count())

trips_fact rows (big): 41169720
zone_stats rows (med): 50518 - yes, we could broadcast it but autoBroadcastJoinThreshold = -1 for demo purposes
location_dim rows (small): 263


### Demo 1 - Broadcast Hash Join

We force BHJ with `broadcast()` hint on `location_dim`.
This should produce `BroadcastHashJoin` in the physical plan.

**NB**: BroadcastExchange is not a shuffle, it is a full copy of the small side sent to every executor.

In [7]:
reset_join_defaults()

bhj = trips_fact.join(broadcast(location_dim), on="PULocationID", how="inner")

bhj.explain()
show_join_nodes(bhj)
show_exchange_nodes(bhj)

== Physical Plan ==
*(2) Project [PULocationID#161, VendorID#160, DOLocationID#162, payment_type#163L, trip_distance#164, fare_amount#165, tip_amount#166, total_amount#167, RatecodeID#187L, congestion_surcharge#188, Airport_fee#189]
+- *(2) BroadcastHashJoin [PULocationID#161], [PULocationID#186], Inner, BuildRight, false
   :- *(2) Filter isnotnull(PULocationID#161)
   :  +- *(2) ColumnarToRow
   :     +- FileScan parquet [VendorID#160,PULocationID#161,DOLocationID#162,payment_type#163L,trip_distance#164,fare_amount#165,tip_amount#166,total_amount#167] Batched: true, DataFilters: [isnotnull(PULocationID#161)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi/trips_fact], PartitionFilters: [], PushedFilters: [IsNotNull(PULocationID)], ReadSchema: struct<VendorID:int,PULocationID:int,DOLocationID:int,payment_type:bigint,trip_distance:double,fa...
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigin

#### Spark builds the broadcast side (`location_dim`) as a separate job

In [8]:
bhj_rows = run_and_report("BHJ action", lambda: bhj.count())
print("BHJ rows:", bhj_rows)

BHJ action -> jobs=[13, 12], stages=[18, 19, 20], stage_count=3
BHJ rows: 41169720


**Memory behavior**:
- no shuffle for the broadcast side - small side is collected on the driver and broadcast to all executors
- executors build a read-only in-memory hash table from the broadcast data
- this lives in storage memory, not execution memory => **Spark cannot spill it, it either fits entirely or causes OOM**

### Demo 2 - Sort-Merge Join (forced)

We disable broadcast and prefer merge strategy.
This should produce `SortMergeJoin` with shuffle + sort on both sides.


In [11]:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")

smj_left = trips_fact.select("PULocationID", "VendorID", "fare_amount", "tip_amount").hint("merge")
smj_right = location_dim.select("PULocationID", "RatecodeID", "congestion_surcharge").hint("merge")

smj = smj_left.join(smj_right, on="PULocationID", how="inner")

smj.explain()
show_join_nodes(smj)
show_exchange_nodes(smj)

== Physical Plan ==
*(5) Project [PULocationID#161, VendorID#160, fare_amount#165, tip_amount#166, RatecodeID#187L, congestion_surcharge#188]
+- *(5) SortMergeJoin [PULocationID#161], [PULocationID#186], Inner
   :- *(2) Sort [PULocationID#161 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(PULocationID#161, 8), ENSURE_REQUIREMENTS, [plan_id=444]
   :     +- *(1) Filter isnotnull(PULocationID#161)
   :        +- *(1) ColumnarToRow
   :           +- FileScan parquet [VendorID#160,PULocationID#161,fare_amount#165,tip_amount#166] Batched: true, DataFilters: [isnotnull(PULocationID#161)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi/trips_fact], PartitionFilters: [], PushedFilters: [IsNotNull(PULocationID)], ReadSchema: struct<VendorID:int,PULocationID:int,fare_amount:double,tip_amount:double>
   +- *(4) Sort [PULocationID#186 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(PULocationID#186, 8), ENSURE_REQUI

In [12]:
smj_rows = run_and_report("SMJ action", lambda: smj.count())
print("SMJ rows:", smj_rows)

SMJ action -> jobs=[14], stages=[21, 22, 23, 24], stage_count=4
SMJ rows: 41169720


### Demo 3 - Shuffle Hash Join (forced)

We disable broadcast, disable sort-merge preference, and apply `shuffle_hash` hint.
This should produce `ShuffledHashJoin` if planner conditions are met.


In [13]:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

shj_left = trips_fact.select("PULocationID", "DOLocationID", "total_amount").hint("shuffle_hash")
shj_right = zone_stats.select("PULocationID", "DOLocationID", "trip_count").hint("shuffle_hash")

shj = shj_left.join(shj_right, on=["PULocationID", "DOLocationID"], how="inner")

shj.explain()
show_join_nodes(shj)
show_exchange_nodes(shj)

== Physical Plan ==
*(3) Project [PULocationID#161, DOLocationID#162, total_amount#167, trip_count#178L]
+- *(3) ShuffledHashJoin [PULocationID#161, DOLocationID#162], [PULocationID#176, DOLocationID#177], Inner, BuildRight
   :- Exchange hashpartitioning(PULocationID#161, DOLocationID#162, 8), ENSURE_REQUIREMENTS, [plan_id=636]
   :  +- *(1) Filter (isnotnull(PULocationID#161) AND isnotnull(DOLocationID#162))
   :     +- *(1) ColumnarToRow
   :        +- FileScan parquet [PULocationID#161,DOLocationID#162,total_amount#167] Batched: true, DataFilters: [isnotnull(PULocationID#161), isnotnull(DOLocationID#162)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi/trips_fact], PartitionFilters: [], PushedFilters: [IsNotNull(PULocationID), IsNotNull(DOLocationID)], ReadSchema: struct<PULocationID:int,DOLocationID:int,total_amount:double>
   +- Exchange hashpartitioning(PULocationID#176, DOLocationID#177, 8), ENSURE_REQUIREMENTS, [plan_id=642]

**BuildRight** means "build the hash table from the right branch of the join tree". The right side (zone_stats, smaller) is hashed per partition, and the left side (trips_fact, larger) streams through and probes that hash table to find matching keys.

In [15]:
shj_rows = run_and_report("SHJ action", lambda: shj.count())
print("SHJ rows:", shj_rows)

# Restore defaults for remaining sections
reset_join_defaults()

SHJ action -> jobs=[16], stages=[29, 30, 31, 32], stage_count=4
SHJ rows: 41169720


### Handling data skew in joins

**Skew** means key distribution is highly uneven - one or a few keys hold most of the rows.

Why skew is **expensive**:
- shuffle partitions are key-driven, so **hot keys** create very large partitions
- large skewed partitions increase spill probability
- one or a few slow tasks hold back entire stage completion (long-running stragglers)
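
A few lines of plain Python show why hash partitioning cannot fix a hot key. The key distribution is hypothetical, and `hash(k) % NUM_PARTITIONS` is a stand-in for Spark's hash partitioner.

```python
from collections import Counter

NUM_PARTITIONS = 8

# Hypothetical key distribution: one hot key (0) and a few cold keys.
keys = [0] * 9600 + [140] * 200 + [100] * 150 + [80] * 50

# All rows with the same key land in the same partition, however many partitions exist.
rows_per_partition = Counter(hash(k) % NUM_PARTITIONS for k in keys)

for pid in range(NUM_PARTITIONS):
    print(pid, rows_per_partition.get(pid, 0))
```

Raising the partition count only spreads the cold keys further; the hot key still lands in a single partition.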

**Mitigation** techniques:
- salting
- AQE skew join optimization (`spark.sql.adaptive.skewJoin.enabled`)
- broadcast small side to avoid shuffle (and skew) entirely

In [9]:
# Force most rows into one join key (0) to create skew.
skew_left = (
    trips_fact
    .select("PULocationID", "DOLocationID", "total_amount")
    .withColumn("join_key", F.when((F.col("PULocationID") % 20) == 0, F.col("PULocationID")).otherwise(F.lit(0)))
)
skew_right = (
    trips_fact
    .withColumn(
        "join_key",
        F.when((F.col("PULocationID") % 20) == 0, F.col("PULocationID")).otherwise(F.lit(0))
    )
    .select("join_key")
    .dropDuplicates(["join_key"])
)

# Save and read back for clean plans
skew_left.write.mode("overwrite").parquet(f"{base_path}\\skew_left")
skew_right.write.mode("overwrite").parquet(f"{base_path}\\skew_right")
skew_left = spark.read.parquet(f"{base_path}\\skew_left")
skew_right = spark.read.parquet(f"{base_path}\\skew_right")

# Show skew profile: key 0 should dominate.
skew_left.groupBy("join_key").count().orderBy(F.desc("count")).show(10, truncate=False)

+--------+--------+
|join_key|count   |
+--------+--------+
|0       |39679338|
|140     |804618  |
|100     |630553  |
|80      |20310   |
|260     |13158   |
|40      |7817    |
|220     |2777    |
|160     |2612    |
|60      |2162    |
|180     |2086    |
+--------+--------+
only showing top 10 rows



### Demo - skewed join without mitigation

Broadcast is disabled to make shuffle cost visible.


In [10]:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")

skew_join = skew_left.join(skew_right, on="join_key", how="inner")

skew_join.explain()
show_join_nodes(skew_join)
show_exchange_nodes(skew_join)

== Physical Plan ==
*(5) Project [join_key#282, PULocationID#279, DOLocationID#280, total_amount#281]
+- *(5) SortMergeJoin [join_key#282], [join_key#287], Inner
   :- *(2) Sort [join_key#282 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(join_key#282, 8), ENSURE_REQUIREMENTS, [plan_id=559]
   :     +- *(1) Filter isnotnull(join_key#282)
   :        +- *(1) ColumnarToRow
   :           +- FileScan parquet [PULocationID#279,DOLocationID#280,total_amount#281,join_key#282] Batched: true, DataFilters: [isnotnull(join_key#282)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi/skew_left], PartitionFilters: [], PushedFilters: [IsNotNull(join_key)], ReadSchema: struct<PULocationID:int,DOLocationID:int,total_amount:double,join_key:int>
   +- *(4) Sort [join_key#287 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(join_key#287, 8), ENSURE_REQUIREMENTS, [plan_id=568]
         +- *(3) Filter isnotnull(join_key#287)
   

In [11]:
skew_rows = run_and_report("skewed join action", lambda: skew_join.count())
print("skewed join rows:", skew_rows)

skewed join action -> jobs=[19], stages=[28, 29, 30, 31], stage_count=4
skewed join rows: 41169720


In Spark UI, skew shows as a few **tasks with much longer duration**, disproportionate shuffle read, and possible **spill**.

_In my latest run, the skewed task took 19s and read 9.1 MiB / 40M records with 792 MiB memory spill and 5.2 MiB disk spill, while the remaining 7 tasks finished in under 0.5s with negligible shuffle read: **a textbook case of skew**._


### Demo - salting the hot key

**Salting** adds a bucket value (0..N-1) to the hot key, turning one overloaded partition into up to N smaller ones. The salt is often random; in this demo it is derived from a hash of `DOLocationID`. The right side is duplicated across all N buckets for the hot key only, so join results stay correct.

**NB**: in practice `skew_right` is small enough to broadcast. Salting is shown here as a technique for cases where both sides are **medium/large** and broadcast is not an option.

In [12]:
salt_buckets = 8

skew_left_salted = (
    skew_left
    # For hot key (join_key == 0): assign a salt 0..7 derived from the hash of DOLocationID (deterministic, not random).
    # For cold keys: salt = 0.
    # This spreads the hot key over up to 8 (join_key, salt) combinations instead of one.
    .withColumn("salt", F.when(F.col("join_key") == 0, F.pmod(F.hash("DOLocationID"), F.lit(salt_buckets))).otherwise(F.lit(0)))
)

salt_values = spark.range(0, salt_buckets).toDF("salt")
hot_right = skew_right.filter(F.col("join_key") == 0).crossJoin(salt_values) # multiply record with hot key (0) by 8 salt values
cold_right = skew_right.filter(F.col("join_key") != 0).withColumn("salt", F.lit(0))
skew_right_salted = hot_right.unionByName(cold_right)

salted_join = skew_left_salted.join(skew_right_salted, on=["join_key", "salt"], how="inner")

#salted_join.explain() 
show_join_nodes(salted_join)
show_exchange_nodes(salted_join)

Join-related nodes:
+- *(7) SortMergeJoin [join_key#282, cast(salt#323 as bigint)], [join_key#287, salt#331L], Inner
Exchange-related nodes:
:  +- Exchange hashpartitioning(join_key#282, cast(salt#323 as bigint), 8), ENSURE_REQUIREMENTS, [plan_id=788]
+- Exchange hashpartitioning(join_key#287, salt#331L, 8), ENSURE_REQUIREMENTS, [plan_id=805]


In [13]:
salted_rows = run_and_report("salted join action", lambda: salted_join.count())
print("salted join rows:", salted_rows)

reset_join_defaults()

salted join action -> jobs=[20], stages=[32, 33, 34, 35], stage_count=4
salted join rows: 41169720


How this mitigation works:
- the hot key no longer maps to exactly one huge partition
- work is distributed across multiple salted partitions
- this reduces single-task pressure and spreads memory pressure (and any spill) across tasks, but total data volume stays the same
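
The same effect can be checked on a toy example in plain Python. Everything here is illustrative: the row counts are made up, and `partition_of` is a hypothetical, deliberately simple stand-in for hash partitioning on the composite key `(join_key, salt)`.

```python
from collections import Counter

NUM_PARTITIONS = 8
SALT_BUCKETS = 8

# Hypothetical left side: (join_key, dropoff_id); key 0 is the hot key.
left = [(0, d) for d in range(8000)] + [(140, d) for d in range(200)]


def salt_for(join_key, dropoff_id):
    # Hot key gets a salt derived from another column; cold keys keep salt 0.
    return dropoff_id % SALT_BUCKETS if join_key == 0 else 0


def partition_of(join_key, salt):
    # Simplified partitioner on the composite key (not Spark's hash function).
    return (join_key * 31 + salt) % NUM_PARTITIONS


before = Counter(partition_of(k, 0) for k, _ in left)
after = Counter(partition_of(k, salt_for(k, d)) for k, d in left)

# Right side: replicate the hot key once per salt value so every salted left row still matches.
right_salted = {(0, s) for s in range(SALT_BUCKETS)} | {(140, 0)}
matched = sum((k, salt_for(k, d)) in right_salted for k, d in left)

print("largest partition before salting:", max(before.values()))
print("largest partition after salting:", max(after.values()))
print("matched rows:", matched, "of", len(left))
```

The row count of the join is unchanged, while the largest partition shrinks because the hot key is now spread over several composite keys.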

In [30]:
# Optional: enable AQE skew optimization and re-run skewed join action
spark.conf.set("spark.sql.adaptive.enabled", "true")
# AQE detects skewed shuffle partitions at runtime and splits them (same goal as manual salting, different mechanism).
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

print("spark.sql.adaptive.enabled:", spark.conf.get("spark.sql.adaptive.enabled"))
print("spark.sql.adaptive.skewJoin.enabled:", spark.conf.get("spark.sql.adaptive.skewJoin.enabled"))

spark.sql.adaptive.enabled: true
spark.sql.adaptive.skewJoin.enabled: true
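
As a rough guide to when AQE treats a shuffle partition as skewed, the Spark documentation describes a two-part test: the partition must be larger than `spark.sql.adaptive.skewJoin.skewedPartitionFactor` times the median partition size and also larger than `spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes`. The sketch below mirrors that rule in plain Python. `skewed_partitions` is a hypothetical helper, the sizes are invented, and the defaults used (factor 5, threshold 256 MB) are assumed from the Spark 3.x documentation, so check them against your Spark version.

```python
from statistics import median

MB = 1024 * 1024

def skewed_partitions(sizes_bytes, factor=5, threshold_bytes=256 * MB):
    """Return indexes of partitions that pass both skew conditions."""
    median_size = median(sizes_bytes)
    return [
        index
        for index, size in enumerate(sizes_bytes)
        if size > factor * median_size and size > threshold_bytes
    ]

# Seven ordinary partitions and one very large one (illustrative numbers).
heavy = [64 * MB] * 7 + [2048 * MB]
# Same layout, but the outlier is below the absolute threshold.
mild = [64 * MB] * 7 + [200 * MB]

print(skewed_partitions(heavy))  # [7]
print(skewed_partitions(mild))   # []
```

A partition that is large only relative to the median, but below the absolute threshold, is left alone under this rule.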


In [35]:
# Keep deterministic non-AQE behavior for the rest of the notebook.
spark.conf.set("spark.sql.adaptive.enabled", "false")

---

## Section 3 - Partitioning & Data Organization

### Theory: `repartition()` vs `coalesce()`

`repartition()`:
- full shuffle - produces evenly distributed partitions
- can increase or decrease partitions
- supports key-based repartitioning (`repartition(n, key)`)
- introduces `Exchange` and stage boundary

`coalesce()`:
- narrow transformation when reducing partitions
- avoids a full shuffle by merging existing partitions, so the result can be uneven
- can only reduce the partition count; requesting more partitions than currently exist leaves the count unchanged
- adds no `Exchange`: each output partition is a union of input partitions (co-located ones where possible), without key-based redistribution

**Performance implications**:
- use `repartition()` when you need balanced data movement or key alignment
- use `coalesce()` for cheap partition reduction before write/output when data is already reasonably distributed
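
The trade-off between the two can be illustrated with a toy model of partition sizes. This is a sketch under assumptions, not Spark's actual algorithm: `coalesce_sizes` and `repartition_sizes` are hypothetical helpers, the sizes are invented, and the repartition case models round-robin `repartition(n)` (a keyed `repartition(n, key)` is only as even as the key distribution).

```python
def coalesce_sizes(sizes, n):
    """Merge adjacent input partitions into n outputs; no rows cross groups."""
    count = len(sizes)
    return [
        sum(sizes[i * count // n:(i + 1) * count // n])
        for i in range(n)
    ]

def repartition_sizes(sizes, n):
    """Spread all rows evenly over n partitions, as a full shuffle would."""
    base, extra = divmod(sum(sizes), n)
    return [base + (1 if i < extra else 0) for i in range(n)]

# Row counts of six scan partitions, one of them much larger than the rest.
scan_sizes = [100, 100, 100, 900, 50, 50]

coalesced = coalesce_sizes(scan_sizes, 2)
repartitioned = repartition_sizes(scan_sizes, 2)

print("coalesce(2):   ", coalesced)      # [300, 1000]
print("repartition(2):", repartitioned)  # [650, 650]
```

Coalescing inherits whatever imbalance the input already had, which is why it fits best when the data is already reasonably distributed.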

### Demo 1 - `repartition()`

This is a full redistribution by `PULocationID` into 12 partitions.


In [17]:
rep12 = taxi.repartition(12, "PULocationID")
print("rep12 partitions:", rep12.rdd.getNumPartitions())

rep12.explain()
show_exchange_nodes(rep12)

rep12 partitions: 12
== Physical Plan ==
Exchange hashpartitioning(PULocationID#7, 12), REPARTITION_BY_NUM, [plan_id=1031]
+- *(1) ColumnarToRow
   +- FileScan parquet [VendorID#0,tpep_pickup_datetime#1,tpep_dropoff_datetime#2,passenger_count#3L,trip_distance#4,RatecodeID#5L,store_and_fwd_flag#6,PULocationID#7,DOLocationID#8,payment_type#9L,fare_amount#10,extra#11,mta_tax#12,tip_amount#13,tolls_amount#14,improvement_surcharge#15,total_amount#16,congestion_surcharge#17,Airport_fee#18] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:timestamp_ntz,tpep_dropoff_datetime:timestamp_ntz,passen...


Exchange-related nodes:
Exchange hashpartitioning(PULocationID#7, 12), REPARTITION_BY_NUM, [plan_id=1031]


In [18]:
rep12_rows = run_and_report("repartition(12) action",lambda: rep12.count())
print("rep12 rows:", rep12_rows)

repartition(12) action -> jobs=[21], stages=[36, 37, 38], stage_count=3
rep12 rows: 41169720


Plan reading:
- Expect `Exchange hashpartitioning(PULocationID, 12)`
- This confirms full shuffle and stage split


### Demo 2 - `coalesce()`

This reduces the partition count from the current scan partition count to 2 without a full redistribution.


In [19]:
coal2 = taxi.coalesce(2)
print("coal2 partitions:", coal2.rdd.getNumPartitions())

coal2.explain()
show_exchange_nodes(coal2)

coal2 partitions: 2
== Physical Plan ==
Coalesce 2
+- *(1) ColumnarToRow
   +- FileScan parquet [VendorID#0,tpep_pickup_datetime#1,tpep_dropoff_datetime#2,passenger_count#3L,trip_distance#4,RatecodeID#5L,store_and_fwd_flag#6,PULocationID#7,DOLocationID#8,payment_type#9L,fare_amount#10,extra#11,mta_tax#12,tip_amount#13,tolls_amount#14,improvement_surcharge#15,total_amount#16,congestion_surcharge#17,Airport_fee#18] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/taxi], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:timestamp_ntz,tpep_dropoff_datetime:timestamp_ntz,passen...


No Exchange nodes in executed physical plan.


In [20]:
coal2_rows = run_and_report("coalesce(2) action",lambda: coal2.count())
print("coal2 rows:", coal2_rows)

coalesce(2) action -> jobs=[22], stages=[39, 40], stage_count=2
coal2 rows: 41169720


Plan reading:
- In the common reduce-only path, `coalesce()` does not add an `Exchange` (shuffle) node
- `coalesce()` produces fewer stages than `repartition()` for the same `count()` (2 vs 3 here) because it avoids the shuffle `Exchange`

### Theory: Partitioning vs Bucketing

**Partitioning** (directory-based):
- rows are physically split by partition column values into a directory tree (e.g. `year=2024/month=06/...`)
- enables partition pruning: Spark skips entire directories when a filter is on the partition columns
- improves scan I/O by reading only relevant files
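
Conceptually, pruning is a filter over directory names that runs before any file is opened. A minimal pure-Python sketch of that idea follows; `prune_partitions` is a hypothetical helper for illustration, not how Spark implements it.

```python
def prune_partitions(paths, **filters):
    """Keep only partition directories whose column values match the filters."""
    kept = []
    for path in paths:
        # "year=2024/month=6" -> {"year": "2024", "month": "6"}
        values = dict(part.split("=", 1) for part in path.split("/"))
        if all(values.get(col) == str(val) for col, val in filters.items()):
            kept.append(path)
    return kept

# Twelve monthly directories, laid out the way partitionBy("year", "month") names them.
all_dirs = [f"year=2024/month={month}" for month in range(1, 13)]

june_dirs = prune_partitions(all_dirs, year=2024, month=6)
print(june_dirs)  # ['year=2024/month=6']
```

Only the surviving directories would be listed and read, which is where the I/O saving comes from.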


**Bucketing** pre-distributes data at write time into a fixed number of buckets by `hash(key) % num_buckets` (and can optionally sort rows within each bucket via `sortBy`). When two tables are bucketed by the same key into the same number of buckets, Spark already knows which files belong to which bucket, so it can join or aggregate **without shuffling**.

**Bucketing** (hash-based fixed buckets):
- at write time, each row is assigned to a bucket by `hash(key) % num_buckets`
- at read time, Spark knows data is already distributed by key => can skip shuffle for joins and aggregations
- shuffle avoidance only works when both join sides share the same bucket key and bucket count

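**NB**: Bucketing requires a catalog table. Bucket metadata (key, count, sort order) is stored in the metastore, so the data must be written with `saveAsTable()` and read back with `spark.table()`. If the same files are read by path as plain parquet, Spark does not see the bucket metadata and shuffles anyway.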
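
Why matching buckets remove the need for a shuffle can be shown with a small pure-Python model. This is a sketch under assumptions: `bucketize` and `join_by_bucket` are hypothetical helpers, the rows are invented, and Python's built-in integer hash stands in for the hash function Spark actually uses.

```python
from collections import defaultdict

NUM_BUCKETS = 8

def bucketize(rows, num_buckets=NUM_BUCKETS):
    """Group (key, payload) rows by bucket id = hash(key) % num_buckets."""
    buckets = defaultdict(list)
    for row in rows:
        # Assumption: Python's int hash replaces Spark's hash for illustration.
        buckets[hash(row[0]) % num_buckets].append(row)
    return buckets

def join_by_bucket(left_rows, right_rows, num_buckets=NUM_BUCKETS):
    """Inner join that only ever compares rows from the same bucket id."""
    left = bucketize(left_rows, num_buckets)
    right = bucketize(right_rows, num_buckets)
    joined = []
    for bucket_id in range(num_buckets):
        lookup = defaultdict(list)
        for key, payload in right.get(bucket_id, []):
            lookup[key].append(payload)
        for key, payload in left.get(bucket_id, []):
            for other in lookup.get(key, []):
                joined.append((key, payload, other))
    return joined

# Toy data: (PULocationID, total_amount) and (PULocationID, trip_count).
trips = [(132, 52.3), (161, 18.0), (132, 70.1), (237, 9.5), (999, 1.0)]
zones = [(132, 410), (161, 275), (237, 198)]

result = join_by_bucket(trips, zones)
print(sorted(result))
```

Because equal keys always hash to the same bucket id on both sides, no row ever has to be compared with a row from a different bucket, so nothing needs to be moved between buckets at join time.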


### Demo 3 - partitioned write and partition pruning

Write by `tpep_pickup_datetime` year/month, then filter by one month and inspect the scan plan for partition pruning.


In [25]:
import os

partitioned_path = r"C:\code\spark-tuning-handbook\data\tmp\taxi_partitioned_by_year_month"

taxi_to_partition = taxi.select(
    F.year("tpep_pickup_datetime").alias("year"),
    F.month("tpep_pickup_datetime").alias("month"),
    "PULocationID", "DOLocationID", "total_amount"
).filter(F.year("tpep_pickup_datetime") == 2024) # filter out random old records

taxi_to_partition.write \
    .mode("overwrite") \
    .partitionBy("year", "month") \
    .parquet(partitioned_path)

year_dirs = sorted([name for name in os.listdir(partitioned_path) if name.startswith("year=")])
print("year directories:", year_dirs)
last_year = os.path.join(partitioned_path, year_dirs[-1])
month_dirs = sorted([name for name in os.listdir(last_year) if name.startswith("month=")])
print("month directories under", year_dirs[-1] + ":", month_dirs)
print("month dir count:", len(month_dirs))

year directories: ['year=2024']
month directories under year=2024: ['month=1', 'month=10', 'month=11', 'month=12', 'month=2', 'month=3', 'month=4', 'month=5', 'month=6', 'month=7', 'month=8', 'month=9']
month dir count: 12


In [28]:
partitioned_df = spark.read.parquet(partitioned_path)

june_df = partitioned_df.filter((F.col("year") == 2024) & (F.col("month") == 6))
june_df.explain()
#show_exchange_nodes(pruned)

plan = june_df._jdf.queryExecution().executedPlan().toString() # gets Physical plan as string via Java API
print("contains FileScan:", "FileScan" in plan)
print("contains PartitionFilters:", "PartitionFilters" in plan)

june_rows = run_and_report("partition-pruning action",lambda: june_df.count())
print("rows for year=2024/month=6:", june_rows)

== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet [PULocationID#508,DOLocationID#509,total_amount#510,year#511,month#512] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/C:/code/spark-tuning-handbook/data/tmp/taxi_partitioned_by_year_..., PartitionFilters: [isnotnull(year#511), isnotnull(month#512), (year#511 = 2024), (month#512 = 6)], PushedFilters: [], ReadSchema: struct<PULocationID:int,DOLocationID:int,total_amount:double>


contains FileScan: True
contains PartitionFilters: True
partition-pruning action -> jobs=[33], stages=[53, 54], stage_count=2
rows for year=2024/month=6: 3539170


The scan metrics report `dynamic partition pruning time: 0 ms`: the filter values are known at planning time (static pruning), so no runtime (dynamic) pruning is needed. Dynamic partition pruning (DPP) applies when filter values come from the result of another query, e.g. the build (smaller) side of a join.

Plan reading:
- `PartitionFilters: [..., (year = 2024), (month = 6)]` confirms Spark pruned directories at scan time
- the scan metrics `number of partitions read: 1` and `size of files read: 13.7 MiB` show that only June data was touched out of ~670 MB total

### Demo 4 - bucketing

This demo creates two bucketed tables with matching bucket key and count, then checks the physical plan to see if Spark skips shuffle for the join.

Again, at write time Spark assigns rows to buckets by `hash(key) % num_buckets`, so rows with the same key always land in the same bucket. If both tables are bucketed by the same key into the same number of buckets, Spark can join bucket-to-bucket directly, with no shuffle needed.


In [6]:
spark.conf.set("spark.sql.sources.bucketing.enabled", "true")

spark.sql("DROP TABLE IF EXISTS taxi_bucketed_trips")
spark.sql("DROP TABLE IF EXISTS taxi_bucketed_zones")

trips_fact.select("PULocationID", "DOLocationID", "total_amount") \
    .write \
    .mode("overwrite") \
    .bucketBy(8, "PULocationID") \
    .sortBy("PULocationID") \
    .saveAsTable("taxi_bucketed_trips")

zone_stats.select("PULocationID", "DOLocationID", "trip_count") \
    .write \
    .mode("overwrite") \
    .bucketBy(8, "PULocationID") \
    .sortBy("PULocationID") \
    .saveAsTable("taxi_bucketed_zones")

print("bucketed tables created")

bucketed tables created


In [8]:
spark.sql("DESCRIBE EXTENDED taxi_bucketed_trips").show(200, truncate=False)
spark.sql("DESCRIBE EXTENDED taxi_bucketed_zones").show(200, truncate=False)

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") # force disable Broadcast for bucket demo

bucket_join = spark.table("taxi_bucketed_trips").join(
    spark.table("taxi_bucketed_zones"),
    on="PULocationID",
    how="inner"
)

bucket_join.explain()
show_join_nodes(bucket_join)
show_exchange_nodes(bucket_join)

bucket_rows = run_and_report("bucket-join action", lambda: bucket_join.count())
print("bucket join rows:", bucket_rows)

+----------------------------+---------------------------------------------------------------------------------+-------+
|col_name                    |data_type                                                                        |comment|
+----------------------------+---------------------------------------------------------------------------------+-------+
|PULocationID                |int                                                                              |NULL   |
|DOLocationID                |int                                                                              |NULL   |
|total_amount                |double                                                                           |NULL   |
|                            |                                                                                 |       |
|# Detailed Table Information|                                                                                 |       |
|Catalog                     |sp

Plan reading for bucketing:
- Check the `DESCRIBE EXTENDED` output for `Num Buckets: 8` and `Bucket Columns: [PULocationID]`.
- If there is no `Exchange hashpartitioning` in the join plan, Spark leveraged bucket alignment and no shuffle was needed.
- If an `Exchange` remains, the planner did not use the bucket metadata. Always verify in the plan, never assume.
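
The "verify in the plan" step can be turned into a small check on the plan text. The sketch below is one possible approach: `shuffle_exchanges` is a hypothetical helper (separate from this notebook's `show_exchange_nodes`), and the sample plan lines are copied from the salted join output earlier in this notebook.

```python
import re

def shuffle_exchanges(plan_text):
    """Return the plan lines that contain a hash-partitioned shuffle Exchange."""
    pattern = re.compile(r"\bExchange hashpartitioning\(")
    return [
        line.strip()
        for line in plan_text.splitlines()
        if pattern.search(line)
    ]

# Plan lines taken from the salted join demo above (two shuffles expected).
sample_plan = """\
+- *(7) SortMergeJoin [join_key#282, cast(salt#323 as bigint)], [join_key#287, salt#331L], Inner
:  +- Exchange hashpartitioning(join_key#282, cast(salt#323 as bigint), 8), ENSURE_REQUIREMENTS, [plan_id=788]
+- Exchange hashpartitioning(join_key#287, salt#331L, 8), ENSURE_REQUIREMENTS, [plan_id=805]
"""

found = shuffle_exchanges(sample_plan)
print("shuffle exchanges found:", len(found))  # 2
```

In the notebook, the plan text for the bucketed join could be obtained with `bucket_join._jdf.queryExecution().executedPlan().toString()`, the same Java API call used in Demo 3. An empty result would indicate that no hash-partitioned shuffle was planned for the join.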

## Key Takeaways

**Shuffle**
- `Exchange` marks redistribution and stage boundaries.
- Shuffle cost is CPU + disk + network + memory pressure.
- Spill is a correctness mechanism with performance cost; reduce it by memory/partition/skew tuning.

**Joins**
- BHJ: broadcast small side, no shuffle on broadcast path, memory trade-off.
- SMJ: shuffle + sort on both sides, scalable default for large joins.
- SHJ: shuffle + per-partition hash, avoids sort but can pressure memory.
- Skew control is mandatory for stable join latency and spill reduction.

**Partitioning and data organization**
- `repartition()` is a full shuffle tool; `coalesce()` is a narrow reduce-partitions tool.
- Directory partitioning helps scan pruning.
- Bucketing (catalog tables written with `saveAsTable()`) can reduce join shuffle only when key/count/sort/metadata/planner conditions align, and the result must be verified in the plan.
