In [24]:
from pyspark.sql.functions import col, when, lit, rand
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('DataEngineering').getOrCreate()

# Define schema for the sales data
schema = StructType([
    StructField("order_id", StringType(), True),
    StructField("product_id", StringType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("price", DoubleType(), True)
])
# Generate base data
data = [("O" + str(i), "P" + str(101 + (i % 3)), i % 5 + 1, 10.0 + i * 0.1) for i in range(1000)]
df = spark.createDataFrame(data, schema)
df.show(10)

+--------+----------+--------+-----+
|order_id|product_id|quantity|price|
+--------+----------+--------+-----+
|      O0|      P101|       1| 10.0|
|      O1|      P102|       2| 10.1|
|      O2|      P103|       3| 10.2|
|      O3|      P101|       4| 10.3|
|      O4|      P102|       5| 10.4|
|      O5|      P103|       1| 10.5|
|      O6|      P101|       2| 10.6|
|      O7|      P102|       3| 10.7|
|      O8|      P103|       4| 10.8|
|      O9|      P101|       5| 10.9|
+--------+----------+--------+-----+
only showing top 10 rows



In [25]:
# Introduce skew for product_id 'P101'
skewed_data = [("O" + str(i), "P101", i % 5 + 1, 10.0 + i * 0.1) for i in range(1000, 20000)]
skewed_df = spark.createDataFrame(skewed_data, schema)
skewed_df.show(10)
final_df = df.union(skewed_df)

# Display the distribution to show the skew
final_df.groupBy("product_id").count().orderBy(col("count").desc()).show()


+--------+----------+--------+------------------+
|order_id|product_id|quantity|             price|
+--------+----------+--------+------------------+
|   O1000|      P101|       1|             110.0|
|   O1001|      P101|       2|110.10000000000001|
|   O1002|      P101|       3|             110.2|
|   O1003|      P101|       4|110.30000000000001|
|   O1004|      P101|       5|             110.4|
|   O1005|      P101|       1|             110.5|
|   O1006|      P101|       2|110.60000000000001|
|   O1007|      P101|       3|             110.7|
|   O1008|      P101|       4|110.80000000000001|
|   O1009|      P101|       5|             110.9|
+--------+----------+--------+------------------+
only showing top 10 rows

+----------+-----+
|product_id|count|
+----------+-----+
|      P101|19334|
|      P102|  333|
|      P103|  333|
+----------+-----+



In [26]:
file_source_path = "/content/sample_data/skewed_data"
final_df.write.format("parquet").mode("overwrite").save(file_source_path)

In [27]:
try:
    sales_df = spark.read.format("parquet").load(file_source_path)
    print("Successfully read data from source path.")
    sales_df.printSchema()
except Exception as e:
    print(f"Error reading from S3: {e}")
    raise

Successfully read data from source path.
root
 |-- order_id: string (nullable = true)
 |-- product_id: string (nullable = true)
 |-- quantity: integer (nullable = true)
 |-- price: double (nullable = true)



In [28]:
from pyspark.sql.functions import col, count, broadcast, concat, lit, floor, rand, explode, array
from pyspark.sql.types import IntegerType
# This aggregation will be slow due to data skew.
# In the Spark UI, you would see some tasks taking much longer than others.
skewed_agg_df = sales_df.groupBy("product_id").agg(count("order_id").alias("total_orders"))
skewed_agg_df.orderBy(col("total_orders").desc()).show()


+----------+------------+
|product_id|total_orders|
+----------+------------+
|      P101|       19334|
|      P102|         333|
|      P103|         333|
+----------+------------+



In [29]:
# Identify the skewed key(s) from above result
skew_keys = ["P101"]
# Number of salted keys to generate
salt_factor = 5
# Create a salted DataFrame for the sales data
salted_sales_df = sales_df.withColumn("salt",
    when(col("product_id").isin(skew_keys), (rand() * salt_factor).cast(IntegerType()))
    .otherwise(lit(0))
).withColumn("salted_product_id", concat(col("product_id"), lit("_"), col("salt")))

salted_agg_df = salted_sales_df.groupBy("salted_product_id").agg(count("order_id").alias("total_orders"))
salted_agg_df.show()

# Remove the salt to get the final result
processed_df = salted_agg_df.withColumn("product_id",
    when(col("salted_product_id").contains("_"), col("salted_product_id").substr(lit(1), lit(4)))
    .otherwise(col("salted_product_id"))
).groupBy("product_id").agg(count("total_orders").alias("total_orders"))

processed_df.orderBy(col("total_orders").desc()).show()


+-----------------+------------+
|salted_product_id|total_orders|
+-----------------+------------+
|           P101_1|        3986|
|           P101_4|        3851|
|           P101_2|        3762|
|           P101_0|        3838|
|           P101_3|        3897|
|           P102_0|         333|
|           P103_0|         333|
+-----------------+------------+

+----------+------------+
|product_id|total_orders|
+----------+------------+
|      P101|           5|
|      P102|           1|
|      P103|           1|
+----------+------------+



### Data skew handling methods:
    # 1. Salting
    lookup = df.select('skewed_key').distinct().limit(10)
    salted_df = salted_join(df, lookup)
    
    # 2. Broadcast join
    broadcasted_df = broadcast_join(df, lookup, 'skewed_key')
    
    # 3. Repartition
    repartitioned_df = repartition_on_skewed_key(df, 'skewed_key')