# Join Strategies & Performance Tuning with PySpark (DataFrame-only, Serverless-friendly)

**Datasets:**
- `samples.tpch.customer`
- `samples.tpch.orders`
- `samples.tpch.lineitem`

In this notebook you will:
1. Perform star-schema joins
2. Inspect physical plans (`explain`)
3. Build a baseline aggregate query
4. Use broadcast joins
5. Cache and reuse intermediate results
6. Use `repartition` / `coalesce` with **DataFrame-only partition introspection**
7. Enable Adaptive Query Execution (AQE)


In [None]:
from pyspark.sql import functions as F

customer_df = spark.read.table("samples.tpch.customer")
orders_df   = spark.read.table("samples.tpch.orders")
lineitem_df = spark.read.table("samples.tpch.lineitem")

print("Customer count:", customer_df.count())
print("Orders count:", orders_df.count())
print("Lineitem count:", lineitem_df.count())

display(customer_df.limit(5))


## 1. Basic Star-Schema Join

We'll join:
- `customer` -> `orders` on `c_custkey = o_custkey`
- `orders` -> `lineitem` on `o_orderkey = l_orderkey`


In [None]:
# Join customer to orders
cust_orders_df = (
    customer_df.alias("c")
    .join(orders_df.alias("o"), F.col("c.c_custkey") == F.col("o.o_custkey"), "inner")
)

# Join the result to lineitem
cust_orders_lineitem_df = (
    cust_orders_df.alias("co")
    .join(lineitem_df.alias("l"), F.col("co.o_orderkey") == F.col("l.l_orderkey"), "inner")
)

display(cust_orders_lineitem_df.select("c_custkey", "o_orderkey", "l_linenumber").limit(10))


## 2. Inspect the Physical Plan with `explain`

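`explain(mode="extended")` prints the parsed, analyzed, and optimized logical plans, followed by the physical plan. In the physical plan, look for:
- Join strategies (`BroadcastHashJoin`, `SortMergeJoin`, etc.)
- Shuffle operations, which appear as `Exchange` nodes

Estimated statistics are printed by `mode="cost"`, not by `mode="extended"`.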


In [None]:
cust_orders_lineitem_df.explain(mode="extended")


## 3. Aggregate Query as a Baseline

Example query:
- Revenue per customer (`c_custkey`)
- Using sum of `l_extendedprice * (1 - l_discount)`

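To make the revenue expression concrete, the following plain-Python check applies the same arithmetic to three hand-made rows. The values are invented for illustration and are not taken from `samples.tpch`; `Decimal` is used on the assumption that the price and discount columns are decimal-typed.

```python
from collections import defaultdict
from decimal import Decimal

# Hypothetical line items: (c_custkey, l_extendedprice, l_discount)
rows = [
    (1, Decimal("1000.00"), Decimal("0.05")),
    (1, Decimal("200.00"), Decimal("0.00")),
    (2, Decimal("500.00"), Decimal("0.10")),
]

# Sum l_extendedprice * (1 - l_discount) per customer key.
revenue = defaultdict(Decimal)
for custkey, price, discount in rows:
    revenue[custkey] += price * (1 - discount)

print(dict(revenue))
```

Customer 1 ends up with 950 + 200 = 1150 and customer 2 with 450, which is what the `groupBy` / `sum` in the next cell computes per `c_custkey` at scale.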

In [None]:
baseline_revenue_df = (
    cust_orders_lineitem_df
    .groupBy("c_custkey")
    .agg(
        F.sum(
            F.col("l_extendedprice") * (1 - F.col("l_discount"))
        ).alias("customer_revenue")
    )
)

display(baseline_revenue_df.orderBy(F.col("customer_revenue").desc()).limit(20))


## 4. Broadcast Join Optimization

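- If one side of a join is **small enough** to fit in memory on every executor, we can broadcast it.
- Spark then copies the small side to each executor and joins it against the large side in place, so neither side is shuffled for that join.
- The `broadcast()` hint requests this strategy even when the table is larger than the automatic threshold (`spark.sql.autoBroadcastJoinThreshold`).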

We'll:
- Broadcast the `customer` table when joining to `orders`.

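The mechanics of a broadcast hash join can be sketched in plain Python: build a hash table from the small side once, then probe it from each partition of the large side without moving the large side's rows. This is a simplified illustration, not Spark's implementation; the rows and the helper names `build_hash_table` and `probe_partition` are hypothetical.

```python
from collections import defaultdict

# Hypothetical tiny tables (illustrative values only).
customers = [(1, "alice"), (2, "bob")]   # small side: (c_custkey, name)
order_partitions = [                     # large side, already split into partitions
    [(100, 1), (101, 2)],                # rows are (o_orderkey, o_custkey)
    [(102, 1), (103, 3)],
]

def build_hash_table(small_rows):
    """Index the small side by join key."""
    table = defaultdict(list)
    for key, name in small_rows:
        table[key].append(name)
    return table

def probe_partition(partition, table):
    """Inner-join one partition of the large side against the hash table."""
    for orderkey, custkey in partition:
        for name in table.get(custkey, []):
            yield (custkey, name, orderkey)

# Built once; in Spark a copy would be sent to every executor.
broadcast_table = build_hash_table(customers)

joined = [
    row
    for part in order_partitions
    for row in probe_partition(part, broadcast_table)
]
print(joined)
```

Each partition of the large side is processed where it already lives, which is why the plan shows no `Exchange` on that side. The order for customer key 3 is dropped because the join is an inner join and no such customer exists in the small table.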

In [None]:
from pyspark.sql.functions import broadcast

broadcast_cust_orders_df = (
    broadcast(customer_df.alias("c"))
    .join(orders_df.alias("o"), F.col("c.c_custkey") == F.col("o.o_custkey"), "inner")
)

broadcast_all_df = (
    broadcast_cust_orders_df.alias("co")
    .join(lineitem_df.alias("l"), F.col("co.o_orderkey") == F.col("l.l_orderkey"), "inner")
)

broadcast_all_df.explain(mode="extended")


## 5. Caching & Reuse

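If you reuse the same intermediate result several times, use `.cache()` or `.persist()` so that Spark does not recompute the joins and re-read the source tables for every action.
- Caching is lazy: nothing is stored until an action (such as `count()`) runs.
- Call `.unpersist()` once the cached result is no longer needed.
- Some serverless environments do not support the cache APIs, so check your compute's limitations before relying on this step.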


In [None]:
# Cache the heavy join
broadcast_all_df_cached = broadcast_all_df.cache()

# Trigger cache materialization
broadcast_all_df_cached.count()

# Re-use cached DF for multiple aggregations
revenue_by_customer_df = (
    broadcast_all_df_cached
    .groupBy("c_custkey")
    .agg(
        F.sum(
            F.col("l_extendedprice") * (1 - F.col("l_discount"))
        ).alias("customer_revenue")
    )
)

revenue_by_nation_df = (
    broadcast_all_df_cached
    .groupBy("c_nationkey")
    .agg(
        F.sum(
            F.col("l_extendedprice") * (1 - F.col("l_discount"))
        ).alias("nation_revenue")
    )
)

display(revenue_by_customer_df.orderBy(F.col("customer_revenue").desc()).limit(10))
display(revenue_by_nation_df.orderBy(F.col("nation_revenue").desc()).limit(10))


## 6. Repartitioning & Coalescing (DataFrame-only Partition Introspection)

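- Use `repartition()` to **increase** parallelism or to redistribute rows by key. It always performs a full shuffle.
- Use `coalesce()` to **decrease** the number of partitions without a full shuffle. It merges existing partitions.
- Instead of `df.rdd.getNumPartitions()`, which needs the RDD API, we use `spark_partition_id()` to count partitions. This counts only partitions that hold at least one row, so the result can be lower than the number requested.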

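The difference between the requested partition count and the number of non-empty partitions is easy to see in a plain-Python sketch. The `partition_id` helper below is hypothetical: Spark uses a Murmur3-based hash for key-based repartitioning, and plain modulo is substituted here only to keep the example deterministic.

```python
def partition_id(key, num_partitions):
    # Stand-in for Spark's hash partitioner (assumption: modulo instead of
    # the Murmur3-based hash, purely for a deterministic illustration).
    return key % num_partitions

num_partitions = 64
custkeys = [1, 2, 3, 65, 66, 130]  # hypothetical keys, far fewer than 64 distinct values

# Every row with the same key lands in the same partition.
ids = [partition_id(k, num_partitions) for k in custkeys]

# Counting distinct IDs is what the spark_partition_id() approach measures.
non_empty = len(set(ids))
print(non_empty, "non-empty partitions out of", num_partitions)
```

Here only 3 of the 64 partitions receive rows. On a large table with many distinct keys the two numbers normally match, but on small or heavily filtered data the distinct count of `spark_partition_id()` should be read as a lower bound.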

In [None]:
# Repartition by key used in downstream aggregations
repartitioned_df = broadcast_all_df.repartition(64, "c_custkey")  # 64 is just an example

repartitioned_with_pid = repartitioned_df.withColumn("partition_id", F.spark_partition_id())
num_parts_repart = (
    repartitioned_with_pid
    .select("partition_id")
    .agg(F.countDistinct("partition_id").alias("num_partitions"))
    .collect()[0]["num_partitions"]
)

print("Repartitioned partitions (via spark_partition_id):", num_parts_repart)


In [None]:
# Coalesce when writing out or for subsequent stages
coalesced_df = repartitioned_df.coalesce(8)
coalesced_with_pid = coalesced_df.withColumn("partition_id", F.spark_partition_id())
num_parts_coal = (
    coalesced_with_pid
    .select("partition_id")
    .agg(F.countDistinct("partition_id").alias("num_partitions"))
    .collect()[0]["num_partitions"]
)

print("Coalesced partitions (via spark_partition_id):", num_parts_coal)

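Why `coalesce()` avoids a full shuffle can be pictured as follows: whole partitions are grouped together, and no row is re-hashed. This is a simplified plain-Python sketch; the `coalesce` function is a hypothetical stand-in, and Spark's actual grouping also takes data locality into account.

```python
def coalesce(partitions, target):
    """Merge whole partitions into at most `target` groups without re-hashing rows."""
    # Asking for more partitions than exist leaves the count unchanged.
    target = min(target, len(partitions))
    merged = [[] for _ in range(target)]
    for i, part in enumerate(partitions):
        # Neighbouring input partitions are assigned to the same output group.
        merged[i * target // len(partitions)].extend(part)
    return merged

eight = [[i] for i in range(8)]  # eight single-row partitions
two = coalesce(eight, 2)
print(two)
```

Because rows only travel with their original partition, the result can be unevenly sized if the inputs were. When balanced output matters more than avoiding the shuffle, `repartition()` is the safer choice.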

## 7. Adaptive Query Execution (AQE)

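AQE re-optimizes the query plan at runtime, using statistics from shuffle stages that have already completed. It can:
- Automatically coalesce small shuffle partitions
- Switch a sort-merge join to a broadcast join when one side turns out to be small
- Split skewed partitions in sort-merge joins

AQE is enabled by default since Spark 3.2. The cell below sets the flags explicitly: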


In [None]:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

print("AQE enabled:", spark.conf.get("spark.sql.adaptive.enabled"))

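# Define an aggregation. Calling explain() before an action shows the initial plan
# (AdaptiveSparkPlan isFinalPlan=false); AQE re-optimizes it while the query runs.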
aqe_example_df = (
    broadcast_all_df
    .groupBy("c_custkey")
    .agg(F.sum("l_extendedprice").alias("total_extended_price"))
)

aqe_example_df.explain(mode="extended")

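The first AQE behavior listed in this section, coalescing small shuffle partitions, can be approximated by a greedy pass that merges neighbouring partitions until a target size is reached. This is a simplified plain-Python sketch, not Spark's code; `coalesce_by_size` and the sizes are hypothetical, and the target plays the role of `spark.sql.adaptive.advisoryPartitionSizeInBytes`.

```python
def coalesce_by_size(sizes_mb, target_mb):
    """Greedily group contiguous shuffle partitions up to roughly target_mb each."""
    groups, current, current_size = [], [], 0
    for idx, size in enumerate(sizes_mb):
        # Close the current group if adding this partition would exceed the target.
        if current and current_size + size > target_mb:
            groups.append(current)
            current, current_size = [], 0
        current.append(idx)
        current_size += size
    if current:
        groups.append(current)
    return groups

sizes = [10, 20, 30, 5, 64, 1, 2]  # hypothetical post-shuffle partition sizes in MB
groups = coalesce_by_size(sizes, 64)
print(groups)
```

Seven shuffle partitions become four tasks here. The grouping depends on sizes measured after the shuffle has run, which is why this decision can only be made at runtime and does not appear in a plan printed before execution.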