# About

This notebook shows how to use pyspark.  
It was written against Microsoft Azure Databricks; behavior on other cloud providers or environments is not guaranteed.  
Core documentation: https://spark.apache.org/docs/latest/api/python/reference/index.html

# SparkSession Object

Databricks instantiates a `SparkSession` object by default, named `spark`.  
In vanilla Python, you need to instantiate it yourself first.

In [0]:
# https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Querying SQL

## Basic SQL Query Methods

There are three main ways to query SQL tables with pyspark.

1. `spark.sql(query_string)`
2. `spark.table(table_name)`
3. `spark.read.table(table_name)`

### 1. `spark.sql(query_string: str, [**kwargs]) -> pyspark.sql.DataFrame`

**Call Tree:**  
1. `pyspark.sql.SparkSession` object
2. `pyspark.sql.SparkSession.sql()` method
3. `pyspark.sql.DataFrame` object

**Docs:**
- [SparkSession object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/spark_session.html)
- [sql() method](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.sql.html)
- [DataFrame object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.html)

**Example:**

In [0]:
spark.sql("select * from curated_udp_db.txn_policy")

#### Notes:

##### Programmatic Query via `spark.sql`

It is possible to programmatically alter sql queries by providing _kwargs_ options to `spark.sql`.
[This is the officially documented method in pyspark.](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.sql.html#pyspark.sql.SparkSession.sql)

In [0]:
spark.sql(
    "SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9
).show()

Alternatively, Python's `str.format()` method and its variations build the same query for simple numeric values like these:

In [0]:
spark.sql(
    "SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}".format(bound1=7, bound2=9)
).show()
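
The two forms agree for the numeric bounds above, but `str.format()` pastes raw text into the query, so string values need explicit quoting. A minimal pure-Python sketch of that pitfall follows; the table `people` and column `name` are hypothetical names used only for illustration.

```python
# Hypothetical query template; `people` and `name` are illustrative names only.
template = "SELECT * FROM people WHERE name = {name}"

# str.format() substitutes the raw text, so the string value ends up unquoted
# and SQL would read it as an identifier rather than a string literal.
unquoted = template.format(name="alice")

# Quoting has to be added by hand when building the query with str.format().
quoted = "SELECT * FROM people WHERE name = '{name}'".format(name="alice")

print(unquoted)
print(quoted)
```

This sketch says nothing about `spark.sql` itself; whether its keyword-argument form quotes strings for you is worth checking on your Spark version.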

### 2. `spark.table(table_name: str) -> pyspark.sql.DataFrame`

**Call Tree:**  
1. `pyspark.sql.SparkSession` object
2. `pyspark.sql.SparkSession.table()` method
3. `pyspark.sql.DataFrame` object

**Docs:**  
- [SparkSession object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/spark_session.html)
- [table() method](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.table.html)
- [DataFrame object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.html)

**Example:**

In [0]:
spark.table("curated_udp_db.txn_policy")

### 3. `spark.read.table(table_name: str) -> pyspark.sql.DataFrame`

**Call Tree:**  
1. `pyspark.sql.SparkSession` object
2. `pyspark.sql.SparkSession.read` property
3. `pyspark.sql.DataFrameReader` object
4. `pyspark.sql.DataFrameReader.table()` method
5. `pyspark.sql.DataFrame` object

**Docs:**  
- [SparkSession object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/spark_session.html)
- [read property](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.read.html)
- [DataFrameReader object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.html)
- [table() method](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.table.html)
- [DataFrame object](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.html)

**Example:**

In [0]:
spark.read.table("curated_udp_db.txn_policy")

## Select & Joins

You do not need to chain pyspark commands into a single line.  
It is entirely possible to create variables of each subquery, so that the combined query is more legible.  
_Note: Spark evaluates DataFrame transformations lazily, so both styles build the same query plan; chaining versus separate variable declarations should not affect query performance._

**Example:**  

**Chained Commands:**

In [0]:
df = spark.table("ai_dwh.ai_dwh_txn_claims").join(
    spark.table("ai_dwh.ai_dwh_txn_policy"),
    on="party_ref",
    how="left"
).join(
    spark.table("ai_dwh.ai_dwh_txn_claims_diagnosis"),
    on="party_ref",
    how="left"
)

**Separate Variable Declarations:**

In [0]:
claims = spark.table("ai_dwh.ai_dwh_txn_claims")
policy = spark.table("ai_dwh.ai_dwh_txn_policy")
diagno = spark.table("ai_dwh.ai_dwh_txn_claims_diagnosis")

df = claims.join(
    policy,
    on="party_ref",
    how="left"
).join(
    diagno,
    on="party_ref",
    how="left"
)

## Multiple Condition Joins

Refer to: https://stackoverflow.com/a/34463562/5675094

To join with multiple conditions, use a list of join conditions:

In [0]:
df1 = ...
df2 = ...

# A list of Column conditions is combined with AND
list_conditions = [
    (df1.col_a == df2.col_a),
    (df1.col_b == df2.col_b)
]

and_conditions = [
    (df1.col_a == df2.col_a) &
    (df1.col_b == df2.col_b)
]

or_conditions = [
    (df1.col_a == df2.col_a) |
    (df1.col_b == df2.col_b)
]

df1.join(
    df2,
    on=list_conditions
)

## Filtering by Dates

If you have a pyspark.sql.DataFrame `df` with a datetime column `dates`, this is how you can filter for records after an arbitrary date:

**You can compare datetime columns to strings, but not to ints**

In [0]:
df = spark.table("ai_dwh.ai_dwh_txn_claims")

df.filter(df.claim_dt >= "2022") # This will run
df.filter(df.claim_dt >= 2022) # This will raise an error


**Pyspark considers "YEAR" to be the same as the date "YEAR-01-01"**
- _This can ruin inclusive / exclusive range setups, so consider them carefully_
- It is recommended to explicitly write out the full ISO date "YYYY-MM-DD" to prevent confusion

Consider the following examples:
- `date > "2022"` means `date > "2022-01-01"`
  - All dates **AFTER** `January 1st 2022` are True (Exclusive of Jan 1st 2022)
- `date >= "2022"` means `date >= "2022-01-01"`
  - All dates **ON or AFTER** `January 1st 2022` are True (Inclusive of Jan 1st 2022)
- `"2022" > date` means `"2022-01-01" > date`
  - All dates **BEFORE** `January 1st 2022` are True (Exclusive of Jan 1st 2022)
- `"2022" >= date` means `"2022-01-01" >= date`
  - All dates **ON or BEFORE** `January 1st 2022` are True (Inclusive of Jan 1st 2022)

In [0]:
import pandas as pd
import pyspark.sql.functions as F

df = spark.createDataFrame(
  data = pd.DataFrame(
      data = pd.date_range(start="1999-12-30", end="2000-01-02", freq="D"),
      columns = ["date_col"]
    )
)\
.withColumn("date_col", F.col("date_col").cast("date") )\
.withColumn("date > 2000" , F.when((F.col("date_col") > "2000") , 1).otherwise(None) )\
.withColumn("date >= 2000", F.when((F.col("date_col") >= "2000"), 1).otherwise(None) )\
.withColumn("date < 2000" , F.when((F.col("date_col") < "2000") , 1).otherwise(None) )\
.withColumn("date <= 2000", F.when((F.col("date_col") <= "2000"), 1).otherwise(None) )\
\
.withColumn("2000 > date" , F.when(("2000" > F.col("date_col")) , 1).otherwise(None) )\
.withColumn("2000 >= date", F.when(("2000" >= F.col("date_col")), 1).otherwise(None) )\
.withColumn("2000 < date" , F.when(("2000" < F.col("date_col")) , 1).otherwise(None) )\
.withColumn("2000 <= date", F.when(("2000" <= F.col("date_col")), 1).otherwise(None) )

df.show()


Output of the above code block:

```
+----------+-----------+------------+-----------+------------+-----------+------------+-----------+------------+
|  date_col|date > 2000|date >= 2000|date < 2000|date <= 2000|2000 > date|2000 >= date|2000 < date|2000 <= date|
+----------+-----------+------------+-----------+------------+-----------+------------+-----------+------------+
|1999-12-30|       null|        null|          1|           1|          1|           1|       null|        null|
|1999-12-31|       null|        null|          1|           1|          1|           1|       null|        null|
|2000-01-01|       null|           1|       null|           1|       null|           1|       null|           1|
|2000-01-02|          1|           1|       null|        null|       null|        null|          1|           1|
+----------+-----------+------------+-----------+------------+-----------+------------+-----------+------------+
```

# Aggregation

## Group By

### `pyspark.sql.DataFrame.groupBy(*column_names) -> pyspark.sql.GroupedData`

Apply the `groupBy()` method to any `sql.DataFrame`, passing a column name or a list of column names.  
It returns a `GroupedData` object, which is not itself a DataFrame.  
Apply the `agg()` method afterwards to get a usable DataFrame back.

**Examples:**

In [0]:
grouped_df = df.groupBy("column")
grouped_df = df.groupBy(["column_a", "column_b"])

## Pivot Function

### `pyspark.sql.GroupedData.pivot(column_name, [*values]) -> pyspark.sql.GroupedData`

Apply the `pivot()` method to a `GroupedData` object (the result of `groupBy()`) to create a pivot table.  
It pivots by the values in the provided column `column_name`.

Optionally provide the list of values to pivot by in `[*values]`.
This is recommended for performance: without it, Spark first has to compute the distinct values of `column_name`.
Any values that exist in `column_name` but are not listed in `[*values]` are ignored.

**Pure Example:**

In [0]:
grouped_df = df.groupBy("column_a").pivot("column_b")
grouped_df = df.groupBy("column_a").pivot("column_b", [val_1, val_2, val_3])

**Example for Counting number of claims by incident dates, separating by benefit types**

In [0]:
import pyspark.sql.functions as F

# Only these benefit codes become pivot columns
benefit_codes = ['DENTAL','IPD','OPD']

spark.table("ai_dwh.eb_claim_records")\
  .filter(F.col("INCIDENT_DATE").between("2023-01-01", "2023-01-31"))\
  .groupBy("INCIDENT_DATE")\
  .pivot("BENEFIT_CODE", benefit_codes)\
  .agg(F.count("POLICY_NO"))\
  .orderBy("INCIDENT_DATE")

## Aggregation Function

### `pyspark.sql.GroupedData.agg(*exprs) -> pyspark.sql.DataFrame`

Apply the `agg()` method and add desired [aggregation functions](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#aggregate-functions) to perform aggregation.  
You can provide as many aggregation functions as you like, separated by commas.

**Example:**

In [0]:
import pyspark.sql.functions as F

df_agg = df_raw.groupBy("column_key")\
        .agg(
            F.count_distinct("column_a").alias("column_a_counts"),
            F.sum("column_b").alias("column_b_sum"),
            F.first("column_c").alias("column_c_first"),
            F.last("column_d").alias("column_d_last"),
            # ...
        )

### Median / Percentile

See: https://stackoverflow.com/a/51933027/5675094

Use the `pyspark.sql.functions.percentile_approx` function to calculate the approximate median or any other percentile of a column.

**Docs:**  
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.percentile_approx.html

**Example:**

In [0]:
from pyspark.sql import functions as psf

df = spark.table("ai_dwh.ai_dwh_txn_claims")

# Group by hospital that made the claim
grouped = df.groupby("hospital_name")

# Calculate Median
median = grouped.agg(
    psf.percentile_approx("claim_amt", 0.5).alias("median")
)

# Quantiles (Q1, Q2, Q3)
boxplot = grouped.agg(
    psf.percentile_approx("claim_amt", [0.25, 0.5, 0.75]).alias("quantile")
)

## Window Function

### `pyspark.sql.window.Window`

[The Window utility](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/window.html) defines a "window specification". You then pass that specification to [window functions](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#window-functions) through their `.over()` method.

The core syntax is:

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Define a window specification with Window methods such as partitionBy() and orderBy()
# ("group_column" and "order_column" are placeholder column names)
window_spec = Window.partitionBy("group_column").orderBy("order_column")

# Apply a window function (row_number() is one example) over the window specification
df.select(F.row_number().over(window_spec))
df.withColumn("name", F.row_number().over(window_spec))

**Example: You want to find the percentile ranking of each value in a DataFrame's column.**

In [0]:
import pyspark.sql.functions as F
import pyspark.sql.window as W

# DataFrame has columns: [user_id: string, number_of_claims: int]

# Define window specification to sort DataFrame by the specified column
window_spec = W.Window.orderBy("number_of_claims")

# Apply Window aggregation function
df_rank = df.withColumn("percentile_rank", F.percent_rank().over(window_spec))

## Collecting Ordered List

Great answer copied from: https://stackoverflow.com/a/50668635/5675094

As you know, using `collect_list` together with `groupBy` will result in an **unordered** list of values. This is because depending on how your data is partitioned, Spark will append values to your list as soon as it finds a row in the group. The order then depends on how Spark plans your aggregation over the executors.

A `Window` function allows you to control that situation, grouping rows by a certain value so you can perform an operation `over` each of the resultant groups:

```
w = Window.partitionBy('id').orderBy('date')
```

- `partitionBy`: you want groups/partitions of rows with the same id
- `orderBy`: you want each row in the group to be sorted by date

Once you have defined the scope of your `Window` ("rows with the same `id`, sorted by `date`"), you can use it to perform an operation over it, in this case a `collect_list`:

```
F.collect_list('value').over(w)
```

At this point you have created a new column `sorted_list` with an ordered list of values, sorted by date, but you still have duplicated rows per `id`: each row holds the list accumulated up to that row. To trim out the duplicated rows, `groupBy` `id` and keep the `max` value for each group, which is the complete list:

```
.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
```

**Example:**

In [0]:
from pyspark.sql import functions as F
from pyspark.sql import Window as W

window_spec = W.partitionBy('id').orderBy('date')

sorted_list_df = input_df.withColumn(
            'sorted_list', F.collect_list('value').over(window_spec)
        )\
        .groupBy('id')\
        .agg(F.max('sorted_list').alias('sorted_list'))


# Pyspark UDF

## `pyspark.sql.functions.udf`

`pyspark.sql.functions.udf` is the preferred pattern:  
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.udf.html

Despite what the documentation says, `pyspark.sql.functions.pandas_udf` is not an alias of `udf` and has different behaviors:  
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html

In [0]:
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.udf.html
from pyspark.sql import functions as F
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/data_types.html
from pyspark.sql.types import StringType, BooleanType

# Define function in python
def py_my_func(x):
  # ...
  return

# Wrap with pyspark udf and specify the return type that matches your function
my_func = F.udf(py_my_func, StringType())
# Alternative: use BooleanType() if the function returns a bool
# my_func = F.udf(py_my_func, BooleanType())

# Apply the defined function
df = group_df.agg(my_func("column"))
df = df.withColumn("new_column", my_func("column"))