In [None]:
# Step 1: Set up PySpark Session
from pyspark.sql import SparkSession

# Build the session for this notebook under the application name
# "DataSkewHandling"; getOrCreate() hands back an already-running session
# if one exists instead of starting a second one.
spark = (
    SparkSession.builder
    .appName("DataSkewHandling")
    .getOrCreate()
)

In [None]:
# Step 2: Create a Sample DataFrame with Skew
from pyspark.sql import functions as F

# Deliberately lopsided key distribution: 1000 rows for id 1, 100 rows for
# id 2 and 10 rows for id 3 (1110 rows in total, in that order).
data = [
    (key, label)
    for key, label, copies in [(1, "A", 1000), (2, "B", 100), (3, "C", 10)]
    for _ in range(copies)
]
df = spark.createDataFrame(data, schema=["id", "category"])

# Preview the first five rows only
print("Sample DataFrame:")
df.show(n=5)


In [None]:
# Step 3: Diagnose Data Skew
# Count how many rows each physical partition of `df` holds.
print("\nNumber of rows per partition:")
df.groupBy(F.spark_partition_id()).count().show()

# Inspect data distribution in partitions
print("\nData in partitions (first 2 rows per partition):")


def _partition_preview(index, rows, limit=2):
    """Yield ``(partition index, first `limit` rows)`` for a single partition.

    Intended for ``RDD.mapPartitionsWithIndex``: it runs once per partition
    and stops reading after `limit` rows, so only the preview rows are sent
    back to the driver rather than the partition's full contents.

    Args:
        index: Partition index supplied by ``mapPartitionsWithIndex``.
        rows: Iterator over the rows of that partition.
        limit: Maximum number of rows to keep from the partition.

    Yields:
        One ``(index, preview)`` tuple; ``preview`` is an empty list for an
        empty partition.
    """
    # zip() against range(limit) stops after `limit` rows without pulling
    # an extra row from the partition iterator.
    preview = [row for _, row in zip(range(limit), rows)]
    yield index, preview


# Collect only the trimmed previews; a full glom().collect() would ship every
# row of the DataFrame to the driver just to display two rows per partition.
partition_previews = df.rdd.mapPartitionsWithIndex(_partition_preview).collect()
for i, preview in partition_previews:
    print(f"Partition {i}: {preview}")


In [None]:
# Step 4: Check distribution of 'id' across partitions
# Tag every row with the partition it currently lives in, then count rows
# for each (partition_id, id) pair and list them in sorted order.
id_counts_by_partition = (
    df.withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id", "id")
    .count()
    .orderBy("partition_id", "id")
)
id_counts_by_partition.show()

In [None]:
# Step 5: Handle Data Skew - Repartition by Column
print("\nRepartitioning by 'id' column...")

# Shuffle the rows so that partition membership is decided by the 'id' value.
# NOTE(review): repartitioning on a column is hash-based, so all rows that
# share an id should land in the same partition; with the 1000/100/10 split
# created earlier the id=1 rows would still sit together. Confirm this step
# is meant to show that limitation rather than to even out the data.
df_repartitioned = df.repartition("id")

# Row count per partition after the shuffle
print("\nNumber of rows per partition after repartitioning:")
rows_per_partition = df_repartitioned.groupBy(F.spark_partition_id()).count()
rows_per_partition.show()


In [None]:
# Step 6: Handle Data Skew - Salting
# Tunables for the salting demo.
SALT_SEED = 42           # fixed seed so the salt column is repeatable between runs
NUM_SALT_PARTITIONS = 8  # number of partitions to spread the salted rows over

# Add a salt column to evenly distribute data
print("\nAdding a salt column for even distribution...")
# The salt is a random double in [0, 1) that is independent of 'id', so the
# heavy id=1 rows are no longer tied to a single partition.
# Seeding rand() keeps the per-partition counts below stable across
# Restart & Run All, assuming the partitioning of `df` itself is unchanged
# (Spark documents rand as non-deterministic in the general case).
df_with_salt = df.withColumn("salt", F.rand(seed=SALT_SEED))

# Repartition by the salt column
df_salted = df_with_salt.repartition(NUM_SALT_PARTITIONS, "salt")

# Check the new distribution
print("\nNumber of rows per partition after salting:")
df_salted.groupBy(F.spark_partition_id()).count().show()
