# Today's topic: Spark Repartition - How it works and how to screw it up

1. Recap: How Repartitioning works
2. Round Robin Partitioning
3. Hash Partitioning
4. Range Partitioning
5. Summary

# 0. Set-Ups

General hints for this notebook:
- Spark UI usually accesible by http://localhost:4040/ or http://localhost:4041/
- Deep dive Spark UI happens in later episodes
- sc.setJobDescription("Description") replaces the Job Description of an action in the Spark UI with your own
- sdf.rdd.getNumPartitions() returns the number partitions of the current Spark DataFrame
- sdf.write.format("noop").mode("overwrite").save() is a good way to analyze and initiate actions for transformations without side effects during an actual write

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql import types as t

In [2]:
spark = SparkSession \
    .builder \
    .appName("Data with Nikk the Greek Spark Session") \
    .master("local[4]") \
    .enableHiveSupport() \
    .getOrCreate()

sc = spark.sparkContext

In [3]:
#Turning off AQE as it generates more jobs which might be confusing for this scenario here. 
spark.conf.set("spark.sql.adaptive.enabled", "false")
#to not cache datafrimes... this may not create repeatable results
spark.conf.set("spark.databricks.io.cache.enabled", "false")

In [4]:
def sdf_generator(num_rows: int, num_partitions: int = None) -> "DataFrame":
    return (
        spark.range(num_rows, numPartitions=num_partitions)
        .withColumn("date", f.current_date())
        .withColumn("timestamp",f.current_timestamp())
        .withColumn("idstring", f.col("id").cast("string"))
        .withColumn("idfirst", f.col("idstring").substr(0,1))
        .withColumn("idlast", f.col("idstring").substr(-1,1))
        .withColumn("idfirsttwo", f.col("idstring").substr(0,2))
        )

In [23]:
sdf = sdf_generator(20, 4)
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|
+---+----------+--------------------+--------+-------+------+----------+
|  0|2024-01-31|2024-01-31 23:09:...|       0|      0|     0|         0|
|  1|2024-01-31|2024-01-31 23:09:...|       1|      1|     1|         1|
|  2|2024-01-31|2024-01-31 23:09:...|       2|      2|     2|         2|
|  3|2024-01-31|2024-01-31 23:09:...|       3|      3|     3|         3|
|  4|2024-01-31|2024-01-31 23:09:...|       4|      4|     4|         4|
|  5|2024-01-31|2024-01-31 23:09:...|       5|      5|     5|         5|
|  6|2024-01-31|2024-01-31 23:09:...|       6|      6|     6|         6|
|  7|2024-01-31|2024-01-31 23:09:...|       7|      7|     7|         7|
|  8|2024-01-31|2024-01-31 23:09:...|       8|      8|     8|         8|
|  9|2024-01-31|2024-01-31 23:09:...|       9|      9|     9|         9|
| 10|2024-01-31|2024-01-31 23:09:...|      10|     

In [6]:
def rows_per_partition(sdf: "DataFrame") -> None:
    num_rows = sdf.count()
    sdf_part = sdf.withColumn("partition_id", f.spark_partition_id())
    sdf_part_count = sdf_part.groupBy("partition_id").count()
    sdf_part_count = sdf_part_count.withColumn("count_perc", 100*f.col("count")/num_rows)
    sdf_part_count.orderBy("partition_id")
    sdf_part_count.show()
    return sdf_part_count

In [7]:
def rows_per_partition_col(sdf: "DataFrame", col: str) -> None:
    num_rows = sdf.count()
    sdf_part = sdf.withColumn("partition_id", f.spark_partition_id())
    sdf_part_count = sdf_part.groupBy("partition_id", col).count()
    sdf_part_count = sdf_part_count.withColumn("count_perc", 100*f.col("count")/num_rows)
    sdf_part_count = sdf_part_count.orderBy("partition_id", col)
    sdf_part_count.show()
    return sdf_part_count


