In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    collect_list,
    concat,
    expr,
    floor,
    lit,
    rand,
    split,
)

# Obtain a Spark session for this notebook (an already-running one is reused).
spark = (
    SparkSession.builder
    .appName("SaltingExample")
    .getOrCreate()
)

# Deliberately skewed sample: "key1" owns four of the six rows, while
# "key2" and "key3" have a single row each.
data = [
    ("key1", 1),
    ("key1", 2),
    ("key1", 3),
    ("key1", 4),
    ("key2", 5),
    ("key3", 6),
]
columns = ["key", "value"]

df = spark.createDataFrame(data, schema=columns)

# Display the un-salted input.
df.show()
# Expected output:
# +----+-----+
# | key|value|
# +----+-----+
# |key1|    1|
# |key1|    2|
# |key1|    3|
# |key1|    4|
# |key2|    5|
# |key3|    6|
# +----+-----+


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/03 17:44:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/01/03 17:44:49 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

+----+-----+
| key|value|
+----+-----+
|key1|    1|
|key1|    2|
|key1|    3|
|key1|    4|
|key2|    5|
|key3|    6|
+----+-----+



In [3]:
# Salting configuration.
# NUM_SALTS: how many sub-keys each original key is spread across.
# SALT_SEED: fixed seed so the salt assignment can be reproduced on re-run
#            instead of changing every time the notebook is executed.
NUM_SALTS = 3
SALT_SEED = 42

# Draw an integer salt in [0, NUM_SALTS) per row. rand() yields a double in
# [0, 1); scaling and flooring turns it into a bucket index. The explicit
# cast makes the string conversion visible rather than leaving it to
# concat's implicit coercion.
salt = floor(rand(seed=SALT_SEED) * NUM_SALTS).cast("string")

# Append "_<salt>" to the key so rows of one hot key map to several keys.
salted_df = df.withColumn("salted_key", concat(col("key"), lit("_"), salt))

salted_df.show()
# Illustrative output only -- the salt chosen for each row depends on the
# seed (and, presumably, on how Spark partitions the rows), so the suffixes
# you see may differ:
# +----+-----+----------+
# | key|value|salted_key|
# +----+-----+----------+
# |key1|    1|    key1_0|
# |key1|    2|    key1_2|
# |key1|    3|    key1_1|
# |key1|    4|    key1_0|
# |key2|    5|    key2_0|
# |key3|    6|    key3_2|
# +----+-----+----------+

+----+-----+----------+
| key|value|salted_key|
+----+-----+----------+
|key1|    1|    key1_1|
|key1|    2|    key1_0|
|key1|    3|    key1_2|
|key1|    4|    key1_2|
|key2|    5|    key2_0|
|key3|    6|    key3_1|
+----+-----+----------+



In [4]:
# Aggregate per salted key: gather every "value" that shares a salted key
# into a single list column named "values".
values_per_salted_key = collect_list(col("value")).alias("values")
grouped_salted_df = (
    salted_df
    .groupBy("salted_key")
    .agg(values_per_salted_key)
)

grouped_salted_df.show()
# Output recorded from one run. The salts are assigned randomly upstream, so
# which values end up sharing a group can differ from this:
# +----------+------+
# |salted_key|values|
# +----------+------+
# |    key1_1|   [1]|
# |    key1_0|   [2]|
# |    key1_2|[3, 4]|
# |    key2_0|   [5]|
# |    key3_1|   [6]|
# +----------+------+



+----------+------+
|salted_key|values|
+----------+------+
|    key1_1|   [1]|
|    key1_0|   [2]|
|    key1_2|[3, 4]|
|    key2_0|   [5]|
|    key3_1|   [6]|
+----------+------+



                                                                                