# Today's topic: Spark Repartition - How it works and how to screw it up

# 0 - Setup

General hints for this notebook:
- The Spark UI is usually accessible at http://localhost:4040/ (or http://localhost:4041/ if port 4040 is already taken)
- A deep dive into the Spark UI follows in later episodes
- sc.setJobDescription("Description") replaces the job description of subsequent actions in the Spark UI with your own text
- sdf.rdd.getNumPartitions() returns the number of partitions of the current Spark DataFrame
- sdf.write.format("noop").mode("overwrite").save() triggers an action that executes all transformations like a real write would, but without writing any data. This makes it a good way to analyze transformations without side effects

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
import pyspark

In [3]:
spark = SparkSession \
    .builder \
    .appName("Data with Nikk the Greek Spark Session") \
    .master("local[4]") \
    .enableHiveSupport() \
    .getOrCreate()

sc = spark.sparkContext

In [4]:
# Turn off Adaptive Query Execution (AQE): it triggers additional jobs and can change partition counts at runtime, which would be confusing in these examples
spark.conf.set("spark.sql.adaptive.enabled", "false")
# Disable the Databricks disk (IO) cache, as cached data may not give repeatable results
spark.conf.set("spark.databricks.io.cache.enabled", "false")

In [5]:
def sdf_generator(num_rows: int, num_partitions: int = None) -> "DataFrame":
    """Create a test DataFrame with num_rows rows (default parallelism if num_partitions is None)."""
    return (
        spark.range(num_rows, numPartitions=num_partitions)
        .withColumn("date", f.current_date())
        .withColumn("timestamp", f.current_timestamp())
        .withColumn("idstring", f.col("id").cast("string"))
        # Spark substring positions are 1-based (a position of 0 behaves like 1)
        .withColumn("idfirst", f.col("idstring").substr(1, 1))
        .withColumn("idlast", f.col("idstring").substr(-1, 1))
        .withColumn("idfirsttwo", f.col("idstring").substr(1, 2))
        )

In [6]:
sdf = sdf_generator(10, 4)
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|
+---+----------+--------------------+--------+-------+------+----------+
|  0|2024-01-30|2024-01-30 05:56:...|       0|      0|     0|         0|
|  1|2024-01-30|2024-01-30 05:56:...|       1|      1|     1|         1|
|  2|2024-01-30|2024-01-30 05:56:...|       2|      2|     2|         2|
|  3|2024-01-30|2024-01-30 05:56:...|       3|      3|     3|         3|
|  4|2024-01-30|2024-01-30 05:56:...|       4|      4|     4|         4|
|  5|2024-01-30|2024-01-30 05:56:...|       5|      5|     5|         5|
|  6|2024-01-30|2024-01-30 05:56:...|       6|      6|     6|         6|
|  7|2024-01-30|2024-01-30 05:56:...|       7|      7|     7|         7|
|  8|2024-01-30|2024-01-30 05:56:...|       8|      8|     8|         8|
|  9|2024-01-30|2024-01-30 05:56:...|       9|      9|     9|         9|
+---+----------+--------------------+--------+-------+------+----------+

In [30]:
def rows_per_partition(sdf: "DataFrame") -> "DataFrame":
    """Show and return the number and share of rows per (non-empty) partition."""
    num_rows = sdf.count()
    sdf_part = sdf.withColumn("partition_id", f.spark_partition_id())
    sdf_part_count = sdf_part.groupBy("partition_id").count()
    sdf_part_count = sdf_part_count.withColumn("count_perc", 100*f.col("count")/num_rows)
    # orderBy returns a new DataFrame, so the result has to be reassigned
    sdf_part_count = sdf_part_count.orderBy("partition_id")
    sdf_part_count.show()
    return sdf_part_count

In [31]:
def rows_per_partition_col(sdf: "DataFrame", col: str) -> "DataFrame":
    num_rows = sdf.count()
    sdf_part = sdf.withColumn("partition_id", f.spark_partition_id())
    sdf_part_count = sdf_part.groupBy("partition_id", col).count()
    sdf_part_count = sdf_part_count.withColumn("count_perc", 100*f.col("count")/num_rows)
    sdf_part_count = sdf_part_count.orderBy("partition_id", col)
    sdf_part_count.show()
    return sdf_part_count


# 1 - How repartitioning works
- Documentation: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartition.html#pyspark.sql.DataFrame.repartition
- Repartition lets you increase or decrease the number of partitions
- Repartition always shuffles the data, which can be less efficient than coalesce
- On the other hand, it creates roughly uniform partition sizes, unlike coalesce, which only merges existing partitions together
- Instead of only specifying the number of partitions, you can repartition by one or more columns. This uses hash partitioning instead of round-robin partitioning
- If you repartition by columns without defining the number of partitions, the number is taken from spark.sql.shuffle.partitions, which defaults to 200 (important when we evaluate wide transformations in later episodes)

https://stackoverflow.com/questions/65809909/spark-what-is-the-difference-between-repartition-and-repartitionbyrange

# 2 - Round Robin Shuffling
- We define the number of target partitions we want as output
- For each source partition:
    - A random start partition is chosen (range 0 to num_target_partitions - 1)
    - Row by row, each row is assigned to the next target partition, e.g. partition 4, 5, 6, 7, 8, 9, 0, etc.
- Issues:
    - If a source partition holds only a few rows, some target partitions may receive fewer rows or none at all
    - This mainly happens when the number of target partitions is larger than the number of rows per source partition, e.g. 4 source partitions with 5 rows each repartitioned into 10 partitions
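The steps above can be sketched in plain Python. This is a simplified model, not Spark's code, and the function names are hypothetical. As far as we can tell from the Spark source, the start partition comes from a generator seeded with the source partition ID; this sketch uses Python's random.Random instead, so the concrete assignments will not match the Spark output below. What it does reproduce is the effect: with few rows per source partition, the target sizes come out uneven.

```python
import random
from collections import Counter


def round_robin_targets(num_rows: int, source_partition_id: int, num_targets: int) -> list[int]:
    """Target partition for each row of ONE source partition (simplified model)."""
    # Assumption: pseudo-random start seeded by the source partition id (Python RNG, not Spark's)
    position = random.Random(source_partition_id).randrange(num_targets)
    targets = []
    for _ in range(num_rows):
        position += 1  # each row goes to the next target partition
        targets.append(position % num_targets)
    return targets


def simulate(num_sources: int, rows_per_source: int, num_targets: int) -> Counter:
    """Rows per target partition after round-robin repartitioning of all source partitions."""
    counts = Counter()
    for pid in range(num_sources):
        counts.update(round_robin_targets(rows_per_source, pid, num_targets))
    # Keep empty target partitions with a count of 0
    return Counter({t: counts.get(t, 0) for t in range(num_targets)})


# 10 source partitions x 2 rows -> 5 targets (the scenario of In [124])
print(sorted(simulate(10, 2, 5).items()))
# 4 source partitions x 5 rows -> 10 targets (the example from the list above)
print(sorted(simulate(4, 5, 10).items()))
```

If every source partition holds a multiple of the target count (e.g. 10 rows into 5 targets), the start no longer matters and all targets receive the same number of rows.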

In [123]:
sdf = sdf_generator(20, 10)
sdf = sdf.withColumn("partition_id_before", f.spark_partition_id())
rows_per_partition(sdf)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    2|      10.0|
|           6|    2|      10.0|
|           3|    2|      10.0|
|           5|    2|      10.0|
|           9|    2|      10.0|
|           4|    2|      10.0|
|           8|    2|      10.0|
|           7|    2|      10.0|
|           2|    2|      10.0|
|           0|    2|      10.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [124]:
sdf_part = sdf.repartition(5)
rows_per_partition(sdf_part)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    4|      20.0|
|           3|    1|       5.0|
|           4|    6|      30.0|
|           2|    4|      20.0|
|           0|    5|      25.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [125]:
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|
+---+----------+--------------------+--------+-------+------+----------+-------------------+
|  0|2024-01-30|2024-01-30 21:46:...|       0|      0|     0|         0|                  0|
|  1|2024-01-30|2024-01-30 21:46:...|       1|      1|     1|         1|                  0|
|  2|2024-01-30|2024-01-30 21:46:...|       2|      2|     2|         2|                  1|
|  3|2024-01-30|2024-01-30 21:46:...|       3|      3|     3|         3|                  1|
|  4|2024-01-30|2024-01-30 21:46:...|       4|      4|     4|         4|                  2|
|  5|2024-01-30|2024-01-30 21:46:...|       5|      5|     5|         5|                  2|
|  6|2024-01-30|2024-01-30 21:46:...|       6|      6|     6|         6|                  3|
|  7|2024-01-30|2024-01-30 21:46:...|       7|      7|     7|         

In [126]:
sdf_part = sdf_part.withColumn("partition_id", f.spark_partition_id()).orderBy("partition_id_before")
sdf_part.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|partition_id|
+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
|  1|2024-01-30|2024-01-30 21:46:...|       1|      1|     1|         1|                  0|           0|
|  0|2024-01-30|2024-01-30 21:46:...|       0|      0|     0|         0|                  0|           4|
|  3|2024-01-30|2024-01-30 21:46:...|       3|      3|     3|         3|                  1|           0|
|  2|2024-01-30|2024-01-30 21:46:...|       2|      2|     2|         2|                  1|           4|
|  4|2024-01-30|2024-01-30 21:46:...|       4|      4|     4|         4|                  2|           0|
|  5|2024-01-30|2024-01-30 21:46:...|       5|      5|     5|         5|                  2|           4|
|  6|2024-01-30|2024-01-30 21:46:...|       6|

# 3 - Hash Partitioning & Where it can screw you up
- How it works:
    - During repartitioning, each value of the column to repartition by is hashed to an integer
    - Spark uses the Murmur3_x86_32 hash function (the same one behind f.hash(), with seed 42)
    - The positive modulo pmod(hash, num_target_partitions) gives the assigned target partition
- Issue:
    - If you are unlucky, the values in a column don't spread evenly across the target partitions, e.g. you have 10 distinct values but only 4 output partitions
    - This is especially likely when the column has only a few distinct values
    - Adding more rows per partition does not help, because all rows with the same value always end up in the same partition
    - Adjusting the number of target partitions changes the assignment and therefore the result
- https://stackoverflow.com/questions/73303061/the-results-of-murmurhash-in-pyspark-and-local-python-are-different
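To see the hashing without a cluster, here is a pure-Python sketch of the Murmur3_x86_32 variant that Spark's hash() applies to a single string column (seed 42). It is a re-implementation for illustration, not Spark's code, and the function names are hypothetical. Spark mixes every trailing byte (beyond full 4-byte blocks) on its own, which differs from the reference Murmur3 and is one reason common Python libraries return different values (see the Stack Overflow link above).

```python
from collections import Counter

C1 = 0xCC9E2D51
C2 = 0x1B873593
MASK32 = 0xFFFFFFFF


def _rotl32(x: int, r: int) -> int:
    return ((x << r) | (x >> (32 - r))) & MASK32


def _mix_k1(k1: int) -> int:
    k1 = (k1 * C1) & MASK32
    k1 = _rotl32(k1, 15)
    return (k1 * C2) & MASK32


def _mix_h1(h1: int, k1: int) -> int:
    h1 ^= k1
    h1 = _rotl32(h1, 13)
    return (h1 * 5 + 0xE6546B64) & MASK32


def _fmix(h1: int, length: int) -> int:
    h1 ^= length
    h1 ^= h1 >> 16
    h1 = (h1 * 0x85EBCA6B) & MASK32
    h1 ^= h1 >> 13
    h1 = (h1 * 0xC2B2AE35) & MASK32
    h1 ^= h1 >> 16
    return h1


def spark_string_hash(value: str, seed: int = 42) -> int:
    """Hypothetical re-implementation of Spark's hash() for one string column."""
    data = value.encode("utf-8")
    h1 = seed & MASK32
    aligned = len(data) - len(data) % 4
    # Full 4-byte blocks, read as little-endian ints
    for i in range(0, aligned, 4):
        h1 = _mix_h1(h1, _mix_k1(int.from_bytes(data[i:i + 4], "little")))
    # Spark-specific tail: every remaining byte (as a signed Java byte) is mixed on its own
    for b in data[aligned:]:
        signed_byte = b - 256 if b > 127 else b
        h1 = _mix_h1(h1, _mix_k1(signed_byte & MASK32))
    h1 = _fmix(h1, len(data))
    # Return a signed 32-bit int, like Java
    return h1 - (1 << 32) if h1 & 0x80000000 else h1


def target_partition(value: str, num_partitions: int) -> int:
    # pmod: Python's % is already non-negative for a positive divisor
    return spark_string_hash(value) % num_partitions


print(spark_string_hash("0"))  # 735846435, as in the hash column of In [137]

# idlast values of ids 0..19, repartitioned into 5 partitions
idlast = [str(i)[-1] for i in range(20)]
print(sorted(Counter(target_partition(v, 5) for v in idlast).items()))
```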

In [127]:
sdf = sdf_generator(20, 4)
sdf = sdf.withColumn("partition_id_before", f.spark_partition_id())
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|
+---+----------+--------------------+--------+-------+------+----------+-------------------+
|  0|2024-01-30|2024-01-30 21:46:...|       0|      0|     0|         0|                  0|
|  1|2024-01-30|2024-01-30 21:46:...|       1|      1|     1|         1|                  0|
|  2|2024-01-30|2024-01-30 21:46:...|       2|      2|     2|         2|                  0|
|  3|2024-01-30|2024-01-30 21:46:...|       3|      3|     3|         3|                  0|
|  4|2024-01-30|2024-01-30 21:46:...|       4|      4|     4|         4|                  0|
|  5|2024-01-30|2024-01-30 21:46:...|       5|      5|     5|         5|                  1|
|  6|2024-01-30|2024-01-30 21:46:...|       6|      6|     6|         6|                  1|
|  7|2024-01-30|2024-01-30 21:46:...|       7|      7|     7|         

In [128]:
rows_per_partition(sdf)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    5|      25.0|
|           3|    5|      25.0|
|           2|    5|      25.0|
|           0|    5|      25.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [131]:
sdf_part = sdf.repartition(5, "idlast")
rows_per_partition(sdf_part)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           1|    4|      20.0|
|           3|    4|      20.0|
|           4|    8|      40.0|
|           0|    4|      20.0|
+------------+-----+----------+



DataFrame[partition_id: int, count: bigint, count_perc: double]

In [132]:
sdf_part = sdf_part.withColumn("partition_id", f.spark_partition_id()).orderBy("id")
sdf_part.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|partition_id|
+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
|  0|2024-01-30|2024-01-30 21:47:...|       0|      0|     0|         0|                  0|           0|
|  1|2024-01-30|2024-01-30 21:47:...|       1|      1|     1|         1|                  0|           4|
|  2|2024-01-30|2024-01-30 21:47:...|       2|      2|     2|         2|                  0|           4|
|  3|2024-01-30|2024-01-30 21:47:...|       3|      3|     3|         3|                  0|           3|
|  4|2024-01-30|2024-01-30 21:47:...|       4|      4|     4|         4|                  0|           1|
|  5|2024-01-30|2024-01-30 21:47:...|       5|      5|     5|         5|                  1|           4|
|  6|2024-01-30|2024-01-30 21:47:...|       6|

In [137]:
sdf_hash = sdf.withColumn("hash", f.hash("idlast"))
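# Approximation: Spark's hash partitioner uses pmod(hash, 5); abs(hash % 5) only matches it for non-negative hashes or hashes divisible by 5
sdf_hash = sdf_hash.withColumn("part", f.abs(f.col("hash") % 5))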
sdf_hash = sdf_hash.orderBy("part")
sdf_hash.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+-----------+----+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|       hash|part|
+---+----------+--------------------+--------+-------+------+----------+-------------------+-----------+----+
|  0|2024-01-30|2024-01-30 22:08:...|       0|      0|     0|         0|                  0|  735846435|   0|
|  6|2024-01-30|2024-01-30 22:08:...|       6|      6|     6|         6|                  1|-1929623325|   0|
| 10|2024-01-30|2024-01-30 22:08:...|      10|      1|     0|        10|                  2|  735846435|   0|
| 16|2024-01-30|2024-01-30 22:08:...|      16|      1|     6|        16|                  3|-1929623325|   0|
|  3|2024-01-30|2024-01-30 22:08:...|       3|      3|     3|         3|                  0|-1756013582|   2|
| 13|2024-01-30|2024-01-30 22:08:...|      13|      1|     3|        13|                  2|-1756013582|   2|
|  7|2024-

In [134]:
sdf_hash.groupBy("part").count().show()


+----+-----+
|part|count|
+----+-----+
|   3|    2|
|   4|   12|
|   2|    2|
|   0|    4|
+----+-----+



# 4 - Repartition by range

- Documentation: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartitionByRange.html