# 1 - How repartitioning works
- Documentation: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartition.html#pyspark.sql.DataFrame.repartition
- Repartition allows to increase and decrease the number of partitions
- Repartition requires shuffling data which can be more unefficient than Coalesce
- On the other hand it creates uniform distributions unlike coalesce which only unions partions together
- Instead of partition based on the number of partitions you can partition based on a column. This uses Hash Partitioning instead of Round Robin
- If no number of partitions is defined the default value depends on spark.sql.shuffle.partitions which defaults to 200 (important later when evaluating wide transformations in later episodes)

https://stackoverflow.com/questions/65809909/spark-what-is-the-difference-between-repartition-and-repartitionbyrange

# 2 - Round Robin Partitioning
- We define the number of target partitions we want to receive as an output
- Per source partition we do the following:
    - We choose a random target partition (range 0 to number target partitions)
    - Row by row we assign the next target partition, e.g. partition 4, 5, 6, 7, 8, 9, 0 etc.
- How to screw it up: 
    - For a small number of values per partition some partitions might be get assigned less or no values
    - This of course only happens rarely with no target partitions bigger than values per source partition as e.g. if you have 5 values per partition with 4 in total and increase to 10 partitions

In [33]:
sdf = sdf_generator(20002, 4)
sdf = sdf.withColumn("partition_id_before", f.spark_partition_id())
rows_per_partition(sdf)

+------------+-----+---------------+
|partition_id|count|     count_perc|
+------------+-----+---------------+
|           1| 5001|25.002499750025|
|           3| 5001|25.002499750025|
|           2| 5000|24.997500249975|
|           0| 5000|24.997500249975|
+------------+-----+---------------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [34]:
sdf_part = sdf.repartition(10)
rows_per_partition(sdf_part)

+------------+-----+------------------+
|partition_id|count|        count_perc|
+------------+-----+------------------+
|           1| 2000|     9.99900009999|
|           6| 2000|     9.99900009999|
|           3| 2000|     9.99900009999|
|           5| 2000|     9.99900009999|
|           9| 2000|     9.99900009999|
|           4| 2001|10.003999600039997|
|           8| 2001|10.003999600039997|
|           7| 2000|     9.99900009999|
|           2| 2000|     9.99900009999|
|           0| 2000|     9.99900009999|
+------------+-----+------------------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [29]:
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|
+---+----------+--------------------+--------+-------+------+----------+-------------------+
|  0|2024-01-31|2024-01-31 23:17:...|       0|      0|     0|         0|                  0|
|  1|2024-01-31|2024-01-31 23:17:...|       1|      1|     1|         1|                  0|
|  2|2024-01-31|2024-01-31 23:17:...|       2|      2|     2|         2|                  0|
|  3|2024-01-31|2024-01-31 23:17:...|       3|      3|     3|         3|                  0|
|  4|2024-01-31|2024-01-31 23:17:...|       4|      4|     4|         4|                  0|
|  5|2024-01-31|2024-01-31 23:17:...|       5|      5|     5|         5|                  1|
|  6|2024-01-31|2024-01-31 23:17:...|       6|      6|     6|         6|                  1|
|  7|2024-01-31|2024-01-31 23:17:...|       7|      7|     7|         

In [30]:
sdf_part = sdf_part.withColumn("partition_id", f.spark_partition_id()).orderBy("partition_id_before")
sdf_part.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|partition_id|
+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
|  4|2024-01-31|2024-01-31 23:17:...|       4|      4|     4|         4|                  0|           0|
|  0|2024-01-31|2024-01-31 23:17:...|       0|      0|     0|         0|                  0|           1|
|  3|2024-01-31|2024-01-31 23:17:...|       3|      3|     3|         3|                  0|           2|
|  1|2024-01-31|2024-01-31 23:17:...|       1|      1|     1|         1|                  0|           3|
|  2|2024-01-31|2024-01-31 23:17:...|       2|      2|     2|         2|                  0|           9|
|  8|2024-01-31|2024-01-31 23:17:...|       8|      8|     8|         8|                  1|           4|
|  5|2024-01-31|2024-01-31 23:17:...|       5|

