# PySpark: Zero to Hero
## Module 25: Skewness, Spillage, and Salting

In distributed computing, **Data Skew** is one of the most common performance bottlenecks. It occurs when data is not evenly distributed across partitions, causing a few tasks to take significantly longer than others (stragglers).
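To see why a single straggler dominates, here is a toy model in plain Python (no Spark needed). It assumes that every partition gets its own core and that a task's run time is proportional to its record count. The partition sizes are hypothetical. Under those assumptions a stage ends only when its slowest task ends, so the same total work takes several times longer when it is skewed.

```python
# Toy model, not Spark code: assume every partition gets its own core and
# processing time is proportional to the number of records in the partition.
def stage_time(partition_sizes, seconds_per_record=1e-6):
    """A stage finishes only when its slowest task finishes."""
    return max(partition_sizes) * seconds_per_record

NUM_PARTITIONS = 10
TOTAL_RECORDS = 1_000_000

# Same total work, spread two different ways (hypothetical numbers).
balanced = [TOTAL_RECORDS // NUM_PARTITIONS] * NUM_PARTITIONS
skewed = [910_000] + [10_000] * (NUM_PARTITIONS - 1)
assert sum(balanced) == sum(skewed) == TOTAL_RECORDS

print(f"Balanced stage: {stage_time(balanced):.2f}s")
print(f"Skewed stage:   {stage_time(skewed):.2f}s")
```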

This often leads to **Spillage**: when the data a single task has to hold exceeds the execution memory available to it, Spark writes intermediate data to disk and reads it back later. The extra serialization, deserialization, and disk I/O drastically slow down the job.
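A rough way to picture spillage, as a sketch only: assume (hypothetically) that each task can hold a fixed number of records in memory and must spill everything beyond that. Spark's real memory manager works in bytes and spills sorted or aggregated runs, so this only illustrates why the one oversized partition is the one that spills.

```python
# Toy model, not Spark's memory manager: each task buffers its partition in
# memory and spills whatever exceeds its (hypothetical) record budget to disk.
def spilled_records(partition_sizes, memory_budget_records):
    """Return {partition_index: records spilled} for partitions over budget."""
    return {
        idx: size - memory_budget_records
        for idx, size in enumerate(partition_sizes)
        if size > memory_budget_records
    }

skewed = [910_000] + [10_000] * 9
balanced = [100_000] * 10
budget = 200_000  # hypothetical per-task capacity, in records

print(spilled_records(skewed, budget))    # {0: 710000}
print(spilled_records(balanced, budget))  # {}
```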

### Agenda:
1.  **Understanding Skewness:** How it looks in the Spark UI.
2.  **Spillage:** Memory vs. Disk spillage.
3.  **Diagnosis:** Finding the skewed partition using `spark_partition_id()`.
4.  **The Solution:** Implementing the **Salting Technique** to redistribute data evenly.

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat, lit, rand, spark_partition_id
import random

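# We create the session with a small memory budget so that spillage is easier to reproduce on modest data.
# In local mode the executor runs inside the driver JVM, so spark.executor.memory has no effect here;
# spark.driver.memory is used instead (it only applies if no Spark JVM is already running).
spark = SparkSession.builder \
    .appName("Skewness_and_Salting") \
    .master("local[*]") \
    .config("spark.driver.memory", "512m") \
    .config("spark.sql.adaptive.enabled", "false") \
    .config("spark.sql.autoBroadcastJoinThreshold", "-1") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()

# Note: We disable Adaptive Query Execution (AQE) because AQE (on by default since Spark 3.2)
# can automatically split skewed partitions in sort-merge joins. We want to see the problem manually first.
# We also disable automatic broadcast joins: the 10-row department table would otherwise be
# broadcast, the large table would never be shuffled, and no skew would appear.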

print("Spark Session Created")

In [None]:
# In the video, we used external CSVs. Here, let's simulate skewed data 
# so you can run this notebook immediately.

# 1. Create Department Data (Small lookup table)
dept_data = [(i, f"Dept_{i}") for i in range(0, 10)]
dept_df = spark.createDataFrame(dept_data, ["dept_id", "dept_name"])

# 2. Create Skewed Employee Data
# We will generate 1 million records.
# About 90% of employees will belong to dept_id 8 or 9 (creating skew).
def generate_skewed_data():
    data = []
    for _ in range(1000000):
        # High probability for dept 9 and 8
        if random.random() > 0.1:
            dept = random.choice([8, 9])
        else:
            dept = random.choice(range(0, 8))
        data.append((dept, "Emp_Name"))
    return data

# Note: In a real scenario, this step involves reading a large file.
# Generating 1M rows in plain Python might take a moment.
print("Generating Skewed Data...")
emp_rdd = spark.sparkContext.parallelize(generate_skewed_data())
emp_df = spark.createDataFrame(emp_rdd, ["dept_id", "emp_name"])

print("Dataframes Ready.")
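
Before joining, it can be worth confirming that the generator really produces the intended skew. In the notebook itself, `emp_df.groupBy("dept_id").count().show()` gives this picture. The sketch below re-implements the same rule in plain Python with a seeded `random.Random` (the seed and the smaller row count are assumptions for reproducibility) and counts rows per `dept_id`; roughly 90% should land in departments 8 and 9.

```python
import random
from collections import Counter

def generate_skewed_keys(n, seed=42):
    """Same rule as generate_skewed_data(), but returns only the dept_ids."""
    rng = random.Random(seed)  # seeded so the check is reproducible
    keys = []
    for _ in range(n):
        if rng.random() > 0.1:
            keys.append(rng.choice([8, 9]))          # ~90%: a hot department
        else:
            keys.append(rng.choice(range(0, 8)))     # ~10%: any other department
    return keys

keys = generate_skewed_keys(100_000)
counts = Counter(keys)
hot_share = (counts[8] + counts[9]) / len(keys)

print(counts.most_common(3))
print(f"Share of rows in dept 8 and 9: {hot_share:.1%}")
```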

In [None]:
# Let's perform a standard join.
# During the shuffle, all rows with the same dept_id are hashed to the same partition,
# so the tasks that receive dept_id 8 and 9 process far more records than the others.

df_joined = emp_df.join(dept_df, "dept_id", "left_outer")

# Trigger the action. The "noop" format (Spark 3.0+) runs the full plan but discards the output,
# which makes it handy for benchmarking.
print("Running Skewed Join...")
df_joined.write.format("noop").mode("overwrite").save()
print("Skewed Join Completed.")

## 2. Diagnosing the Issue

If you look at the **Spark UI** (Stages tab) for the job above:
1.  **Event Timeline:** Most green bars (tasks) finish quickly, but one or two stretch out for much longer.
2.  **Summary Metrics:** The *Max* duration is much higher than the *Median* or *75th percentile*.
3.  **Spill (Memory/Disk):** If any task spilled, the stage page shows "Spill (Memory)" and "Spill (Disk)" metrics. This means the task's data did not fit in its execution memory.
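
The *Summary Metrics* table reports quantiles of per-task values. As a rough illustration, the sketch below computes the same five columns over hypothetical per-task record counts using Python's `statistics.quantiles` (Spark's own quantile computation may differ slightly). With two hot partitions out of ten, even the 75th percentile stays small while the *Max* explodes, which is the skew signature described above.

```python
import statistics

def summary_metrics(values):
    """Min / 25th / Median / 75th / Max, similar to the Spark UI's Summary Metrics row."""
    q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {"Min": min(values), "25th": q1, "Median": median, "75th": q3, "Max": max(values)}

# Hypothetical records per task for the 10 non-empty partitions of a skewed shuffle.
records_per_task = [450_000, 445_000, 13_000, 12_800, 12_600,
                    12_500, 12_400, 12_300, 12_200, 12_000]

metrics = summary_metrics(records_per_task)
print(metrics)
print(f"Max / Median = {metrics['Max'] / metrics['Median']:.1f}x")
```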

### Verifying Partition Distribution via Code
We can use the `spark_partition_id()` function to tag each row with the ID of the partition it lives in, then group by that ID and count records.

In [None]:
# Count rows per partition of the join output (i.e., after the shuffle)
skew_analysis = df_joined \
    .withColumn("partition_num", spark_partition_id()) \
    .groupBy("partition_num") \
    .count() \
    .orderBy(col("count").desc())

print("Top 5 Heaviest Partitions:")
skew_analysis.show(5)

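# You will likely see 1 or 2 partitions with huge counts (e.g., 400k+) while the rest
# have far fewer rows. Only 10 distinct dept_id values exist, so at most 10 of the
# 200 shuffle partitions receive any data at all.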
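
What the `spark_partition_id()` query reveals can be mimicked without Spark. Shuffle partitioning sends every row with the same key to the same partition. The sketch below assumes Python's built-in `hash()` as a stand-in for Spark's Murmur3-based hash, so the partition numbers will not match Spark's, but the shape of the result does: the two hot keys dominate, and only 10 of the 200 partitions are non-empty.

```python
import random
from collections import Counter

NUM_PARTITIONS = 200  # mirrors spark.sql.shuffle.partitions in the session above

def partition_for(key, num_partitions=NUM_PARTITIONS):
    # Assumption: Python's hash() stands in for Spark's Murmur3-based hash
    # partitioning; the idea (same key -> same partition) is identical.
    return hash(key) % num_partitions

# Same skew rule as generate_skewed_data(), seeded and scaled down.
rng = random.Random(7)
keys = [rng.choice([8, 9]) if rng.random() > 0.1 else rng.randrange(8) for _ in range(100_000)]

# Equivalent of groupBy(partition_num).count().orderBy(count desc)
rows_per_partition = Counter(partition_for(k) for k in keys)
print("Top 5 heaviest partitions:", rows_per_partition.most_common(5))
print("Non-empty partitions:", len(rows_per_partition), "of", NUM_PARTITIONS)
```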

## 3. The Solution: Salting

**Salting** involves adding a random number (the "salt") to the join keys of the skewed dataset (Employee) and replicating the rows of the small dataset (Department) to match those salts.

**Logic:**
1.  **Salt Factor:** Choose a number, e.g., 16 (often your core count or a multiple of it).
2.  **Small Table:** Cross join the department table with the numbers 0-15, replicating each row 16 times. Create a new key `salted_dept_id` (e.g., `8_0`, `8_1`... `8_15`).
3.  **Large Table:** Add a *random* number 0-15 to every employee row. Create the same `salted_dept_id` key (e.g., `8_3`).
4.  **Join:** Join on the new composite key `salted_dept_id`.

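This breaks the massive "bucket" for each of Departments 8 and 9 into up to 16 smaller buckets. Because every salted employee key still has exactly one matching salted department row, the join result itself is unchanged.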
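
The four steps can be checked end to end in plain Python. This sketch uses small hypothetical tables, a dict lookup in place of Spark's hash join, and a seeded random generator. It confirms that replicating the small side and randomly salting the large side yields exactly the same left-outer-join result as the unsalted join, while dept 8's rows are spread over 16 distinct join keys.

```python
import random

SALT_FACTOR = 16

# Hypothetical small tables following the notebook's skew rule.
dept = [(i, f"Dept_{i}") for i in range(10)]
rng = random.Random(0)
emp = [(rng.choice([8, 9]) if rng.random() > 0.1 else rng.randrange(8), f"Emp_{n}")
       for n in range(10_000)]

# Step 2: replicate every department row once per salt value.
salted_dept = {f"{d}_{s}": name for d, name in dept for s in range(SALT_FACTOR)}

# Step 3: give every employee row one random salt.
salted_emp = [(f"{d}_{rng.randrange(SALT_FACTOR)}", d, emp_name) for d, emp_name in emp]

# Step 4: left outer join on the composite key (dict.get returns None on no match).
salted_result = sorted((d, emp_name, salted_dept.get(key)) for key, d, emp_name in salted_emp)

# Reference: the plain, unsalted left outer join.
dept_lookup = dict(dept)
plain_result = sorted((d, emp_name, dept_lookup.get(d)) for d, emp_name in emp)

# Distinct join keys now carrying dept 8's rows.
hot_keys = {key for key, d, _ in salted_emp if d == 8}
print("Salted dept rows:", len(salted_dept))
print("Distinct keys for dept 8:", len(hot_keys))
print("Results identical:", salted_result == plain_result)
```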

In [None]:
SALT_FACTOR = 16  # Splitting the skewed key into 16 parts

# 1. Create a DataFrame containing the range of salts
salt_df = spark.range(0, SALT_FACTOR).toDF("salt_id")

# 2. Cross Join Department with Salt Range
# If Dept table has 10 rows, it will now have 10 * 16 = 160 rows
salted_dept_df = dept_df.crossJoin(salt_df) \
    .withColumn("salted_dept_id", concat(col("dept_id"), lit("_"), col("salt_id"))) \
    .drop("salt_id")

print("Salted Department Data Sample:")
salted_dept_df.show(5)

In [None]:
# 1. Add a random salt (0 to 15) to every record in the large skewed table
salted_emp_df = emp_df \
    .withColumn("salt_id", (rand() * SALT_FACTOR).cast("int")) \
    .withColumn("salted_dept_id", concat(col("dept_id"), lit("_"), col("salt_id"))) \
    .drop("salt_id")

print("Salted Employee Data Sample:")
salted_emp_df.show(5)

In [None]:
# Now we join on the new 'salted_dept_id'.
# Since the data for depts 8 and 9 is now broken into 16 chunks each (8_0 to 8_15, 9_0 to 9_15),
# Spark can process these chunks in parallel tasks (though some chunks may still hash to the same partition).

salted_joined_df = salted_emp_df.join(
    salted_dept_df, 
    on="salted_dept_id", 
    how="left_outer"
)

print("Running Salted Join...")
# Trigger action
salted_joined_df.write.format("noop").mode("overwrite").save()
print("Salted Join Completed.")

In [None]:
# Let's verify if the data is distributed more evenly now.
salted_analysis = salted_joined_df \
    .withColumn("partition_num", spark_partition_id()) \
    .groupBy("partition_num") \
    .count() \
    .orderBy(col("count").desc())

print("Top 5 Heaviest Partitions (After Salting):")
salted_analysis.show(5)

# You should see counts that are much lower and closer to each other
# compared to the unsalted join.
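
How much does the salt factor matter? The sketch below estimates the heaviest shuffle partition for several `SALT_FACTOR` values. It assumes an MD5-based hash as a stand-in for Spark's Murmur3 partitioner and 200 partitions as in the session config, so the numbers are illustrative only. Two effects should show up: the heaviest partition generally shrinks as the factor grows, and it never reaches a perfectly even split, because different salted keys can hash to the same partition.

```python
import hashlib
import random
from collections import Counter

NUM_PARTITIONS = 200  # mirrors spark.sql.shuffle.partitions above

def partition_for(key: str, num_partitions: int = NUM_PARTITIONS) -> int:
    # Assumption: an MD5-based hash stands in for Spark's Murmur3 partitioner.
    digest = hashlib.md5(key.encode()).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions

def heaviest_partition(dept_ids, salt_factor, seed=0):
    """Rows in the fullest partition after salting each dept_id with 0..salt_factor-1."""
    rng = random.Random(seed)
    loads = Counter(
        partition_for(f"{d}_{rng.randrange(salt_factor)}") for d in dept_ids
    )
    return max(loads.values())

# Same skew rule as generate_skewed_data(), seeded and scaled down.
rng = random.Random(1)
dept_ids = [rng.choice([8, 9]) if rng.random() > 0.1 else rng.randrange(8) for _ in range(100_000)]

for factor in (1, 4, 16, 64):
    print(f"SALT_FACTOR={factor:>2}: heaviest partition holds {heaviest_partition(dept_ids, factor):,} rows")
```

Each increase in the factor also multiplies the size of the replicated department table, so a larger factor is not free.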

## Summary

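1.  **Skewness:** Occurs when a few key values account for a disproportionate share of the rows, so the partitions (and tasks) holding those keys are far larger than the rest.
2.  **Spillage:** When a task's execution memory fills up, Spark writes intermediate data to disk and reads it back later, causing severe performance hits.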
3.  **Salting:**
    *   **Pros:** Spreads each hot key across up to `SALT_FACTOR` partitions, which can reduce or eliminate spillage and lets the stragglers' work run in parallel.
    *   **Cons:** Replicates the smaller table `SALT_FACTOR` times. Be careful if the "small" table is actually quite large.
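4.  **AQE:** In Spark 3.0+, Adaptive Query Execution (`spark.sql.adaptive.enabled`, on by default since Spark 3.2) can detect and split skewed partitions in sort-merge joins automatically via `spark.sql.adaptive.skewJoin.enabled`, often without any manual salting code. However, knowing salting is still vital for edge cases or older Spark versions.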
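
For comparison with manual salting, AQE's skew-join handling can be switched back on at runtime. The snippet below is a configuration sketch: the config keys are standard Spark SQL settings, but the `1MB` threshold is an illustrative value chosen only because this demo's data is small (the default is 256MB). Per the Spark documentation, a partition is treated as skewed only when it exceeds both `skewedPartitionFactor` times the median partition size and the byte threshold, so whether AQE splits anything here depends on the actual shuffle sizes. Check the SQL tab of the Spark UI after re-running the unsalted join.

```python
from pyspark.sql import SparkSession

# getOrCreate() returns the session already running in this notebook.
spark = SparkSession.builder.getOrCreate()

# Re-enable AQE and its skew-join optimization (skewJoin.enabled defaults to true).
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# A partition counts as skewed only if it is larger than BOTH
# factor x median partition size AND the byte threshold below.
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
# Illustrative value for this small demo; the default is 256MB.
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "1MB")

# Then re-run the unsalted join, e.g.:
# emp_df.join(dept_df, "dept_id", "left_outer").write.format("noop").mode("overwrite").save()
```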