# Case Study 3: Joins on Skewed Data

A `skewed dataset` is one whose values are unevenly distributed, like a class imbalance. Skew can make Spark jobs run very slowly or fail outright, often with `OOM` (out of memory) errors.

When you `join` on a `skewed dataset`, the imbalance is in the `key`s the join is performed on. Spark shuffles rows by hashing the join key, so every row with the same key goes to the same partition. As a result, most of the data can land on one partition, and the task for that partition takes much longer than the others.

Some examples of this are:
1. Most of the keys are `null`, and all of those rows fall onto a single partition.
2. A small subset of keys makes up most of the rows, and each of those keys falls onto a single partition.
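Both examples come down to how the shuffle assigns rows to partitions. The sketch below is a toy, pure-Python model, not Spark code: `partition_for` and `rows_per_partition` are hypothetical helpers, and Python's `hash()` stands in for the Murmur3 hash Spark actually uses. The property that matters carries over: equal keys, including every `null`, always go to the same partition.

```python
from collections import Counter


def partition_for(key, num_partitions):
    """Toy partitioner: Spark's hashpartitioning uses Murmur3, not Python's hash()."""
    # Equal keys always hash equally, so every None lands on the same partition.
    return hash(key) % num_partitions


def rows_per_partition(keys, num_partitions):
    """How many rows each partition receives after the shuffle."""
    return Counter(partition_for(k, num_partitions) for k in keys)


null_heavy = [None] * 90 + list(range(10))  # example 1: mostly null keys
hot_key = [1] * 95 + [2] * 5                # example 2: one key owns most rows

print(max(rows_per_partition(null_heavy, 8).values()))  # at least 90
print(max(rows_per_partition(hot_key, 8).values()))     # 95
```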

### Library Imports

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
```

Create a `SparkSession`. There's no need to create a `SparkContext` yourself; the `SparkSession` provides one as `spark.sparkContext`.

```python
spark = (
    SparkSession.builder
    .master("local")
    .appName("Exploring Joins")
    # Placeholder option; replace it with a real setting or remove it.
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)

sc = spark.sparkContext
```

## Situation 1: Null Keys

Initial Datasets

```python
customers = spark.createDataFrame([
    (1, None),
    (2, None),
    (3, 1),
], ["id", "card_id"])

customers.toPandas()
```

| | id | card_id |
|---|---|---|
| 0 | 1 | |
| 1 | 2 | |
| 2 | 3 | 1.0 |


```python
cards = spark.createDataFrame([
    (1, "john", "doe", 21),
    (2, "rick", "roll", 10),
    (3, "bob", "brown", 2)
], ["card_id", "first_name", "last_name", "age"])

cards.toPandas()
```

| | card_id | first_name | last_name | age |
|---|---|---|---|---|
| 0 | 1 | john | doe | 21 |
| 1 | 2 | rick | roll | 10 |
| 2 | 3 | bob | brown | 2 |


### Option #1: Join Regularly

```python
df = customers.join(cards, "card_id", "left")

df.toPandas()
```

| | card_id | id | first_name | last_name | age |
|---|---|---|---|---|---|
| 0 | | 1 | | | |
| 1 | | 2 | | | |
| 2 | 1.0 | 3 | john | doe | 21.0 |


```python
df.explain()
```

```
== Physical Plan ==
*(3) Project [card_id#85L, id#84L, first_name#89, last_name#90, age#91L]
+- SortMergeJoin [card_id#85L], [card_id#88L], LeftOuter
   :- *(1) Sort [card_id#85L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(card_id#85L, 200)
   :     +- Scan ExistingRDD[id#84L,card_id#85L]
   +- *(2) Sort [card_id#88L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(card_id#88L, 200)
         +- Scan ExistingRDD[card_id#88L,first_name#89,last_name#90,age#91L]
```


**What Happened**:
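* All rows were shuffled into the join, including the `null`-key rows, which can never match because `null` is not equal to anything in a join condition.
* Because this is a left join, those rows are kept and get `null` values for the right-side columns.

**Results**:
* We brought more data to the join than we had to, and every `null`-key row was hashed to the same partition.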

### Option #2: Filter Null Keys First, then Join, then Union

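```python
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Split off the rows whose join key is null.
        2. Add the columns the join would produce, filled with nulls.
        3. Join the remaining rows.
        4. Union the null-key rows back onto the joined table.
    """
    # Steps 1-2: null keys can never match, so build their right-side columns directly.
    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
        df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))

    # Step 3: only rows with a non-null key are shuffled into the join.
    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")

    # Step 4: union() matches columns by position, so put df2 in df1's column order.
    return df1.union(df2.select(df1.columns))


df = null_skew_helper(customers, cards, "card_id")

df.toPandas()
```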

| | id | card_id | first_name | last_name | age |
|---|---|---|---|---|---|
| 0 | 1 | | | | |
| 1 | 2 | | | | |
| 2 | 3 | 1.0 | john | doe | 21.0 |


```python
df.explain()
```

```
== Physical Plan ==
Union
:- *(1) Project [id#84L, null AS card_id#102L, null AS first_name#105, null AS last_name#109, null AS age#114L]
:  +- *(1) Filter isnull(card_id#85L)
:     +- Scan ExistingRDD[id#84L,card_id#85L]
+- *(5) Project [id#84L, card_id#85L, first_name#89, last_name#90, age#91L]
   +- SortMergeJoin [card_id#85L], [card_id#88L], LeftOuter
      :- *(3) Sort [card_id#85L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#85L, 200)
      :     +- *(2) Filter isnotnull(card_id#85L)
      :        +- Scan ExistingRDD[id#84L,card_id#85L]
      +- *(4) Sort [card_id#88L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(card_id#88L, 200)
            +- Scan ExistingRDD[card_id#88L,first_name#89,last_name#90,age#91L]
```


**What Happened**:
* We filtered the `null`-key rows out before the join.
* We did the join with less data.
* We scanned the source a second time to get the `null`-key rows (note the two `Scan ExistingRDD[id#84L,card_id#85L]` nodes).
* We unioned them with the joined results.

**Results**:
* We brought less data to the join.
* But we read the data twice.
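To check that the split-then-union approach returns the same rows as a plain left join, here is a pure-Python model of both on the Situation 1 data. It is a sketch that mirrors the steps of `null_skew_helper` with rows as dicts; `left_join`, `split_null_then_join`, and `by_id` are hypothetical helpers, not Spark APIs.

```python
def left_join(left, right, key, right_cols):
    """Left join on `key`; as in SQL, a None key never matches anything."""
    index = {}
    for r in right:
        index.setdefault(r[key], []).append(r)
    out = []
    for row in left:
        matches = [] if row[key] is None else index.get(row[key], [])
        if not matches:
            out.append({**row, **{c: None for c in right_cols}})
        for m in matches:
            out.append({**row, **{c: m[c] for c in right_cols}})
    return out


def split_null_then_join(left, right, key, right_cols):
    """Mirror of null_skew_helper: split off nulls, pad them, join the rest, union."""
    null_rows = [row for row in left if row[key] is None]
    padded = [{**row, **{c: None for c in right_cols}} for row in null_rows]
    non_null = [row for row in left if row[key] is not None]
    return padded + left_join(non_null, right, key, right_cols)


def by_id(rows):
    return sorted(rows, key=lambda r: r["id"])


customer_rows = [{"id": 1, "card_id": None}, {"id": 2, "card_id": None}, {"id": 3, "card_id": 1}]
card_rows = [
    {"card_id": 1, "first_name": "john", "last_name": "doe", "age": 21},
    {"card_id": 2, "first_name": "rick", "last_name": "roll", "age": 10},
    {"card_id": 3, "first_name": "bob", "last_name": "brown", "age": 2},
]
cols = ["first_name", "last_name", "age"]

plain = left_join(customer_rows, card_rows, "card_id", cols)
split = split_null_then_join(customer_rows, card_rows, "card_id", cols)
print(by_id(plain) == by_id(split))  # True
```

Only one of the three customer rows enters the join in the split version, yet the output is identical.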

### Option #3: Cache Table, Filter Null Keys First, then Join, then Union

**Helper Function**

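```python
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Cache the left table.
        2. Split off the rows whose join key is null.
        3. Add the columns the join would produce, filled with nulls.
        4. Join the remaining rows.
        5. Union the null-key rows back onto the joined table.
    """
    # cache() is lazy: the data is stored on the first action, and both branches
    # below then read the cached copy. Call left.unpersist() when you are done.
    left = left.cache()

    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
        df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))

    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")

    # union() matches columns by position, so put df2 in df1's column order.
    return df1.union(df2.select(df1.columns))
```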

```python
df = null_skew_helper(customers, cards, "card_id")

df.toPandas()
```

| | id | card_id | first_name | last_name | age |
|---|---|---|---|---|---|
| 0 | 1 | | | | |
| 1 | 2 | | | | |
| 2 | 3 | 1.0 | john | doe | 21.0 |


```python
df.explain()
```

```
== Physical Plan ==
Union
:- *(1) Project [id#84L, null AS card_id#147L, null AS first_name#150, null AS last_name#154, null AS age#159L]
:  +- *(1) Filter isnull(card_id#85L)
:     +- *(1) InMemoryTableScan [card_id#85L, id#84L], [isnull(card_id#85L)]
:           +- InMemoryRelation [id#84L, card_id#85L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
:                 +- Scan ExistingRDD[id#84L,card_id#85L]
+- *(5) Project [id#84L, card_id#85L, first_name#89, last_name#90, age#91L]
   +- SortMergeJoin [card_id#85L], [card_id#88L], LeftOuter
      :- *(3) Sort [card_id#85L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#85L, 200)
      :     +- *(2) Filter isnotnull(card_id#85L)
      :        +- *(2) InMemoryTableScan [id#84L, card_id#85L], [isnotnull(card_id#85L)]
      :              +- InMemoryRelation [id#84L, card_id#85L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
      :                    +- Scan ExistingRDD[i
```

**What Happened**:
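* Similar to Option #2, but both branches read the cached data through an `InMemoryTableScan` instead of scanning the source twice.

**Results**:
* We brought less data to the join.
* We read the source once instead of twice, but the cached table takes up memory (spilling to disk if it does not fit).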
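The trade-off between Option #2 and Option #3 can be seen with a toy source that counts its scans. `CountingSource` is a hypothetical stand-in for the input table, and holding a Python list in memory stands in for `cache()`; Spark's caching is lazy and far more involved, so this sketch only illustrates the read count.

```python
class CountingSource:
    """Hypothetical stand-in for a table; counts how many times it is scanned."""

    def __init__(self, rows):
        self._rows = rows
        self.reads = 0

    def scan(self):
        self.reads += 1
        return list(self._rows)


rows = [{"id": 1, "card_id": None}, {"id": 2, "card_id": None}, {"id": 3, "card_id": 1}]

# Option #2: each branch scans the source on its own.
source = CountingSource(rows)
null_rows = [r for r in source.scan() if r["card_id"] is None]
non_null = [r for r in source.scan() if r["card_id"] is not None]
option2_reads = source.reads

# Option #3: scan once, keep the rows in memory, and filter that copy twice.
source = CountingSource(rows)
cached = source.scan()
null_rows = [r for r in cached if r["card_id"] is None]
non_null = [r for r in cached if r["card_id"] is not None]
option3_reads = source.reads

print(option2_reads, option3_reads)  # 2 1
```

The price of the single read is that `cached` stays in memory for as long as both branches need it.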

### TL;DR

As always, there are pros and cons.

Pros:
* Ideally you want to bring less data to a join.
* The `null`-key rows are unneeded data there, and with enough of them they can cause a Spark job to fail.
* This is because all the `null`-key rows go onto one partition.

Cons:
* There will either be one extra read of data or more memory used.

All to say:
1. It's better to bring less data to a join, so filtering out the `null keys` before the join and unioning them back afterward is recommended.
2. This costs either an extra read of the data or extra memory.
3. Decide whether you can better afford the extra read or the extra memory; if memory is the cheaper option, `cache` the table before the filter.

Always check the distribution of values in the `join key` to detect skew and to decide which optimization, if any, applies.
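On a real DataFrame, a quick way to see the spread is something like `customers.groupBy("card_id").count().orderBy(F.desc("count"))`. The hypothetical pure-Python helper below, `key_skew_report`, shows which two numbers to look at in those counts: the share of `null` keys and the share of the most frequent non-null key.

```python
from collections import Counter


def key_skew_report(keys):
    """Hypothetical helper: how concentrated is a join key?"""
    if not keys:
        return {"null_share": 0.0, "top_key": None, "top_key_share": 0.0}
    counts = Counter(keys)
    total = len(keys)
    # Most frequent non-null key; null keys are reported separately.
    non_null = [(k, c) for k, c in counts.items() if k is not None]
    top_key, top_count = max(non_null, key=lambda kc: kc[1], default=(None, 0))
    return {
        "null_share": counts[None] / total,
        "top_key": top_key,
        "top_key_share": top_count / total,
    }


card_ids = [None, None, 1]          # customers.card_id from Situation 1
customer_ids = [1] * 95 + [2] * 5   # orders.customer_id from Situation 2
print(key_skew_report(card_ids))
print(key_skew_report(customer_ids))
```

A high `null_share` points to the Situation 1 fix; a high `top_key_share` points to salting, covered next.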

## Situation 2: Data Skew

Initial Datasets

```python
customers = spark.createDataFrame([
    (1, "John"),
    (2, "Bob"),
], ["customer_id", "first_name"])

customers.toPandas()
```

| | customer_id | first_name |
|---|---|---|
| 0 | 1 | John |
| 1 | 2 | Bob |


```python
orders = spark.createDataFrame([
    (i, 1 if i < 95 else 2, "order #{}".format(i)) for i in range(100)
], ["id", "customer_id", "order_name"])

orders.toPandas().tail(6)
```

| | id | customer_id | order_name |
|---|---|---|---|
| 94 | 94 | 1 | order #94 |
| 95 | 95 | 2 | order #95 |
| 96 | 96 | 2 | order #96 |
| 97 | 97 | 2 | order #97 |
| 98 | 98 | 2 | order #98 |
| 99 | 99 | 2 | order #99 |


### Option #1: Regular Join

```python
df = customers.join(orders, "customer_id")

df.toPandas().tail(10)
```

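| | customer_id | first_name | id | order_name |
|---|---|---|---|---|
| 90 | 1 | John | 90 | order #90 |
| 91 | 1 | John | 91 | order #91 |
| 92 | 1 | John | 92 | order #92 |
| 93 | 1 | John | 93 | order #93 |
| 94 | 1 | John | 94 | order #94 |
| 95 | 2 | Bob | 95 | order #95 |
| 96 | 2 | Bob | 96 | order #96 |
| 97 | 2 | Bob | 97 | order #97 |
| 98 | 2 | Bob | 98 | order #98 |
| 99 | 2 | Bob | 99 | order #99 |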


```python
df.explain()
```

```
== Physical Plan ==
*(5) Project [customer_id#201L, first_name#202, id#205L, order_name#207]
+- *(5) SortMergeJoin [customer_id#201L], [customer_id#206L], Inner
   :- *(2) Sort [customer_id#201L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#201L, 200)
   :     +- *(1) Filter isnotnull(customer_id#201L)
   :        +- Scan ExistingRDD[customer_id#201L,first_name#202]
   +- *(4) Sort [customer_id#206L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(customer_id#206L, 200)
         +- *(3) Filter isnotnull(customer_id#206L)
            +- Scan ExistingRDD[id#205L,customer_id#206L,order_name#207]
```


**What Happened**:
* We want to find the `order`s each `customer` made, so we `join` the `customers` table to the `orders` table.
* For the join, both tables are shuffled with `hashpartitioning` on `customer_id`.
* Because customer `1` has 95 of the 100 orders, 95% of the `orders` rows land on a single partition.

**Results**:
* As in the `Null Skew` case, the task for that single partition takes much longer than the others, and on real data volumes it is likely to fail.

### Option #2: Salt the Key, then Join

**Helper Function**

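```python
def data_skew_helper(left, right, key, number_of_partitions, how="inner"):
    """Salted join. `left` is the skewed side; `number_of_partitions` is the number of
    salt values. Only safe for how="inner"/"left": with "right"/"outer", unmatched
    exploded right rows would show up as extra duplicates."""
    # Random integer salt in [0, number_of_partitions) for every left row.
    salt_value = (F.rand() * number_of_partitions % number_of_partitions).cast("int")
    left = left.withColumn("salt", salt_value)

    # One copy of each right row per salt value 0 .. number_of_partitions - 1.
    salt_col = F.explode(F.array([F.lit(i) for i in range(number_of_partitions)])).alias("salt")
    right = right.select("*", salt_col)

    # Join on the original key plus the salt, then drop the helper column.
    return left.join(right, [key, "salt"], how).drop("salt")
```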
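The helper can be modeled in plain Python to check two things: the salted join returns the same rows as a regular inner join, and customer `1`'s orders are split into several smaller groups. This is a sketch, not Spark code: `salted_inner_join` is a hypothetical name, `random.Random` stands in for `F.rand()`, and a dict lookup stands in for the shuffle.

```python
import random
from collections import Counter


def salted_inner_join(left, right, key, num_salts, seed=0):
    """Hypothetical pure-Python model of data_skew_helper(..., how="inner")."""
    rng = random.Random(seed)  # stand-in for F.rand()
    # Skewed side: each row gets one random salt in [0, num_salts).
    salted_left = [{**row, "salt": rng.randrange(num_salts)} for row in left]
    # Small side: one copy of each row per salt value, like F.explode(F.array(...)).
    exploded_right = [{**row, "salt": s} for row in right for s in range(num_salts)]

    # Index the small side on (key, salt), the new composite join key.
    index = {}
    for r in exploded_right:
        index.setdefault((r[key], r["salt"]), []).append(r)

    joined = []
    for row in salted_left:
        for match in index.get((row[key], row["salt"]), []):
            merged = {**row, **match}
            del merged["salt"]  # same idea as .drop("salt")
            joined.append(merged)

    # Rows per (key, salt) group: rows in the same group end up on the same partition.
    groups = Counter((row[key], row["salt"]) for row in salted_left)
    return joined, groups, len(exploded_right)


order_rows = [{"id": i, "customer_id": 1 if i < 95 else 2} for i in range(100)]
customer_rows = [{"customer_id": 1, "first_name": "John"}, {"customer_id": 2, "first_name": "Bob"}]

joined, groups, right_rows = salted_inner_join(order_rows, customer_rows, "customer_id", num_salts=5)
print(len(joined))           # 100: every order still matches exactly one customer
print(right_rows)            # 10: 2 customers x 5 salt values
print(max(groups.values()))  # largest group; customer 1's orders are spread over up to 5 groups
```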

**Example**

```python
left = orders

salt_value = (F.rand() * 2 % 2).cast("int")
left.withColumn("salt", salt_value).toPandas().head(5)
```

| | id | customer_id | order_name | salt |
|---|---|---|---|---|
| 0 | 0 | 1 | order #0 | 1 |
| 1 | 1 | 1 | order #1 | 1 |
| 2 | 2 | 1 | order #2 | 1 |
| 3 | 3 | 1 | order #3 | 1 |
| 4 | 4 | 1 | order #4 | 1 |
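The first five rows above all happen to get salt `1`; that is chance, since each row draws its own random value. Because `rand()` returns a value in `[0, 1)`, `rand * n` is already below `n`, so the `% n` in the expression has no effect. A quick check of the same formula in plain Python, with a seeded `random.Random` as an assumed stand-in for Spark's `F.rand()`:

```python
import random
from collections import Counter

rng = random.Random(7)  # assumed stand-in for F.rand(); Spark's generator differs
num_salts = 2

# Same formula as the Spark expression: cast(rand * n % n as int).
salts = [int(rng.random() * num_salts % num_salts) for _ in range(10_000)]
counts = Counter(salts)
print(sorted(counts))  # [0, 1]
print(counts[0], counts[1])
```

Over many rows, each salt value shows up about equally often, which is what spreads a hot key evenly across its groups.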


```python
right = customers

salt_col = F.explode(F.array([F.lit(i) for i in range(2)])).alias("salt")
right.select("*", salt_col).toPandas().head(10)
```


| | customer_id | first_name | salt |
|---|---|---|---|
| 0 | 1 | John | 0 |
| 1 | 1 | John | 1 |
| 2 | 2 | Bob | 0 |
| 3 | 2 | Bob | 1 |


```python
df = data_skew_helper(orders, customers, "customer_id", 5)

df.orderBy('id').toPandas().tail(10)
```

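| | customer_id | id | order_name | first_name |
|---|---|---|---|---|
| 90 | 1 | 90 | order #90 | John |
| 91 | 1 | 91 | order #91 | John |
| 92 | 1 | 92 | order #92 | John |
| 93 | 1 | 93 | order #93 | John |
| 94 | 1 | 94 | order #94 | John |
| 95 | 2 | 95 | order #95 | Bob |
| 96 | 2 | 96 | order #96 | Bob |
| 97 | 2 | 97 | order #97 | Bob |
| 98 | 2 | 98 | order #98 | Bob |
| 99 | 2 | 99 | order #99 | Bob |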


```python
df.explain()
```

```
== Physical Plan ==
*(5) Project [customer_id#206L, id#205L, order_name#207, first_name#202]
+- *(5) SortMergeJoin [customer_id#206L, salt#225], [customer_id#201L, salt#231], Inner
   :- *(2) Sort [customer_id#206L ASC NULLS FIRST, salt#225 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#206L, salt#225, 200)
   :     +- *(1) Filter (isnotnull(customer_id#206L) && isnotnull(salt#225))
   :        +- *(1) Project [id#205L, customer_id#206L, order_name#207, cast(((rand(454193232799453665) * 5.0) % 5.0) as int) AS salt#225]
   :           +- Scan ExistingRDD[id#205L,customer_id#206L,order_name#207]
   +- *(4) Sort [customer_id#201L ASC NULLS FIRST, salt#231 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(customer_id#201L, salt#231, 200)
         +- Generate explode([0,1,2,3,4]), [customer_id#201L, first_name#202], false, [salt#231]
            +- *(3) Filter isnotnull(customer_id#201L)
               +- Scan ExistingRDD[customer_id#201L,first_name#20
```

**What Happened**:
* We add a `salt` column to both datasets.
* On `orders` (the skewed side), each row gets a random `salt` from `0` to `number_of_partitions - 1`; on `customers`, we duplicate each row so there is one row per `salt` value.
* The join now shuffles with `hashpartitioning` on `[customer_id, salt]`.

**Results**:
* Producing one row per `salt` value adds `(number_of_partitions - 1) * N` duplicate rows to the right side, where `N` is its original row count.
* In exchange, the rows for a hot key are spread across up to `number_of_partitions` partitions, as you can see from `hashpartitioning(customer_id#206L, salt#225, 200)`.
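With the numbers from this example, writing `S` for the number of salt values (`number_of_partitions = 5`), `N = 2` for the rows in `customers`, and `n_1 = 95` for customer `1`'s orders, the cost and the benefit work out as follows. The last line is an expected value, since the salt is random.

$$
\begin{aligned}
\text{right-side rows after the explode} &= S \cdot N = 5 \cdot 2 = 10 \\
\text{extra right-side rows} &= (S - 1) \cdot N = 4 \cdot 2 = 8 \\
\mathbb{E}\left[\text{orders per (customer 1, salt) group}\right] &= \frac{n_1}{S} = \frac{95}{5} = 19
\end{aligned}
$$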

### TL;DR

All to say:
* Again, we sacrifice more resources in order to get a performance gain or a successful run.
* By `salt`ing our keys, each skewed key is divided across several smaller partitions, which reduces the skew.
* The cost is `(number_of_partitions - 1) * N` extra rows on the right side, so salting works best when the right side is the small table.