# 3 - Hash Partitioning
- How it works:
    - During repartitioning each value of the column to be rapartitioned by is hashed to an integer
    - Spark uses the Murmur3_x86_32 
    - Applying Module % num_target_partitions results into returning the assigned target partition
- How to screw it up:
    - If you are unlucky the values you have in a column don't distribute the same way. E.g. you have 10 distinct values but only 4 output partitions
    - Especially for a low distinct value range this might be the case
    - Increasing the number of rows per partition does not help
    - adjusting the num of target partitions influences the result

In [45]:
sdf = sdf_generator(20, 4)
sdf = sdf.withColumn("partition_id_before", f.spark_partition_id())
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|
+---+----------+--------------------+--------+-------+------+----------+-------------------+
|  0|2024-01-31|2024-01-31 23:28:...|       0|      0|     0|         0|                  0|
|  1|2024-01-31|2024-01-31 23:28:...|       1|      1|     1|         1|                  0|
|  2|2024-01-31|2024-01-31 23:28:...|       2|      2|     2|         2|                  0|
|  3|2024-01-31|2024-01-31 23:28:...|       3|      3|     3|         3|                  0|
|  4|2024-01-31|2024-01-31 23:28:...|       4|      4|     4|         4|                  0|
|  5|2024-01-31|2024-01-31 23:28:...|       5|      5|     5|         5|                  1|
|  6|2024-01-31|2024-01-31 23:28:...|       6|      6|     6|         6|                  1|
|  7|2024-01-31|2024-01-31 23:28:...|       7|      7|     7|         

In [46]:
rows_per_partition(sdf)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    5|      25.0|
|           3|    5|      25.0|
|           2|    5|      25.0|
|           0|    5|      25.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [47]:
sdf_part = sdf.repartition(5, "idlast")
rows_per_partition(sdf_part)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    4|      20.0|
|           3|    4|      20.0|
|           4|    8|      40.0|
|           0|    4|      20.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [15]:
sdf_part = sdf_part.withColumn("partition_id", f.spark_partition_id()).orderBy("id")
sdf_part.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|partition_id|
+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
|  0|2024-01-31|2024-01-31 23:00:...|       0|      0|     0|         0|                  0|           0|
|  1|2024-01-31|2024-01-31 23:00:...|       1|      1|     1|         1|                  0|           4|
|  2|2024-01-31|2024-01-31 23:00:...|       2|      2|     2|         2|                  0|           4|
|  3|2024-01-31|2024-01-31 23:00:...|       3|      3|     3|         3|                  0|           3|
|  4|2024-01-31|2024-01-31 23:00:...|       4|      4|     4|         4|                  0|           1|
|  5|2024-01-31|2024-01-31 23:00:...|       5|      5|     5|         5|                  1|           4|
|  6|2024-01-31|2024-01-31 23:00:...|       6|

In [36]:
sdf_hash = sdf.withColumn("hash", f.hash("idlast"))
sdf_hash = sdf_hash.withColumn("part", f.abs(f.col("hash") % 5))
sdf_hash = sdf_hash.orderBy("id")
sdf_hash.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+-----------+----+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|       hash|part|
+---+----------+--------------------+--------+-------+------+----------+-------------------+-----------+----+
|  0|2024-01-31|2024-01-31 23:25:...|       0|      0|     0|         0|                  0|  735846435|   0|
|  1|2024-01-31|2024-01-31 23:25:...|       1|      1|     1|         1|                  0| 1625004744|   4|
|  2|2024-01-31|2024-01-31 23:25:...|       2|      2|     2|         2|                  0|  870267989|   4|
|  3|2024-01-31|2024-01-31 23:25:...|       3|      3|     3|         3|                  0|-1756013582|   2|
|  4|2024-01-31|2024-01-31 23:25:...|       4|      4|     4|         4|                  0|-2142269034|   4|
|  5|2024-01-31|2024-01-31 23:25:...|       5|      5|     5|         5|                  1|  135093849|   4|
|  6|2024-

