## Fixing data skew by salting
Reference: https://www.youtube.com/watch?v=rZGsc5y8AQk

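Salting: appending a random component (the "salt") to a skewed key so that rows sharing one hot key are spread across several partitions instead of all landing in a single one.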

In [66]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
# Create (or reuse an existing) Spark session for this notebook and keep a
# handle on its SparkContext, which later cells use to label jobs in the UI.
spark = (
    SparkSession.builder
    .appName("skew")
    .getOrCreate()
)
sc = spark.sparkContext

In [67]:
# Keep the number of shuffle partitions tiny so per-partition row counts are
# easy to read (the salt count below is derived from this value).
spark.conf.set("spark.sql.shuffle.partitions", "3")
# Disable Adaptive Query Execution: AQE can coalesce shuffle partitions and
# split skewed join partitions at runtime, which would hide the skew this
# notebook sets out to demonstrate.
spark.conf.set("spark.sql.adaptive.enabled", "false")
# Disable broadcast joins so the joins below are always shuffle-based; a
# broadcast join would not repartition rows by the join key at all.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
# Last expression is displayed: confirm the effective shuffle partition count.
spark.conf.get("spark.sql.shuffle.partitions")

### Simulating a skewed join

In [68]:
# One million distinct integers (0 .. 999,999) in a single `value` column.
# Every key appears exactly once, so this side of the join is not skewed.
uniform_df = spark.createDataFrame(list(range(1_000_000)), IntegerType())
uniform_df.show(n=5, truncate=False)

+-----+
|value|
+-----+
|0    |
|1    |
|2    |
|3    |
|4    |
+-----+
only showing top 5 rows



                                                                                

In [69]:
sc.setJobDescription("uniform_dataset")

# Rows per physical partition: for the uniform data these should be roughly equal.
(
    uniform_df
    .withColumn('partition', spark_partition_id())
    .groupBy('partition')
    .count()
    .orderBy('partition')
    .show()
)

+---------+------+
|partition| count|
+---------+------+
|        0|124928|
|        1|124928|
|        2|124928|
|        3|124928|
|        4|124928|
|        5|124928|
|        6|124928|
|        7|125504|
+---------+------+



In [70]:
# Build a heavily skewed dataset: key 0 has 999,990 rows while keys 1-3 have
# only 15, 10 and 5 rows. repartition(1) squeezes each piece into a single
# partition, and union concatenates those partitions without a shuffle, which
# is why the next cell reports one partition per key.
df0 = spark.createDataFrame([0] * 999_990, IntegerType()).repartition(1)
df1 = spark.createDataFrame([1] * 15, IntegerType()).repartition(1)
df2 = spark.createDataFrame([2] * 10, IntegerType()).repartition(1)
df3 = spark.createDataFrame([3] * 5, IntegerType()).repartition(1)
df_skew = (
    df0
    .union(df1)
    .union(df2)
    .union(df3)
)
df_skew.show(n=5, truncate=False)

+-----+
|value|
+-----+
|0    |
|0    |
|0    |
|0    |
|0    |
+-----+
only showing top 5 rows



                                                                                

In [71]:
sc.setJobDescription("skewed_data")

# Rows per physical partition of the skewed input: one partition holds almost everything.
(
    df_skew
    .withColumn('partition', spark_partition_id())
    .groupBy('partition')
    .count()
    .orderBy('partition')
    .show()
)

+---------+------+
|partition| count|
+---------+------+
|        0|999990|
|        1|    15|
|        2|    10|
|        3|     5|
+---------+------+



In [72]:
# Plain (un-salted) inner join on the skewed key.
df_joined_c1 = df_skew.join(uniform_df, on="value", how="inner")

In [73]:
sc.setJobDescription("skewed_join")

# Rows per physical partition of the un-salted join output. The join shuffles
# by key, so every row for the hot key 0 lands in the same partition.
(
    df_joined_c1
    .withColumn('partition', spark_partition_id())
    .groupBy('partition')
    .count()
    .orderBy('partition')
    .show()
)

[Stage 19:>                                                         (0 + 8) / 8]

+---------+-------+
|partition|  count|
+---------+-------+
|        0|1000005|
|        1|     15|
+---------+-------+



                                                                                

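### Salting

Each row on the skewed side gets a random salt in `[0, salt_no)`. The uniform side is replicated once per possible salt value (via `explode`), so every `(value, salt)` pair on the skewed side still finds exactly its original match. The join result is therefore unchanged, but rows for the hot key `0` are now spread across several shuffle partitions.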

In [49]:
# One salt bucket per shuffle partition, so the hot key can be spread across
# (up to) all shuffle partitions. The conf value is a string, hence int().
salt_no = int(spark.conf.get("spark.sql.shuffle.partitions"))
salt_no

3

In [50]:
# Give each row of the skewed side a random salt in [0, salt_no): rand() is
# uniform on [0, 1), so scaling by salt_no and casting to int yields the
# integers 0 .. salt_no - 1. withColumn replaces an existing "salt" column, so
# re-running this cell does not add columns. The fixed seed (42) makes the salt
# assignment reproducible across runs as long as the input partitioning is unchanged.
df_skew = df_skew.withColumn("salt", (rand(seed=42) * salt_no).cast("int"))

In [51]:
# Peek at the salted skewed side: every row now carries a random salt.
df_skew.show(n=10)

+-----+----+
|value|salt|
+-----+----+
|    0|   1|
|    0|   1|
|    0|   1|
|    0|   1|
|    0|   0|
|    0|   2|
|    0|   1|
|    0|   1|
|    0|   0|
|    0|   1|
+-----+----+
only showing top 10 rows



In [52]:
# Replicate every row of the uniform (non-skewed) side once per possible salt
# value, so that whichever salt a skewed row received, a matching
# (value, salt) row exists on this side.
# The guard keeps this cell idempotent: without it, re-running the cell would
# explode the already-exploded frame again and multiply its row count by
# salt_no on every execution. If you change salt_no, rebuild uniform_df from
# its creation cell before running this one.
if "salt" not in uniform_df.columns:
    uniform_df = (
        uniform_df
        .withColumn('salt_vals', array([lit(i) for i in range(salt_no)]))
        .withColumn('salt', explode(col('salt_vals')))
    )

In [53]:
# Each original value now appears once per salt value.
uniform_df.show(n=10, truncate=False)

+-----+---------+----+
|value|salt_vals|salt|
+-----+---------+----+
|0    |[0, 1, 2]|0   |
|0    |[0, 1, 2]|1   |
|0    |[0, 1, 2]|2   |
|1    |[0, 1, 2]|0   |
|1    |[0, 1, 2]|1   |
|1    |[0, 1, 2]|2   |
|2    |[0, 1, 2]|0   |
|2    |[0, 1, 2]|1   |
|2    |[0, 1, 2]|2   |
|3    |[0, 1, 2]|0   |
+-----+---------+----+
only showing top 10 rows



In [54]:
# Salted join: the join key is (value, salt) instead of value alone.
df_joined = df_skew.join(uniform_df, on=["value", "salt"], how="inner")

In [None]:
sc.setJobDescription("salted_join")

# Rows per (key, physical partition) after the salted join: the hot key 0
# should now be split across several partitions instead of just one.
(
    df_joined
    .withColumn('partition', spark_partition_id())
    .groupBy('value', 'partition')
    .count()
    .orderBy('value', 'partition')
    .show()
)

[Stage 33:>                                                         (0 + 8) / 8]

+-----+---------+------+
|value|partition| count|
+-----+---------+------+
|    0|        0|333289|
|    0|        1|333404|
|    0|        2|333297|
|    1|        0|     4|
|    1|        1|    11|
|    2|        0|     6|
|    2|        1|     2|
|    2|        2|     2|
|    3|        0|     2|
|    3|        2|     3|
+-----+---------+------+



                                                                                

In [56]:
# Baseline: per-key row counts computed with an ordinary (un-salted) aggregation.
(
    df_skew
    .groupBy("value")
    .count()
    .show()
)

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+



In [58]:
# Two-stage ("salted") aggregation:
#   1. aggregate per (value, salt), so the hot key 0 is counted in up to
#      salt_no independent groups instead of one huge group;
#   2. sum those partial counts per value.
# The per-value totals must match the plain groupBy count above.
# The salt is recomputed here (withColumn replaces any existing "salt" column)
# so this cell does not depend on the earlier salting cell; the seed keeps it
# reproducible.
# Note: `count` and `sum` are pyspark.sql.functions (brought in by the star
# import), not the Python builtins.
(
    df_skew
    .withColumn("salt", (rand(seed=42) * salt_no).cast("int"))
    .groupBy("value", "salt")
    .agg(count("value").alias("count"))
    .groupBy("value")
    .agg(sum("count").alias("count"))
    .show()
)

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+



In [65]:
# Shut down the Spark session (and its underlying SparkContext) to release
# resources; re-run the session setup cell before executing any other cell.
spark.stop()