1) Unsafe program that can cause OOM errors

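This intentionally builds a large DataFrame (large_df) and joins it, without a broadcast hint, to a smaller lookup_df, then calls .collect() at the end to pull the result onto the driver. On a modest cluster this can exhaust executor memory during the shuffle or driver memory during the collect. Note that the inner join keeps only the 100k user_id values present in lookup_df, so the collected result here is 100k rows; scale up the row counts if the job completes without error on your hardware.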

# bad_program.py
from pyspark.sql import SparkSession

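# spark.driver.memory only takes effect if PySpark starts the driver JVM here
# (e.g. `python bad_program.py`); under spark-submit, or in a notebook whose JVM
# is already running, use --driver-memory or spark-defaults.conf instead.
spark = SparkSession.builder \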
    .appName("OOMDemoBad") \
    .config("spark.executor.memory", "2g") \
    .config("spark.driver.memory", "1g") \
    .getOrCreate()

# Simulate a large dataset (increase the range size until it exceeds your memory)
large_df = spark.range(0, 50_000_000).withColumnRenamed("id", "user_id")  # 50M rows

# Small lookup table, but Spark may still not auto-broadcast it: that depends on its
# estimated size versus spark.sql.autoBroadcastJoinThreshold (10 MB by default)
lookup_data = [(i, f"category_{i%100}") for i in range(100000)]  # 100k rows
lookup_df = spark.createDataFrame(lookup_data, ["user_id", "category"])

# This join will cause a large shuffle (SortMergeJoin) and possibly OOM
joined = large_df.join(lookup_df, on="user_id", how="inner")

# Force everything to the driver -> driver OOM for large outputs
result = joined.collect()   # DANGEROUS: collects entire joined result to driver

print("Rows:", len(result))
spark.stop()
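If rows really are needed on the driver, bound how much comes back. The sketch below wraps two standard DataFrame methods: take(n), which returns at most n rows, and toLocalIterator(), which fetches one partition at a time so the driver only needs room for roughly the largest partition rather than the whole result. The helper names preview_rows and process_rows_streaming are hypothetical; both assume a PySpark DataFrame is passed in. Setting spark.driver.maxResultSize is also worth considering, since an action whose serialized results exceed it is aborted with an error instead of running the driver out of memory.

```python
from typing import Any, Callable, List


def preview_rows(df: Any, n: int = 20) -> List[Any]:
    """Return at most n rows; the rest of the result never reaches the driver."""
    return df.take(n)


def process_rows_streaming(df: Any, handle_row: Callable[[Any], None]) -> int:
    """Call handle_row on every row without holding the full result on the driver.

    DataFrame.toLocalIterator() fetches one partition at a time, so the driver
    needs roughly enough memory for the largest partition, not the whole result.
    Returns the number of rows processed.
    """
    processed = 0
    for row in df.toLocalIterator():
        handle_row(row)
        processed += 1
    return processed
```

For the program above, preview_rows(joined, 5) is a safe replacement for a quick look, and joined.count() gives the row count without transferring any rows to the driver.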


# 2) Fix A — Broadcast the small table (best when one side is small)
# fix_broadcast.py
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.appName("OOMFixBroadcast").getOrCreate()

large_df = spark.range(0, 50_000_000).withColumnRenamed("id", "user_id")
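# Same lookup table as in part 1; in practice it might come from storage,
# e.g. spark.read.parquet("/path/to/lookup_small.parquet")
lookup_df = spark.createDataFrame([(i, f"category_{i%100}") for i in range(100000)],
                                  ["user_id", "category"])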

# Broadcast the small DataFrame
joined = large_df.join(broadcast(lookup_df), on="user_id", how="inner")

# Do NOT collect results; write to disk instead
joined.write.mode("overwrite").parquet("/tmp/joined_output/")

spark.stop()
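Why the broadcast helps: in a broadcast hash join, Spark collects the small side to the driver, sends a copy to every executor, builds a hash table from it, and each task probes that table while scanning its own partition of the large side, so the large side is never shuffled. The broadcast() hint requests this strategy even when the table is above spark.sql.autoBroadcastJoinThreshold (10 MB by default). The toy, single-process sketch below mimics only the build and probe steps on plain Python lists; build_hash_table and broadcast_hash_join are illustrative names, not Spark internals.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def build_hash_table(small_side: Iterable[Tuple[int, str]]) -> Dict[int, List[str]]:
    """Build side: map each key to all of its values (built once, then shared)."""
    table: Dict[int, List[str]] = defaultdict(list)
    for key, value in small_side:
        table[key].append(value)
    return dict(table)


def broadcast_hash_join(
    large_partitions: List[List[int]],
    small_side: Iterable[Tuple[int, str]],
) -> List[Tuple[int, str]]:
    """Inner-join every partition of the large side against one shared hash table.

    Each outer loop iteration plays the role of one task probing its own
    broadcast copy; large-side rows never move between partitions.
    """
    table = build_hash_table(small_side)
    output: List[Tuple[int, str]] = []
    for partition in large_partitions:
        for user_id in partition:
            # Keys missing from the small side produce no output (inner join).
            for category in table.get(user_id, []):
                output.append((user_id, category))
    return output
```

The same sketch shows the cost: the entire small side must fit in memory wherever the table is built, so broadcasting a side that is not actually small can move the OOM to the driver or to every executor instead of removing it.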

# 3) Fix B — Repartition to increase parallelism and reduce partition size
# fix_repartition.py
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("OOMFixRepartition").getOrCreate()
spark.conf.set("spark.sql.shuffle.partitions", 400)   # increase shuffle parallelism (default: 200)

large_df = spark.range(0, 50_000_000).withColumnRenamed("id", "user_id")

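# Repartition large_df by the join key before the join so each partition holds less data;
# 400 matches spark.sql.shuffle.partitions, so the join can reuse this partitioning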
large_df = large_df.repartition(400, "user_id")

lookup_df = spark.createDataFrame([(i, f"category_{i%100}") for i in range(100000)],
                                  ["user_id", "category"])

joined = large_df.join(lookup_df, on="user_id", how="inner")

# Write out rather than collect
joined.write.mode("overwrite").parquet("/tmp/joined_output_repart/")
spark.stop()
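How many partitions are enough depends on data volume. A common rule of thumb, used here as an assumption rather than a Spark default, is to aim for roughly 100 to 200 MB of shuffle data per partition. The hypothetical helper suggest_partitions below turns an estimated size (for instance, the shuffle write reported for the stage in the Spark UI) into a partition count, rounded up to a multiple of the total executor cores so tasks run in even waves. On Spark 3.2 and later, adaptive query execution is on by default and can also coalesce small shuffle partitions at runtime.

```python
import math


def suggest_partitions(total_bytes: int, total_cores: int,
                       target_partition_mb: int = 128) -> int:
    """Suggest a partition count so each partition holds ~target_partition_mb.

    The result is rounded up to a multiple of total_cores so every core gets
    the same number of tasks per wave (an assumed convention, not a requirement).
    """
    if total_bytes <= 0 or total_cores <= 0 or target_partition_mb <= 0:
        raise ValueError("sizes and core count must be positive")
    target_bytes = target_partition_mb * 1024 * 1024
    by_size = math.ceil(total_bytes / target_bytes)  # partitions needed for the size target
    waves = math.ceil(by_size / total_cores)         # full rounds of tasks across all cores
    return max(1, waves) * total_cores
```

For example, about 50 GB of shuffle data on 40 cores gives suggest_partitions(50 * 1024**3, 40) == 400, a value that could then be used for spark.sql.shuffle.partitions and the repartition() call.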


# 4) Fix C — Persist with MEMORY_AND_DISK and unpersist when done
# fix_persist.py
from pyspark import StorageLevel
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("OOMFixPersist").getOrCreate()

df = spark.range(0, 50_000_000).withColumnRenamed("id", "user_id")
# some transformations
df = df.withColumn("flag", (df.user_id % 2 == 0).cast("int"))

# Persist to memory+disk
df.persist(StorageLevel.MEMORY_AND_DISK)

# Trigger materialization once
df.count()

# Reuse df several times; with MEMORY_AND_DISK, cached partitions that do not fit
# in memory are written to local disk instead of being dropped and recomputed
# ... do joins, aggregations ...
df.unpersist()
spark.stop()
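persist() and unpersist() are easy to leave unbalanced when a step in between raises, which leaves cached blocks on the executors until they are evicted. A small context manager can release the cache in a finally block. The sketch below assumes it is given a PySpark DataFrame and, optionally, a StorageLevel such as StorageLevel.MEMORY_AND_DISK; the name persisted is hypothetical, not a PySpark API.

```python
from contextlib import contextmanager
from typing import Any, Iterator, Optional


@contextmanager
def persisted(df: Any, storage_level: Optional[Any] = None,
              materialize: bool = True) -> Iterator[Any]:
    """Persist df for the duration of a with-block, then always unpersist it.

    storage_level: e.g. pyspark.StorageLevel.MEMORY_AND_DISK; None uses the
    DataFrame's default storage level.
    materialize: run count() once so the cache is filled before reuse.
    """
    if storage_level is None:
        df.persist()
    else:
        df.persist(storage_level)
    try:
        if materialize:
            df.count()  # action that fills the cache, as in the example above
        yield df
    finally:
        df.unpersist()  # runs even if the with-block raises
```

Usage would look like with persisted(df, StorageLevel.MEMORY_AND_DISK) as cached_df: followed by the joins and aggregations on cached_df, with no separate unpersist() call to forget.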

8) How to verify the fix (what to check)

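Spark UI (http://<driver-host>:4040 by default)

  - Storage tab: size of each cached RDD/DataFrame, in memory and on disk
  - Stages tab: shuffle read/write sizes, spill (memory/disk), and unusually long tasks
  - Executors tab: memory usage and GC time per executor

Logs: look for java.lang.OutOfMemoryError, ExecutorLostFailure task failures (e.g. an executor killed for exceeding its memory limit), or frequent messages about spilling to disk.
Execution plan: run df.explain(True) and check whether the join is a BroadcastHashJoin or a SortMergeJoin.
Time/size metrics: after a fix, runs should show less shuffle read/write and less spill per task, and shorter task times.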
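Two of these checks can be scripted for before/after comparisons. The sketch below uses only the standard library: join_strategies captures what DataFrame.explain() prints and extracts the join operators it names, and fetch_stages / summarize_stages read stage data from Spark's monitoring REST API (/api/v1/applications/<app-id>/stages on the UI port) and total the shuffle and spill bytes. The helper names are hypothetical, and the REST field names used here (shuffleReadBytes, shuffleWriteBytes, memoryBytesSpilled, diskBytesSpilled) should be checked against the monitoring documentation for your Spark version.

```python
import io
import json
import re
from contextlib import redirect_stdout
from typing import Any, Dict, List
from urllib.request import urlopen

JOIN_OPERATORS = ("BroadcastHashJoin", "SortMergeJoin", "ShuffledHashJoin",
                  "BroadcastNestedLoopJoin", "CartesianProduct")
STAGE_METRICS = ("shuffleReadBytes", "shuffleWriteBytes",
                 "memoryBytesSpilled", "diskBytesSpilled")


def join_strategies(df: Any) -> List[str]:
    """List the join operators named in df's physical plan, in printed order.

    DataFrame.explain() prints the plan rather than returning it, so stdout
    is captured while it runs.
    """
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        df.explain()
    return re.findall("|".join(JOIN_OPERATORS), buffer.getvalue())


def fetch_stages(ui_url: str, app_id: str) -> List[Dict[str, Any]]:
    """Download stage data from the monitoring REST API of a running Spark UI."""
    url = f"{ui_url.rstrip('/')}/api/v1/applications/{app_id}/stages"
    with urlopen(url) as response:
        return json.load(response)


def summarize_stages(stages: List[Dict[str, Any]]) -> Dict[str, int]:
    """Total shuffle and spill bytes over all stages; missing fields count as 0."""
    totals = {name: 0 for name in STAGE_METRICS}
    for stage in stages:
        for name in STAGE_METRICS:
            totals[name] += int(stage.get(name, 0))
    return totals
```

For Fix A, join_strategies(joined) should include "BroadcastHashJoin"; with adaptive query execution enabled, a plan printed before execution is not final, so the strategy can still change at runtime. summarize_stages(fetch_stages("http://localhost:4040", spark.sparkContext.applicationId)) can be run after each variant while its application is still up.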