In [1]:
import time

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import *

In [2]:
# Create (or reuse) a SparkSession whose master is "local[*]".
spark = (
    SparkSession.builder
    .master("local[*]")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/20 18:07:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/08/20 18:07:12 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# Session settings for the skew demo:
#   * shuffle output is fixed at 3 partitions (SALT_NUMBER is later read back
#     from this same setting);
#   * adaptive query execution is turned off -- presumably so it cannot change
#     the number/size of shuffle partitions behind the demo's back (TODO confirm).
spark.conf.set("spark.sql.shuffle.partitions", "3")
spark.conf.set("spark.sql.adaptive.enabled", "false")

# Kept as the LAST expression of the cell so the notebook actually displays the
# effective value; in the middle of the cell its result was silently discarded.
spark.conf.get("spark.sql.shuffle.partitions")

In [4]:
# Evenly distributed side of the join: the integers 0..999999, one per row,
# in a single IntegerType column.
uniform_values = list(range(1000000))
df_uniform = spark.createDataFrame(uniform_values, IntegerType())
df_uniform.show(5, truncate=False)

+-----+
|value|
+-----+
|0    |
|1    |
|2    |
|3    |
|4    |
+-----+
only showing top 5 rows


                                                                                

In [5]:
# Rows per physical partition of df_uniform: tag each row with the id of the
# partition it lives in, count per id, and list the ids in ascending order.
uniform_partition_counts = (
    df_uniform
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
)
uniform_partition_counts.show(15, truncate=False)

+---------+------+
|partition|count |
+---------+------+
|0        |99328 |
|1        |100352|
|2        |100352|
|3        |99328 |
|4        |100352|
|5        |100352|
|6        |99328 |
|7        |100352|
|8        |100352|
|9        |99904 |
+---------+------+



                                                                                

In [6]:
# Skewed side of the join: key 0 carries 999990 rows while keys 1, 2 and 3
# carry 15, 10 and 5 rows. Each key is built as its own single-partition frame
# and the four frames are then unioned in key order.
int_type = IntegerType()
df0 = spark.createDataFrame([0] * 999990, int_type).repartition(1)
df1 = spark.createDataFrame([1] * 15, int_type).repartition(1)
df2 = spark.createDataFrame([2] * 10, int_type).repartition(1)
df3 = spark.createDataFrame([3] * 5, int_type).repartition(1)

df_skew = df0.union(df1)
df_skew = df_skew.union(df2)
df_skew = df_skew.union(df3)
df_skew.show(5, truncate=False)

+-----+
|value|
+-----+
|0    |
|0    |
|0    |
|0    |
|0    |
+-----+
only showing top 5 rows


In [7]:
# Rows per physical partition of df_skew, listed by ascending partition id.
skew_partition_counts = (
    df_skew
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
)
skew_partition_counts.show()

+---------+------+
|partition| count|
+---------+------+
|        0|999990|
|        1|    15|
|        2|    10|
|        3|     5|
+---------+------+



                                                                                

In [8]:
# Baseline (unsalted) inner join of the skewed frame with the uniform frame on "value".
df_joined_c1 = df_skew.join(df_uniform, on="value", how="inner")

In [9]:
# Time the unsalted join and show how many joined rows end up in each output
# partition (the partition id is read after the join).
# perf_counter() is a monotonic clock meant for measuring intervals; time.time()
# is wall-clock time and can jump if the system clock is adjusted.
# `time` is imported in the notebook's first (imports) cell.
start = time.perf_counter()
(
    df_joined_c1
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .show(5, False)
)
end = time.perf_counter()
print(f"total_execution time {end-start}")

[Stage 19:>                                                       (0 + 10) / 10]

+---------+-------+
|partition|count  |
+---------+-------+
|0        |1000005|
|1        |15     |
+---------+-------+

total_execution time 1.193896770477295


                                                                                

In [10]:
# Number of salt buckets: taken from the configured shuffle partition count,
# which the session returns as a string and is therefore converted to int.
shuffle_partitions_setting = spark.conf.get("spark.sql.shuffle.partitions")
SALT_NUMBER = int(shuffle_partitions_setting)
SALT_NUMBER

3

In [11]:
# Attach a salt bucket in [0, SALT_NUMBER) to every skewed row: a uniform random
# number in [0, 1) is scaled by SALT_NUMBER and truncated to an int.
# A fixed seed is passed so the salt assignment (and the per-partition counts
# derived from it further down) is reproducible across kernel restarts.
# Re-running this cell simply overwrites the existing "salt" column.
df_skew = df_skew.withColumn("salt", (F.rand(seed=42) * SALT_NUMBER).cast("int"))

In [12]:
# Peek at the first five salted rows without truncating cell contents.
df_skew.show(5, truncate=False)

+-----+----+
|value|salt|
+-----+----+
|0    |1   |
|0    |0   |
|0    |0   |
|0    |2   |
|0    |0   |
+-----+----+
only showing top 5 rows


In [13]:
# Replicate every df_uniform row once per salt bucket so that each
# (value, salt) pair produced on the skewed side has a matching row here:
#   salt_value -> array [0, 1, ..., SALT_NUMBER - 1]
#   salt       -> one row per element of that array (explode)
#
# The guard keeps the cell idempotent. Without it, re-running the cell would
# explode the already-salted frame again, multiplying the row count by
# SALT_NUMBER on every run and duplicating (value, salt) keys, which silently
# inflates the salted join. NOTE: if SALT_NUMBER is changed, re-create
# df_uniform first so the salt columns are rebuilt.
if "salt" not in df_uniform.columns:
    df_uniform = (
        df_uniform
        .withColumn("salt_value", F.array([F.lit(i) for i in range(SALT_NUMBER)]))
        .withColumn("salt", F.explode(F.col("salt_value")))
    )

In [14]:
# Peek at the first five rows of the salted uniform frame without truncation.
df_uniform.show(5, truncate=False)

+-----+----------+----+
|value|salt_value|salt|
+-----+----------+----+
|0    |[0, 1, 2] |0   |
|0    |[0, 1, 2] |1   |
|0    |[0, 1, 2] |2   |
|1    |[0, 1, 2] |0   |
|1    |[0, 1, 2] |1   |
+-----+----------+----+
only showing top 5 rows


In [15]:
# Salted inner join: rows match on the original key AND the salt bucket.
salted_join_keys = ["value", "salt"]
df_joined = df_skew.join(df_uniform, on=salted_join_keys, how="inner")

In [16]:
# Time the salted join and show, per key and output partition, how many joined
# rows landed there (the partition id is read after the join).
# perf_counter() is a monotonic clock meant for measuring intervals; time.time()
# is wall-clock time and can jump if the system clock is adjusted.
# `time` is imported in the notebook's first (imports) cell.
start = time.perf_counter()
(
    df_joined
    .withColumn("partition", F.spark_partition_id())
    .groupBy("value", "partition")
    .count()
    .orderBy("value", "partition")
    .show()
)
end = time.perf_counter()
print(f"total_execution time {end-start}")

[Stage 41:>                                                       (0 + 10) / 10]

+-----+---------+------+
|value|partition| count|
+-----+---------+------+
|    0|        0|333567|
|    0|        1|333138|
|    0|        2|333285|
|    1|        0|     4|
|    1|        1|    11|
|    2|        0|     6|
|    2|        1|     1|
|    2|        2|     3|
|    3|        0|     1|
|    3|        1|     3|
|    3|        2|     1|
+-----+---------+------+

total_execution time 1.2699790000915527


                                                                                

In [17]:
# Baseline aggregation: time a plain row count per key on the skewed frame.
# perf_counter() is a monotonic clock meant for measuring intervals; time.time()
# is wall-clock time and can jump if the system clock is adjusted.
# `time` is imported in the notebook's first (imports) cell.
start = time.perf_counter()

df_skew.groupBy("value").count().show()
end = time.perf_counter()
print(f"total_execution time {end-start}")

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+

total_execution time 0.6018378734588623


In [18]:
# Salted (two-stage) aggregation, timed:
#   1. assign a fresh random salt bucket in [0, SALT_NUMBER) to every row,
#   2. count rows per (value, salt),
#   3. sum those partial counts per value to get the final count.
# count("*") counts rows, exactly like the baseline groupBy("value").count();
# count("value") would skip NULLs and report 0 for a NULL key group.
# perf_counter() is a monotonic clock meant for measuring intervals.
# `time` is imported in the notebook's first (imports) cell.
start = time.perf_counter()
(
    df_skew
    .withColumn("salt", (F.rand() * SALT_NUMBER).cast("int"))
    .groupBy("value", "salt")
    .agg(F.count("*").alias("count"))
    .groupBy("value")
    .agg(F.sum("count").alias("count"))
    .show()
)
end = time.perf_counter()
print(f"total_execution time {end-start}")

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+

total_execution time 0.6174449920654297


In [19]:
# Stop the SparkSession created at the top of the notebook; cells that use
# `spark` must not be run after this one without re-creating the session.
spark.stop()