# Today's topic: Spark Repartition - How it works and how to screw it up

# 0 - Setup

General hints for this notebook:
- The Spark UI is usually accessible at http://localhost:4040/ or http://localhost:4041/
- A deep dive into the Spark UI follows in later episodes
- sc.setJobDescription("Description") replaces the job description of an action in the Spark UI with your own
- sdf.rdd.getNumPartitions() returns the number of partitions of the current Spark DataFrame
- sdf.write.format("noop").mode("overwrite").save() triggers an action that runs all transformations without writing any output, which makes it a good way to analyze them without the side effects of an actual write

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
import pyspark

In [3]:
spark = SparkSession \
    .builder \
    .appName("Data with Nikk the Greek Spark Session") \
    .master("local[4]") \
    .enableHiveSupport() \
    .getOrCreate()

sc = spark.sparkContext

In [4]:
# Turn off AQE, as it generates additional jobs which might be confusing for this scenario.
spark.conf.set("spark.sql.adaptive.enabled", "false")
# Do not cache DataFrames, as caching may make results non-repeatable (Databricks-specific setting).
spark.conf.set("spark.databricks.io.cache.enabled", "false")

In [5]:
def sdf_generator(num_rows: int, num_partitions: int = None) -> "DataFrame":
    return (
        spark.range(num_rows, numPartitions=num_partitions)
        .withColumn("date", f.current_date())
        .withColumn("timestamp",f.current_timestamp())
        .withColumn("idstring", f.col("id").cast("string"))
        .withColumn("idfirst", f.col("idstring").substr(0,1))
        .withColumn("idlast", f.col("idstring").substr(-1,1))
        .withColumn("idfirsttwo", f.col("idstring").substr(0,2))
        )

In [6]:
sdf = sdf_generator(10, 4)
sdf.show()

+---+----------+--------------------+--------+-------+------+----------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|
+---+----------+--------------------+--------+-------+------+----------+
|  0|2024-01-30|2024-01-30 05:56:...|       0|      0|     0|         0|
|  1|2024-01-30|2024-01-30 05:56:...|       1|      1|     1|         1|
|  2|2024-01-30|2024-01-30 05:56:...|       2|      2|     2|         2|
|  3|2024-01-30|2024-01-30 05:56:...|       3|      3|     3|         3|
|  4|2024-01-30|2024-01-30 05:56:...|       4|      4|     4|         4|
|  5|2024-01-30|2024-01-30 05:56:...|       5|      5|     5|         5|
|  6|2024-01-30|2024-01-30 05:56:...|       6|      6|     6|         6|
|  7|2024-01-30|2024-01-30 05:56:...|       7|      7|     7|         7|
|  8|2024-01-30|2024-01-30 05:56:...|       8|      8|     8|         8|
|  9|2024-01-30|2024-01-30 05:56:...|       9|      9|     9|         9|
+---+----------+--------------------+--------+-------+------+----------+

In [16]:
def rows_per_partition(sdf: "DataFrame") -> "DataFrame":
    # Count the rows per partition and each partition's share of all rows in percent.
    num_rows = sdf.count()
    sdf_part = sdf.withColumn("partition_id", f.spark_partition_id())
    sdf_part_count = sdf_part.groupBy("partition_id").count()
    sdf_part_count = sdf_part_count.withColumn("count_perc", 100*f.col("count")/num_rows)
    # orderBy returns a new DataFrame, so the result has to be assigned.
    sdf_part_count = sdf_part_count.orderBy("partition_id")
    sdf_part_count.show()
    return sdf_part_count

In [17]:
def rows_per_partition_col(sdf: "DataFrame", col: str) -> "DataFrame":
    num_rows = sdf.count()
    sdf_part = sdf.withColumn("partition_id", f.spark_partition_id())
    sdf_part_count = sdf_part.groupBy("partition_id", col).count()
    sdf_part_count = sdf_part_count.withColumn("count_perc", 100*f.col("count")/num_rows)
    sdf_part_count = sdf_part_count.orderBy("partition_id", col)
    sdf_part_count.show()
    return sdf_part_count


# 1 - How repartitioning works
- Documentation: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartition.html#pyspark.sql.DataFrame.repartition
- Repartition can both increase and decrease the number of partitions
- Repartition requires a full shuffle of the data, which can be less efficient than coalesce
- On the other hand, it creates a uniform distribution, unlike coalesce, which only merges existing partitions
- Instead of repartitioning by a number of partitions only, you can repartition by a column. This uses hash partitioning instead of round robin
- If no number of partitions is defined, the default is taken from spark.sql.shuffle.partitions, which defaults to 200 (this becomes important when evaluating wide transformations in later episodes)
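
To make the difference between coalesce and a round robin repartition a bit more tangible, here is a toy model in plain Python. It is not Spark code: partitions are simply lists, and `coalesce_model` and `round_robin_model` are hypothetical helper names made up for this sketch. It also assumes that coalesce merges neighbouring partitions and that round robin starts dealing at partition 0, whereas real Spark decides both differently (locality, random start position). Treat it as an illustration of the idea only.

```python
from typing import List

Partition = List[int]


def coalesce_model(partitions: List[Partition], n: int) -> List[Partition]:
    """Toy coalesce: merge whole neighbouring partitions, no per-row shuffle."""
    # Coalesce can only reduce the number of partitions, never increase it.
    n = min(n, len(partitions))
    out: List[Partition] = [[] for _ in range(n)]
    for i, part in enumerate(partitions):
        # Assumption: contiguous input partitions end up in the same output partition.
        out[i * n // len(partitions)].extend(part)
    return out


def round_robin_model(partitions: List[Partition], n: int) -> List[Partition]:
    """Toy round robin repartition: deal rows out one by one across n partitions."""
    out: List[Partition] = [[] for _ in range(n)]
    position = 0
    for part in partitions:
        for row in part:
            out[position % n].append(row)
            position += 1
    return out


# Four unevenly sized input partitions holding 20 rows in total (14, 3, 2 and 1 rows).
skewed = [list(range(0, 14)), list(range(14, 17)), list(range(17, 19)), [19]]

coalesced_sizes = [len(p) for p in coalesce_model(skewed, 2)]
round_robin_sizes = [len(p) for p in round_robin_model(skewed, 2)]
grown_sizes = [len(p) for p in round_robin_model(skewed, 5)]

print("coalesce to 2:   ", coalesced_sizes)
print("round robin to 2:", round_robin_sizes)
print("round robin to 5:", grown_sizes)
```

Under these assumptions the coalesced partitions keep the skew of their inputs (17 and 3 rows), while round robin ends up with 10 and 10 rows for two partitions and 4 rows each for five partitions. The price in Spark is that every row may have to move, which is the shuffle.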

https://stackoverflow.com/questions/65809909/spark-what-is-the-difference-between-repartition-and-repartitionbyrange

# 2 - Round Robin Shuffling

In [18]:
sdf = sdf_generator(20, 4)

In [24]:
sdf_part = sdf.withColumn("partition_id_before", f.spark_partition_id())
sdf_part.show()

+---+----------+--------------------+--------+-------+------+----------+-------------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|
+---+----------+--------------------+--------+-------+------+----------+-------------------+
|  0|2024-01-30|2024-01-30 06:21:...|       0|      0|     0|         0|                  0|
|  1|2024-01-30|2024-01-30 06:21:...|       1|      1|     1|         1|                  0|
|  2|2024-01-30|2024-01-30 06:21:...|       2|      2|     2|         2|                  0|
|  3|2024-01-30|2024-01-30 06:21:...|       3|      3|     3|         3|                  0|
|  4|2024-01-30|2024-01-30 06:21:...|       4|      4|     4|         4|                  0|
|  5|2024-01-30|2024-01-30 06:21:...|       5|      5|     5|         5|                  1|
|  6|2024-01-30|2024-01-30 06:21:...|       6|      6|     6|         6|                  1|
|  7|2024-01-30|2024-01-30 06:21:...|       7|      7|     7|         

In [27]:
sdf_round = sdf_part.repartition(10)
sdf_round = sdf_round.withColumn("partition_id", f.spark_partition_id()).orderBy("partition_id_before")
sdf_round.show()


+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
| id|      date|           timestamp|idstring|idfirst|idlast|idfirsttwo|partition_id_before|partition_id|
+---+----------+--------------------+--------+-------+------+----------+-------------------+------------+
|  4|2024-01-30|2024-01-30 06:23:...|       4|      4|     4|         4|                  0|           0|
|  0|2024-01-30|2024-01-30 06:23:...|       0|      0|     0|         0|                  0|           1|
|  3|2024-01-30|2024-01-30 06:23:...|       3|      3|     3|         3|                  0|           2|
|  1|2024-01-30|2024-01-30 06:23:...|       1|      1|     1|         1|                  0|           3|
|  2|2024-01-30|2024-01-30 06:23:...|       2|      2|     2|         2|                  0|           9|
|  8|2024-01-30|2024-01-30 06:23:...|       8|      8|     8|         8|                  1|           4|
|  5|2024-01-30|2024-01-30 06:23:...|       5|

# 3 - Hash Partitioning & Where it can screw you up

In [None]:
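# f.hash uses Murmur3, the same hash Spark applies for hash partitioning.
sdf1 = sdf.withColumn("hash", f.hash("idstring"))
# Spark maps a hash to a partition with a non-negative modulo (pmod), not abs(hash % n).
sdf1 = sdf1.withColumn("part", f.pmod(f.col("hash"), f.lit(4)))
sdf1 = sdf1.orderBy("part")
sdf1.show()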
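
Where hash partitioning can hurt is probably easiest to see with a column that has few distinct values. The following is a plain Python sketch, not Spark: `toy_hash_partition` is a hypothetical helper, and it uses `zlib.crc32` as a stand-in for Spark's Murmur3 hash, so the concrete partition ids will differ from what Spark produces. The keys mimic the `idfirst` column for ids 0 to 19.

```python
import zlib
from collections import Counter


def toy_hash_partition(key: str, num_partitions: int) -> int:
    """Map a key to a partition id via a non-negative modulo of its hash."""
    # Stand-in hash (assumption): crc32 instead of the Murmur3 hash Spark uses.
    return zlib.crc32(key.encode("utf-8")) % num_partitions


ids = range(20)
# Same idea as the idfirst column: the first character of the id as a string.
keys = [str(i)[0] for i in ids]

num_partitions = 20
rows_per_partition = Counter(toy_hash_partition(k, num_partitions) for k in keys)

non_empty = len(rows_per_partition)
largest = max(rows_per_partition.values())

print("rows per partition:", dict(sorted(rows_per_partition.items())))
print("non-empty partitions:", non_empty, "of", num_partitions)
print("largest partition holds", largest, "of", len(keys), "rows")
```

All rows with the same key land in the same partition, so 10 distinct keys can fill at most 10 of the 20 partitions, and the partition that receives key "1" holds at least 11 of the 20 rows. Asking for more partitions does not help here, because the number of distinct keys is the limit.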

# 4 - Repartition by range

- Documentation: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartitionByRange.html#pyspark.sql.DataFrame.repartitionByRange
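
As a rough idea of what partitioning by range means, here is a small plain Python sketch. It is not Spark code: `toy_range_bounds` and `toy_range_partition` are hypothetical helpers, and the sketch assumes the bounds are computed from all values, whereas Spark estimates them from a sample of the data, so real partition sizes can be less even.

```python
from bisect import bisect_left
from typing import List


def toy_range_bounds(values: List[int], num_partitions: int) -> List[int]:
    """Pick num_partitions - 1 upper bounds at evenly spaced positions of the sorted values."""
    ordered = sorted(values)
    step = len(ordered) / num_partitions
    return [ordered[int(step * i) - 1] for i in range(1, num_partitions)]


def toy_range_partition(value: int, bounds: List[int]) -> int:
    """Partition id is the index of the first bound that is >= value."""
    return bisect_left(bounds, value)


values = list(range(20))
bounds = toy_range_bounds(values, 4)
assignment = {v: toy_range_partition(v, bounds) for v in values}
sizes = [list(assignment.values()).count(p) for p in range(4)]

print("bounds:", bounds)
print("rows per partition:", sizes)
```

For the 20 ids this yields the bounds 4, 9 and 14 and four partitions of 5 rows each. Unlike hash partitioning, neighbouring values stay together and the partition id grows with the value.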