# Exploring Joins in Spark

### Library Imports

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

Create a `SparkSession`. No need to create `SparkContext` as you automatically get it as part of the `SparkSession`.

In [3]:
spark = SparkSession.builder \
    .master("local") \
    .appName("Exploring Joins") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# Case Study 1: Natural vs Regular Joins

Two types of joins:
1. `Natural Join`  
    A Natural Join is where 2 tables are joined on the basis of all common columns.      
    ie. `left.join(right, 'key')`

2. `Regular Join`  
    A Inner Join is where 2 tables are joined on the basis of common columns mentioned in the ON clause.
    ie. `left.join(right, left[lkey] == right[rkey])`

**Question:**
    Is `rename`ing a `column` then doing a `natural join` better than doing an `inner join`? Is it just a style choice?

### Initial Datasets

In [3]:
df_1 = spark.createDataFrame(
    [
        (1, 1, 'a'), 
        (2, 1, 'b'), 
        (2, 2, 'c'), 
    ], ['id', 'data_id', 'val_1']
)

df_1.toPandas()

Unnamed: 0,id,data_id,val_1
0,1,1,a
1,2,1,b
2,2,2,c


In [5]:
df_2 = spark.createDataFrame(
    [
        (1, 1, 10), 
        (2, 2, 20), 
    ], ['shop_id', 'data_id', 'val_2']
)

df_2.toPandas()

Unnamed: 0,shop_id,data_id,val_2
0,1,1,10
1,2,2,20


## Option 1: Rename Key, then Join

In [6]:
df_3 = df_1.withColumnRenamed('id', 'shop_id')

df = df_3.join(df_2, 'shop_id')

df.toPandas()

Unnamed: 0,shop_id,data_id,val_1,data_id.1,val_2
0,1,1,a,1,10
1,2,1,b,2,20
2,2,2,c,2,20


In [7]:
df.explain()

