### Objective

In this notebook, learners explore two PySpark optimizations: **caching** and **broadcast joins**.

They will:

- Create and explore large and small DataFrames
- Use caching to avoid repeated computation
- Use broadcast joins to optimize performance

These optimizations are essential for working efficiently with large-scale distributed data.

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, broadcast

# Create a Spark session, the entry point to PySpark
spark = SparkSession.builder \
    .appName("CachingAndBroadcasting") \
    .getOrCreate()

In [None]:
# Generate a large dataset: 1 million sales records
# Each record contains: sale_id, user_id (user_0 to user_99), state_code (0 to 49)
sales_data = [(i, f"user_{i%100}", i % 50) for i in range(1_000_000)]

# Create a DataFrame from the sales data
sales_df = spark.createDataFrame(sales_data, ["sale_id", "user_id", "state_code"])

# Create a small lookup DataFrame for state codes
states = [
    (0, "Andhra Pradesh"),
    (1, "Bihar"),
    (2, "Delhi"),
    (3, "Karnataka"),
    (4, "Maharashtra")
] + [(i, f"State_{i}") for i in range(5, 50)]  # Generic names for remaining codes

states_df = spark.createDataFrame(states, ["state_code", "state_name"])

In [None]:
# Caching Example

# Filter the large dataset to only include state codes less than 10
# This simulates a heavy computation
filtered_sales = sales_df.filter(col("state_code") < 10)

# Mark the filtered DataFrame for caching
# cache() is lazy: nothing is computed or stored until the first action runs
filtered_sales.cache()

# First action: group by state_code and count sales
# This triggers the filter and stores the filtered rows (not the counts) in the cache
filtered_sales.groupBy("state_code").count().show()

+----------+-----+
|state_code|count|
+----------+-----+
|         0|20000|
|         7|20000|
|         6|20000|
|         9|20000|
|         5|20000|
|         1|20000|
|         3|20000|
|         8|20000|
|         2|20000|
|         4|20000|
+----------+-----+



In [None]:
# Second action: count distinct user IDs
# Spark reads filtered_sales from the cache instead of re-running the filter
filtered_sales.select("user_id").distinct().count()

20

### What Just Happened?

- We filtered the DataFrame on `state_code < 10`, simulating a time-consuming operation.
- Calling `.cache()` marked the filtered DataFrame for caching. Caching is lazy, so nothing was stored at that point.
- The **first action** (`groupBy + count`) ran the filter and stored the filtered rows in the cache.
- The **second action** read those cached rows instead of running the filter again.

**Why use caching?**
- It reduces execution time when reusing the same filtered or transformed data multiple times.
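
The lazy behaviour described above can be imitated in plain Python. The sketch below is a toy model, not Spark's implementation: `LazyDataset`, `compute_calls`, and `filtered_sales_rows` are hypothetical names invented for illustration, and the data is scaled down to 1,000 rows. It counts how often the filter actually runs with and without caching.

```python
class LazyDataset:
    """Toy stand-in for a lazily evaluated DataFrame (not Spark's implementation)."""

    def __init__(self, compute):
        self._compute = compute    # function that produces the rows
        self._use_cache = False    # set by cache(); no work happens yet
        self._cached_rows = None   # filled in by the first action
        self.compute_calls = 0     # how many times the work actually ran

    def cache(self):
        # Like DataFrame.cache(): only marks the dataset, computes nothing.
        self._use_cache = True
        return self

    def _rows(self):
        if self._use_cache and self._cached_rows is not None:
            return self._cached_rows          # reuse the stored rows
        self.compute_calls += 1
        rows = list(self._compute())          # run the expensive step
        if self._use_cache:
            self._cached_rows = rows
        return rows

    def count(self):
        # An action: forces evaluation.
        return len(self._rows())

    def distinct_count(self, key):
        # Another action: also forces evaluation.
        return len({key(row) for row in self._rows()})


def filtered_sales_rows():
    # Same shape as the notebook's sales data, filtered on state_code < 10.
    for i in range(1_000):
        if i % 50 < 10:
            yield (i, f"user_{i % 100}", i % 50)


# Without cache(): each action re-runs the filter.
uncached = LazyDataset(filtered_sales_rows)
uncached.count()
uncached.distinct_count(lambda row: row[1])

# With cache(): the first action computes, the second reuses the stored rows.
cached = LazyDataset(filtered_sales_rows).cache()
calls_after_cache = cached.compute_calls   # still 0, cache() did no work
cached.count()
users = cached.distinct_count(lambda row: row[1])

print("filter runs without cache:", uncached.compute_calls)
print("filter runs with cache:   ", cached.compute_calls)
```

In real PySpark, a cached DataFrame keeps occupying storage until it is released with `unpersist()`, so it is worth calling that once the data is no longer reused.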

In [None]:
# Broadcast Example
# Regular join: joins sales_df with states_df using a shuffle
joined_normal = sales_df.join(states_df, on="state_code")
joined_normal.select("sale_id", "state_name").show(5)

+-------+----------+
|sale_id|state_name|
+-------+----------+
|     26|  State_26|
|     29|  State_29|
|     76|  State_26|
|     79|  State_29|
|    126|  State_26|
+-------+----------+
only showing top 5 rows



In [None]:
# Optimized join: broadcast the small lookup table
# Spark sends a copy of 'states_df' to all executor nodes
joined_broadcast = sales_df.join(broadcast(states_df), on="state_code")
joined_broadcast.select("sale_id", "state_name").show(5)

+-------+--------------+
|sale_id|    state_name|
+-------+--------------+
|      0|Andhra Pradesh|
|      1|         Bihar|
|      2|         Delhi|
|      3|     Karnataka|
|      4|   Maharashtra|
+-------+--------------+
only showing top 5 rows



### What Just Happened?

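- In the **normal join**, Spark shuffles both DataFrames by `state_code` so that matching rows land in the same partition. This moves the 1-million-row `sales_df` across the network, not just `states_df`. (Spark can also choose a broadcast join on its own when it estimates one side to be smaller than `spark.sql.autoBroadcastJoinThreshold`.)
- In the **broadcast join**, Spark sends a full copy of the small DataFrame (`states_df`) to each executor, so every partition of `sales_df` is joined locally without being shuffled.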
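
One way to confirm which strategy Spark picked is to read the physical plan, where a broadcast join usually appears as `BroadcastHashJoin` and a shuffle-based join as `SortMergeJoin` or `ShuffledHashJoin`. The helper below is a small sketch, and `join_strategy` is a hypothetical name. It assumes that `df.explain()` prints the plan to standard output, as PySpark DataFrames do, and it only searches that text for operator names.

```python
import contextlib
import io


def join_strategy(df):
    """Return the join operator names that appear in df's physical plan.

    Assumes df.explain() prints the plan to stdout (true for PySpark DataFrames).
    """
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        df.explain()
    plan = buffer.getvalue()
    known_operators = ("BroadcastHashJoin", "SortMergeJoin", "ShuffledHashJoin")
    return [name for name in known_operators if name in plan]


# Usage in this notebook (requires the DataFrames defined in the cells above):
#   join_strategy(joined_normal)
#   join_strategy(joined_broadcast)
```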

**Why broadcast?**
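- It avoids shuffling the large DataFrame; only the small table travels over the network.
- It suits joins between a large DataFrame and a lookup table small enough to fit in each executor's memory.
- It can cut join time substantially on a cluster, because the shuffle is often the most expensive part of a join.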
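
To make the difference in data movement concrete, here is a plain-Python toy model of the two strategies. It is a sketch under simplifying assumptions, not how Spark is implemented: the four partitions, the round-robin placement, and all function names are hypothetical, the data is scaled down to 10,000 sales rows, and each partition joins with a simple dictionary lookup.

```python
NUM_PARTITIONS = 4  # hypothetical cluster size for this toy model


def split_round_robin(rows, n):
    """Spread rows over n partitions, ignoring the join key."""
    parts = [[] for _ in range(n)]
    for idx, row in enumerate(rows):
        parts[idx % n].append(row)
    return parts


def repartition_by_key(parts, key_idx, n):
    """Send each row to partition (key % n); count rows that change partition."""
    new_parts = [[] for _ in range(n)]
    moved = 0
    for src, part in enumerate(parts):
        for row in part:
            dest = row[key_idx] % n
            if dest != src:
                moved += 1
            new_parts[dest].append(row)
    return new_parts, moved


def local_join(sales_part, states_rows):
    """Join inside one partition: build a dict from states, probe it with sales."""
    lookup = dict(states_rows)
    return [(sale_id, lookup[code]) for sale_id, _, code in sales_part if code in lookup]


def shuffle_join(sales_parts, states_parts, n):
    """Repartition both sides by state_code, then join partition by partition."""
    sales_by_key, moved_sales = repartition_by_key(sales_parts, 2, n)
    states_by_key, moved_states = repartition_by_key(states_parts, 0, n)
    joined = []
    for p in range(n):
        joined += local_join(sales_by_key[p], states_by_key[p])
    return joined, moved_sales + moved_states


def broadcast_join(sales_parts, states_rows):
    """Copy the whole small table to every partition; sales rows never move."""
    joined = []
    for part in sales_parts:
        joined += local_join(part, states_rows)
    moved = len(states_rows) * len(sales_parts)  # one copy per partition
    return joined, moved


sales = [(i, f"user_{i % 100}", i % 50) for i in range(10_000)]
states = [(i, f"State_{i}") for i in range(50)]
sales_parts = split_round_robin(sales, NUM_PARTITIONS)
states_parts = split_round_robin(states, NUM_PARTITIONS)

joined_shuffle, moved_shuffle = shuffle_join(sales_parts, states_parts, NUM_PARTITIONS)
joined_bcast, moved_bcast = broadcast_join(sales_parts, states)

print("rows moved, shuffle join:  ", moved_shuffle)
print("rows moved, broadcast join:", moved_bcast)
```

Both functions return the same joined rows. In the broadcast version the number of rows moved depends only on the size of the lookup table and the number of partitions, not on the size of the sales data.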


### Summary

In this notebook, you learned:

- How to use **caching** to store intermediate results in memory
- How to use **broadcast joins** (the `broadcast()` hint) to optimize joins with small DataFrames

These techniques are essential for building efficient data pipelines with PySpark.
