### Data Skew

It's one of the most common causes of poor data shuffle performance. Data skew occurs when some partitions hold much more data than others,
which leads to slow-running (straggler) tasks, disk spills, and Out-Of-Memory errors.

Partitions are the chunks of data that Spark distributes across worker nodes and processes in parallel.

<img src="../img/spark_cluster.png" alt="Cluster" width="650">
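To put a number on skew, a common rule of thumb compares the largest partition with the mean partition size. The plain-Python sketch below assumes rows are placed by `hash(key) % num_partitions`, a simplified stand-in for Spark's Murmur3-based hash partitioning; `hash_partition_sizes` and `skew_factor` are hypothetical helper names, not Spark APIs.

```python
from collections import Counter


def hash_partition_sizes(keys, num_partitions):
    """Count rows per partition when each row is placed by hash(key) % num_partitions.

    Simplified stand-in for Spark's hash partitioning (Spark itself uses a Murmur3 hash).
    """
    return Counter(hash(key) % num_partitions for key in keys)


def skew_factor(partition_sizes, num_partitions):
    """Largest partition divided by the mean partition size (1.0 means perfectly balanced)."""
    largest = max(partition_sizes.values(), default=0)
    mean = sum(partition_sizes.values()) / num_partitions
    return largest / mean


# Same shape as the sample data used in the notebook: 1000 rows with id=1, 100 with id=2, 10 with id=3
keys = [1] * 1000 + [2] * 100 + [3] * 10
sizes = hash_partition_sizes(keys, 4)
print(dict(sorted(sizes.items())))      # {1: 1000, 2: 100, 3: 10}  (partition 0 stays empty)
print(round(skew_factor(sizes, 4), 2))  # 3.6
```

A factor close to 1.0 means the work is spread evenly; here one partition holds about 3.6 times its fair share.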

- Correct Data Skew:
    - Repartition: `df = df.repartition(<number-of-partitions>, <columns>)` (pass a partition count, one or more columns, or both)
    - Add a "salt" column with random values and repartition on it

```python
import pyspark.sql.functions as F

# Add a 'salt' column with a random value in [0, 1) for each row
df = df.withColumn("salt", F.rand())

# Repartition the DataFrame into 8 partitions based on the 'salt' column
df = df.repartition(8, "salt")
```
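Salting with `F.rand()` scatters rows across partitions but breaks the link between a row and its key, so it suits repartitioning rather than keyed operations. For a skewed aggregation, a common variant appends a small integer salt to the key, aggregates on `(key, salt)`, then merges the partial results per key. The plain-Python sketch below models those two stages; `NUM_SALTS = 4` and the round-robin salt (chosen so the output is reproducible) are assumptions, and in PySpark the salt would more typically be something like `(F.rand() * 4).cast("int")`.

```python
from collections import Counter, defaultdict

NUM_SALTS = 4  # assumption: how many sub-groups each key is split into


def salted_count(keys, num_salts=NUM_SALTS):
    """Count rows per key in two stages, the way a salted aggregation would."""
    # Stage 1: partial counts per (key, salt). A round-robin salt keeps this sketch
    # deterministic; a random salt spreads rows the same way on average.
    partial = Counter((key, i % num_salts) for i, key in enumerate(keys))
    # Stage 2: merge the partial counts back into a single count per key.
    totals = defaultdict(int)
    for (key, _salt), count in partial.items():
        totals[key] += count
    return partial, dict(totals)


keys = [1] * 1000 + [2] * 100 + [3] * 10
partial, totals = salted_count(keys)
print(totals)                 # {1: 1000, 2: 100, 3: 10}
print(max(partial.values()))  # 250: the largest (key, salt) group
```

The hot key's 1000 rows are split into four groups of 250 for the first stage, while the final per-key counts are unchanged.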

```python
# Imports
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
```

```python
# Step 1: Set up PySpark Session

# Initialize Spark session
spark = SparkSession.builder.appName("DataSkewHandling").getOrCreate()

# Set log level to ERROR to reduce verbosity
spark.sparkContext.setLogLevel("ERROR")
```

```python
# Step 2: Create a Sample DataFrame with skewed data
data = [(1, "A")] * 1000 + [(2, "B")] * 100 + [(3, "C")] * 10
df = spark.createDataFrame(data, ["id", "category"])

# Show the DataFrame
print("Sample DataFrame:")
df.show(5)
```


```
Sample DataFrame:
+---+--------+
| id|category|
+---+--------+
|  1|       A|
|  1|       A|
|  1|       A|
|  1|       A|
|  1|       A|
+---+--------+
only showing top 5 rows
```



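```python
# Step 3: Diagnose Data Skew

# Check the number of rows per partition.
# Note: a DataFrame created from a local list is split by row position, not by key,
# so these counts can look balanced even though the 'id' values are heavily skewed.
print("\nNumber of rows per partition:")
df.groupBy(F.spark_partition_id()).count().show()

# Inspect the first rows of each partition.
# glom() turns each partition into a list and collect() pulls everything to the driver,
# so only do this on small DataFrames.
print("\nData in partitions (first 2 rows per partition):")
partitions = df.rdd.glom().collect()
for i, partition in enumerate(partitions):
    print(f"Partition {i}: {partition[:2]}")
```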



```
Number of rows per partition:
+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   0|   92|
|                   1|   92|
|                   2|   92|
|                   3|   92|
|                   4|   92|
|                   5|   92|
|                   6|   92|
|                   7|   92|
|                   8|   92|
|                   9|   92|
|                  10|   92|
|                  11|   98|
+--------------------+-----+

Data in partitions (first 2 rows per partition):
Partition 0: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 1: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 2: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 3: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 4: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 5: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 6: [Row(id=1, category='A'), Row(id=1, category='A')]
...
```
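Spark 3.x can also flag skew on its own: with Adaptive Query Execution, a shuffle partition in a sort-merge join is treated as skewed when it is larger than `spark.sql.adaptive.skewJoin.skewedPartitionFactor` times the median partition size and also larger than `spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes`. The sketch below models only that detection rule in plain Python, assuming the documented defaults of 5 and 256 MB; `skewed_partitions` is a hypothetical helper and the partition sizes are made up for illustration.

```python
from statistics import median

MB = 1024 * 1024


def skewed_partitions(sizes_in_bytes, factor=5.0, threshold=256 * MB):
    """Return indices of partitions an AQE-style rule would flag as skewed:
    larger than `factor` x the median size AND larger than `threshold` bytes."""
    med = median(sizes_in_bytes)
    return [i for i, size in enumerate(sizes_in_bytes)
            if size > factor * med and size > threshold]


# Hypothetical shuffle partition sizes: seven ~100 MB partitions and one 2 GB partition
sizes = [100 * MB] * 7 + [2048 * MB]
print(skewed_partitions(sizes))  # [7]

# Tiny partitions are never flagged, however lopsided: 1000 bytes is far below 256 MB
print(skewed_partitions([1000, 100, 10, 0]))  # []
```

The second call is why a toy dataset like the one in this notebook will not trigger automatic skew handling, even though its keys are badly unbalanced.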

```python
# Step 4: Check the distribution of 'id' across partitions
(
    df.withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id", "id")
    .count()
    .orderBy("partition_id", "id")
    .show()
)
```

```
+------------+---+-----+
|partition_id| id|count|
+------------+---+-----+
|           0|  1|   92|
|           1|  1|   92|
|           2|  1|   92|
|           3|  1|   92|
|           4|  1|   92|
|           5|  1|   92|
|           6|  1|   92|
|           7|  1|   92|
|           8|  1|   92|
|           9|  1|   92|
|          10|  1|   80|
|          10|  2|   12|
|          11|  2|   88|
|          11|  3|   10|
+------------+---+-----+
```
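The breakdown above shows `id=1` filling partitions 0 through 10 by itself. Before salting, it can help to find which keys are hot so that only those keys get a salt; on a DataFrame the counts come from `df.groupBy("id").count()`. The plain-Python sketch below shows the selection step; `hot_keys` and its 10% cut-off are illustrative assumptions, not Spark settings.

```python
from collections import Counter


def hot_keys(keys, min_share=0.1):
    """Return keys that hold more than `min_share` of all rows, most frequent first."""
    counts = Counter(keys)
    total = sum(counts.values())
    return [key for key, count in counts.most_common() if count / total > min_share]


keys = [1] * 1000 + [2] * 100 + [3] * 10
print(hot_keys(keys))                  # [1]     (id=1 holds about 90% of the rows)
print(hot_keys(keys, min_share=0.05))  # [1, 2]  (id=2 holds about 9%)
```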



```python
# Step 5: Handle Data Skew - Repartition by Column

# Repartition the DataFrame by the skewed column.
# Hash partitioning sends every row with the same 'id' to the same partition,
# so this co-locates each key but cannot split a hot key such as id=1.
print("\nRepartitioning by 'id' column...")
df_repartitioned = df.repartition("id")

# Check the new distribution
print("\nNumber of rows per partition after repartitioning:")
df_repartitioned.groupBy(F.spark_partition_id()).count().show()
```



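```
Repartitioning by 'id' column...

Number of rows per partition after repartitioning:
+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   0| 1110|
+--------------------+-----+
```

All 1110 rows end up in a single partition. With only three distinct `id` values, hash partitioning can fill at most three partitions, and Adaptive Query Execution (enabled by default since Spark 3.2) likely coalesced these small shuffle partitions into one. Either way, repartitioning by the skewed key does not spread out the 1000 rows of `id=1`.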
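Repartitioning by a column hashes its value, so every row with the same key goes to the same partition no matter how many partitions are requested. The plain-Python sketch below shows the largest partition never dropping below the 1000 rows of `id=1`; `hash(key) % n` is a simplified stand-in for Spark's Murmur3 hash, and `largest_partition` is a hypothetical helper.

```python
from collections import Counter


def largest_partition(keys, num_partitions):
    """Rows in the biggest partition when rows are placed by hash(key) % num_partitions
    (a simplified stand-in for Spark's Murmur3-based hash partitioning)."""
    sizes = Counter(hash(key) % num_partitions for key in keys)
    return max(sizes.values())


keys = [1] * 1000 + [2] * 100 + [3] * 10
for n in (2, 4, 8, 200):
    print(n, largest_partition(keys, n))
# 2 1010
# 4 1000
# 8 1000
# 200 1000
```

More partitions only add empty ones; the hot key stays in one piece, which is the problem salting addresses.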



```python
# Step 6: Handle Data Skew - Salting

# Add a salt column to evenly distribute data
print("\nAdding a salt column for even distribution...")
df_salted = df.withColumn("salt", F.rand())

# Repartition by the salt column
df_salted = df_salted.repartition(8, "salt")

# Check the new distribution
print("\nNumber of rows per partition after salting:")
df_salted.groupBy(F.spark_partition_id()).count().show()
```



```
Adding a salt column for even distribution...

Number of rows per partition after salting:
+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   0|  116|
|                   1|  142|
|                   2|  152|
|                   3|  132|
|                   4|  132|
|                   5|  159|
|                   6|  130|
|                   7|  147|
+--------------------+-----+
```
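The salted partitions range from 116 to 159 rows because each row's salt is random. When the goal is only evenly sized partitions, with no key involved, `df.repartition(8)` without column arguments uses round-robin partitioning instead. The plain-Python sketch below compares the two placement strategies; it simplifies Spark's round-robin (which starts at a random partition for each input partition), and `random_salt_sizes` and `round_robin_sizes` are hypothetical helpers.

```python
import random
from collections import Counter


def random_salt_sizes(num_rows, num_partitions, seed=0):
    """Partition sizes when each row gets a random bucket, like rand()-based salting."""
    rng = random.Random(seed)
    return Counter(rng.randrange(num_partitions) for _ in range(num_rows))


def round_robin_sizes(num_rows, num_partitions):
    """Partition sizes when rows are dealt out in turn (simplified round-robin)."""
    return Counter(i % num_partitions for i in range(num_rows))


rr_sizes = round_robin_sizes(1110, 8)
print(sorted(rr_sizes.values()))    # [138, 138, 139, 139, 139, 139, 139, 139]

rand_sizes = random_salt_sizes(1110, 8)
print(sorted(rand_sizes.values()))  # uneven; exact counts depend on the seed
```

Random salting is still the right tool when a hot key must be split for a keyed operation; round-robin only balances row counts.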



```python
# Stop the Spark session
spark.stop()
```