== Physical Plan ==
*(5) Project [shop_id#12L, data_id#1L, val_1#2, data_id#7L, val_2#8L]
+- *(5) SortMergeJoin [shop_id#12L], [shop_id#6L], Inner
   :- *(2) Sort [shop_id#12L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(shop_id#12L, 200)
   :     +- *(1) Project [id#0L AS shop_id#12L, data_id#1L, val_1#2]
   :        +- *(1) Filter isnotnull(id#0L)
   :           +- Scan ExistingRDD[id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, 200)
         +- *(3) Filter isnotnull(shop_id#6L)
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


**What Happened**:
* An extra `Project` was performed before the join due to the rename.

## Option 2: Don't Rename, Regular Join, Drop Column

In [8]:
join_condition = df_1['id'] == df_2['shop_id']

df = df_1 \
    .join(df_2, join_condition) \
    .drop(df_1['id'])

df.toPandas()

Unnamed: 0,data_id,val_1,shop_id,data_id.1,val_2
0,1,a,1,1,10
1,1,b,2,2,20
2,2,c,2,2,20


In [9]:
df.explain()

== Physical Plan ==
*(5) Project [data_id#1L, val_1#2, shop_id#6L, data_id#7L, val_2#8L]
+- *(5) SortMergeJoin [id#0L], [shop_id#6L], Inner
   :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, 200)
   :     +- *(1) Filter isnotnull(id#0L)
   :        +- Scan ExistingRDD[id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, 200)
         +- *(3) Filter isnotnull(shop_id#6L)
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


**What Happened**:
* No `Project` was done before the join

## TL;DR

Option #1
* Looks nicer and more elegant.
* How does it perform?

Option #2
* There is one less project as expected without the `withColumnRenamed`.
* But the join claus is a lot longer and ugly.

**I personally like option #1.**

# Case Study 2: Filter Pushdown

> `Filter pushdown` improves performance by reducing the amount of data shuffled during any dataframes transformations.

### Initial Datasets

In [4]:
df_1 = spark.createDataFrame(
    [
        (1, 1, 'a'), 
        (2, 1, 'b'), 
        (2, 2, 'c'), 
    ], ['shop_id', 'data_id', 'val_1']
)

df_1.toPandas()

Unnamed: 0,shop_id,data_id,val_1
0,1,1,a
1,2,1,b
2,2,2,c


In [5]:
df_2 = spark.createDataFrame(
    [
        (1, 1, 10), 
        (2, 2, 20), 
    ], ['shop_id', 'data_id', 'val_2']
)

df_2.toPandas()

Unnamed: 0,shop_id,data_id,val_2
0,1,1,10
1,2,2,20


## Option #1: Join the data, then perform Filter

In [10]:
df = df_1 \
    .join(df_2.drop('shop_id'), 'data_id') \
    .filter(F.col('shop_id') == 1)

df.toPandas()

Unnamed: 0,data_id,shop_id,val_1,val_2
0,1,1,a,10


In [11]:
df.explain()

== Physical Plan ==
*(5) Project [data_id#1L, shop_id#0L, val_1#2, val_2#8L]
+- *(5) SortMergeJoin [data_id#1L], [data_id#7L], Inner
   :- *(2) Sort [data_id#1L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(data_id#1L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#0L) && (shop_id#0L = 1)) && isnotnull(data_id#1L))
   :        +- Scan ExistingRDD[shop_id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [data_id#7L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(data_id#7L, 200)
         +- *(3) Project [data_id#7L, val_2#8L]
            +- *(3) Filter isnotnull(data_id#7L)
               +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


**What Happened:**

* We can see that the filter is after the join and not pushed down. 
* This means all of the data is brough to the join.
* Then the filter is done.

**Results:**

We bring more data to the join and shuffle, **this is bad**.

## Option #2: Join on Filter Key, then Filter

In [12]:
df = df_1 \
    .join(df_2, ['shop_id', 'data_id']) \
    .filter(F.col('shop_id') == 1)

df.toPandas()

Unnamed: 0,shop_id,data_id,val_1,val_2
0,1,1,a,10


In [13]:
df.explain()

== Physical Plan ==
*(5) Project [shop_id#0L, data_id#1L, val_1#2, val_2#8L]
+- *(5) SortMergeJoin [shop_id#0L, data_id#1L], [shop_id#6L, data_id#7L], Inner
   :- *(2) Sort [shop_id#0L ASC NULLS FIRST, data_id#1L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(shop_id#0L, data_id#1L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#0L) && (shop_id#0L = 1)) && isnotnull(data_id#1L))
   :        +- Scan ExistingRDD[shop_id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST, data_id#7L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, data_id#7L, 200)
         +- *(3) Filter ((isnotnull(shop_id#6L) && isnotnull(data_id#7L)) && (shop_id#6L = 1))
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


**What Happened:**
* The filter got pushed down.
* Less data is brought to the join and shuffle.

**Results:**

We bring less data to the join and shuffle, **this is good**.

## Option #3: Filter Left, then Join

In [14]:
df = df_1 \
    .filter(F.col('shop_id') == 1) \
    .join(df_2.drop('shop_id'), 'data_id')

df.toPandas()

Unnamed: 0,data_id,shop_id,val_1,val_2
0,1,1,a,10


In [15]:
df.explain()

== Physical Plan ==
*(5) Project [data_id#1L, shop_id#0L, val_1#2, val_2#8L]
+- *(5) SortMergeJoin [data_id#1L], [data_id#7L], Inner
   :- *(2) Sort [data_id#1L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(data_id#1L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#0L) && (shop_id#0L = 1)) && isnotnull(data_id#1L))
   :        +- Scan ExistingRDD[shop_id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [data_id#7L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(data_id#7L, 200)
         +- *(3) Project [data_id#7L, val_2#8L]
            +- *(3) Filter isnotnull(data_id#7L)
               +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


**What Happened:**
* This is exactly the same as case 1.

**Results:**

We bring less data to the join and shuffle, **this is bad**.

## Option #4: Filter Both, then Join

In [16]:
df = df_2 \
    .filter(F.col('shop_id') == 1) \
    .drop('shop_id')

df = df_1 \
    .filter(F.col('shop_id') == 1) \
    .join(df_3, 'data_id')

df.toPandas()

Unnamed: 0,data_id,shop_id,val_1,shop_id.1,val_1.1,val_2
0,1,1,a,1,a,10


In [17]:
df.explain()

== Physical Plan ==
*(8) Project [data_id#1L, shop_id#0L, val_1#2, shop_id#54L, val_1#56, val_2#8L]
+- *(8) SortMergeJoin [data_id#1L], [data_id#55L], Inner
   :- *(2) Sort [data_id#1L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(data_id#1L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#0L) && (shop_id#0L = 1)) && isnotnull(data_id#1L))
   :        +- Scan ExistingRDD[shop_id#0L,data_id#1L,val_1#2]
   +- *(7) Project [data_id#55L, shop_id#54L, val_1#56, val_2#8L]
      +- *(7) SortMergeJoin [data_id#55L], [data_id#7L], Inner
         :- *(4) Sort [data_id#55L ASC NULLS FIRST], false, 0
         :  +- ReusedExchange [shop_id#54L, data_id#55L, val_1#56], Exchange hashpartitioning(data_id#1L, 200)
         +- *(6) Sort [data_id#7L ASC NULLS FIRST], false, 0
            +- Exchange hashpartitioning(data_id#7L, 200)
               +- *(5) Project [data_id#7L, val_2#8L]
                  +- *(5) Filter isnotnull(data_id#7L)
                     +- Scan ExistingRDD[shop_id

## TL;DR

* We should always try to push the filter down as much as possible. 
* This means that there will be less data being shuffled and joined during the join. 
* This can be achieved with join in case #2 or #4.

**Option #2** (Good)
* When we `join`ed on `filter`ed on the key `shop_id` this caused a `filter-pushdown` which is good.
* But this made us `sort` on 2 keys.

**Option #4** (Better)
* When we pre `filter` the `join`ing datasets, this caused a `filter-pushdown` which is good.
* We only `join` on one key as well, which is good as we only sort on 1 key.

# Case Study 3: Joins on Skewed Data

A `skewed dataset` is defined by a dataset that has a class imbalance, this leads to poor or uncompletable spark jobs often getting `OOM` (out of memory) errors.

When performing a `join` onto a `skewed dataset` it means that there exists a class imbalance on the `key`s on which the join is performed on. This results in a majority of the data falls onto one partition, which will take longer to complete than the other partitions.

Some examples of this are:
1. The keys consist mainly of `null` values which fall onto a single partition.
2. There is a subset of keys that makeup the majority percentage of the keys which fall onto a single partition.

## Situation 1: Null Keys

Inital Datasets

In [64]:
customers = spark.createDataFrame([
    (1, None), 
    (2, None), 
    (3, 1),
], ["id", "card_id"])

customers.toPandas()

Unnamed: 0,id,card_id
0,1,
1,2,
2,3,1.0


In [65]:
cards = spark.createDataFrame([
    (1, "john", "doe", 21), 
    (2, "rick", "roll", 10), 
    (3, "bob", "brown", 2)
], ["card_id", "first_name", "last_name", "age"])

cards.toPandas()

Unnamed: 0,card_id,first_name,last_name,age
0,1,john,doe,21
1,2,rick,roll,10
2,3,bob,brown,2


### Option #1: Join Regularly

In [23]:
df = customers.join(cards, "card_id", "left")

df.toPandas()

Unnamed: 0,card_id,id,first_name,last_name,age
0,,1,,,
1,,2,,,
2,1.0,3,john,doe,21.0


In [24]:
df.explain()

== Physical Plan ==
*(3) Project [card_id#84L, id#83L, first_name#88, last_name#89, age#90L]
+- SortMergeJoin [card_id#84L], [card_id#87L], LeftOuter
   :- *(1) Sort [card_id#84L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(card_id#84L, 200)
   :     +- Scan ExistingRDD[id#83L,card_id#84L]
   +- *(2) Sort [card_id#87L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(card_id#87L, 200)
         +- Scan ExistingRDD[card_id#87L,first_name#88,last_name#89,age#90L]


**What Happened**:
* Rows that didn't join up were brought to the join.
* They get `Null` values for the right side columns.

**Results**:
* We brought more data to the join than we had to.

### Option #2: Filter Null Keys First, then Join, then Union

In [27]:
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Filter out the null rows.
        2. Create the columns you would get from the join.
        3. Join the tables.
        4. Union the null rows to joined table.
    """
    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
            df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))
    
    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")
    
    return df1.union(df2.select(df1.columns))
    
    
df = null_skew_helper(customers, cards, "card_id")

df.toPandas()

Unnamed: 0,id,card_id,first_name,last_name,age
0,1,,,,
1,2,,,,
2,3,1.0,john,doe,21.0


In [28]:
df.explain()

== Physical Plan ==
Union
:- *(1) Project [id#83L, null AS card_id#101L, null AS first_name#104, null AS last_name#108, null AS age#113L]
:  +- *(1) Filter isnull(card_id#84L)
:     +- Scan ExistingRDD[id#83L,card_id#84L]
+- *(5) Project [id#83L, card_id#84L, first_name#88, last_name#89, age#90L]
   +- SortMergeJoin [card_id#84L], [card_id#87L], LeftOuter
      :- *(3) Sort [card_id#84L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#84L, 200)
      :     +- *(2) Filter isnotnull(card_id#84L)
      :        +- Scan ExistingRDD[id#83L,card_id#84L]
      +- *(4) Sort [card_id#87L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(card_id#87L, 200)
            +- Scan ExistingRDD[card_id#87L,first_name#88,last_name#89,age#90L]


**What Happened**:
* We filtered all the rows out before the join.
* We did the join with less data.
* We read the table again and got the null rows.
* Unioned with the joined results.

**Results**:
* We brought less data to the join.
* But we read the data twice.

### Option #3: Cache Table, Filter Null Keys First, then Join, then Union

**Helper Function**

In [62]:
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Filter out the null rows.
        2. Create the columns you would get from the join.
        3. Join the tables.
        4. Union the null rows to joined table.
    """
    left = left.cache()
    
    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
            df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))
    
    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")
    
    return df1.union(df2.select(df1.columns))

In [66]:
df = null_skew_helper(customers, cards, "card_id")

df.toPandas()

Unnamed: 0,id,card_id,first_name,last_name,age
0,1,,,,
1,2,,,,
2,3,1.0,john,doe,21.0


In [30]:
df.explain()

== Physical Plan ==
Union
:- *(1) Project [id#83L, null AS card_id#146L, null AS first_name#149, null AS last_name#153, null AS age#158L]
:  +- *(1) Filter isnull(card_id#84L)
:     +- *(1) InMemoryTableScan [card_id#84L, id#83L], [isnull(card_id#84L)]
:           +- InMemoryRelation [id#83L, card_id#84L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
:                 +- Scan ExistingRDD[id#83L,card_id#84L]
+- *(5) Project [id#83L, card_id#84L, first_name#88, last_name#89, age#90L]
   +- SortMergeJoin [card_id#84L], [card_id#87L], LeftOuter
      :- *(3) Sort [card_id#84L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#84L, 200)
      :     +- *(2) Filter isnotnull(card_id#84L)
      :        +- *(2) InMemoryTableScan [id#83L, card_id#84L], [isnotnull(card_id#84L)]
      :              +- InMemoryRelation [id#83L, card_id#84L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
      :                    +- Scan ExistingRDD[i

**What Happened**:
* Similar to option #2, but we did a `InMemoryTableScan` instead of a read of the data.

**Results**:
* We brought less data to the join.
* We did 1 less read, but we used more memory.

## TL;DR

As always there is pros and cons.

Pros:
* Ideally you want to bring less data to a join. 
* This is unneeded data and in most cases causes a spark job to fail. 
* This is due to the fact that all the null key rows will go onto one partition.

Cons:
* There will either be one extra read of data or more memory used.

All to say:
1. It's definitely better to bring less data to a join, so performing an union and filter of the `null keys` before the join is definitely suggested.
2. This will result in an extra read of data or memory usage.
3. Decide if you can afford the extra read vs memory usage and `cache` the table before the filter.

Always check the spread of values for the `join key`, to detect if there's any skew and any optimizations you can do.

## Situation 2: Data Skew

Inital Datasets

In [85]:
customers = spark.createDataFrame([
    (1, "John"), 
    (2, "Bob"),
], ["customer_id", "first_name"])

customers.toPandas()

Unnamed: 0,customer_id,first_name
0,1,John
1,2,Bob


In [46]:
orders = spark.createDataFrame([
    (i, 1 if i < 95 else 2, "order #{}".format(i)) for i in range(100) 
], ["id", "customer_id", "order_name"])

orders.toPandas().tail(6)

Unnamed: 0,id,customer_id,order_name
94,94,1,order #94
95,95,2,order #95
96,96,2,order #96
97,97,2,order #97
98,98,2,order #98
99,99,2,order #99


### Option 1: Regular Join

In [58]:
df = customers.join(orders, "customer_id")

df.toPandas().tail(10)

Unnamed: 0,customer_id,first_name,id,order_name
90,1,John,90,order #90
91,1,John,91,order #91
92,1,John,92,order #92
93,1,John,93,order #93
94,1,John,94,order #94
95,2,Bob,95,order #95
96,2,Bob,96,order #96
97,2,Bob,97,order #97
98,2,Bob,98,order #98
99,2,Bob,99,order #99


In [59]:
df.explain()

== Physical Plan ==
*(5) Project [customer_id#194L, first_name#195, id#188L, order_name#190]
+- *(5) SortMergeJoin [customer_id#194L], [customer_id#189L], Inner
   :- *(2) Sort [customer_id#194L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#194L, 200)
   :     +- *(1) Filter isnotnull(customer_id#194L)
   :        +- Scan ExistingRDD[customer_id#194L,first_name#195]
   +- *(4) Sort [customer_id#189L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(customer_id#189L, 200)
         +- *(3) Filter isnotnull(customer_id#189L)
            +- Scan ExistingRDD[id#188L,customer_id#189L,order_name#190]


**What Happened**:
* We want to find what `order`s each `customer` made, so we will be `join`ing the `customer`s table to the `order`s table.
* When performing the join, we perform a `hashpartitioning` on `customer_id`.
* From our data creation, this means 95% of the data landed onto a single partition. 

**Results**:
* Similar to the `Null Skew` case, this means that single task/partition will take a lot longer than the others, and most likely erroring out.

### Option 2: Salt the key, then Join

**Helper Function**

In [123]:
def data_skew_helper(left, right, key, number_of_partitions, how="inner"):
    salt_value = F.lit(F.rand() * num_partitions % num_partitions).cast('int')
    
    left = left.withColumn("salt", salt_value)

    salt_values = right.sql_ctx.range(numPartitions, numPartitions=1)
    salt_values = salt_values.withColumnRenamed(salt_values.columns[0], "salt")
    
    right = right.crossJoin(salt_values)

    return left.join(right, [key, "salt"], how)

**Example**

In [127]:
left = orders

salt_value = F.lit(F.rand() * num_partitions % num_partitions).cast('int')    
left.withColumn("salt", salt_value).toPandas().head(5)

Unnamed: 0,id,customer_id,order_name,salt
0,0,1,order #0,1
1,1,1,order #1,4
2,2,1,order #2,4
3,3,1,order #3,1
4,4,1,order #4,1


In [128]:
right = customers

salt_values = right.sql_ctx.range(2, numPartitions=1)
salt_values = salt_values.withColumnRenamed(salt_values.columns[0], "salt")

right.crossJoin(salt_values).toPandas().head(10)


Unnamed: 0,customer_id,first_name,salt
0,1,John,0
1,1,John,1
2,2,Bob,0
3,2,Bob,1


In [117]:
df = data_skew_helper(orders, customers, "customer_id", 5)

df.toPandas().tail(10)

Unnamed: 0,customer_id,salt,id,order_name,first_name
90,1,2,40,order #40,John
91,1,2,47,order #47,John
92,1,2,54,order #54,John
93,1,2,56,order #56,John
94,1,2,61,order #61,John
95,1,2,78,order #78,John
96,1,2,81,order #81,John
97,1,2,86,order #86,John
98,1,2,89,order #89,John
99,1,2,94,order #94,John


In [94]:
df.explain()

== Physical Plan ==
*(6) Project [customer_id#373L, salt#445, first_name#374, id#188L, order_name#190]
+- *(6) SortMergeJoin [customer_id#373L, cast(salt#445 as bigint)], [customer_id#189L, salt#451L], Inner
   :- *(2) Sort [customer_id#373L ASC NULLS FIRST, cast(salt#445 as bigint) ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#373L, cast(salt#445 as bigint), 200)
   :     +- *(1) Filter (isnotnull(salt#445) && isnotnull(customer_id#373L))
   :        +- *(1) Project [customer_id#373L, first_name#374, cast(((rand(2394801416674268296) * 5.0) % 5.0) as int) AS salt#445]
   :           +- Scan ExistingRDD[customer_id#373L,first_name#374]
   +- *(5) Sort [customer_id#189L ASC NULLS FIRST, salt#451L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(customer_id#189L, salt#451L, 200)
         +- BroadcastNestedLoopJoin BuildRight, Cross
            :- *(3) Filter isnotnull(customer_id#189L)
            :  +- Scan ExistingRDD[id#188L,customer_id#189L,ord

**What Happened**:
* We create a new `salt` column for both datasets.
* On one of the dataset we duplicate the data so we have a row for each `salt` value.
* When performing the join, we perform a `hashpartitioning` on `[customer_id, salt]`.

**Results**:
* When we produce a row per `salt` value, we have essentially duplicated `(num_partitions - 1) * N` rows.
* This allows us to spread the data across more partitions as you can see from `hashpartitioning(customer_id, salt)`.

## TL;DR

All to say:
* Again we will sacrifice more resources in order to get a performance gain or a successful run.
* By `salt`ing our keys, the `skewed` dataset gets divided into smaller partitions. Thus removing the skew.
* We produced more data by creating `(num_partitions - 1) * N` more data for the right side.


## Case Study 4: Range Join Conditions [TODO]

> A naive approach (just specifying this as the range condition) would result in a full cartesian product and a filter that enforces the condition (tested using Spark 2.0). This has a horrible effect on performance, especially if DataFrames are more than a few hundred thousands records.

source: http://zachmoshe.com/2016/09/26/efficient-range-joins-with-spark.html

> The source of the problem is pretty simple. When you execute join and join condition is not equality based the only thing that Spark can do right now is expand it to Cartesian product followed by filter what is pretty much what happens inside `BroadcastNestedLoopJoin`

source: https://stackoverflow.com/questions/37953830/spark-sql-performance-join-on-value-between-min-and-max?answertab=active#tab-top

### Initial Dataset

In [21]:
geo_loc_table = spark.createDataFrame([
    (1, 10, "foo"), 
    (11, 36, "bar"), 
    (37, 59, "baz"),
], ["ipstart", "ipend", "loc"])

geo_loc_table.toPandas()

Unnamed: 0,ipstart,ipend,loc
0,1,10,foo
1,11,36,bar
2,37,59,baz


In [22]:
records_table = spark.createDataFrame([
    (1, 11), 
    (2, 38), 
    (3, 50),
],["id", "inet"])

records_table.toPandas()

Unnamed: 0,id,inet
0,1,11
1,2,38
2,3,50


### Option #1

In [21]:
join_condition = [
    records_table['inet'] >= geo_loc_table['ipstart'],
    records_table['inet'] <= geo_loc_table['ipend'],
]

df = records_table.join(geo_loc_table, join_condition, "left")

df.toPandas()

Unnamed: 0,id,inet,ipstart,ipend,loc
0,1,11,11,36,bar
1,2,38,37,59,baz
2,3,50,37,59,baz


In [22]:
df.explain()

== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, LeftOuter, ((inet#90L >= ipstart#83L) && (inet#90L <= ipend#84L))
:- Scan ExistingRDD[id#89L,inet#90L]
+- BroadcastExchange IdentityBroadcastMode
   +- Scan ExistingRDD[ipstart#83L,ipend#84L,loc#85]


### Option #2

In [23]:
from bisect import bisect_right
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

geo_start_bd = sc.broadcast(map(lambda x: x.ipstart, geo_loc_table
    .select("ipstart")
    .orderBy("ipstart")
    .collect()
))

def find_le(x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(geo_start_bd.value, x)
    if i:
        return geo_start_bd.value[i-1]
    return None

records_table_with_ipstart = records_table.withColumn(
    "ipstart", udf(find_le, LongType())("inet")
)

df = records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left")

df.toPandas()

Unnamed: 0,ipstart,id,inet,ipend,loc
0,37,2,38,59,baz
1,37,3,50,59,baz
2,11,1,11,36,bar


In [24]:
df.explain()

== Physical Plan ==
*(4) Project [ipstart#110L, id#89L, inet#90L, ipend#84L, loc#85]
+- SortMergeJoin [ipstart#110L], [ipstart#83L], LeftOuter
   :- *(2) Sort [ipstart#110L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(ipstart#110L, 200)
   :     +- *(1) Project [id#89L, inet#90L, pythonUDF0#119L AS ipstart#110L]
   :        +- BatchEvalPython [find_le(inet#90L)], [id#89L, inet#90L, pythonUDF0#119L]
   :           +- Scan ExistingRDD[id#89L,inet#90L]
   +- *(3) Sort [ipstart#83L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(ipstart#83L, 200)
         +- Scan ExistingRDD[ipstart#83L,ipend#84L,loc#85]


In [29]:
df = spark.createDataFrame([
    (1, 11), 
    (2, 38), 
    (3, 50),
],["id", "val"])

df.toPandas()

Unnamed: 0,id,val
0,1,11
1,2,38
2,3,50
