# PySpark Internals and Concepts -- A Deep Dive

## For Databricks Serverless Free Account

This notebook explains the **conceptual foundations** of Apache Spark that every data
engineer and analyst must understand. Instead of just showing syntax, we explain *why*
things work the way they do and *how* Spark executes your code under the hood.

All examples use inline sample data -- no external files or clusters needed.

---

### Topics Covered

| # | Topic | Why It Matters |
|---|---|---|
| 1 | Lazy Evaluation | Spark does not run anything until you ask for a result |
| 2 | Transformations vs Actions | The two categories every Spark operation falls into |
| 3 | Narrow vs Wide Transformations | Determines whether Spark needs a shuffle |
| 4 | Shuffle -- The Most Expensive Operation | Moving data across the network |
| 5 | DAG, Jobs, Stages, Tasks | How Spark breaks your code into execution units |
| 6 | Catalyst Optimizer and Tungsten | How Spark rewrites your query to make it faster |
| 7 | Reading Explain Plans | The single most useful debugging skill |
| 8 | Caching and Persistence | When to keep data in memory and when not to |
| 9 | Broadcast Joins and Variables | Avoiding shuffle for small tables |
| 10 | Partitioning (Memory and Disk) | Controlling how data is distributed |
| 11 | Adaptive Query Execution (AQE) | Runtime optimizations (Databricks default) |
| 12 | Data Skew and Salting | Fixing the most common performance problem |
| 13 | Best Practices Checklist | Rules of thumb for production code |

In [None]:
# -- Setup: imports and sample data used throughout this notebook --
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType
)
from pyspark.sql.window import Window

# In Databricks, `spark` is pre-created. On local, uncomment:
# spark = SparkSession.builder.appName("Concepts").master("local[*]").getOrCreate()

print(f"Spark version : {spark.version}")
print(f"App name      : {spark.sparkContext.appName}")

# -- Sample data: 10 employees --
data = [
    ("Alice",   "Engineering", 95000,  30, "2020-01-15"),
    ("Bob",     "Marketing",   72000,  28, "2021-03-22"),
    ("Charlie", "Engineering", 110000, 35, "2018-07-10"),
    ("Diana",   "HR",          68000,  26, "2022-06-01"),
    ("Eve",     "Marketing",   85000,  32, "2019-11-30"),
    ("Frank",   "Engineering", 102000, 40, "2017-04-18"),
    ("Grace",   "HR",          71000,  29, "2021-09-05"),
    ("Hank",    "Sales",       78000,  34, "2020-02-14"),
    ("Ivy",     "Sales",       92000,  38, "2018-12-20"),
    ("Jack",    "Engineering", 115000, 45, "2016-08-25"),
]
columns = ["name", "department", "salary", "age", "hire_date"]
df = spark.createDataFrame(data, columns)
df.show()

---
## 1. Lazy Evaluation -- Spark Does Not Run Until You Ask

### The Concept

When you write `df.filter(...)` or `df.select(...)`, Spark does **nothing immediately**.
It only records *what* you want to do -- it builds a logical plan. The actual computation
happens only when you call an **action** like `.show()`, `.count()`, or `.collect()`.

### Why?

Because lazy evaluation allows Spark to **optimise the entire pipeline** before running it.
If you filter 1 billion rows down to 100 and then select 2 columns out of 200, Spark sees
both steps at once. It can push the filter down towards the data source and read only the
columns the query actually uses, skipping the rest. If it ran eagerly (line by line), it
would miss these optimisation opportunities.

### Analogy

Think of it like a restaurant kitchen. The waiter (your code) writes down the full order
(appetiser, main, dessert). The kitchen (Spark) does not start cooking the appetiser
immediately -- it looks at the full order first and plans how to use the ovens and pans
most efficiently. The order is "lazy"; the cooking starts only when the waiter says
"fire the order" (an action).
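To make the idea concrete, here is a toy, pure-Python model of lazy evaluation. It is *not* how Spark is implemented, and `LazyFrame` is a hypothetical class. Transformations only append a step to a recorded plan. Actions replay the whole plan from the source data every time they are called.

```python
from typing import Callable, Dict, List, Tuple

Row = Dict[str, object]
Step = Tuple[str, Callable[[List[Row]], List[Row]]]


class LazyFrame:
    """Toy lazy frame: transformations record steps, actions execute them."""

    run_log: List[str] = []  # one entry per plan execution (shared by all frames)

    def __init__(self, rows: List[Row], plan: Tuple[Step, ...] = ()):
        self._rows = rows
        self._plan = plan

    # -- Transformations: return a NEW frame with one more recorded step --
    def filter(self, pred: Callable[[Row], bool]) -> "LazyFrame":
        step = ("filter", lambda rows: [r for r in rows if pred(r)])
        return LazyFrame(self._rows, self._plan + (step,))

    def select(self, *cols: str) -> "LazyFrame":
        step = ("select", lambda rows: [{c: r[c] for c in cols} for r in rows])
        return LazyFrame(self._rows, self._plan + (step,))

    # -- Actions: replay the whole plan from the source, every time --
    def _execute(self) -> List[Row]:
        LazyFrame.run_log.append(" -> ".join(name for name, _ in self._plan))
        rows = self._rows
        for _, fn in self._plan:
            rows = fn(rows)
        return rows

    def collect(self) -> List[Row]:
        return self._execute()

    def count(self) -> int:
        return len(self._execute())


data = [
    {"name": "Alice", "salary": 95000},
    {"name": "Bob", "salary": 72000},
    {"name": "Jack", "salary": 115000},
]
plan = LazyFrame(data).filter(lambda r: r["salary"] > 80000).select("name")
print(len(LazyFrame.run_log))  # 0 -- only a plan exists so far
print(plan.collect())          # [{'name': 'Alice'}, {'name': 'Jack'}]
print(plan.count())            # 2 -- the plan ran a second time
print(LazyFrame.run_log)       # ['filter -> select', 'filter -> select']
```

Calling `count()` after `collect()` replays the recorded plan from scratch. For the same reason, Spark recomputes a DataFrame on every action unless it is cached.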

In [None]:
# -- Demonstration of lazy evaluation --

# These three lines add transformations to the plan.
# Nothing is computed yet -- no data is scanned, no CPU time is spent.
step1 = df.filter(F.col("salary") > 80000)          # plan step 1
step2 = step1.select("name", "department", "salary") # plan step 2
step3 = step2.withColumn("tax", F.col("salary") * 0.2)  # plan step 3

# At this point 'step3' is just a plan (a recipe).
# Prove it -- check the type:
print(f"Type of step3: {type(step3)}")   # DataFrame, not a list of results

# Now trigger an ACTION --> Spark compiles and runs the plan.
step3.show()    # <-- THIS is when computation actually happens

# Another action:
print(f"Row count: {step3.count()}")     # <-- This runs the plan again

---
## 2. Transformations vs Actions -- The Two Categories

Every Spark operation is either a **Transformation** or an **Action**.

### Transformations (Lazy -- build the plan)

A transformation takes a DataFrame and returns a **new** DataFrame. The original is never
modified (DataFrames are immutable). Transformations are lazy -- they only add a step to
the plan.

| Transformation | What it does |
|---|---|
| `select()` | Pick columns |
| `filter()` / `where()` | Keep rows matching a condition |
| `withColumn()` | Add or replace a column |
| `drop()` | Remove columns |
| `groupBy()` | Group rows (returns `GroupedData`; follow it with `agg()`, `count()`, etc. to get a DataFrame) |
| `join()` | Combine two DataFrames |
| `orderBy()` / `sort()` | Sort rows |
| `distinct()` | Remove duplicates |
| `union()` / `unionAll()` | Stack two DataFrames vertically |
| `repartition()` | Redistribute data across partitions |
| `coalesce()` | Reduce partition count without full shuffle |

### Actions (Eager -- trigger execution)

An action forces Spark to execute all the queued transformations and return a result to
the driver (your notebook) or write it to storage.

| Action | What it returns |
|---|---|
| `show(n)` | Prints first n rows to console |
| `count()` | Returns the number of rows (integer) |
| `collect()` | Returns ALL rows as a Python list (careful with large data!) |
| `first()` / `head()` | Returns the first row |
| `take(n)` | Returns first n rows as a list |
| `toPandas()` | Converts to a Pandas DataFrame (pulls all data to driver) |
| `write.save()` | Writes data to storage (Delta, Parquet, CSV, etc.) |
| `foreach()` | Applies a function to each row |
| `describe()` | Returns summary statistics |

### Key Rule

**Each action triggers at least one job.** Some actions trigger more, for example those involving a
global sort. If you call `.show()` and then `.count()` on the same DataFrame, Spark runs the entire
plan **twice** (unless you cache -- see Section 8).

In [None]:
# -- Transformations vs Actions demo --

# TRANSFORMATION chain (nothing runs yet)
result = (
    df
    .filter(F.col("department") == "Engineering")   # transformation
    .select("name", "salary")                       # transformation
    .withColumn("bonus", F.col("salary") * 0.15)    # transformation
    .orderBy(F.col("salary").desc())                 # transformation
)

print("Plan built. No computation yet.")

# ACTION 1 -- triggers the full plan
print("\n--- Action: show() ---")
result.show()                                        # ACTION

# ACTION 2 -- triggers the full plan AGAIN (from scratch)
print(f"--- Action: count() = {result.count()} ---") # ACTION

# ACTION 3 -- collect all rows to driver as a Python list
rows = result.collect()                               # ACTION
print(f"--- Action: collect() returned {len(rows)} Python Row objects ---")
print(f"First row: {rows[0]}")

# ACTION 4 -- convert to Pandas (pulls everything to driver memory)
pdf = result.toPandas()                               # ACTION
print("\n--- Action: toPandas() ---")
print(pdf)

---
## 3. Narrow vs Wide Transformations

Not all transformations are equal. Spark classifies them by how much data needs to
move between machines.

### Narrow Transformations (No Shuffle)

Each input partition feeds **at most one** output partition, so no data has to be
shuffled across the network. These are fast.

```
Partition 1 --> Partition 1
Partition 2 --> Partition 2
Partition 3 --> Partition 3
```

**Examples:** `select`, `filter`, `withColumn`, `map`, `flatMap`, `union`, `coalesce`

### Wide Transformations (Requires Shuffle)

Rows from each input partition may be sent to **many** output partitions, so each output
partition collects data from many inputs. Data must be moved across the network
(shuffled). These are expensive.

```
Partition 1 ──┐
Partition 2 ──┼──> Partition A
Partition 3 ──┘
```

**Examples:** `groupBy`, `join`, `orderBy/sort`, `distinct`, `repartition`,
`reduceByKey`, `aggregateByKey`

### Why Does This Matter?

- A **shuffle** means Spark must write intermediate data to disk, send it over the
  network, and read it back. On large datasets this can take minutes.
- Spark creates a **new stage** at every shuffle boundary. More shuffles = more stages
  = more overhead.
- Reducing unnecessary shuffles is the single biggest performance optimisation you can
  make.

### Visual: Narrow vs Wide

```
NARROW (fast):                    WIDE (slow):
+--------+    +--------+         +--------+
| Part 1 | -> | Part 1 |         | Part 1 |---\
+--------+    +--------+         +--------+    \   +--------+
+--------+    +--------+         +--------+     +->| Part A |
| Part 2 | -> | Part 2 |         | Part 2 |---/   +--------+
+--------+    +--------+         +--------+
                                  Each output may need data from EVERY input
```
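The difference comes down to *dependencies* between partitions. The pure-Python sketch below is a toy model, not Spark code, and its helper names are hypothetical. For each output partition, it records which input partitions had to be read. A narrow filter reads exactly one, while a wide group-by must collect each key's rows from wherever they live. `zlib.crc32` stands in as a deterministic hash; Spark's own hash partitioning uses Murmur3.

```python
import zlib
from collections import defaultdict
from typing import Dict, List, Set, Tuple

Row = Tuple[str, int]          # (department, salary)
Partition = List[Row]


def target_partition(key: str, num_partitions: int) -> int:
    """Deterministic hash partitioning (crc32 stands in for Spark's Murmur3)."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def narrow_filter(parts: List[Partition], min_salary: int):
    """Narrow: output partition i is built from input partition i only."""
    out = [[row for row in part if row[1] > min_salary] for part in parts]
    deps = {i: {i} for i in range(len(parts))}
    return out, deps


def wide_sum_by_key(parts: List[Partition], num_out: int):
    """Wide: rows with the same key must meet in one output partition,
    so an output partition may read from every input partition."""
    totals: List[Dict[str, int]] = [defaultdict(int) for _ in range(num_out)]
    deps: Dict[int, Set[int]] = {j: set() for j in range(num_out)}
    for i, part in enumerate(parts):
        for dept, salary in part:
            j = target_partition(dept, num_out)
            totals[j][dept] += salary
            deps[j].add(i)
    return [dict(t) for t in totals], deps


parts = [
    [("Engineering", 95000), ("Marketing", 72000)],
    [("Engineering", 110000), ("HR", 68000)],
    [("Marketing", 85000), ("Sales", 78000)],
]
_, narrow_deps = narrow_filter(parts, 80000)
wide_out, wide_deps = wide_sum_by_key(parts, 2)
print(narrow_deps)  # {0: {0}, 1: {1}, 2: {2}}
print(wide_deps)    # at least one output partition reads several input partitions
print(wide_out)
```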

In [None]:
# -- Narrow vs Wide -- see it in the execution plan --

# NARROW transformations only -- no shuffle.
narrow_result = df.filter(F.col("salary") > 80000).select("name", "salary")
print("=== NARROW plan (no Exchange node) ===")
narrow_result.explain()
# You will NOT see "Exchange" in the plan.

print("\n" + "=" * 70 + "\n")

# WIDE transformation -- groupBy triggers a shuffle.
wide_result = df.groupBy("department").agg(F.avg("salary").alias("avg_salary"))
print("=== WIDE plan (Exchange = shuffle) ===")
wide_result.explain()
# You WILL see "Exchange hashpartitioning" -- that is the shuffle.

print("\n" + "=" * 70 + "\n")

# WIDE -- orderBy also triggers a shuffle (a global sort range-partitions the data).
sorted_result = df.orderBy(F.col("salary").desc())
print("=== WIDE plan (orderBy = shuffle) ===")
sorted_result.explain()
# "Exchange rangepartitioning" = data is shuffled for the sort.

---
## 4. Shuffle -- The Most Expensive Operation in Spark

### What Happens During a Shuffle

1. **Map side (write):** Each executor writes its portion of data to local disk,
   organised by the target partition key (hash or range).
2. **Network transfer:** Data is sent from every mapper to every reducer over the
   network.
3. **Reduce side (read):** Each reducer reads the pieces it needs from all mappers and
   combines them.
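The three phases can be sketched in plain Python. This is a simplified model with hypothetical function names. Real Spark writes sorted, compressed shuffle files to local disk and fetches them over the network. In the sketch, the map side buckets every row by target reducer, and each reducer then fetches its bucket from every mapper. The number of block fetches therefore grows as mappers x reducers.

```python
import zlib
from typing import Dict, List, Tuple

Row = Tuple[str, int]  # (key, value)


def map_side_write(mapper_rows: List[Row], num_reducers: int) -> Dict[int, List[Row]]:
    """Phase 1: one mapper splits its rows into one bucket per reducer
    (in Spark these buckets end up in a shuffle file on local disk)."""
    buckets: Dict[int, List[Row]] = {r: [] for r in range(num_reducers)}
    for key, value in mapper_rows:
        buckets[zlib.crc32(key.encode("utf-8")) % num_reducers].append((key, value))
    return buckets


def shuffle(mappers: List[List[Row]], num_reducers: int):
    """Phases 2 and 3: every reducer fetches its bucket from every mapper."""
    shuffle_files = [map_side_write(rows, num_reducers) for rows in mappers]
    reducer_inputs: List[List[Row]] = [[] for _ in range(num_reducers)]
    fetches = 0
    for reducer in range(num_reducers):
        for mapper_output in shuffle_files:
            reducer_inputs[reducer].extend(mapper_output[reducer])  # "network" read
            fetches += 1
    return reducer_inputs, fetches


mappers = [
    [("Engineering", 1), ("HR", 1), ("Sales", 1)],
    [("Engineering", 1), ("Marketing", 1)],
    [("HR", 1), ("Marketing", 1), ("Sales", 1)],
]
reduced, fetches = shuffle(mappers, num_reducers=4)
print(fetches)  # 12 = 3 mappers x 4 reducers (real Spark may skip empty blocks)
```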

### Shuffle is expensive because:

- **Disk I/O:** Intermediate data is serialised and written to disk.
- **Network I/O:** Data crosses the network; bandwidth is limited.
- **Serialisation:** Data is converted from internal format to bytes and back.
- **Memory pressure:** Buffers fill up, causing spills to disk.

### When does Shuffle happen?

| Operation | Shuffle? | Why |
|---|---|---|
| `filter`, `select`, `withColumn` | No | Each row is independent |
| `groupBy().agg()` | Yes | Rows with the same key must be on the same machine |
| `join` (sort-merge) | Yes | Matching keys need to be co-located |
| `join` (broadcast) | No | Small table is copied to all nodes |
| `orderBy` / `sort` | Yes | Rows are range-partitioned so each partition holds one contiguous key range |
| `distinct` | Yes | Identical rows must land in the same partition to be compared |
| `repartition(n)` | Yes | Explicitly redistributes data |
| `coalesce(n)` | No | Merges existing partitions without redistributing rows by key |

### How to Reduce Shuffle

1. **Filter early** -- reduce data volume before a groupBy or join.
2. **Broadcast small tables** -- avoids shuffle in joins.
3. **Use coalesce instead of repartition** when reducing partition count.
4. **Pre-partition data on join/group keys** when writing to storage.
5. **Avoid unnecessary sorts** -- use `orderBy` only when the consumer needs it.

In [None]:
# -- Seeing shuffle in action --

# Query 1: groupBy causes a shuffle
grouped = df.groupBy("department").agg(
    F.count("*").alias("count"),
    F.round(F.avg("salary"), 2).alias("avg_salary")
)
print("=== groupBy plan -- look for 'Exchange' ===")
grouped.explain()
grouped.show()

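# Query 2: join causes a shuffle (sort-merge join).
# df_dept is tiny, so Spark would normally auto-broadcast it (Section 9);
# disable auto-broadcast for this query so the sort-merge plan is visible.
dept_data = [("Engineering", "Building A"), ("Marketing", "Building B"),
             ("HR", "Building C"), ("Sales", "Building D")]
df_dept = spark.createDataFrame(dept_data, ["department", "location"])

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
joined = df.join(df_dept, on="department", how="inner")
print("\n=== join plan -- TWO Exchanges (one per side) ===")
joined.explain()
spark.conf.unset("spark.sql.autoBroadcastJoinThreshold")  # back to the default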

# Query 3: broadcast join -- NO shuffle for the small table
joined_broadcast = df.join(F.broadcast(df_dept), on="department", how="inner")
print("\n=== broadcast join plan -- BroadcastExchange + BroadcastHashJoin, no shuffle Exchange ===")
joined_broadcast.explain()
joined_broadcast.show()

---
## 5. DAG, Jobs, Stages, and Tasks -- How Spark Executes Your Code

### The Execution Hierarchy

```
YOUR CODE
   |
   v
 Logical Plan  (what you want)
   |
   v  [Catalyst Optimizer]
 Optimised Logical Plan
   |
   v
 Physical Plan (how to do it)
   |
   v
 JOB  (one or more per action: .show(), .count(), .write)
   |
   +-- STAGE 1  (runs until a shuffle boundary)
   |     +-- Task 1.1  (one per partition)
   |     +-- Task 1.2
   |     +-- Task 1.3
   |
   +-- STAGE 2  (starts after shuffle from Stage 1)
         +-- Task 2.1
         +-- Task 2.2
```

### Definitions

| Term | What it is |
|---|---|
| **DAG** (Directed Acyclic Graph) | The full graph of transformations from source to action. Spark builds this from your code. |
| **Job** | A unit of work triggered by an action. Each `.show()`, `.count()`, or `.write()` creates at least one job. |
| **Stage** | A subset of the job that can run without a shuffle. Stages are separated by shuffle (Exchange) boundaries. |
| **Task** | The smallest unit of work. One task processes one partition in one stage. If a stage has 200 partitions, it has 200 tasks. |
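A short model shows how a job is cut into stages. This is a deliberately simplified sketch with hypothetical names. It treats the plan as a straight chain and starts a new stage at every wide operation. In real Spark, part of the work (such as partial aggregation) still runs before the shuffle. The sketch also assumes the first stage has one task per input partition, while every post-shuffle stage has `spark.sql.shuffle.partitions` tasks (200 by default), ignoring AQE coalescing.

```python
from typing import List, Tuple

Op = Tuple[str, bool]  # (operation name, is_wide)


def split_into_stages(ops: List[Op]) -> List[List[str]]:
    """Start a new stage at every wide (shuffle) operation."""
    stages: List[List[str]] = [[]]
    for name, is_wide in ops:
        if is_wide and stages[-1]:
            stages.append([])  # shuffle boundary
        stages[-1].append(name)
    return stages


def tasks_per_stage(stages: List[List[str]], input_partitions: int,
                    shuffle_partitions: int = 200) -> List[int]:
    """One task per partition: stage 1 reads the input, later stages read shuffle output."""
    return [input_partitions] + [shuffle_partitions] * (len(stages) - 1)


two_stage = [("scan", False), ("filter", False), ("groupBy.agg", True)]
three_stage = [("scan", False), ("groupBy.agg", True), ("orderBy", True)]

print(split_into_stages(two_stage))    # [['scan', 'filter'], ['groupBy.agg']]
print(split_into_stages(three_stage))  # [['scan'], ['groupBy.agg'], ['orderBy']]
print(tasks_per_stage(split_into_stages(three_stage), input_partitions=4))  # [4, 200, 200]
```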

### How to see Jobs and Stages

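In Databricks, open the **Spark UI**. On classic compute, it is linked from the **Spark Jobs** section of a cell's output. Serverless compute does not expose the Spark UI, so use the query profile there instead. You will see: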
- **Jobs tab:** One row per action.
- **Stages tab:** Shows how many stages each job has and where the shuffle boundaries are.
- **Tasks:** Inside each stage, you can see how many tasks ran, their duration, and
  whether any were slow (skewed).

### Practical Impact

- **Too few partitions (e.g., 1):** Only 1 task runs -- no parallelism.
- **Too many partitions (e.g., 10,000 for 100 rows):** Excessive overhead scheduling
  10,000 tiny tasks.
- **Rule of thumb:** 2-4 partitions per CPU core, each partition 100-200 MB.
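The rule of thumb turns into a quick back-of-the-envelope calculation. The helper below is hypothetical. It assumes a 128 MB target partition size (inside the 100-200 MB range above) and at least 2 partitions per core, so adjust both to your workload.

```python
import math


def suggest_partitions(total_bytes: int, cores: int,
                       target_partition_mb: int = 128, min_per_core: int = 2) -> int:
    """Enough partitions to keep each near the target size, and at least
    `min_per_core` per core so every core has work."""
    size_based = math.ceil(total_bytes / (target_partition_mb * 1024 * 1024))
    core_based = cores * min_per_core
    return max(size_based, core_based)


GB = 1024 ** 3
print(suggest_partitions(50 * GB, cores=16))  # 400 partitions of ~128 MB
print(suggest_partitions(1 * GB, cores=16))   # 32 -- core count dominates, ~32 MB each
```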

In [None]:
# -- Inspecting DAG, partitions, and stages --

# How many partitions does our DataFrame have?
print(f"Partitions in df: {df.rdd.getNumPartitions()}")

# A query with TWO stages (one shuffle from groupBy):
two_stage_query = (
    df
    .filter(F.col("salary") > 70000)                          # Stage 1: narrow
    .groupBy("department")                                     # shuffle boundary
    .agg(F.sum("salary").alias("total"), F.count("*").alias("n"))  # partial agg in Stage 1, final agg in Stage 2
)

print("=== Physical plan -- each Exchange = new stage ===")
two_stage_query.explain(mode="formatted")

# A query with THREE stages (two shuffles: groupBy + orderBy):
three_stage_query = (
    df
    .groupBy("department")                                     # shuffle 1
    .agg(F.avg("salary").alias("avg_sal"))
    .orderBy(F.col("avg_sal").desc())                          # shuffle 2
)

print("\n=== Three-stage plan ===")
three_stage_query.explain(mode="formatted")
three_stage_query.show()

---
## 6. Catalyst Optimizer and Tungsten Engine

### Catalyst Optimizer -- Rewrites Your Query

When you write PySpark code, Catalyst turns it into four successive plans:

```
1. PARSED Logical Plan     -- raw translation of your code
       |
2. ANALYZED Logical Plan   -- columns and types resolved against the catalog
       |
3. OPTIMIZED Logical Plan  -- rules applied to make it faster
       |
4. PHYSICAL Plan           -- the actual execution strategy chosen
```

### Common Catalyst Optimisation Rules

| Rule | What it does | Example |
|---|---|---|
| **Predicate Pushdown** | Moves filters as close to the data source as possible | If you filter after a join, Catalyst pushes the filter *before* the join to reduce data early |
| **Column Pruning** | Reads only the columns you actually use | You select 3 columns from a 200-column Parquet file -- Spark skips the other 197 |
| **Constant Folding** | Pre-computes constant expressions | `F.lit(2) * F.lit(3)` becomes `F.lit(6)` at plan time |
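| **Join Reordering** | Reorders multi-way joins to minimise intermediate data (cost-based; needs table statistics and `spark.sql.cbo.joinReorder.enabled`) | Joins the most selective tables first so later joins process fewer rows |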
| **Broadcast Detection** | Auto-broadcasts small tables | If a table is below `spark.sql.autoBroadcastJoinThreshold` (default 10 MB), Spark broadcasts it |
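Rules like predicate pushdown are tree rewrites over the logical plan. The toy below uses hypothetical `Scan`/`Filter`/`Join` classes, not Catalyst's real API. It moves a single-column filter that sits above a join down into whichever join input produces that column. Catalyst's real rules handle arbitrary predicates and can also infer extra filters for both sides of a join.

```python
from dataclasses import dataclass
from typing import FrozenSet, Union


@dataclass(frozen=True)
class Scan:
    table: str
    columns: FrozenSet[str]


@dataclass(frozen=True)
class Filter:
    column: str      # the single column the predicate reads
    predicate: str   # readable predicate text, e.g. "salary > 90000"
    child: "Plan"


@dataclass(frozen=True)
class Join:
    key: str
    left: "Plan"
    right: "Plan"


Plan = Union[Scan, Filter, Join]


def output_columns(plan: Plan) -> FrozenSet[str]:
    if isinstance(plan, Scan):
        return plan.columns
    if isinstance(plan, Filter):
        return output_columns(plan.child)
    return output_columns(plan.left) | output_columns(plan.right)


def push_down_filters(plan: Plan) -> Plan:
    """Rewrite Filter(Join(l, r)) into Join(Filter(l), r) or Join(l, Filter(r))."""
    if isinstance(plan, Join):
        return Join(plan.key, push_down_filters(plan.left), push_down_filters(plan.right))
    if isinstance(plan, Filter):
        child = push_down_filters(plan.child)
        if isinstance(child, Join):
            if plan.column in output_columns(child.left):
                pushed = push_down_filters(Filter(plan.column, plan.predicate, child.left))
                return Join(child.key, pushed, child.right)
            if plan.column in output_columns(child.right):
                pushed = push_down_filters(Filter(plan.column, plan.predicate, child.right))
                return Join(child.key, child.left, pushed)
        return Filter(plan.column, plan.predicate, child)
    return plan


emp = Scan("emp", frozenset({"name", "department", "salary"}))
dept = Scan("dept", frozenset({"department", "location"}))
written = Filter("salary", "salary > 90000", Join("department", emp, dept))
optimized = push_down_filters(written)
print(optimized)  # the Filter now sits on the emp side, below the Join
```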

### Tungsten Engine -- How Data Is Stored in Memory

Tungsten manages memory at the byte level instead of using Java objects. This means:
- **Less garbage collection** -- fewer Java objects to track.
- **Cache-friendly layouts** -- data is stored in contiguous memory for CPU efficiency.
- **Code generation** -- Spark generates Java bytecode at runtime for your specific
  query, avoiding the overhead of generic interpreters.

You do not need to configure Tungsten -- it is on by default. Knowing it exists helps
explain why built-in DataFrame functions, which run in the JVM on this binary format, are
much faster than row-by-row Python UDFs.
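Python's `struct` module can illustrate the idea of storing rows as raw bytes rather than as separate objects. This is only an analogy. Tungsten's actual row format (`UnsafeRow`) also carries a null bitmap and a variable-length region, and none of the names below are Spark APIs.

```python
import struct
from typing import List, Tuple

RECORD = struct.Struct("<qq")  # two little-endian 8-byte ints: (age, salary)


def pack_rows(rows: List[Tuple[int, int]]) -> bytearray:
    """Store every row back-to-back in one contiguous buffer (no per-row objects)."""
    buf = bytearray(RECORD.size * len(rows))
    for i, row in enumerate(rows):
        RECORD.pack_into(buf, i * RECORD.size, *row)
    return buf


def salary_at(buf: bytearray, i: int) -> int:
    """Read one field straight from its fixed byte offset."""
    return RECORD.unpack_from(buf, i * RECORD.size)[1]


def total_salary(buf: bytearray) -> int:
    """Sequential scan over contiguous bytes (the access pattern Tungsten favours)."""
    return sum(salary for _, salary in RECORD.iter_unpack(buf))


rows = [(30, 95000), (28, 72000), (35, 110000)]
buf = pack_rows(rows)
print(len(buf))           # 48 bytes: 3 rows x 16 bytes
print(salary_at(buf, 2))  # 110000
print(total_salary(buf))  # 277000
```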

In [None]:
# -- See Catalyst in action: Predicate Pushdown --

# Write a query where filter comes AFTER the join.
# Catalyst will push the filter BEFORE the join automatically.

result = (
    df.join(df_dept, on="department")
    .filter(F.col("salary") > 90000)          # written after join
    .select("name", "department", "salary", "location")
)

print("=== Extended plan -- watch where the Filter sits relative to the Join ===")
# ('PushedFilters' only appears on file-based scans such as Parquet/Delta.)
result.explain(mode="extended")
# In the optimised plan, you will see the Filter node moved below the Join.

print("\n" + "=" * 70)

# -- Column pruning --
# We only select 2 columns. Spark will not read the rest from a Parquet/Delta source.
pruned = df.select("name", "salary")
print("\n=== Column Pruning ===")
pruned.explain()
# On a file-based source you would see: ReadSchema with only name and salary.

print("\n" + "=" * 70)

# -- Compare all four plan stages --
print("\n=== All four plan stages (parsed -> analyzed -> optimized -> physical) ===")
(
    df
    .filter(F.col("department") == "Engineering")
    .groupBy("department")
    .agg(F.avg("salary").alias("avg_sal"))
).explain(mode="extended")

---
## 7. Reading Explain Plans -- The Most Useful Debugging Skill

### How to Read an Explain Plan

Read a plan **bottom to top**. The bottom (most deeply nested) nodes are where Spark starts
(reading data), and the top node is where it finishes (returning results).

### Key Nodes to Look For

| Node Name | What It Means |
|---|---|
| `Scan` / `FileScan` | Reading data from a source (Delta, Parquet, CSV, memory) |
| `Filter` | Applying a WHERE condition |
| `Project` | Selecting or computing columns |
| `Exchange` | **SHUFFLE** -- data is redistributed across partitions. This is the expensive one. |
| `HashAggregate` | Computing aggregations (sum, avg, count, etc.) |
| `SortMergeJoin` | Joining two large tables (both sides shuffled and sorted) |
| `BroadcastHashJoin` | Joining with a small table broadcast to all nodes (fast, no shuffle) |
| `BroadcastExchange` | Sending the small table to all executors |
| `Sort` | Sorting data |
| `WholeStageCodegen` | Tungsten has generated optimised bytecode for this section |

### explain() Modes

| Mode | What it shows |
|---|---|
| `explain()` or `explain(False)` | Physical plan only (most common) |
| `explain(True)` or `explain(mode="extended")` | All four plans: parsed, analyzed, optimised, physical |
| `explain(mode="formatted")` | Physical plan as a short operator outline plus per-node details (easiest to read) |
| `explain(mode="cost")` | Plan with estimated row counts and sizes |
| `explain(mode="codegen")` | The generated Java code (advanced) |

### What to Look For When Debugging Slow Queries

1. **How many Exchange nodes?** Each one is a shuffle. Can you eliminate any?
2. **SortMergeJoin vs BroadcastHashJoin?** If one side is small, broadcast it.
3. **Filter position:** Is it before or after the join? It should be before.
4. **Scan:** How many columns are being read? Are partition filters being applied?
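Part of this checklist can be automated by scanning the text of a plan. The helper below is a hypothetical sketch that relies only on the node names listed in the table. The sample plan string is hand-written and abbreviated for illustration, not real Spark output. Node names and layout vary across Spark versions, and `mode="formatted"` or an executed AQE plan can print nodes twice, so treat the counts as hints and feed it the default `explain()` text.

```python
import re
from typing import Dict

# "Exchange" not preceded by a letter: matches shuffle exchanges but skips
# BroadcastExchange (not a shuffle) and ReusedExchange.
SHUFFLE = re.compile(r"(?<![A-Za-z])Exchange\b")


def plan_checklist(plan_text: str) -> Dict[str, int]:
    """Count the plan nodes that matter most when a query is slow."""
    return {
        "shuffle_exchanges": len(SHUFFLE.findall(plan_text)),
        "broadcast_exchanges": plan_text.count("BroadcastExchange"),
        "sort_merge_joins": plan_text.count("SortMergeJoin"),
        "broadcast_hash_joins": plan_text.count("BroadcastHashJoin"),
    }


# Hand-written, abbreviated plan text (illustrative only).
sample_plan = """
== Physical Plan ==
Sort [avg_salary DESC NULLS LAST]
+- Exchange rangepartitioning(avg_salary DESC NULLS LAST, 200)
   +- HashAggregate(keys=[department, location])
      +- Exchange hashpartitioning(department, location, 200)
         +- HashAggregate(keys=[department, location])
            +- BroadcastHashJoin [department], [department], Inner, BuildRight
               :- Filter (salary > 70000)
               :  +- Scan ExistingRDD[name, department, salary, age, hire_date]
               +- BroadcastExchange HashedRelationBroadcastMode
                  +- Scan ExistingRDD[department, location]
"""
print(plan_checklist(sample_plan))
```

To check a real query, capture what `query.explain()` prints with `contextlib.redirect_stdout` and pass the captured string to `plan_checklist`. PySpark's `explain()` prints through Python's `print`, so this should work.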

In [None]:
# -- Reading explain plans: a guided example --

# A realistic query: "Average salary by department for employees earning > 70K,
# enriched with department location."
query = (
    df
    .filter(F.col("salary") > 70000)                           # filter early
    .join(F.broadcast(df_dept), on="department", how="inner")   # broadcast the small table
    .groupBy("department", "location")
    .agg(
        F.round(F.avg("salary"), 2).alias("avg_salary"),
        F.count("*").alias("emp_count")
    )
    .orderBy(F.col("avg_salary").desc())
)

print("=== FORMATTED plan -- read bottom to top ===")
query.explain(mode="formatted")

# What you should see (bottom to top):
# 1. Scan (read df from memory)
# 2. Filter (salary > 70000)           <-- written before the join, so it already runs early
# 3. BroadcastExchange (send df_dept to all nodes)
# 4. BroadcastHashJoin                 <-- no shuffle for the join!
# 5. HashAggregate (partial)           <-- pre-aggregation on each partition
# 6. Exchange hashpartitioning         <-- shuffle to group by department
# 7. HashAggregate (final)             <-- final aggregation
# 8. Exchange rangepartitioning        <-- shuffle for orderBy
# 9. Sort                              <-- final ordering

print("\n=== RESULT ===")
query.show()

---
## 8. Caching and Persistence -- When to Keep Data in Memory

### The Problem

Every time you call an action (`.show()`, `.count()`, `.write()`), Spark re-executes the
entire plan from scratch -- reading data, applying filters, joins, etc. If you use the
same transformed DataFrame in 5 different places, Spark does the work **5 times**.

### The Solution: cache() and persist()

Calling `.cache()` tells Spark: "After computing this DataFrame for the first time, keep
the result in memory so future actions can reuse it instantly."

### cache() vs persist()

| Method | Storage Level | When to Use |
|---|---|---|
| `df.cache()` | `MEMORY_AND_DISK` | Default choice. Stores in memory; spills to disk if memory is full. |
| `df.persist(StorageLevel.MEMORY_ONLY)` | Memory only | When you have enough memory and want maximum speed. If it does not fit, partitions are recomputed. |
| `df.persist(StorageLevel.DISK_ONLY)` | Disk only | When data is too large for memory but you still want to avoid recomputation. |
| `df.persist(StorageLevel.MEMORY_AND_DISK)` | Same as cache() | Explicit equivalent of `.cache()`. |
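| `df.persist(StorageLevel.MEMORY_AND_DISK_2)` | Replicated | Like `MEMORY_AND_DISK`, but each partition is stored on two executors, so losing one executor does not force recomputation. Uses twice the space. |

PySpark does not define the Scala API's `*_SER` storage levels (such as `MEMORY_AND_DISK_SER`),
so referencing `StorageLevel.MEMORY_AND_DISK_SER` in Python fails.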

### When to Cache

- You use the same DataFrame in **multiple actions** (e.g., show + count + write).
- The DataFrame involves **expensive transformations** (joins, aggregations, UDFs).
- The DataFrame is used in a **loop** (e.g., iterative ML algorithms).

### When NOT to Cache

- The DataFrame is used **only once** -- caching adds overhead for no benefit.
- The data is **too large** to fit in memory -- you get spills and slowdowns.
- You are reading Delta/Parquet data on Databricks compute with the **disk cache**
  (formerly called the Delta cache) enabled. It already keeps local copies of remote file data.

### Important: Cache is Lazy

`.cache()` itself does not trigger computation. The data is only cached when the **first
action** runs. After that, subsequent actions reuse the cached data.

### Cleanup with unpersist()

Always call `df.unpersist()` when you are done with a cached DataFrame. Cached data
occupies cluster memory that other queries could use.
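A small memoising wrapper can model the cache lifecycle: mark lazily, fill on the first action, reuse, then release. `ToyCachedResult` is hypothetical and only counts how often the expensive computation runs. It ignores storage levels, eviction, and partial caching.

```python
from typing import Callable, List, Optional


class ToyCachedResult:
    """Toy model of DataFrame caching semantics (not Spark's implementation)."""

    def __init__(self, compute: Callable[[], List[int]]):
        self._compute = compute
        self._marked = False
        self._stored: Optional[List[int]] = None
        self.computations = 0  # how many times the expensive plan really ran

    def cache(self) -> "ToyCachedResult":
        self._marked = True  # lazy: nothing is computed here
        return self

    def unpersist(self) -> "ToyCachedResult":
        self._marked, self._stored = False, None
        return self

    @property
    def is_cached(self) -> bool:
        return self._marked

    def _rows(self) -> List[int]:
        if self._stored is not None:
            return self._stored        # reuse the cached copy
        self.computations += 1
        rows = self._compute()         # run the full "plan"
        if self._marked:
            self._stored = rows        # first action after cache() fills the cache
        return rows

    def count(self) -> int:
        return len(self._rows())


salaries = [95000, 72000, 110000, 68000, 85000, 102000, 71000, 78000, 92000, 115000]
result = ToyCachedResult(lambda: [s for s in salaries if s > 70000])

result.count()
result.count()
print(result.computations)  # 2 -- no cache, computed twice
result.cache()
print(result.computations)  # 2 -- cache() alone computes nothing
result.count()
result.count()
result.count()
print(result.computations)  # 3 -- computed once more, then reused
result.unpersist()
result.count()
print(result.computations)  # 4 -- cache dropped, recomputed
```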

In [None]:
from pyspark import StorageLevel

# -- WITHOUT caching: plan runs twice --
expensive_df = (
    df
    .filter(F.col("salary") > 70000)
    .join(F.broadcast(df_dept), on="department")
    .withColumn("bonus", F.col("salary") * 0.1)
)

print("=== Without cache: each action recomputes ===")
expensive_df.show(3)       # computes the full plan
print(f"Count: {expensive_df.count()}")  # computes the full plan AGAIN

# -- WITH caching: plan runs once, result is reused --
expensive_df.cache()       # mark for caching (lazy -- nothing happens yet)
expensive_df.count()       # first action: computes AND caches
print(f"Is cached: {expensive_df.is_cached}")

print("\n=== With cache: subsequent actions use cached data ===")
expensive_df.show(3)                                  # fast -- read from cache
print(f"Count: {expensive_df.count()}")               # fast -- read from cache
expensive_df.groupBy("department").count().show()     # input read from cache

# Check the explain plan -- you will see "InMemoryTableScan" instead of the full plan
print("=== Plan after caching (InMemoryTableScan) ===")
expensive_df.explain()

# -- Cleanup --
expensive_df.unpersist()
print(f"\nAfter unpersist, is cached: {expensive_df.is_cached}")

In [None]:
# -- persist() with different storage levels --

# MEMORY_ONLY -- fastest, but if data does not fit, partitions are recomputed
df_mem = df.persist(StorageLevel.MEMORY_ONLY)
df_mem.count()
print(f"MEMORY_ONLY cached: {df_mem.is_cached}")
df_mem.unpersist()

# DISK_ONLY -- slowest, but useful for very large intermediate results
df_disk = df.persist(StorageLevel.DISK_ONLY)
df_disk.count()
print(f"DISK_ONLY cached:   {df_disk.is_cached}")
df_disk.unpersist()

# MEMORY_AND_DISK_2 -- like MEMORY_AND_DISK, but each partition is replicated on two executors
# (PySpark has no MEMORY_AND_DISK_SER; the *_SER levels exist only in the Scala/Java API)
df_rep = df.persist(StorageLevel.MEMORY_AND_DISK_2)
df_rep.count()
print(f"MEMORY_AND_DISK_2 cached: {df_rep.is_cached}")
df_rep.unpersist()

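# -- Check what is currently cached in the cluster --
# (_jsc is an internal py4j handle; this likely prints nothing here because
# every DataFrame above has already been unpersisted.)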
print("\nStorage info:")
for rdd_info in spark.sparkContext._jsc.sc().getRDDStorageInfo():
    print(f"  {rdd_info}")

---
## 9. Broadcast Joins and Broadcast Variables

### The Problem with Regular Joins

A regular join (SortMergeJoin) shuffles **both** tables so that rows with the same key
land on the same partition. If one table has 1 billion rows and the other has 100 rows,
shuffling the 100-row table is wasteful.

### Broadcast Join -- The Solution for Small Tables

Instead of shuffling, Spark copies (broadcasts) the small table to every executor's
memory. Each executor then joins its partition of the large table with the local copy
of the small table. **No shuffle on either side.**

```
SORT-MERGE JOIN (both sides shuffled):
  Large Table --> shuffle --+
                            +--> join
  Small Table --> shuffle --+

BROADCAST JOIN (no shuffle):
  Large Table (stays in place) ------+
                                     +--> join (on every executor)
  Small Table (copied to all nodes) -+
```

### When Does Spark Auto-Broadcast?

Spark broadcasts automatically when the *estimated* size of a table is below
`spark.sql.autoBroadcastJoinThreshold` (default: **10 MB**).

You can also force it manually with `F.broadcast(df)`.
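The broadcast strategy itself is simple enough to sketch in plain Python. First, build a hash table from the small table once; this is what gets shipped to every executor. Then let each partition of the large table probe it independently. The function name and data layout are hypothetical. Sales is deliberately left out of the small table to show inner-join semantics.

```python
from typing import Dict, List, Tuple

Employee = Tuple[str, str]   # (department, name)
Location = Tuple[str, str]   # (department, location)


def broadcast_hash_join(large_parts: List[List[Employee]],
                        small_table: List[Location]) -> List[List[Tuple[str, str, str]]]:
    # Build side: a hash table over the small table, built once and "broadcast".
    lookup: Dict[str, List[str]] = {}
    for dept, location in small_table:
        lookup.setdefault(dept, []).append(location)

    # Probe side: each large-table partition joins against its local copy.
    # Partitions never exchange rows, so no shuffle is needed.
    return [
        [(dept, name, location)
         for dept, name in part
         for location in lookup.get(dept, [])]  # no match -> row dropped (inner join)
        for part in large_parts
    ]


large = [
    [("Engineering", "Alice"), ("HR", "Diana")],
    [("Sales", "Hank"), ("Marketing", "Bob")],
]
small = [("Engineering", "Building A"), ("Marketing", "Building B"), ("HR", "Building C")]
print(broadcast_hash_join(large, small))
# [[('Engineering', 'Alice', 'Building A'), ('HR', 'Diana', 'Building C')],
#  [('Marketing', 'Bob', 'Building B')]]
```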

### Broadcast Variables (General Purpose)

A **broadcast variable** is not limited to joins. You can broadcast any picklable Python
object (dictionary, list, model) to all executors for use inside UDFs or transformations.

```python
lookup = spark.sparkContext.broadcast({"key": "value"})
# Access in UDF: lookup.value["key"]
```

In [None]:
# -- Broadcast join vs sort-merge join comparison --

# Check the auto-broadcast threshold. Spark may report it with a unit suffix
# (e.g. "10485760b"), so strip a trailing "b" before converting.
threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
threshold_bytes = int(threshold.rstrip("b"))
print(f"Auto broadcast threshold: {threshold} ({threshold_bytes / 1024 / 1024:.0f} MB)")

# SORT-MERGE JOIN -- disable auto-broadcast to force it
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)  # disable auto-broadcast

smj = df.join(df_dept, on="department")
print("=== Sort-Merge Join (both sides shuffled) ===")
smj.explain()
# You will see: Exchange hashpartitioning on BOTH sides + SortMergeJoin

# BROADCAST JOIN -- restore the original threshold, then broadcast explicitly
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", threshold)  # restore previous value

bhj = df.join(F.broadcast(df_dept), on="department")
print("\n=== Broadcast Hash Join (no shuffle) ===")
bhj.explain()
# You will see: BroadcastExchange + BroadcastHashJoin, NO Exchange on the df side

bhj.show()

# -- Broadcast variable example (for UDFs) --
# A lookup dictionary broadcast to all executors.
dept_codes = {"Engineering": "ENG", "Marketing": "MKT", "HR": "HR", "Sales": "SAL"}
broadcast_codes = spark.sparkContext.broadcast(dept_codes)

from pyspark.sql.functions import udf

@udf(StringType())
def get_dept_code(dept):
    return broadcast_codes.value.get(dept, "UNK")

df.select("name", "department", get_dept_code("department").alias("dept_code")).show()

---
## 10. Partitioning -- Controlling Data Distribution

### What is a Partition?

A partition is a chunk of data that one task processes. Spark splits your DataFrame into
multiple partitions and processes them in parallel across the cluster.

### In-Memory Partitioning (repartition vs coalesce)

| Method | What it does | Shuffle? |
|---|---|---|
| `repartition(n)` | Redistributes data into exactly `n` partitions | Yes (full shuffle) |
| `repartition("col")` | Redistributes by column hash (same key values go to same partition) | Yes |
| `coalesce(n)` | Reduces partitions by merging (only reduces, cannot increase) | No (just merges) |
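A toy model shows why `coalesce` can avoid a shuffle while `repartition(n)` cannot. `coalesce` hands whole existing partitions to fewer output partitions. `repartition(n)`, by contrast, deals rows out one by one; Spark uses round-robin partitioning when no columns are given. Both functions are hypothetical sketches. Spark's real `coalesce` also considers data locality when grouping partitions, and its round-robin starts at a random offset.

```python
from typing import List, TypeVar

T = TypeVar("T")


def toy_coalesce(parts: List[List[T]], n: int) -> List[List[T]]:
    """Merge whole partitions into n groups; rows inside a partition stay together."""
    if n >= len(parts):
        return parts  # coalesce can only reduce the partition count
    out: List[List[T]] = [[] for _ in range(n)]
    for i, part in enumerate(parts):
        out[i * n // len(parts)].extend(part)  # contiguous groups of input partitions
    return out


def toy_round_robin_repartition(parts: List[List[T]], n: int) -> List[List[T]]:
    """Send every single row to the next partition in turn (a full shuffle)."""
    out: List[List[T]] = [[] for _ in range(n)]
    position = 0
    for part in parts:
        for row in part:
            out[position % n].append(row)
            position += 1
    return out


parts = [["a", "b"], ["c"], ["d", "e", "f"], ["g"]]
print(toy_coalesce(parts, 2))                 # [['a', 'b', 'c'], ['d', 'e', 'f', 'g']]
print(len(toy_coalesce(parts, 8)))            # 4 -- cannot increase
print(toy_round_robin_repartition(parts, 3))  # [['a', 'd', 'g'], ['b', 'e'], ['c', 'f']]
```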

### When to Repartition

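- **Before several joins or groupBys on the same key:** `repartition` is itself a shuffle,
  so pre-partitioning on the key pays off only when the repartitioned (usually cached)
  DataFrame is reused. Later operations on that key can then skip their own shuffle.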
- **Before writing to storage:** Control the number of output files.
- **When partitions are very uneven (skewed).**

### On-Disk Partitioning (partitionBy in write)

When you write data to Delta/Parquet with `.partitionBy("column")`, Spark creates separate
folders for each unique value. When you later read with a filter on that column, Spark
skips entire folders -- this is called **partition pruning**.

```python
df.write.format("delta").partitionBy("department").save("/path")
# Creates folders: department=Engineering/, department=HR/, etc.
# Reading with filter(department == 'HR') only scans the HR folder.
```

### Choosing Partition Columns

- Pick a column you **frequently filter on** (date, region, department).
- The column should have **moderate cardinality** (10-1000 unique values).
  - Too few (2-3): Large partitions, no benefit.
  - Too many (millions): Millions of tiny files -- the "small files problem".
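The folder layout and the pruning step can be sketched in pure Python. The helpers are hypothetical. `layout` mimics the Hive-style `column=value` directory names that `partitionBy` produces; the real writer also drops the partition column from the data files. `prune` shows how a filter on the partition column lets a reader skip whole directories without opening their files.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Tuple[str, str, int]  # (name, department, salary)


def layout(rows: List[Row], base: str = "/path") -> Dict[str, List[Row]]:
    """Group rows into one directory per department value."""
    dirs: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        dirs[f"{base}/department={row[1]}"].append(row)
    return dict(dirs)


def prune(dirs: Dict[str, List[Row]], department: str) -> List[str]:
    """Keep only directories whose partition value matches the filter."""
    return [path for path in dirs if path.endswith(f"/department={department}")]


rows = [("Alice", "Engineering", 95000), ("Diana", "HR", 68000),
        ("Grace", "HR", 71000), ("Hank", "Sales", 78000)]
dirs = layout(rows)
print(sorted(dirs))       # ['/path/department=Engineering', '/path/department=HR', '/path/department=Sales']
print(prune(dirs, "HR"))  # ['/path/department=HR'] -- the other folders are never read
```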

In [None]:
# -- Partitioning in practice --

# Check current partitions
print(f"Default partitions in df: {df.rdd.getNumPartitions()}")

# repartition(n) -- redistribute into n partitions (FULL SHUFFLE)
df_4 = df.repartition(4)
print(f"After repartition(4): {df_4.rdd.getNumPartitions()}")

# repartition by column -- same department values go to the same partition
df_by_dept = df.repartition("department")
print(f"After repartition('department'): {df_by_dept.rdd.getNumPartitions()}")

# See how data is distributed across partitions
from pyspark.sql.functions import spark_partition_id
(
    df_by_dept
    .withColumn("partition_id", spark_partition_id())
    .groupBy("partition_id", "department")
    .count()
    .orderBy("partition_id")
    .show()
)

# coalesce(n) -- reduce partitions WITHOUT a shuffle (only merges)
df_2 = df_4.coalesce(2)
print(f"After coalesce(2): {df_2.rdd.getNumPartitions()}")

# -- On-disk partitioning --
delta_path = "/tmp/concepts_tutorial/emp_partitioned"
df.write.format("delta").mode("overwrite").partitionBy("department").save(delta_path)

# When you filter by department, Spark only reads that folder (partition pruning).
# In the plan, look for PartitionFilters on the scan node.
print("\n=== Partition pruning -- only scans Engineering folder ===")
spark.read.format("delta").load(delta_path).filter(F.col("department") == "Engineering").explain()

# Cleanup
dbutils.fs.rm("/tmp/concepts_tutorial", recurse=True)

---
## 11. Adaptive Query Execution (AQE) -- Runtime Optimisation

### What is AQE?

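AQE is a runtime optimisation framework introduced in Spark 3.0. It has been enabled by
default in open-source Spark since 3.2 and is **enabled by default in Databricks**. Catalyst's
static optimisation happens before execution, using estimated statistics; AQE re-optimises the
remaining plan **during execution**, using the actual data sizes measured at each completed
shuffle stage.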

### What AQE Does

| Feature | What it does | Why it helps |
|---|---|---|
| **Coalescing shuffle partitions** | After a shuffle, AQE merges tiny partitions into larger ones | Reduces task scheduling overhead |
| **Converting joins** | Converts SortMergeJoin to BroadcastHashJoin at runtime if one side is small enough after filtering | Eliminates a shuffle mid-execution |
| **Skew join optimisation** | Detects partitions that are much larger than others and splits them | Prevents one task from running much longer than the rest |

### Example: Dynamic Partition Coalescing

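Without AQE, the reduce side of every shuffle runs one task per shuffle partition, and the
count comes from `spark.sql.shuffle.partitions` (200 by default). If your data produces only
10 groups, at most 10 of those partitions hold data and at least 190 are empty, yet each still
gets a task -- wasted scheduling overhead. AQE measures the real partition sizes once the
shuffle has been written and merges adjacent small or empty partitions toward a target size
(influenced by settings such as `spark.sql.adaptive.advisoryPartitionSizeInBytes`), so a small
result can end up in just a few partitions, or even one.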
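The merging step can be pictured as a greedy pass over the shuffle partitions in index order, packing adjacent partitions into one reduce task until adding the next would exceed a target size. This is a simplified pure-Python model for intuition only: Spark's real implementation also honours a minimum partition size and other settings. The partition sizes are made-up numbers; the 64 MB target mirrors the default advisory partition size.

```python
def coalesce_partitions(sizes, target):
    """Greedily merge adjacent partitions (by index) up to `target` total size.

    Returns a list of groups, each a list of original partition indices.
    Simplified model of AQE partition coalescing, not Spark's exact algorithm.
    """
    groups, current, current_size = [], [], 0
    for idx, size in enumerate(sizes):
        # Close the current group if adding this partition would overshoot the target
        if current and current_size + size > target:
            groups.append(current)
            current, current_size = [], 0
        current.append(idx)
        current_size += size
    if current:
        groups.append(current)
    return groups

# 200 shuffle partitions (sizes in MB), but only 10 of them received any data
sizes = [0] * 200
for idx in range(0, 200, 20):
    sizes[idx] = 10

groups = coalesce_partitions(sizes, target=64)
print(f"{len(sizes)} shuffle partitions -> {len(groups)} reduce tasks")  # 200 -> 2
```

Empty partitions cost nothing to merge, so they simply ride along with their non-empty neighbours.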

### Configuration

```python
# Check if AQE is enabled (it should be on Databricks)
spark.conf.get("spark.sql.adaptive.enabled")  # "true"

# Key AQE settings
spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled")      # true
spark.conf.get("spark.sql.adaptive.skewJoin.enabled")                # true
spark.conf.get("spark.sql.adaptive.localShuffleReader.enabled")      # true
```

### Should You Do Anything?

On Databricks Serverless, AQE is fully managed. You rarely need to tune it. The main
scenario where you still intervene is **severe data skew** that AQE cannot fully resolve
-- in that case, use manual salting (Section 12).

In [None]:
# -- AQE configuration check --

print("=== Adaptive Query Execution Settings ===")
aqe_settings = [
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
    "spark.sql.adaptive.skewJoin.enabled",
    "spark.sql.adaptive.localShuffleReader.enabled",
    "spark.sql.shuffle.partitions",
    "spark.sql.autoBroadcastJoinThreshold",
]

for setting in aqe_settings:
    try:
        print(f"  {setting} = {spark.conf.get(setting)}")
    except Exception:
        print(f"  {setting} = (not set)")

# -- Demonstrate AQE coalescing partitions --
# Default shuffle.partitions is 200. With only 4 departments, at most 4 partitions
# receive data, so at least 196 are empty. AQE will coalesce them automatically.

print(f"\nDefault shuffle partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")

grouped = df.groupBy("department").agg(F.sum("salary").alias("total"))
grouped.show()

# In the Spark UI (SQL tab), the AQEShuffleRead node (called CustomShuffleReader in
# Spark 3.0/3.1) shows the partition count reduced from 200 to something much smaller
# (often just 1 for a dataset this small). On Serverless, use the query profile instead.
print("Check the Spark UI -> SQL tab (or the query profile on Serverless) for the AQEShuffleRead node.")

---
## 12. Data Skew and Salting -- Fixing the Most Common Performance Problem

### What is Data Skew?

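Data skew means one partition (or a few) holds **much more data** than the others. When you
`join` on a skewed key, or run a `groupBy("customer_id")` or window function over one, every
row with the hot key lands in the same partition: one task gets millions of rows while others
get thousands. The slow task becomes a bottleneck -- all other tasks finish quickly but the
stage waits for the one slow task. (For simple aggregates such as `sum` and `count`, Spark
pre-aggregates each partition before the shuffle, which softens `groupBy` skew; joins and
window functions get no such help.)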

### How to Detect Skew

1. **Spark UI -> Stages -> Summary Metrics:** Compare the median and max task duration (and
   shuffle read size). A max many times the median (10x or more) is a strong sign of skew.
2. **Check value distribution:** `df.groupBy("key").count().orderBy(F.desc("count")).show()`
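The max-versus-median check can be written as a tiny helper that works on any list of per-task or per-key numbers, for example row counts collected from `df.groupBy("key").count()`. This pure-Python sketch uses a hypothetical `skew_ratio` helper; the counts mirror the skewed sample data used later in this section, and the 10x threshold is the rule of thumb above, not a Spark setting.

```python
from statistics import median

def skew_ratio(values):
    """Ratio of the largest value to the median (1.0 means perfectly even)."""
    med = median(values)
    return max(values) / med if med else float("inf")

# Row counts per key, e.g. from df.groupBy("department").count()
counts = {"Engineering": 100, "Marketing": 10, "HR": 5, "Sales": 5}

ratio = skew_ratio(list(counts.values()))
print(f"max/median = {ratio:.1f}")  # max/median = 13.3
if ratio >= 10:
    print("Likely skew: consider AQE skew handling or salting")
```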

### Salting -- The Manual Fix

Salting adds a random number to the skewed key, spreading its rows across multiple
partitions. After the aggregation or join, you remove the salt.

```
Before salting:           After salting (salt 0-3):
Key "USA" -> 1M rows      Key "USA_0" -> 250K rows
                           Key "USA_1" -> 250K rows
                           Key "USA_2" -> 250K rows
                           Key "USA_3" -> 250K rows
```
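For a join, salting needs one extra step that aggregation does not: because each row on the skewed (large) side gets a random salt, the other side must contain a copy of every row for every salt value, otherwise salted keys would find no match. The sketch below simulates this with plain Python lists and a hash-join dictionary; the tables, the `salted_join` name, and the bucket count are illustrative assumptions. In PySpark the same idea is usually expressed by adding a salt column to the large side and exploding an array of salt values (`F.explode(F.array(...))`) on the small side.

```python
import random
from collections import Counter, defaultdict

NUM_SALT_BUCKETS = 4
random.seed(7)  # fixed seed only so the demo is reproducible

# Hypothetical tables: a skewed fact side and a small dimension side
orders = [("USA", i) for i in range(1000)] + [("FR", i) for i in range(10)]
countries = [("USA", "United States"), ("FR", "France")]

def salted_join(large, small, buckets):
    """Inner join on key, using a random salt on `large` and replicating `small`."""
    # 1. Salt the large side: each row's join key becomes (key, random salt)
    salted_large = [((key, random.randrange(buckets)), val) for key, val in large]
    # 2. Replicate every small-side row once per salt value so all salted keys match
    lookup = defaultdict(list)
    for key, name in small:
        for salt in range(buckets):
            lookup[(key, salt)].append(name)
    # 3. Join on (key, salt) and drop the salt from the output rows
    joined = [
        (key, val, name)
        for (key, salt), val in salted_large
        for name in lookup.get((key, salt), [])
    ]
    # Rows per salted key = the work each join task would receive
    load = Counter(salted_key for salted_key, _ in salted_large)
    return joined, load

joined, load = salted_join(orders, countries, NUM_SALT_BUCKETS)

# Reference result: a plain (unsalted) nested-loop join
plain = [(k, v, n) for k, v in orders for ck, n in countries if k == ck]
print(sorted(joined) == sorted(plain))  # True: salting does not change the result
print({sk: c for sk, c in load.items() if sk[0] == "USA"})  # "USA" split across 4 keys
```

The cost of the technique is visible in step 2: the small side grows by a factor of `NUM_SALT_BUCKETS`, which is why the bucket count is kept small.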

### When to Use Salting vs AQE

- **Try AQE first** -- it handles moderate skew automatically.
- **Use salting** when AQE is not enough (extreme skew, e.g., one key has 100x more
  data than others).

In [None]:
# -- Detecting and fixing data skew --

# Create a skewed dataset: department "Engineering" has many more rows.
skewed_data = (
    [("Engineering", i * 1000) for i in range(100)] +   # 100 rows
    [("Marketing",   i * 1000) for i in range(10)]  +   # 10 rows
    [("HR",          i * 1000) for i in range(5)]    +   # 5 rows
    [("Sales",       i * 1000) for i in range(5)]        # 5 rows
)
df_skewed = spark.createDataFrame(skewed_data, ["department", "revenue"])

# Step 1: DETECT skew -- look at the distribution
print("=== Data distribution (skew detected!) ===")
df_skewed.groupBy("department").count().orderBy(F.desc("count")).show()

# Step 2: SALTING -- add a random salt to spread the skewed key
NUM_SALT_BUCKETS = 4

df_salted = df_skewed.withColumn(
    "salt", (F.rand() * NUM_SALT_BUCKETS).cast("int")   # random int in 0..3
).withColumn(
    "salted_key",
    F.concat(F.col("department"), F.lit("_"), F.col("salt").cast("string"))
)

# Aggregate on the salted key: "Engineering" is split into up to 4 salted keys,
# which can be processed by different shuffle partitions (tasks)
partial_agg = df_salted.groupBy("salted_key", "department").agg(
    F.sum("revenue").alias("partial_sum"),
    F.count("*").alias("partial_count")
)
print("=== Partial aggregation on salted key ===")
partial_agg.show(truncate=False)

# Remove salt and do final aggregation
final_agg = partial_agg.groupBy("department").agg(
    F.sum("partial_sum").alias("total_revenue"),
    F.sum("partial_count").alias("total_count")
)
print("=== Final aggregation after removing salt ===")
final_agg.orderBy(F.desc("total_revenue")).show()

# Verify: same result without salting
print("=== Direct aggregation (for comparison) ===")
df_skewed.groupBy("department").agg(
    F.sum("revenue").alias("total_revenue"),
    F.count("*").alias("total_count")
).orderBy(F.desc("total_revenue")).show()

---
## 13. Best Practices Checklist -- Rules for Production Spark Code

### Transformations and Queries

| Rule | Why |
|---|---|
| Filter early, filter often | Reduces data volume before expensive operations |
| Select only the columns you need | Column pruning saves memory and I/O |
| Avoid `collect()` and `toPandas()` on large data | Pulls everything to the driver -- causes OOM |
| Use built-in functions instead of UDFs | Built-in functions run in the JVM; UDFs serialize data to Python and back |
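| Use `F.col("x")` instead of `df.x` | `F.col()` does not depend on a DataFrame variable, so it works in chained expressions and for column names that clash with DataFrame attributes (e.g. `count`); after a join with duplicate column names, alias the inputs and use `F.col("alias.x")` |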

### Shuffle and Joins

| Rule | Why |
|---|---|
| Broadcast small tables (`F.broadcast(df)`) | Eliminates shuffle in joins |
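| Use `coalesce(n)` to reduce partitions, not `repartition(n)` | Coalesce avoids a full shuffle (but can leave partitions uneven, and a drastic coalesce, e.g. to 1, also reduces the parallelism of the work before it) |
| Bucket tables on join keys (`bucketBy`) for repeated joins | Tables bucketed on the same key with the same bucket count can be joined without a shuffle; on-disk `partitionBy` does not do this, and Delta tables do not support bucketing |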
| Avoid unnecessary `orderBy` | Global sort requires a shuffle; use it only when the consumer needs sorted output |
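Why does broadcasting remove the shuffle? In a broadcast hash join, the small table is sent to every executor as a hash map, and each partition of the large table probes its local copy, so the large side's rows never move between partitions. The pure-Python sketch below models partitions as lists; `broadcast_hash_join` and the sample data are illustrative, not Spark internals.

```python
# Large table already split into partitions (lists of (department, name) rows)
employee_partitions = [
    [("Engineering", "Alice"), ("HR", "Bob")],
    [("Engineering", "Carol"), ("Sales", "Dan")],
    [("Marketing", "Eve")],
]
# Small dimension table: department -> floor (assumes department is unique here)
departments = [("Engineering", 3), ("HR", 1), ("Sales", 2)]

def broadcast_hash_join(partitions, small):
    """Inner join where each partition probes a local copy of the small table."""
    build = dict(small)  # "broadcast": one hash map shipped to every task
    out = []
    for part in partitions:  # each partition is joined independently, no shuffle
        out.append([(dept, name, build[dept]) for dept, name in part if dept in build])
    return out

result = broadcast_hash_join(employee_partitions, departments)
print(result)
```

The output keeps the large side's original three partitions; "Eve" is dropped only because "Marketing" has no match in this inner join.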

### Caching

| Rule | Why |
|---|---|
| Cache only if a DataFrame is used in multiple actions | Caching a one-time DataFrame wastes memory |
| Always `unpersist()` when done | Frees cluster memory for other work |
| Prefer the Databricks disk cache (formerly called Delta cache) over manual caching | The disk cache is automatic and stores data on local disk, not in executor memory |

### Storage

| Rule | Why |
|---|---|
| Use Delta format for all production tables | ACID transactions, time travel, MERGE support |
| Partition on-disk by a column you filter on frequently (date, region) | Enables partition pruning |
| Avoid over-partitioning on disk (high-cardinality partition columns) | Too many partitions cause the small files problem: millions of tiny files slow down reads |
| Use Z-ORDER for multi-column filter optimisation (Databricks) | Co-locates related data within files |

### Debugging

| Rule | Why |
|---|---|
| Use `explain(mode="formatted")` to check plans | Catch unnecessary shuffles and missing pushdowns |
| Check the Spark UI for skew (max vs median task time) | Skew is the number one cause of slow jobs |
| Monitor `Exchange` nodes in the plan | Each Exchange is a shuffle -- minimise them |
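The Exchange check is easy to automate. This sketch assumes you capture the text printed by `df.explain()` (PySpark's `explain()` prints the plan, so `contextlib.redirect_stdout` can collect it) and then counts plan-tree lines whose operator is `Exchange` or `BroadcastExchange`. `capture_explain` and `count_exchanges` are hypothetical helpers. Two caveats: once AQE has run, a plan can show both a Final Plan and an Initial Plan section, so count one section; and other engines may use different operator names (Photon on Databricks, for instance, prefixes its operators with `Photon`).

```python
import io
import re
from contextlib import redirect_stdout

# Matches the operator name at the start of a plan-tree line, skipping tree
# prefixes such as "+- ", ":- " and whole-stage-codegen markers like "*(2) "
_NODE = re.compile(r"^[\s+\-:]*(?:\*\(\d+\)\s*)?(\w+)")

def count_exchanges(plan_text):
    """Count shuffle Exchange and BroadcastExchange operators in a plan string."""
    counts = {"Exchange": 0, "BroadcastExchange": 0}
    for line in plan_text.splitlines():
        m = _NODE.match(line)
        if m and m.group(1) in counts:
            counts[m.group(1)] += 1
    return counts

def capture_explain(df, mode="simple"):
    """Return the plan text that df.explain(mode=...) prints to stdout."""
    buf = io.StringIO()
    with redirect_stdout(buf):
        df.explain(mode=mode)
    return buf.getvalue()
```

In the notebook you might call `count_exchanges(capture_explain(some_df))`, where `some_df` is any DataFrame whose plan you want to check.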

In [None]:
# -- Best practices applied: GOOD vs BAD examples --
# Catalyst pushes filters and column pruning down on its own, so compare the two plans:
# the BAD plan still carries the extra Sort (and its Exchange) that the code asked for.

# ─── BAD: filter late, select all columns, unnecessary sort ─────────
bad_query = (
    df
    .join(df_dept, on="department")                   # join ALL rows first
    .orderBy("salary")                                 # unnecessary sort (adds shuffle)
    .filter(F.col("department") == "Engineering")     # filter AFTER join + sort
    .select("name", "salary")                         # select late
)
print("=== BAD query plan (filter late, unnecessary sort) ===")
bad_query.explain()

# ─── GOOD: filter early, select only what is needed, broadcast ──────
good_query = (
    df
    .filter(F.col("department") == "Engineering")     # filter FIRST (reduces rows)
    .select("name", "department", "salary")           # select early (column pruning)
    .join(F.broadcast(df_dept), on="department")      # broadcast the small table
    .select("name", "salary")                         # same output columns as bad_query
)
print("\n=== GOOD query plan (filter early, broadcast, no unnecessary sort) ===")
good_query.explain()

# Both return the same rows and columns; only bad_query guarantees sorted output
print("\n=== BAD result ===")
bad_query.show()
print("=== GOOD result ===")
good_query.show()

---
## Summary -- Concept Map

```
YOUR PYSPARK CODE
       |
       v
  Transformations (lazy)          Actions (eager)
  - select, filter, withColumn    - show, count, collect
  - join, groupBy, orderBy        - write, toPandas, foreach
       |                                |
       v                                v
  LOGICAL PLAN (what you want)    triggers EXECUTION
       |
       v  [Catalyst Optimizer]
  OPTIMISED PLAN
  - Predicate pushdown
  - Column pruning
  - Join reordering
  - Constant folding
       |
       v
  PHYSICAL PLAN
  - Narrow transforms --> same partition (fast)
  - Wide transforms   --> SHUFFLE (Exchange) (slow)
       |
       v
  DAG --> Jobs --> Stages (split at shuffle) --> Tasks (one per partition)
       |
       v  [Tungsten Engine]
  EXECUTION with code generation and memory management
       |
       v  [AQE -- runtime optimisation]
  - Coalesce shuffle partitions
  - Convert SortMergeJoin to BroadcastHashJoin
  - Handle skew
       |
       v
  RESULT (or written to Delta/Parquet)
```

### Key Takeaways

1. **Lazy evaluation** lets Spark optimise before running -- do not fight it.
2. **Transformations** build the plan; **actions** execute it.
3. **Shuffle (Exchange)** is the most expensive operation -- minimise it.
4. **Broadcast** small tables to eliminate shuffle in joins.
5. **Cache** DataFrames you reuse in multiple actions; **unpersist** when done.
6. **explain()** is your best debugging tool -- read plans bottom to top.
7. **AQE** handles most runtime optimisations automatically on Databricks.
8. **Salting** fixes extreme data skew that AQE cannot resolve.
9. **Filter early, select only needed columns, use built-in functions over UDFs.**
10. **Use Delta format** for production storage -- it adds ACID, time travel, and MERGE.

---
*This notebook is compatible with Databricks Community Edition (Free) and Databricks Serverless Compute.*