# PySpark: Zero to Hero
## Module 23: Optimizing Joins in Spark

Joins are one of the most resource-intensive operations in Spark because they usually involve **Shuffling** large amounts of data across the network.

In this module, we will learn how to optimize joins by choosing the right strategy based on the table sizes and using techniques like **Broadcasting** and **Bucketing**.

### Agenda:
1.  **Join Strategies:** Shuffle Hash Join vs. Sort Merge Join vs. Broadcast Hash Join.
2.  **Big vs. Small Table:** Using `broadcast()` for Map-Side Joins.
3.  **Big vs. Big Table:** Understanding Sort Merge Join and avoiding skew.
4.  **Bucketing:** Pre-shuffling data to speed up joins.

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, broadcast
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
import time

# Auto-broadcast is disabled at first to force shuffle-based joins for the demonstration.
spark = SparkSession.builder \
    .appName("Optimizing_Joins") \
    .master("local[*]") \
    .config("spark.sql.autoBroadcastJoinThreshold", -1) \
    .getOrCreate()

print("Spark Session Active (Auto Broadcast Disabled)")

In [None]:
# 1. Big Table: Transactions (1 million rows; the end of range() is exclusive)
transactions_df = spark.range(1, 1000001).toDF("txn_id") \
    .withColumn("store_id", (col("txn_id") % 10).cast("integer")) \
    .withColumn("amount", col("txn_id") * 0.5)

# 2. Small Table: Stores (10 rows)
stores_data = [(i, f"Store_{i}", f"City_{i}") for i in range(10)]
schema = StructType([
    StructField("store_id", IntegerType(), False),
    StructField("store_name", StringType(), True),
    StructField("city", StringType(), True)
])
stores_df = spark.createDataFrame(stores_data, schema)

print("--- Transactions (Big Table) ---")
transactions_df.show(5)

print("--- Stores (Small Table) ---")
stores_df.show()

## 1. Standard Join (Sort Merge Join)

When both tables are large (or when auto-broadcast is disabled, as in this session), Spark defaults to **Sort Merge Join** for equi-joins. It runs in three steps:
1.  **Shuffle:** Move data with the same key to the same partition.
2.  **Sort:** Sort data within each partition by the join key.
3.  **Merge:** Iterate through both sorted datasets and join matching rows.

This is expensive: the shuffle moves both tables across the network, and every partition must then be sorted before the merge can run.
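
The three steps above can be sketched in plain Python, without Spark, to show what each phase does. This is a minimal illustration, not Spark's implementation: the helper names (`shuffle`, `merge_join`, `sort_merge_join`) are hypothetical, rows are tuples with the join key first, and `key % num_partitions` stands in for Spark's hash partitioning.

```python
from collections import defaultdict


def shuffle(rows, num_partitions):
    """Step 1: route each row to a partition chosen from its join key (row[0])."""
    partitions = defaultdict(list)
    for row in rows:
        # Assumption: integer keys and a simple modulo as the partitioner.
        partitions[row[0] % num_partitions].append(row)
    return partitions


def merge_join(left, right):
    """Steps 2 and 3: sort both sides by key, then merge with two cursors."""
    left = sorted(left, key=lambda r: r[0])
    right = sorted(right, key=lambda r: r[0])
    out, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        lk, rk = left[i][0], right[j][0]
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # Find the run of equal keys on the right side.
            j_end = j
            while j_end < len(right) and right[j_end][0] == lk:
                j_end += 1
            # Pair every left row carrying this key with that run.
            while i < len(left) and left[i][0] == lk:
                for r in right[j:j_end]:
                    out.append(left[i] + r[1:])
                i += 1
            j = j_end
    return out


def sort_merge_join(left_rows, right_rows, num_partitions=4):
    left_parts = shuffle(left_rows, num_partitions)
    right_parts = shuffle(right_rows, num_partitions)
    result = []
    for p in range(num_partitions):
        # Matching keys are guaranteed to sit in the same partition number.
        result.extend(merge_join(left_parts[p], right_parts[p]))
    return result


# Toy data: (store_id, txn_id) and (store_id, store_name)
txns = [(i % 3, i) for i in range(1, 7)]
stores = [(0, "Store_0"), (1, "Store_1"), (2, "Store_2")]
joined = sort_merge_join(txns, stores)
print(sorted(joined))
```

Both inputs pass through `shuffle` and `sorted`, which is the cost the real plan shows as `Exchange` and `Sort` on each side of the join.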

In [None]:
# Perform a standard join
start_time = time.time()

joined_df = transactions_df.join(stores_df, "store_id")
# Trigger Action
print(f"Count: {joined_df.count()}")

print(f"Time taken (Sort Merge Join): {time.time() - start_time:.2f} seconds")

# Explain Plan
print("--- Execution Plan ---")
joined_df.explain()
# Look for 'SortMergeJoin' and 'Exchange' (Shuffle) in the plan.

## 2. Broadcast Hash Join (Map-Side Join)

If one table is **small** (fits in memory), we can avoid shuffling the large table.
Spark sends a copy of the small table to **every executor**. Each executor then joins its partition of the large table with the local copy of the small table.

**Benefits:**
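*   No shuffle of the large table.
*   No sorting required.
*   Usually much faster than a Sort Merge Join, as long as the small table fits in memory on the driver and on every executor.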
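
The idea can be sketched in plain Python as a build phase followed by a probe phase. This is an illustration under simplifying assumptions, not Spark's code: `broadcast_hash_join` is a hypothetical helper, the big table is modeled as a list of partitions, and rows are tuples with the join key first.

```python
def broadcast_hash_join(big_partitions, small_rows):
    # Build phase: hash the small table once. In Spark, this is the copy
    # that every executor receives.
    lookup = {}
    for key, *rest in small_rows:
        lookup.setdefault(key, []).append(tuple(rest))

    # Probe phase: each partition of the big table is joined where it
    # already lives, so no row of the big table changes partition.
    joined_partitions = []
    for partition in big_partitions:
        out = []
        for key, *rest in partition:
            for match in lookup.get(key, []):
                out.append((key, *rest, *match))
        joined_partitions.append(out)
    return joined_partitions


# Toy data: two partitions of (store_id, txn_id) and a small (store_id, store_name) table
big_partitions = [
    [(0, 10), (1, 11)],
    [(2, 12), (0, 13)],
]
small_rows = [(0, "Store_0"), (1, "Store_1")]
result = broadcast_hash_join(big_partitions, small_rows)
print(result)
```

The output keeps the same number of partitions as the big table, which is why the plan shows no shuffle on that side. Rows with no match in the lookup (here `store_id` 2) are dropped, as in an inner join.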

In [None]:
# Explicitly broadcast the small table
start_time = time.time()

broadcast_joined_df = transactions_df.join(broadcast(stores_df), "store_id")
# Trigger Action
print(f"Count: {broadcast_joined_df.count()}")

print(f"Time taken (Broadcast Join): {time.time() - start_time:.2f} seconds")

# Explain Plan
print("--- Execution Plan ---")
broadcast_joined_df.explain()
# Look for 'BroadcastHashJoin'. The Stores side shows a 'BroadcastExchange',
# but there is NO shuffle 'Exchange' for the Transactions table.

## 3. Bucketing for Big Table Joins

When joining two **large** tables, broadcasting is not an option: neither table fits in memory, and forcing a broadcast risks an OOM error.
However, if we frequently join these tables on a specific column (e.g., `user_id`), we can **Bucket** them.

**Bucketing** pre-shuffles and sorts the data into a fixed number of "buckets" (files) on disk at write time. When we join two tables bucketed on the same column with the same number of buckets:
*   Spark knows that rows with the same `user_id` hash to the same bucket number in BOTH tables.
*   It can skip the Shuffle during the join, and often the Sort as well.
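
A plain-Python sketch of why matching buckets line up is shown below. It is an illustration only: `bucket_id`, `write_bucketed`, and `bucketed_join` are hypothetical helpers, and `key % num_buckets` stands in for the hash function Spark applies to the bucketing column.

```python
def bucket_id(key, num_buckets):
    # Assumption: integer keys and a modulo as a stand-in for Spark's hash.
    return key % num_buckets


def write_bucketed(rows, num_buckets):
    """Mimic bucketBy + sortBy: place each row in its bucket, then sort each bucket by key."""
    buckets = [[] for _ in range(num_buckets)]
    for row in rows:
        buckets[bucket_id(row[0], num_buckets)].append(row)
    return [sorted(b, key=lambda r: r[0]) for b in buckets]


def bucketed_join(left_buckets, right_buckets):
    """Join bucket N of the left table only with bucket N of the right table."""
    out = []
    for left_bucket, right_bucket in zip(left_buckets, right_buckets):
        right_by_key = {}
        for row in right_bucket:
            right_by_key.setdefault(row[0], []).append(row[1:])
        for row in left_bucket:
            for match in right_by_key.get(row[0], []):
                out.append(row + match)
    return out


# Toy data: (store_id, txn_id) and (store_id, store_name)
txns = [(i % 10, i) for i in range(1, 21)]
stores = [(i, f"Store_{i}") for i in range(10)]

# Both tables use the same bucket count, as in the bucketBy(4, ...) calls in this module.
txn_buckets = write_bucketed(txns, 4)
store_buckets = write_bucketed(stores, 4)
joined = bucketed_join(txn_buckets, store_buckets)
print(len(joined))
```

No row is moved at join time: the routing cost was paid once in `write_bucketed`. If the two tables used different bucket counts, bucket N on one side would no longer hold the same keys as bucket N on the other, and the pairing in `bucketed_join` would miss matches.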

In [None]:
# We need to save the data as a managed table to use bucketing.
# Note: bucketBy requires saving as a Table, not just a file.

db_name = "spark_optimization_demo"
spark.sql(f"CREATE DATABASE IF NOT EXISTS {db_name}")
spark.sql(f"USE {db_name}")

# Save Transactions Data as Bucketed Table
transactions_df.write \
    .bucketBy(4, "store_id") \
    .sortBy("store_id") \
    .mode("overwrite") \
    .saveAsTable("transactions_bucketed")

# Save Stores Data as Bucketed Table
stores_df.write \
    .bucketBy(4, "store_id") \
    .sortBy("store_id") \
    .mode("overwrite") \
    .saveAsTable("stores_bucketed")

print("Bucketed Tables Created.")

In [None]:
# Read tables back
t_bucketed = spark.table("transactions_bucketed")
s_bucketed = spark.table("stores_bucketed")

# Join them
start_time = time.time()
bucket_join_df = t_bucketed.join(s_bucketed, "store_id")
print(f"Count: {bucket_join_df.count()}")
print(f"Time taken (Bucketed Join): {time.time() - start_time:.2f} seconds")

# Explain Plan
print("--- Execution Plan (Bucketed) ---")
bucket_join_df.explain()
# With the same bucket count on both tables, you should NOT see 'Exchange' (Shuffle) here,
# because the data was already pre-shuffled during the write phase.
# A 'Sort' step may still appear in the plan.

## Summary

1.  **Sort Merge Join:** Default for big tables. Safe but slow (Shuffle + Sort).
2.  **Broadcast Join:** Best for Big + Small table. Avoiding shuffle makes it very fast. Use `broadcast()`.
3.  **Bucketing:** Best for frequent Big + Big table joins. Pre-shuffles data on write to speed up future reads/joins.
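
The summary can be condensed into a rule-of-thumb helper. This is a sketch of the guidance in this module, not Spark's planner: `suggest_join_strategy` and its parameters are hypothetical, and the only Spark-specific value is the 10 MB default of `spark.sql.autoBroadcastJoinThreshold`.

```python
# Default value of spark.sql.autoBroadcastJoinThreshold (10 MB).
DEFAULT_BROADCAST_THRESHOLD = 10 * 1024 * 1024


def suggest_join_strategy(left_bytes, right_bytes, both_bucketed_on_key=False,
                          threshold=DEFAULT_BROADCAST_THRESHOLD):
    """Return a join strategy name following the three summary rules."""
    # Rule 2: one side is small enough to ship to every executor.
    if min(left_bytes, right_bytes) <= threshold:
        return "broadcast"
    # Rule 3: both sides are large, but pre-shuffled on the join key.
    if both_bucketed_on_key:
        return "bucketed sort merge"
    # Rule 1: the safe default for two large tables.
    return "sort merge"


GB = 1024 ** 3
print(suggest_join_strategy(50 * GB, 2 * 1024))
print(suggest_join_strategy(50 * GB, 40 * GB, both_bucketed_on_key=True))
print(suggest_join_strategy(50 * GB, 40 * GB))
```

The size inputs are assumed to be estimates you already have, for example from file sizes on disk.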

**Next Steps:**
This concludes the core optimization techniques. In the next (and final) notebook, we will briefly cover **Spark SQL** syntax and how to mix SQL with DataFrames seamlessly.