# PySpark Helper Functions
### [For less verbose and fool-proof operations]

In [1]:
try:
    from pyspark import SparkConf
except ImportError:
    ! pip install pyspark==3.2.1

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import types as st

import spark.helpers as sh
from spark.join import JoinValidator, JoinStatement

In [2]:
# Setup Spark

# Build the configuration step by step: a local master (`local[1]`) and an
# application name for the examples in this notebook.
conf = SparkConf()
conf.setMaster("local[1]")
conf.setAppName("examples")

# Reuse an existing session if one is already running, otherwise create it.
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Keep the cell outputs readable: only log messages at ERROR level.
spark.sparkContext.setLogLevel('ERROR')

In [3]:
# Load example datasets

# Only the header option is set, so no explicit schema is supplied here; the
# schema check later in this notebook expects untouched columns to be strings.
dataframe_1 = spark.read.options(header=True).csv("./data/dataset_1.csv")
dataframe_2 = spark.read.options(header=True).csv("./data/dataset_2.csv")

# Show each frame with its own statement. `show()` prints the table as a side
# effect, so there is nothing to capture, and assigning a throwaway tuple to
# `_` would rebind IPython's "last output" variable.
dataframe_1.show()
dataframe_2.show()

+---+---+---+---+-----+
| x1| x2| x3| x4|   x5|
+---+---+---+---+-----+
|  A|  J|734|499|595.0|
|  B|  I|357|202|525.0|
|  C|  H|864|568|433.5|
|  D|  G|530|703|112.3|
|  E|  F| 61|521|906.0|
|  F|  E|482|496| 13.0|
|  G|  D|350|279|941.0|
|  H|  C|171|267|423.0|
|  I|  B|755|133|600.0|
|  J|  A|228|765|  7.0|
+---+---+---+---+-----+

+---+---+---+---+------+
| x1| x3| x4| x6|    x7|
+---+---+---+---+------+
|  W|  K|391|140| 872.0|
|  X|  G| 88|483| 707.1|
|  Y|  M|144|476| 714.3|
|  Z|  J|896| 68| 902.0|
|  A|  O|946|187| 431.0|
|  B|  P|692|523| 503.5|
|  C|  Q|550|988|181.05|
|  D|  R| 50|419|  42.0|
|  E|  S|824|805| 558.2|
|  F|  T| 69|722| 721.0|
+---+---+---+---+------+



## Pandas-like group by

In [4]:
# Walk over the (group key, sub-dataframe) pairs produced for column "x1" and
# print the number of rows in each group.
for group, data in sh.group_iterator(dataframe_1, "x1"):
    # `count()` returns the row count directly, so the group does not have to
    # be collected into a pandas frame just to read its length.
    print(group, data.count())

A 1
B 1
C 1
D 1
E 1
F 1
G 1
H 1
I 1
J 1


## Bulk-change schema

In [5]:
def field_types(dataframe):
    """Return the (column name, type) pairs of a dataframe's schema, in schema order."""
    return [
        (field["name"], field["type"])
        for field in dataframe.schema.jsonValue()["fields"]
    ]


# Snapshot of the schema before any change.
before = field_types(dataframe_1)

# Target types, keyed by the name of the column to change.
# NOTE(review): in the sample data shown above, x2 holds letters; confirm that
# casting it to an integer type is the intended example (x3 may have been meant).
schema = {
    "x2": st.IntegerType(),
    "x5": st.FloatType(),
}
new_dataframe = sh.change_schema(dataframe_1, schema)

# Snapshot of the schema after the change.
after = field_types(new_dataframe)

# Expected result: only the two columns listed in `schema` get a new type.
check = [
    ('x1', 'string'),
    ('x2', 'integer'),
    ('x3', 'string'),
    ('x4', 'string'),
    ('x5', 'float'),
]

assert before != after
assert after == check

## Improved joins

In [6]:
# Join condition: column x1 of the left frame against column x1 of the right.
x1_x1 = JoinStatement("x1", "x1")

# No `duplicate_keep` is passed, so the helper's default handling of column
# names present in both frames applies.
joined = sh.join(dataframe_1, dataframe_2, x1_x1)
joined.toPandas()

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7
0,A,J,734,499,595.0,187,431.0
1,B,I,357,202,525.0,523,503.5
2,C,H,864,568,433.5,988,181.05
3,D,G,530,703,112.3,419,42.0
4,E,F,61,521,906.0,805,558.2
5,F,E,482,496,13.0,722,721.0


### [Keeping the duplicate columns from the right dataframe]

In [7]:
# Same x1-to-x1 condition, but columns that exist in both frames are taken
# from the right-hand frame.
x1_x1 = JoinStatement("x1", "x1")
joined = sh.join(
    dataframe_1,
    dataframe_2,
    x1_x1,
    duplicate_keep="right",
)
joined.toPandas()

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7
0,A,J,O,946,595.0,187,431.0
1,B,I,P,692,525.0,523,503.5
2,C,H,Q,550,433.5,988,181.05
3,D,G,R,50,112.3,419,42.0
4,E,F,S,824,906.0,805,558.2
5,F,E,T,69,13.0,722,721.0


### [Keeping the duplicate columns from both]

In [8]:
# `duplicate_keep` as a pair of lists: the first names the shared columns kept
# from the left frame, the second those kept from the right frame (see the
# x3 and x4 values in the output below).
keep_from_left = ["x1", "x3"]
keep_from_right = ["x4"]

joined = sh.join(
    dataframe_1,
    dataframe_2,
    JoinStatement("x1", "x1"),
    duplicate_keep=[keep_from_left, keep_from_right],
)
joined.toPandas()

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7
0,A,J,734,946,595.0,187,431.0
1,B,I,357,692,525.0,523,503.5
2,C,H,864,550,433.5,988,181.05
3,D,G,530,50,112.3,419,42.0
4,E,F,61,824,906.0,805,558.2
5,F,E,482,69,13.0,722,721.0


### [Complex join]

In [9]:
# Two simple conditions: left x1 against right x1, and left x1 against right x3.
x1_x1 = JoinStatement("x1", "x1")
x1_x3 = JoinStatement("x1", "x3")

# Combine them into a single statement with "or".
statement = JoinStatement(x1_x1, x1_x3, "or")

joined = sh.join(
    dataframe_1,
    dataframe_2,
    statement,
)
joined.toPandas()

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7
0,A,J,734,499,595.0,187,431.0
1,B,I,357,202,525.0,523,503.5
2,C,H,864,568,433.5,988,181.05
3,D,G,530,703,112.3,419,42.0
4,E,F,61,521,906.0,805,558.2
5,F,E,482,496,13.0,722,721.0
6,G,D,350,279,941.0,483,707.1
7,J,A,228,765,7.0,68,902.0


### [Further nested joins are not supported]
(Perform sequential joins instead)

In [10]:
# Build a one-level combined statement, exactly as in the complex-join example.
x1_x1 = JoinStatement("x1", "x1")
x1_x3 = JoinStatement("x1", "x3")  # named after the columns it actually compares
statement = JoinStatement(x1_x1, x1_x3, "or")

# Nest one level deeper: a statement whose operands are combined statements.
statement_complex = JoinStatement(statement, statement, "and")

try:
    # The result is not kept: this call is expected to raise.
    sh.join(dataframe_1, dataframe_2, statement_complex)
except NotImplementedError as error:
    print(f"Error raised as expected: [{error}]")
else:
    # Without this branch the cell would pass silently if the join ever
    # stopped rejecting nested statements.
    raise AssertionError("Expected NotImplementedError for a nested JoinStatement")

Error raised as expected: [Recursive JoinStatement not implemented]
