## Using Salting to Handle Data Skew When Writing Datasets


In this case, my skew column is `bucket`, and the number of records per value ranges from 100 to 10,000. Writing this data without resolving the skew:

- Makes writes slower
- Puts pressure on a few executors

Steps to resolve the skew:

1. Add a salt column.
2. Repartition by the salt column.
3. Write the data partitioned by the skew column.


## Read the Data

In [5]:
input_df = spark.read.csv("s3://<bucket>/dataskew/files/")
input_df = input_df.toDF("bucket","prefix")
input_df.createOrReplaceTempView("input_df_v")
input_df.show(5)

+----------+------+
|    bucket|prefix|
+----------+------+
|5UTV1TLVXT|  5296|
|5UTV1TLVXT|  7400|
|5UTV1TLVXT|  8573|
|5UTV1TLVXT|  4216|
|5UTV1TLVXT|  9965|
+----------+------+
only showing top 5 rows

In [9]:
spark.sql("Select bucket, count(1) as group_counts from input_df_v group by bucket").show(3)

+----------+------------+
|    bucket|group_counts|
+----------+------------+
|GZVQF7CQRP|        1000|
|5UTV1TLVXT|       10000|
|TMG607P21Z|         100|
+----------+------------+

### Find Split Size and Salt Key Column

In [6]:
group_counts=spark.sql("Select bucket, count(1) as group_counts from input_df_v group by bucket").rdd.collectAsMap()
group_counts_b = sc.broadcast(group_counts) 

In [7]:
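from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
import random
import math

# Target roughly 100 records per salted key.
# A bucket with n records draws its salt from 1..ceil(n / 100).
salt_udf = udf(
    lambda key: key + str(random.randint(1, math.ceil(group_counts_b.value[key] / 100))),
    StringType(),
).asNondeterministic()  # the output is random, so tell Spark not to assume repeatable results

input_df = input_df.withColumn("salted_key", salt_udf(input_df.bucket))
input_df.show(5)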

+----------+------+------------+
|    bucket|prefix|  salted_key|
+----------+------+------------+
|5UTV1TLVXT|  5296|5UTV1TLVXT30|
|5UTV1TLVXT|  7400|5UTV1TLVXT56|
|5UTV1TLVXT|  8573|5UTV1TLVXT43|
|5UTV1TLVXT|  4216|5UTV1TLVXT32|
|5UTV1TLVXT|  9965|5UTV1TLVXT19|
+----------+------+------------+
only showing top 5 rows

In [8]:
input_df.createOrReplaceTempView("input_df_v")

In [10]:
spark.sql("Select bucket, count(distinct salted_key) as group_counts from input_df_v group by bucket").show(10)

+----------+------------+
|    bucket|group_counts|
+----------+------------+
|GZVQF7CQRP|          10|
|5UTV1TLVXT|         100|
|TMG607P21Z|           1|
+----------+------------+
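
The distinct salted-key counts above follow from the ceiling division inside the UDF. A minimal sketch in plain Python (no Spark required) that reproduces the arithmetic; `salts_for` and `TARGET_ROWS_PER_SALT` are hypothetical names introduced here for illustration, not part of the notebook:

```python
import math

TARGET_ROWS_PER_SALT = 100  # the same constant the UDF divides by


def salts_for(row_count, target=TARGET_ROWS_PER_SALT):
    """Upper bound on the distinct salt values a bucket with row_count rows can receive."""
    return math.ceil(row_count / target)


# Row counts per bucket, copied from the group-by output earlier in the notebook.
group_counts = {"GZVQF7CQRP": 1000, "5UTV1TLVXT": 10000, "TMG607P21Z": 100}

salt_counts = {bucket: salts_for(n) for bucket, n in group_counts.items()}
print(salt_counts)  # {'GZVQF7CQRP': 10, '5UTV1TLVXT': 100, 'TMG607P21Z': 1}
```

Because the salt is drawn at random, the observed number of distinct salted keys can in principle fall below this bound. With at least 100 rows per bucket that is very unlikely, which matches the counts shown.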

In [13]:
spark.sql("Select salted_key, count(1) as group_counts from input_df_v group by salted_key order by 2 asc").show(2)

+------------+------------+
|  salted_key|group_counts|
+------------+------------+
|5UTV1TLVXT94|          79|
|5UTV1TLVXT63|          80|
+------------+------------+
only showing top 2 rows

In [41]:
spark.sql("Select salted_key, count(1) as group_counts from input_df_v group by salted_key order by 2 desc").show(2)

+------------+------------+
|  salted_key|group_counts|
+------------+------------+
|5UTV1TLVXT55|         128|
|5UTV1TLVXT45|         124|
+------------+------------+
only showing top 2 rows

Note that the record count per salted key varies somewhat, from 79 to 128, but that should not be a problem.
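
The spread arises because every row draws its salt independently, so group sizes scatter around the target of 100. The sketch below simulates this in plain Python for the largest bucket and compares it with a round-robin assignment. Round-robin is an assumed alternative shown only for contrast; it is not what the notebook does, and the seed and variable names are illustrative:

```python
import random
from collections import Counter

ROWS = 10_000  # rows in the largest bucket (5UTV1TLVXT)
SALTS = 100    # ceil(10000 / 100)

rng = random.Random(42)  # fixed seed so the sketch is repeatable

# Random salting, as in the UDF: each row independently picks a salt in 1..SALTS.
random_sizes = Counter(rng.randint(1, SALTS) for _ in range(ROWS))

# Round-robin salting (assumed alternative): row i gets salt (i % SALTS) + 1.
round_robin_sizes = Counter((i % SALTS) + 1 for i in range(ROWS))

print("random     :", min(random_sizes.values()), "to", max(random_sizes.values()))
print("round-robin:", min(round_robin_sizes.values()), "to", max(round_robin_sizes.values()))
```

Round-robin yields exactly 100 rows per salt, while random salting produces a band around 100. Random salting is simpler to apply per row, and a band of this width is small next to the original 100-to-10,000 skew.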

In [14]:
partition_columns=["salted_key"]
input_df=input_df.repartition(*partition_columns)
input_df.rdd.getNumPartitions()

29

In [28]:
input_df.show(3)

+----------+------+------------+
|    bucket|prefix|  salted_key|
+----------+------+------------+
|5UTV1TLVXT|  6781|5UTV1TLVXT67|
|5UTV1TLVXT|  1899|5UTV1TLVXT67|
|5UTV1TLVXT|  2992|5UTV1TLVXT67|
+----------+------+------------+
only showing top 3 rows

In [42]:
partition_columns='bucket'
input_df.write.mode("OVERWRITE").partitionBy(partition_columns).csv("s3://<bucket>/dataskew/output/",header=True)

### Verify Outputs

In [33]:
from pyspark.sql.functions import input_file_name 

output_df=spark.read.csv("s3://<bucket>/dataskew/output/",header=True)
output_df=output_df.withColumn("filename",input_file_name())
output_df.createOrReplaceTempView("output_df_v")
output_df.printSchema()

root
 |-- prefix: string (nullable = true)
 |-- salted_key: string (nullable = true)
 |-- bucket: string (nullable = true)
 |-- filename: string (nullable = false)

In [40]:
spark.sql("SELECT bucket, count(distinct filename) as files from output_df_v group by bucket ORDER by 2 DESC").show(3)

+----------+-----+
|    bucket|files|
+----------+-----+
|5UTV1TLVXT|   28|
|GZVQF7CQRP|    8|
|TMG607P21Z|    1|
+----------+-----+