In [37]:
sdf_hash.groupBy("part").count().show()


+----+-----+
|part|count|
+----+-----+
|   3|    2|
|   4|   12|
|   2|    2|
|   0|    4|
+----+-----+



# 4 - Repartition by range
- How it works:
  - In range partitioning spark creates ranges based on the values
  - Sampling is used to identify the ranges which then are divided into the target partitions
  - e.g. id col below is split into 5 partitions 1-4, 5-8, 9-12, 13-16, 17-20
  - it usually used for continious values
  - If no number of partitions is defined the default value depends on spark.sql.shuffle.partitions which defaults to 200 (important later when evaluating wide transformations in later episodes)
- How to screw it up:
  - Using columns with descrete values and some values repeat more often than others. Basically a skewed dataset

- Documentation: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartitionByRange.html
- Reference: https://stackoverflow.com/questions/65809909/spark-what-is-the-difference-between-repartition-and-repartitionbyrange

In [48]:
sdf = sdf_generator(20, 4)
sdf = sdf.withColumn("partition_id_before", f.spark_partition_id()) #.where((f.col("id") < 10) | (f.col("id") > 14))
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|
+---+----------+--------------------+--------+-------+------+----------+-------------------+
|  0|2024-01-31|2024-01-31 23:29:...|       0|      0|     0|         0|                  0|
|  1|2024-01-31|2024-01-31 23:29:...|       1|      1|     1|         1|                  0|
|  2|2024-01-31|2024-01-31 23:29:...|       2|      2|     2|         2|                  0|
|  3|2024-01-31|2024-01-31 23:29:...|       3|      3|     3|         3|                  0|
|  4|2024-01-31|2024-01-31 23:29:...|       4|      4|     4|         4|                  0|
|  5|2024-01-31|2024-01-31 23:29:...|       5|      5|     5|         5|                  1|
|  6|2024-01-31|2024-01-31 23:29:...|       6|      6|     6|         6|                  1|
|  7|2024-01-31|2024-01-31 23:29:...|       7|      7|     7|         

In [49]:
rows_per_partition(sdf)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    5|      25.0|
|           3|    5|      25.0|
|           2|    5|      25.0|
|           0|    5|      25.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [52]:
sdf_part = sdf.repartitionByRange(5, "idfirst")
rows_per_partition(sdf_part)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    1|       5.0|
|           3|    2|      10.0|
|           4|    4|      20.0|
|           2|    1|       5.0|
|           0|   12|      60.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [21]:
sdf_part = sdf_part.withColumn("partition_id", f.spark_partition_id()).orderBy("id")
sdf_part.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|partition_id|
+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
|  0|2024-01-31|2024-01-31 23:00:...|       0|      0|     0|         0|                  0|           0|
|  1|2024-01-31|2024-01-31 23:00:...|       1|      1|     1|         1|                  0|           0|
|  2|2024-01-31|2024-01-31 23:00:...|       2|      2|     2|         2|                  0|           1|
|  3|2024-01-31|2024-01-31 23:00:...|       3|      3|     3|         3|                  0|           1|
|  4|2024-01-31|2024-01-31 23:00:...|       4|      4|     4|         4|                  0|           2|
|  5|2024-01-31|2024-01-31 23:00:...|       5|      5|     5|         5|                  1|           2|
|  6|2024-01-31|2024-01-31 23:00:...|       6|

# 5 - Summary
- Use Round Robin Shuffling if you want to have equal partition sizes not depending on column grouping
- Use Hash Partitioning to repartition data based on cols. Take care to have a bigger number of distinct values. It might create a skew even though the col distribution is not skewed
- Use Range partitioning for continious increasing columns