# How Spark partitions data for wide transformations

This notebook demonstrates how Spark partitions data in wide transformations such as `.groupBy()`, `.join()` and `Window.partitionBy()`. To test this, we will create a skewed variable within a DataFrame, which makes it more obvious how Spark is partitioning the data.

We will also see how `.cache()` might affect the execution plan.

## Getting set up
Start with some imports and a Spark session. Note that `spark.executor.memory` is set to `1g`.

In [1]:
from pyspark.sql import SparkSession, functions as F, Window

spark = (
    SparkSession.builder.appName("wide-skew")
    .config("spark.executor.memory", "1g")
    .getOrCreate()
)

Next, create a DataFrame with a skewed variable, `skew_col`. Also add another column, `rand_val`, which we will use later. We use `spark.range()` and set the number of partitions to two.

In [2]:
row_ct = 10**7
seed_no = 42

skewed_df = spark.range(row_ct, numPartitions=2)

skewed_df = (
    skewed_df.withColumn("skew_col", F.when(F.col("id") < 100, "A")
                               .when(F.col("id") < 1000, "B")
                               .when(F.col("id") < 10000, "C")
                               .when(F.col("id") < 100000, "D")
                               .otherwise("E"))
              .withColumn("rand_val", F.rint(F.rand(seed_no)*10).cast("int"))
)

skewed_df.show(5)

+---+--------+--------+
| id|skew_col|rand_val|
+---+--------+--------+
|  0|       A|       7|
|  1|       A|       9|
|  2|       A|       9|
|  3|       A|       9|
|  4|       A|       4|
+---+--------+--------+
only showing top 5 rows
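Because `spark.range()` produces consecutive ids, the size of each `skew_col` group follows directly from the `F.when()` thresholds. The Spark-free sketch below does that arithmetic and also works out the expected mean of `rand_val`, assuming `F.rand()` is uniform on [0, 1), so that `F.rint(F.rand()*10)` takes the values 0 to 10, with 0 and 10 each half as likely as the other values. The variable names are illustrative only.

```python
from fractions import Fraction

row_ct = 10**7

# (value, first id, first id of the next bucket), mirroring the F.when() chain
buckets = [("A", 0, 100), ("B", 100, 1_000), ("C", 1_000, 10_000),
           ("D", 10_000, 100_000), ("E", 100_000, row_ct)]
group_sizes = {key: stop - start for key, start, stop in buckets}

# rint(10 * U) for U ~ Uniform[0, 1): 0 and 10 each cover a half-width interval
probs = {k: Fraction(1, 10) for k in range(1, 10)}
probs[0] = probs[10] = Fraction(1, 20)
expected_mean = sum(k * p for k, p in probs.items())

# Expected sum(rand_val) per group = group size * expected mean
expected_sums = {key: n * expected_mean for key, n in group_sizes.items()}

for key, n in group_sizes.items():
    print(f"{key}: {n:>9,} rows ({n / row_ct:.3%}), "
          f"expected sum(rand_val) ~ {int(expected_sums[key]):,}")
```

Group `E` holds 99% of the rows, and the expected sums (4,500 for `B`, 450,000 for `D` and 49,500,000 for `E`) are close to the observed sums of 4,639, 448,943 and 49,491,417 that appear in the group-by output.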



To do the `.join()` we'll need a second DataFrame. Note that it has only five rows.

In [3]:
small_df = spark.createDataFrame([
    ["A", 1],
    ["B", 2],
    ["C", 3],
    ["D", 4],
    ["E", 5]
], ["skew_col", "number_col"])

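We're going to need the Spark UI, so create a link to it. This cell builds the URL from environment variables set by Cloudera Data Science Workbench (CDSW), so it only works there; elsewhere, `spark.sparkContext.uiWebUrl` gives the UI address.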

In [4]:
import os, IPython
url = "spark-%s.%s" % (os.environ["CDSW_ENGINE_ID"], os.environ["CDSW_DOMAIN"])
IPython.display.HTML("<a href=http://%s>Spark UI</a>" % url)

## Some wide transformations

The wide transformations are wrapped in functions so we can call them repeatedly later. Each function sets a job description so we can find the corresponding jobs in the Spark UI. The functions print a small sample of the result with `.show()`, but that output isn't important; what matters are the task metrics and, optionally, the SQL DAG diagrams in the Spark UI.

There is one function for a `.groupBy()`, another for a `.join()` and a third for a `Window.partitionBy()`.

In [5]:
def action_join():
    spark.sparkContext.setJobDescription("join")
    joined_df = skewed_df.join(small_df, on="skew_col", how="left")
    return joined_df.show(3)

def action_groupby_sum():
    spark.sparkContext.setJobDescription("groupby_sum")
    return skewed_df.groupBy("skew_col").agg(F.sum("rand_val")).show(3)

def action_window_sum():
    spark.sparkContext.setJobDescription("Window")
    window_df = skewed_df.withColumn("id_window_sum", F.sum("rand_val").over(Window.partitionBy("skew_col")))
    return window_df.show(3)

In [6]:
action_join()
action_groupby_sum()
action_window_sum()

+--------+-------+--------+----------+
|skew_col|     id|rand_val|number_col|
+--------+-------+--------+----------+
|       E| 100000|       6|         5|
|       E|5438200|       7|         5|
|       E| 100001|       6|         5|
+--------+-------+--------+----------+
only showing top 3 rows

+--------+-------------+
|skew_col|sum(rand_val)|
+--------+-------------+
|       E|     49491417|
|       B|         4639|
|       D|       448943|
+--------+-------------+
only showing top 3 rows

+-------+--------+--------+-------------+
|     id|skew_col|rand_val|id_window_sum|
+-------+--------+--------+-------------+
|5000000|       E|       4|     49491417|
| 438200|       E|       1|     49491417|
|5000001|       E|       3|     49491417|
+-------+--------+--------+-------------+
only showing top 3 rows



Here are some screenshots from the Spark UI showing how Spark partitioned the data and executed the tasks. Note that the images might look slightly different on repeated runs of this notebook, and the processing times will also vary.

### Join

In the Task Timeline below we can clearly see a skew in task duration, which reflects a skew in partition size. This is because Spark has shuffled the data into partitions by `skew_col` before doing the `.join()`, so every row with the same value ends up in the same partition. Also note that the largest partition did not fit in executor memory, as shown by the Shuffle Spill (Memory) and Shuffle Spill (Disk) metrics.

![skew_join](./images/skew_join.PNG)
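A simplified model of why the shuffle before this join is skewed: every row is sent to a shuffle partition chosen from a hash of its join key, so all 9,900,000 `E` rows land in the same partition however many partitions there are. This sketch uses `zlib.crc32` as a stand-in hash (Spark itself uses a Murmur3-based hash) and assumes 200 shuffle partitions, the default of `spark.sql.shuffle.partitions`. The group sizes come from the `F.when()` thresholds used to build `skew_col`.

```python
import zlib
from collections import Counter

# Rows per skew_col value, derived from the F.when() thresholds
group_sizes = {"A": 100, "B": 900, "C": 9_000, "D": 90_000, "E": 9_900_000}

def shuffle_partition(key: str, num_partitions: int) -> int:
    """Pick a target partition from a hash of the key (stand-in for Spark's hash)."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions

def partition_sizes(num_partitions: int) -> Counter:
    """Count how many rows each non-empty shuffle partition would receive."""
    sizes = Counter()
    for key, n in group_sizes.items():
        # All rows sharing a key go to the same partition
        sizes[shuffle_partition(key, num_partitions)] += n
    return sizes

sizes = partition_sizes(200)
print(f"{len(sizes)} of 200 partitions receive data; "
      f"largest holds {max(sizes.values()):,} rows")
```

Increasing the number of partitions does not help here: the largest partition always holds at least 9,900,000 rows, because rows with the same key are never split across partitions.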


### Group by

This time the tasks in the Task Timeline are much more evenly sized. To complete the `.groupBy()`, Spark does not need to move every row to a partition for its `skew_col` value. Instead, it first aggregates within the original partitions (a partial aggregation), then shuffles only these small partial results by `skew_col` and sums them at the end. Note that there are no Shuffle Spill (Memory) or Shuffle Spill (Disk) metrics, meaning there was no need to spill data from executor memory to disk for this `.groupBy()`.

We also see less green (executor computing time) in the Task Timeline, which indicates that a larger share of each task is spent on scheduling or on serialising data for shuffling, although we won't worry about that too much here.

![skew_groupby](./images/skew_groupby.PNG)
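A simplified, Spark-free model of this two-phase aggregation: each input partition is reduced to one partial sum per key, and only those partial results are shuffled and merged. The sketch assumes, as with `spark.range()`, that the two input partitions hold contiguous id ranges, and scales the data down to 200,000 rows. `rand_val` is replaced by a hypothetical deterministic stand-in, `(id * 7) % 11`, so the result is reproducible.

```python
from collections import defaultdict

def skew_key(i: int) -> str:
    """Same thresholds as the F.when() chain used to build skew_col."""
    if i < 100:
        return "A"
    if i < 1_000:
        return "B"
    if i < 10_000:
        return "C"
    if i < 100_000:
        return "D"
    return "E"

def rand_val(i: int) -> int:
    """Hypothetical deterministic stand-in for rand_val (values 0..10)."""
    return (i * 7) % 11

def partial_sums(start: int, stop: int) -> dict:
    """Phase 1: aggregate one input partition locally, one record per key."""
    acc = defaultdict(int)
    for i in range(start, stop):
        acc[skew_key(i)] += rand_val(i)
    return dict(acc)

row_ct = 200_000
input_partitions = [(0, row_ct // 2), (row_ct // 2, row_ct)]

partials = [partial_sums(start, stop) for start, stop in input_partitions]
shuffled_records = sum(len(p) for p in partials)  # only partial results move

# Phase 2: merge the partial results for each key
final = defaultdict(int)
for partial in partials:
    for key, total in partial.items():
        final[key] += total

print(f"records shuffled: {shuffled_records} (instead of {row_ct:,} rows)")
print(dict(final))
```

Only five partial records are shuffled in this model: the first partition contains keys `A` to `D` and the second contains only `E`. That is why the skew in `skew_col` barely affects the group by.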

### Window

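Again we see a clear skew for the Window function, and the spill is back. This is expected: unlike a group by, a window function returns every input row, so Spark cannot pre-aggregate and must move all rows with the same `skew_col` value into a single partition.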

![skew_window](./images/skew_window.PNG)


### Another group by - count distinct
The group by behaved differently from the join and the Window function, so let's try a different aggregation to see whether Spark does the same thing.

In [7]:
def action_groupby_distinct():
    spark.sparkContext.setJobDescription("groupby_countDistinct")
    return skewed_df.groupBy("skew_col").agg(F.countDistinct("rand_val")).show(3)

action_groupby_distinct()

+--------+------------------------+
|skew_col|count(DISTINCT rand_val)|
+--------+------------------------+
|       E|                      11|
|       B|                      11|
|       D|                      11|
+--------+------------------------+
only showing top 3 rows



It seems Spark does the same thing as with the previous `.groupBy()`: it aggregates within the original partitions first and combines the partial results for each group at the end. We haven't included a screenshot here; finding the relevant diagram in the Spark UI is left to the reader.
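A similar simplified model for `countDistinct`: each partition first removes duplicate `(skew_col, rand_val)` pairs locally, so at most one record per distinct pair per partition is shuffled before the final count. This is a sketch of the idea rather than Spark's exact physical plan; it assumes contiguous input partitions and uses the hypothetical stand-in `(id * 7) % 11` for `rand_val`.

```python
import bisect
from collections import defaultdict

THRESHOLDS = [100, 1_000, 10_000, 100_000]  # same cut-offs as the F.when() chain
KEYS = "ABCDE"

def skew_key(i: int) -> str:
    return KEYS[bisect.bisect_right(THRESHOLDS, i)]

def rand_val(i: int) -> int:
    """Hypothetical deterministic stand-in for rand_val (values 0..10)."""
    return (i * 7) % 11

def local_distinct_pairs(start: int, stop: int) -> set:
    """Phase 1: de-duplicate (key, value) pairs within one input partition."""
    return {(skew_key(i), rand_val(i)) for i in range(start, stop)}

row_ct = 200_000
input_partitions = [(0, row_ct // 2), (row_ct // 2, row_ct)]
local_sets = [local_distinct_pairs(a, b) for a, b in input_partitions]
shuffled_records = sum(len(s) for s in local_sets)

# Phase 2: merge the pairs, then count distinct values per key
distinct_counts = defaultdict(int)
for key, _ in set().union(*local_sets):
    distinct_counts[key] += 1

print(f"records shuffled: {shuffled_records}; distinct counts: {dict(distinct_counts)}")
```

Here only 55 pair records (five keys with 11 values each) cross the shuffle, compared with 200,000 input rows.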

### Window alternative: group by and join

Let's look more closely at the Window function. We can achieve the same result by doing a `.groupBy()` first and then using `.join()` to attach the aggregated data to the original DataFrame. In general, the groupby-join method is slower than a Window function because it involves two shuffles, one for the `.groupBy()` and a second for the `.join()`, instead of the single shuffle needed for `Window.partitionBy()`.
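To make the equivalence concrete, here is a small pure-Python sketch (not Spark code) in which a window-style sum per key and a group-by followed by a left join produce the same value on every row. The sample rows are made up for illustration.

```python
from collections import defaultdict

# Hypothetical sample rows: (id, skew_col, rand_val)
rows = [(0, "A", 7), (1, "A", 9), (2, "B", 4), (3, "B", 1), (4, "E", 6)]

def window_sum(rows):
    """Like Window.partitionBy(skew_col): every row keeps its columns and gains its group's sum."""
    by_key = defaultdict(list)
    for row in rows:  # analogous to the single shuffle: gather all rows of a key
        by_key[row[1]].append(row)
    out = []
    for key_rows in by_key.values():
        total = sum(r[2] for r in key_rows)
        out.extend(r + (total,) for r in key_rows)
    return sorted(out)

def groupby_join_sum(rows):
    """Like groupBy(skew_col).agg(sum) followed by a left join back on skew_col."""
    totals = defaultdict(int)
    for _, key, val in rows:  # step 1: aggregate per key
        totals[key] += val
    # step 2: attach each key's total to the original rows
    return sorted(row + (totals[row[1]],) for row in rows)

print(groupby_join_sum(rows))
```

Both functions return one output row per input row; they differ only in how the per-key total is obtained.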

In [8]:
def action_groupby_join():
    spark.sparkContext.setJobDescription("groupby-join")
    sum_id = skewed_df.join(
        skewed_df.groupBy("skew_col").agg(F.sum("rand_val")),
        on="skew_col",
        how="left",
    )
    return sum_id.show(3)

action_groupby_join()

+--------+-------+--------+-------------+
|skew_col|     id|rand_val|sum(rand_val)|
+--------+-------+--------+-------------+
|       E| 100000|       6|     49491417|
|       E|5438200|       7|     49491417|
|       E| 100001|       6|     49491417|
+--------+-------+--------+-------------+
only showing top 3 rows



The duration of these Spark jobs will vary between runs of this notebook, and finding the right times to compare isn't easy either. The fairest comparison is to note the job numbers for each method, then go to the SQL tab in the UI and look at the duration of the corresponding SQL queries.
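If you would rather time the action functions from the notebook itself, a rough alternative is to measure wall-clock time on the driver. The helper below is hypothetical; its timings include driver-side overhead, and because each action only calls `.show(3)`, they may not match the SQL query durations shown in the UI.

```python
import statistics
import time
from typing import Callable

def median_runtime(action: Callable[[], object], repeats: int = 3) -> float:
    """Run `action` several times and return the median wall-clock time in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        action()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)

# Stand-in workload; in the notebook you might pass action_window_sum instead
print(f"{median_runtime(lambda: sum(range(100_000))):.4f} s")
```

For example, `median_runtime(action_window_sum)` and `median_runtime(action_groupby_join)` would give one number per method. Taking the median of several runs reduces the influence of a single slow run.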

After running the `action_window_sum()` and `action_groupby_join()` functions a couple of times, I got fairly consistent results: the Window method takes 6 seconds and the groupby-join method takes 3 seconds.

It looks like the groupby-join method is more efficient in this case. Why? 

There are a couple of things at play here. 
- Firstly, there is a large skew that slows down the Window function but has little effect on the group by.
- Secondly, the right-hand DataFrame in the join is the aggregated result, which has only five rows, so the join is relatively quick.

So the groupby-join method is quicker in this very specific case. In practice, skew this severe is quite rare, and it is unusual to join such a small DataFrame.

## Does `.cache()` have any effect?

As a final investigation, does caching both DataFrames have any effect on how Spark processes the data?

The function below caches both DataFrames, then the various functions are run again to see how Spark processes the data differently.

In [9]:
def action_cache():
    spark.sparkContext.setJobDescription("Cache")
    skewed_df.cache().count()
    small_df.cache().count()
    
action_cache()

In [10]:
action_join()
action_groupby_sum()
action_groupby_distinct()
action_window_sum()
action_groupby_join()

+--------+---+--------+----------+
|skew_col| id|rand_val|number_col|
+--------+---+--------+----------+
|       A|  0|       7|         1|
|       A|  1|       9|         1|
|       A|  2|       9|         1|
+--------+---+--------+----------+
only showing top 3 rows

+--------+-------------+
|skew_col|sum(rand_val)|
+--------+-------------+
|       E|     49491417|
|       B|         4639|
|       D|       448943|
+--------+-------------+
only showing top 3 rows

+--------+------------------------+
|skew_col|count(DISTINCT rand_val)|
+--------+------------------------+
|       E|                      11|
|       B|                      11|
|       D|                      11|
+--------+------------------------+
only showing top 3 rows

+-------+--------+--------+-------------+
|     id|skew_col|rand_val|id_window_sum|
+-------+--------+--------+-------------+
|5000000|       E|       4|     49491417|
| 438200|       E|       1|     49491417|
|5000001|       E|       3|     49491417|
+-------+--------+--------+-------------+
only showing top 3 rows

It seems that caching the DataFrames had little or no effect on either group by, the Window function or the groupby-join. Looking through the Spark UI, the only noticeable difference was in the `.join()`. This is because after caching the small DataFrame, `small_df`, Spark had an accurate size estimate showing that it was small, and therefore performed a broadcast hash join instead of the default sort-merge join.
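A simplified model of that decision and of the join itself: Spark chooses a broadcast join when its size estimate for one side is at or below `spark.sql.autoBroadcastJoinThreshold` (10 MB by default). In a broadcast hash join, the small side becomes a hash table sent to every task, and the large side is streamed through without being shuffled, so its original row order survives. That is consistent with the post-cache join output starting at `id` 0. The size estimates and helper functions below are hypothetical.

```python
from typing import Optional

# Default value of spark.sql.autoBroadcastJoinThreshold (10 MB)
AUTO_BROADCAST_THRESHOLD = 10 * 1024 * 1024

def choose_join(small_side_bytes: Optional[int]) -> str:
    """Simplified planner rule: broadcast only if a size estimate exists and is small."""
    if small_side_bytes is not None and small_side_bytes <= AUTO_BROADCAST_THRESHOLD:
        return "BroadcastHashJoin"
    return "SortMergeJoin"

def broadcast_hash_join(large_rows, small_rows):
    """Left join: build a hash table from the small side, stream the large side in order."""
    lookup = dict(small_rows)  # skew_col -> number_col
    return [row + (lookup.get(row[1]),) for row in large_rows]

small_rows = [("A", 1), ("B", 2), ("C", 3), ("D", 4), ("E", 5)]
large_rows = [(0, "A", 7), (1, "A", 9), (100_000, "E", 6)]  # (id, skew_col, rand_val)

print(choose_join(None))  # no usable size estimate: treated as too big to broadcast
print(choose_join(200))   # hypothetical tiny estimate, as after caching
print(broadcast_hash_join(large_rows, small_rows))
```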

## Summary

**Join**

- Without caching we get spill and a SortMergeJoin
- With caching we don't get spill and a BroadcastHashJoin

**Group by**

- No spill
- Not badly affected by skew or caching

**Window**

- Obvious skew and spill
- Cache has little effect

**Groupby and join**

- No skew or spill in groupby part
- Skew and spill in the join part, but this could be avoided by forcing a broadcast join, e.g. by wrapping the smaller DataFrame in `F.broadcast()`
- This method was quicker than the Window function here, but it is unlikely to be quicker in most cases

In [11]:
spark.stop()