<a href="https://colab.research.google.com/github/nhsbsa-data-analytics/coffee-and-coding/blob/master/2025-01-30%20%20What%20is%20Spark%3F/2025_01_30_coffee_and_coding_intro_to_spark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Coffee & Coding: What is Spark?

- **Date:** 30/01/2025
- **Presented by:** Alistair Jones

## Overview
This session provides an introduction to Spark by example.
We will explore basic operations to read, transform and write data using Spark SQL and PySpark.
We will then work through a short practical example of a simple analytical pipeline.

The session will cover the fundamentals of how Spark works, why it is important/useful and offer some considerations for using it in your workflows.

## Background and context

### Why should you care about Spark?
Apache Spark is a unified analytics engine designed for large-scale data processing.
Some of the strengths of Spark include:

- **Scalability** is the primary benefit of using Spark. It distributes work across a cluster of computers, which means it can handle much bigger workloads than a single computer could by itself.
- **Support for multiple languages** such as Python (via PySpark) and SQL.
- **Managed compute** in platforms such as Microsoft Fabric.
- Spark is completely **open-source** and **well-documented**!
- Spark has an **active community** of developers with lots of helpful threads and tutorials available online.

### What are the limitations of Spark?
- **Additional complexity** due to the distributed nature of Spark, which means you need a cluster of computers to leverage it
- **Row-wise computation** means we need to be careful about things like row-ordering (we will show this below)
- Can be **hard to debug** due to the distributed nature and because Spark runs on the Java Virtual Machine under the hood (which means lengthy errors and stack traces!)
- Can be **hard to set up** due to the need for a cluster and the vast configuration options!


## Getting started


### Environment Setup
The first step is to set up our environment with the packages and configuration required to do the work.

#### Installing PySpark Locally
Note that Apache Spark (including PySpark and Spark SQL) is meant to run on a distributed system.
By distributed, we just mean a cluster of computers (or nodes) all working together to perform a task.

You can install PySpark locally to learn how to write and run code via the PySpark API.
Because this runs on a single computer rather than a distributed system, the code will look the same but the way it works under the hood will differ between the two environments.

This has implications for performance and certain behaviours.
It won't be a problem for learning how to interact with the PySpark API, but it is a consideration when transferring your code to an actual Spark system (such as Microsoft Fabric).

In [1]:
!pip install pyspark



### Import Packages

In [2]:
from pyspark import SparkFiles
from pyspark.sql import SparkSession, functions as F


### What is a Spark Session?
A Spark Session is the entry point to programming Spark. It provides a way to interact with Spark's functionality and create DataFrames.

In managed data platforms such as Microsoft Fabric, the Spark Session is usually created when you start the notebook and available via a global `spark` variable without any further setup.

In this example, since we are in our own environment, we need to create a Spark Session to be able to work with Spark.

In [3]:
# Initialize Spark Session since we are not working in an environment
# where one is created for us
spark = (
    SparkSession.builder
    .appName("Coffee_And_Coding")
    .getOrCreate()
)

### Loading Inputs from the NHSBSA Open Data Portal
Usually when we work in a managed environment like a data platform, we have access to data stored in tables that already exist and we just want to do something with that data.
For this tutorial there is an additional setup step so that we can interact with data as we would in a data platform: we will download a csv file from the NHSBSA Open Data Portal and load it into a table, so that we have something to query!

First we download the csv file and load it into a PySpark DataFrame.

In [4]:
# Adding the file to the Spark Context creates a reference to the data
# so that we can query it.
data_url = "https://opendata.nhsbsa.net/dataset/a436f2d3-e6c8-46bd-b730-4f4d7af7ca56/resource/9e2d01a5-7644-4843-809a-b09ddd6f447a/download/monthly-hospital-data-oct24.csv"
spark.sparkContext.addFile(data_url)

# Get the local file spark created
filename = data_url.split("/")[-1]  # Part of the URL after the last '/'
local_filepath = f"file://{SparkFiles.get(filename)}"

# Load the data from a csv
df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv(local_filepath)
)

Then we can write the data to a table.
This step creates [parquet](https://parquet.apache.org/) files in our local drive; parquet is the default storage format Spark uses for tables under the hood.
Spark works by distributing a dataset around the cluster, with each node reading and processing a subset of the parquet files.

This is how Spark systems are able to scale so well, because each node in the cluster only has to do a small portion of the work.
If we add more nodes, or make the nodes bigger, then we can do more work in a way that wouldn't be possible with a single computer!

In [5]:
# Store the data in a new table
table_name = filename.split(".")[0]  # Part of the filename before .csv
table_name = table_name.replace("-", "_")  # Sanitize the table name

# There are many options to configure the table to be written
# We are only using two which provide the name and indicate
# that we should overwrite the table if it already exists
df.write.mode("overwrite").saveAsTable(table_name)
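
To make the idea of splitting work across nodes more concrete, here is a minimal plain-Python sketch (it does not use Spark). Threads stand in for nodes, and the function names and example quantities are made up for illustration; Spark's real scheduling is far more sophisticated.

```python
from concurrent.futures import ThreadPoolExecutor


def split_into_partitions(rows, n_partitions):
    """Deal rows out round-robin into n_partitions lists."""
    return [rows[i::n_partitions] for i in range(n_partitions)]


def partial_sum(partition):
    """The work one 'node' does: sum only its own slice of the data."""
    return sum(partition)


def distributed_sum(rows, n_partitions=4):
    partitions = split_into_partitions(rows, n_partitions)
    # Each partition is processed independently (threads stand in for nodes)
    with ThreadPoolExecutor(max_workers=n_partitions) as pool:
        partials = list(pool.map(partial_sum, partitions))
    # A final, cheap step combines the per-partition results
    return sum(partials)


quantities = [7.0, 84.0, 14.0, 21.0, 60.0]  # Made-up example quantities
print(distributed_sum(quantities, n_partitions=2))
```

Each worker only ever sees its own slice, which is why adding more workers lets the same code handle more data.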

## Fundamental Operations
In this section we'll step through a few of the fundamental operations for interacting with PySpark and Spark SQL.
Many of these will be familiar to anyone who has used SQL before!

In each part we'll show both the PySpark and the Spark SQL command: feel free to follow along to whichever you prefer.
If you don't have any prior Python experience, the Spark SQL may be more accessible.
Equally, PySpark can be a good entrypoint to learn some Python if you are already familiar with SQL!

### Create a Dataframe From Input Data
Spark can read data from many different sources: tables, csv files, web urls, etc.
Since we are usually working in a managed environment where the tables are curated for us, we'll explore how to load data from a table (the one we created above!).



#### What is a DataFrame?
Whenever we read data using the PySpark API, we create a 'DataFrame', which represents the data we are querying.
A DataFrame isn't actually the data itself: it is essentially a set of instructions to read, transform and write data.
These instructions are evaluated 'lazily' which means that nothing happens straight away - you queue up all of your instructions and Spark works out when to actually do the work (usually this is when you want to print or write some results).

Spark uses lazy evaluation to optimise query performance: by only executing the code at the point we need it to, Spark can cleverly work out the fastest and most efficient way to perform all the operations we are asking it to!
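
The toy class below is a rough plain-Python sketch of lazy evaluation, not how Spark is actually implemented. `LazyFrame`, `map` and `collect` here are hypothetical names chosen for illustration: transformations only extend a stored plan, and nothing runs until an action is called.

```python
class LazyFrame:
    """Toy stand-in for a DataFrame: stores a plan, not results."""

    def __init__(self, rows, plan=()):
        self._rows = rows
        self._plan = plan  # Tuple of (kind, function) steps, not yet run

    def filter(self, predicate):
        # Transformation: returns a new LazyFrame with a longer plan
        return LazyFrame(self._rows, self._plan + (("filter", predicate),))

    def map(self, func):
        # Transformation: also just recorded for later
        return LazyFrame(self._rows, self._plan + (("map", func),))

    def collect(self):
        # Action: only now is the plan executed, step by step
        rows = iter(self._rows)
        for kind, func in self._plan:
            rows = filter(func, rows) if kind == "filter" else map(func, rows)
        return list(rows)


calls = []


def is_large(quantity):
    calls.append(quantity)  # Record when the predicate actually runs
    return quantity > 50


lazy = LazyFrame([7.0, 84.0, 14.0, 21.0, 60.0]).filter(is_large)
print(len(calls))      # 0: nothing has been evaluated yet
print(lazy.collect())  # [84.0, 60.0]
print(len(calls))      # 5: the work happened inside collect()
```

In PySpark, calling `explain()` on a DataFrame prints the plan Spark has built up so far without running it.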

#### PySpark
We can create a DataFrame in PySpark using the `spark.table` function.

In [6]:
df = spark.table(table_name)
df.show(5)

+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+-----------------+---------+
|PERIOD|            BNF NAME|       BNF CODE|HOSPITAL TRUST CODE|      HOSPITAL TRUST|QUANTITY|TOTAL QUANTITY|TOTAL ITEMS|TOTAL ACTUAL COST|TOTAL NIC|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+-----------------+---------+
|202410|Scheriproct suppo...|0107020P0BCABAD|              R0A00|MANCHESTER UNIVER...|     7.0|           7.0|          1|           1.0024|     0.89|
|202410|Circadin 2mg modi...|0401010ADBBAAAA|              R0A00|MANCHESTER UNIVER...|    84.0|         252.0|          3| 123.148680889096|   129.27|
|202410|Equasym XL 20mg c...|0404000M0BCAEAQ|              R0A00|MANCHESTER UNIVER...|    14.0|         168.0|         12|         169.3488|    168.0|
|202410|Elvanse 20mg caps...|0404000U0BCADAA|              R0A00|MANCHESTER UNIVER...|    21.0

#### Spark SQL
To query data with SQL, we use the `spark.sql` function, providing a query that will look familiar to current users of SQL.

In [7]:
query = f"""
    SELECT * FROM {table_name} LIMIT 5
"""
print(f"Showing result of: {query}")
spark.sql(query).show()

Showing result of: 
    SELECT * FROM monthly_hospital_data_oct24 LIMIT 5

+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+-----------------+---------+
|PERIOD|            BNF NAME|       BNF CODE|HOSPITAL TRUST CODE|      HOSPITAL TRUST|QUANTITY|TOTAL QUANTITY|TOTAL ITEMS|TOTAL ACTUAL COST|TOTAL NIC|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+-----------------+---------+
|202410|Scheriproct suppo...|0107020P0BCABAD|              R0A00|MANCHESTER UNIVER...|     7.0|           7.0|          1|           1.0024|     0.89|
|202410|Circadin 2mg modi...|0401010ADBBAAAA|              R0A00|MANCHESTER UNIVER...|    84.0|         252.0|          3| 123.148680889096|   129.27|
|202410|Equasym XL 20mg c...|0404000M0BCAEAQ|              R0A00|MANCHESTER UNIVER...|    14.0|         168.0|         12|         169.3488|    168.0|
|202410|Elvanse 20m

### Selecting and Renaming Columns
It is good practice to select only the columns we need to use in our analysis / pipeline / model.

Column selection helps:
- **Minimise data** to only the fields we need to use, thus decreasing the likelihood of any privacy risks!
- **Improve query performance** by reducing the amount of data we pull through, reducing memory usage and processing time!
- **Focus on relevant data** as some of our datasets can be very wide and it might not be helpful to see everything at once!

#### PySpark
To select columns in PySpark, we use the `select` method on the DataFrame we created above.

In [8]:
# Columns to select
# We can put these directly in the arguments of 'select' but
# it is often easier to read / maintain if we pull them out
# into their own list
columns = [
    "HOSPITAL TRUST CODE",
    "BNF CODE",
    "QUANTITY",
    "PERIOD",
]
selected_df = df.select(columns)
selected_df.show(5)

+-------------------+---------------+--------+------+
|HOSPITAL TRUST CODE|       BNF CODE|QUANTITY|PERIOD|
+-------------------+---------------+--------+------+
|              R0A00|0107020P0BCABAD|     7.0|202410|
|              R0A00|0401010ADBBAAAA|    84.0|202410|
|              R0A00|0404000M0BCAEAQ|    14.0|202410|
|              R0A00|0404000U0BCADAA|    21.0|202410|
|              R0A00|0407010F0AAAFAF|    60.0|202410|
+-------------------+---------------+--------+------+
only showing top 5 rows



Notice that our column names contain spaces: we may want to rename these to standardise and avoid potential problems down the line.

In PySpark we can reference columns by name, or via a 'Column' object created with the `col` function from `pyspark.sql.functions` (aliased here as `F`). Using a Column object is often helpful because we can perform operations on it, which enables us to specify complex transformations in a modular way.
Think about those times you've seen many lines to define a single column in the select part of a SQL query!

In [24]:
renamed_col = F.col("HOSPITAL TRUST CODE").alias("HOSPITAL_TRUST_CODE")
renamed_df = df.select(renamed_col)

print("Renamed columns:")
print(renamed_df.columns)

Renamed columns:
['HOSPITAL_TRUST_CODE']


We can rename all the columns with a for-loop.

In [9]:
renamed_columns = []

for old_name in df.columns:
    old_col = F.col(old_name)  # Define the Column by referencing via its name

    # Get the new name and use it to create a new Column by aliasing the old
    new_name = old_name.replace(" ", "_")  # Replace spaces with underscore
    new_col = old_col.alias(new_name)

    # Add the new column to the list of columns we will select
    renamed_columns.append(new_col)

# Select the newly renamed columns
renamed_df = df.select(renamed_columns)
print("Renamed columns:")
print(renamed_df.columns)

Renamed columns:
['PERIOD', 'BNF_NAME', 'BNF_CODE', 'HOSPITAL_TRUST_CODE', 'HOSPITAL_TRUST', 'QUANTITY', 'TOTAL_QUANTITY', 'TOTAL_ITEMS', 'TOTAL_ACTUAL_COST', 'TOTAL_NIC']
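
As a possible alternative to the loop above, the sketch below renames every column in one call using the DataFrame `toDF` method, which takes the new names positionally. The helper names `underscore_names` and `rename_all_columns` are made up for this example, and `df` is assumed to be a PySpark DataFrame.

```python
def underscore_names(names):
    """Replace spaces with underscores in each column name."""
    return [name.replace(" ", "_") for name in names]


def rename_all_columns(df):
    """Return a copy of `df` with every column renamed in one call.

    `df` is assumed to be a PySpark DataFrame; `toDF` takes the new
    names positionally, in the same order as `df.columns`.
    """
    return df.toDF(*underscore_names(df.columns))


# The name clean-up itself is plain Python, so we can try it without Spark
print(underscore_names(["PERIOD", "BNF NAME", "HOSPITAL TRUST CODE"]))
```

With a real DataFrame this would be used as `renamed_df = rename_all_columns(df)`. The loop version is still useful when different columns need different treatment.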


#### Spark SQL
In Spark SQL selecting columns is straightforward: we just modify the query we pass to `spark.sql`.
The only issue to be aware of is that it can be hard to write modular code in this way, so while it might feel easier, sometimes it can be harder to read/write/maintain Spark SQL compared to PySpark as the size/complexity of a query grows.

Similar to above, we can rename our columns by aliasing, this time with 'AS'.

In [10]:
query = f"""
    SELECT
        PERIOD AS PERIOD,
        `BNF NAME` AS BNF_NAME,
        `BNF CODE` AS BNF_CODE,
        `HOSPITAL TRUST CODE` AS HOSPITAL_TRUST_CODE,
        `HOSPITAL TRUST` AS HOSPITAL_TRUST,
        QUANTITY AS QUANTITY,
        `TOTAL QUANTITY` AS TOTAL_QUANTITY,
        `TOTAL ITEMS` AS TOTAL_ITEMS,
        `TOTAL ACTUAL COST` AS TOTAL_ACTUAL_COST,
        `TOTAL NIC` AS TOTAL_NIC
    FROM {table_name}
    LIMIT 5
"""
spark.sql(query).show()

+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+-----------------+---------+
|PERIOD|            BNF_NAME|       BNF_CODE|HOSPITAL_TRUST_CODE|      HOSPITAL_TRUST|QUANTITY|TOTAL_QUANTITY|TOTAL_ITEMS|TOTAL_ACTUAL_COST|TOTAL_NIC|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+-----------------+---------+
|202410|Scheriproct suppo...|0107020P0BCABAD|              R0A00|MANCHESTER UNIVER...|     7.0|           7.0|          1|           1.0024|     0.89|
|202410|Circadin 2mg modi...|0401010ADBBAAAA|              R0A00|MANCHESTER UNIVER...|    84.0|         252.0|          3| 123.148680889096|   129.27|
|202410|Equasym XL 20mg c...|0404000M0BCAEAQ|              R0A00|MANCHESTER UNIVER...|    14.0|         168.0|         12|         169.3488|    168.0|
|202410|Elvanse 20mg caps...|0404000U0BCADAA|              R0A00|MANCHESTER UNIVER...|    21.0

### Filtering Rows
Filtering rows is a key operation that we use frequently in preparing and transforming data for analysis, helping to:
- **Minimise data** to only the records we need to use, thus decreasing the likelihood of any privacy risks!
- **Improve query performance** by reducing the amount of data we pull through, reducing memory usage and processing time!
- **Focus on relevant data** as some of our datasets can be very long and it might not be helpful to see everything at once!

#### PySpark
In PySpark we can filter rows using the `filter` or `where` methods of a DataFrame (you can use either - they are exactly the same under the hood!).

Filter conditions can be created either using a SQL expression within a string (e.g. `"my_column > 5"`) or using the Column object we discussed above (e.g. `F.col("my_column") > 5`). You can use either, but Column objects are usually preferable for more complex conditions, since they can be created and combined in a modular way (unlike string SQL expressions, Column objects support most operations like addition, subtraction, greater or less than, and so on).

In [12]:
hospital_condition_col = F.col("HOSPITAL_TRUST_CODE") == "RKB00"  # Coventry
bnf_condition_col = F.substring("BNF_CODE", 1, 9) == "0407010H0"  # Paracetamol (positions are 1-based)
condition_col = hospital_condition_col & bnf_condition_col
filtered_df = renamed_df.filter(condition_col)
filtered_df.show(5)

+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+------------------+---------+
|PERIOD|            BNF_NAME|       BNF_CODE|HOSPITAL_TRUST_CODE|      HOSPITAL_TRUST|QUANTITY|TOTAL_QUANTITY|TOTAL_ITEMS| TOTAL_ACTUAL_COST|TOTAL_NIC|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+------------------+---------+
|202410|Paracetamol 500mg...|0407010H0AAAMAM|              RKB00|UNIVERSITY HOSPIT...|   100.0|         200.0|          2|3.1299551649067054|     3.88|
|202410|Paracetamol 500mg...|0407010H0AAAMAM|              RKB00|UNIVERSITY HOSPIT...|    32.0|          32.0|          1|0.5085845882067416|     0.62|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+------------------+---------+
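
The `&` combination used above extends naturally to a whole list of conditions. The sketch below is one way to do that with `functools.reduce`; the helper name `all_of` is made up for this example, and plain booleans stand in for Column objects so that it runs without Spark.

```python
from functools import reduce
from operator import and_


def all_of(conditions):
    """Combine conditions with `&`, left to right.

    Works for any objects that support `&`, which includes PySpark Column
    objects such as `F.col("QUANTITY") > 50`.
    """
    conditions = list(conditions)
    if not conditions:
        raise ValueError("Need at least one condition")
    return reduce(and_, conditions)


# Plain booleans stand in for Column objects so the sketch runs without Spark
print(all_of([True, True, False]))
```

With the conditions defined above, this could be used as `renamed_df.filter(all_of([hospital_condition_col, bnf_condition_col]))`, which makes it easy to add or remove conditions from the list.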



#### Spark SQL
In Spark SQL, similarly to the sections above, we simply provide a SQL query with a 'WHERE' clause, as we would with other flavours of SQL.

Note how the query size starts to grow: when we create a DataFrame with PySpark the query is dynamically created at runtime (i.e. after we've added all the transformations) but when we write SQL directly we have to put everything in one long query string.

One way to get around this is to create temporary tables or views to hold intermediate outputs from SQL queries.
We'll not cover that here though, instead simply recommending to give PySpark a try!

In [13]:
query = f"""
    SELECT
        PERIOD AS PERIOD,
        `BNF NAME` AS BNF_NAME,
        `BNF CODE` AS BNF_CODE,
        `HOSPITAL TRUST CODE` AS HOSPITAL_TRUST_CODE,
        `HOSPITAL TRUST` AS HOSPITAL_TRUST,
        QUANTITY AS QUANTITY,
        `TOTAL QUANTITY` AS TOTAL_QUANTITY,
        `TOTAL ITEMS` AS TOTAL_ITEMS,
        `TOTAL ACTUAL COST` AS TOTAL_ACTUAL_COST,
        `TOTAL NIC` AS TOTAL_NIC
    FROM {table_name}
    WHERE `HOSPITAL TRUST CODE` = 'RKB00'
        AND substring(`BNF CODE`, 1, 9) = '0407010H0'
    LIMIT 5
"""
spark.sql(query).show()

+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+------------------+---------+
|PERIOD|            BNF_NAME|       BNF_CODE|HOSPITAL_TRUST_CODE|      HOSPITAL_TRUST|QUANTITY|TOTAL_QUANTITY|TOTAL_ITEMS| TOTAL_ACTUAL_COST|TOTAL_NIC|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+------------------+---------+
|202410|Paracetamol 500mg...|0407010H0AAAMAM|              RKB00|UNIVERSITY HOSPIT...|   100.0|         200.0|          2|3.1299551649067054|     3.88|
|202410|Paracetamol 500mg...|0407010H0AAAMAM|              RKB00|UNIVERSITY HOSPIT...|    32.0|          32.0|          1|0.5085845882067416|     0.62|
+------+--------------------+---------------+-------------------+--------------------+--------+--------------+-----------+------------------+---------+



### Adding New Columns

Above we looked at how to select existing columns, but we can also add new columns to our data.
This allows us to perform calculations involving existing columns or apply business rules and derivations.

#### PySpark
In PySpark, we _can_ add new columns with `select`, but it is often clearer if we stick to using `select` for picking existing columns and `withColumn` to add a new column.

If the new column involves a conditional 'case-when' statement, we can do this using `when` from `pyspark.sql.functions`.

In [14]:
# Extract date from period
derived_df = renamed_df.withColumn("DATE", F.to_date("PERIOD", "yyyyMM"))

# Add a new column to classify the quantity
# Case when statement using F.when(cond, value_if_true).otherwise(default)
quantity_class_col = (
    F.when(F.col("QUANTITY") > 50, "high")
    .otherwise("low")
)
derived_df = derived_df.withColumn("QUANTITY_CLASS", quantity_class_col)

# Show the new columns and the columns used to derive them
show_cols = [
    "PERIOD",
    "DATE",
    "QUANTITY",
    "QUANTITY_CLASS",
]
derived_df.select(show_cols).show(5)

+------+----------+--------+--------------+
|PERIOD|      DATE|QUANTITY|QUANTITY_CLASS|
+------+----------+--------+--------------+
|202410|2024-10-01|     7.0|           low|
|202410|2024-10-01|    84.0|          high|
|202410|2024-10-01|    14.0|           low|
|202410|2024-10-01|    21.0|           low|
|202410|2024-10-01|    60.0|          high|
+------+----------+--------+--------------+
only showing top 5 rows



Note: if we want to add a constant-value column, we need to use `lit` which means we are defining a literal value column.
This is because PySpark expects columns to either be Column objects (which represent a SQL expression under the hood) or a string which references the column by name.
The `lit` function simply tells PySpark to create a Column with a single constant value.

In [15]:
# Note that the line below doesn't reassign the DataFrame derived_df with the
# new column (i.e. there is no '='), so this transformation will not be
# preserved downstream
(
    derived_df
    .withColumn("CONSTANT", F.lit(1))  # Add the column
    .select("CONSTANT")  # Select the new column
    .show(5)
)

+--------+
|CONSTANT|
+--------+
|       1|
|       1|
|       1|
|       1|
|       1|
+--------+
only showing top 5 rows



#### Spark SQL
As in the other SQL sections above, we can use syntax familiar from other flavours of SQL: we add a new column via a `SELECT` statement and use `CASE... WHEN... ELSE...` to implement conditional column derivations.

In [16]:
# As above, we derive 2 columns
# 'DATE' by parsing 'PERIOD' as a year and month
# 'QUANTITY_CLASS' from conditional cases on 'QUANTITY'
query = f"""
    SELECT
        PERIOD,
        to_date(CAST(PERIOD AS STRING), 'yyyyMM') AS DATE,
        QUANTITY,
        CASE
            WHEN QUANTITY > 50 THEN 'high'
            ELSE 'low'
        END AS QUANTITY_CLASS
    FROM {table_name}
"""
spark.sql(query).show(5)

+------+----------+--------+--------------+
|PERIOD|      DATE|QUANTITY|QUANTITY_CLASS|
+------+----------+--------+--------------+
|202410|2024-10-01|     7.0|           low|
|202410|2024-10-01|    84.0|          high|
|202410|2024-10-01|    14.0|           low|
|202410|2024-10-01|    21.0|           low|
|202410|2024-10-01|    60.0|          high|
+------+----------+--------+--------------+
only showing top 5 rows



### Aggregations
Preparing, transforming and analysing data often involves some form of aggregation.
We can aggregate data through 'group by' operations, which may be familiar to folks who have used languages such as SQL, Python or R to interact with data.

There are two parts to a 'group by':
1. The fields or columns in the data we want to group records by
2. The fields or columns to be aggregated over the group and how they will be aggregated (e.g. summed, counted, etc)
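
The two parts above can be sketched in plain Python (no Spark involved) to show what a 'group by' does conceptually. The records below are made-up example values, not taken from the dataset.

```python
from collections import defaultdict

# (HOSPITAL_TRUST_CODE, BNF_CODE, TOTAL_QUANTITY): made-up example records
records = [
    ("R0A00", "0407010H0AAAMAM", 200.0),
    ("R0A00", "0407010H0AAAMAM", 32.0),
    ("RKB00", "0407010H0AAAMAM", 100.0),
]

totals = defaultdict(float)
for trust_code, bnf_code, quantity in records:
    group_key = (trust_code, bnf_code)  # Part 1: what we group by
    totals[group_key] += quantity       # Part 2: how we aggregate (a sum)

print(dict(totals))
```

Spark does the same thing conceptually, but spreads the records (and the partial sums) across the nodes in the cluster.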

#### PySpark
In PySpark, we use the `groupBy` method of a DataFrame to perform a 'group by' operation and `agg` to do the aggregation.
We can provide Column objects or strings to specify the fields to use in each step of this transformation and as above we often want to use Column objects.

Additionally, it is often easier to read, understand and maintain code where the columns have been pulled out of the transformation itself, as shown below.

In [17]:
# Lift the columns for grouping and aggregation out of the transformation
# This can make code easier to read and maintain, especially as the complexity
# increases
group_cols = [
    F.col("HOSPITAL_TRUST_CODE"),
    F.col("BNF_CODE"),
]
agg_cols = [
    # `alias` used to name the resulting column
    F.sum("TOTAL_QUANTITY").alias("TOTAL_QUANTITY")
]
totals_by_hospital_and_bnf_df = (
    renamed_df
    .groupBy(*group_cols)
    .agg(*agg_cols)
)

totals_by_hospital_and_bnf_df.show(5)

+-------------------+---------------+--------------+
|HOSPITAL_TRUST_CODE|       BNF_CODE|TOTAL_QUANTITY|
+-------------------+---------------+--------------+
|              RBK00|0604020C0AAAAAA|         182.0|
|              R0A00|    21220000205|        3100.0|
|              R1F00|0407020A0AAAFAF|           5.0|
|              RAT00|0407042F0AAAAAA|        2332.0|
|              R1K00|0406000B0AAABAB|         410.0|
+-------------------+---------------+--------------+
only showing top 5 rows



#### Spark SQL
Aggregation in Spark SQL uses the familiar 'GROUP BY' expression from other flavours of SQL, along with aggregation expressions such as 'sum' in the columns to be selected.

In [18]:
query = f"""
    SELECT
        `HOSPITAL TRUST CODE` AS HOSPITAL_TRUST_CODE,
        `BNF CODE` AS BNF_CODE,
        sum(`TOTAL QUANTITY`) AS TOTAL_QUANTITY
    FROM {table_name}
    GROUP BY `HOSPITAL TRUST CODE`, `BNF CODE`
"""
spark.sql(query).show(5)

+-------------------+---------------+--------------+
|HOSPITAL_TRUST_CODE|       BNF_CODE|TOTAL_QUANTITY|
+-------------------+---------------+--------------+
|              RBK00|0604020C0AAAAAA|         182.0|
|              R0A00|    21220000205|        3100.0|
|              R1F00|0407020A0AAAFAF|           5.0|
|              RAT00|0407042F0AAAAAA|        2332.0|
|              R1K00|0406000B0AAABAB|         410.0|
+-------------------+---------------+--------------+
only showing top 5 rows



### Joins
We frequently need to join datasets together to enrich data or add dimensions to provide additional context for analysis.

#### PySpark
In PySpark we can use the `join` method of a DataFrame, which takes arguments specifying:
1. the other DataFrame to join
2. the columns to join on (these can be a list of strings or Column objects)
3. how to do the join (e.g. inner, left, etc)

In [19]:
# Get the mapping between codes and names
# We'll join this to the aggregated data above to provide context
bnf_hospital_cols = [
    "HOSPITAL_TRUST",
    "HOSPITAL_TRUST_CODE",
    "BNF_NAME",
    "BNF_CODE",
]
bnf_hospital_map_df = (
    renamed_df.select(bnf_hospital_cols).distinct()
)

# It can be helpful in terms of readability/maintainability to pull the
# join columns out of the join transformation
join_on = [
    "HOSPITAL_TRUST_CODE",
    "BNF_CODE",
]
joined_df = (
    totals_by_hospital_and_bnf_df
    .join(
        bnf_hospital_map_df,
        on=join_on,
        how="inner"
    )
)

joined_df.show(5)

+-------------------+---------------+--------------+--------------------+--------------------+
|HOSPITAL_TRUST_CODE|       BNF_CODE|TOTAL_QUANTITY|      HOSPITAL_TRUST|            BNF_NAME|
+-------------------+---------------+--------------+--------------------+--------------------+
|              R0A00|040702040AAACAC|          28.0|MANCHESTER UNIVER...|Tramadol 100mg mo...|
|              RBN00|1304000H0BBAAAA|         190.0|MERSEY & WEST LAN...|Eumovate 0.05% cream|
|              RAT00|0403030E0AAAPAP|         177.0|NORTH EAST LONDON...|Fluoxetine 10mg t...|
|              R1L00|0408010F0AAABAB|        1539.0|ESSEX PARTNERSHIP...|Clonazepam 500mic...|
|              RAE00|0304010I0AAAAAA|          28.0|BRADFORD TEACHING...|Cetirizine 10mg t...|
+-------------------+---------------+--------------+--------------------+--------------------+
only showing top 5 rows



#### Spark SQL
We can join in Spark SQL again using familiar syntax from other SQL flavours, e.g. 'INNER JOIN', 'LEFT JOIN' etc.

In [20]:
query = f"""
    SELECT
        agg_tbl.HOSPITAL_TRUST_CODE,
        agg_tbl.BNF_CODE,
        agg_tbl.TOTAL_QUANTITY,
        map_tbl.HOSPITAL_TRUST,
        map_tbl.BNF_NAME
    FROM (
        SELECT
            `HOSPITAL TRUST CODE` AS HOSPITAL_TRUST_CODE,
            `BNF CODE` AS BNF_CODE,
            sum(`TOTAL QUANTITY`) AS TOTAL_QUANTITY
        FROM {table_name}
        GROUP BY `HOSPITAL TRUST CODE`, `BNF CODE`
    ) AS agg_tbl
    INNER JOIN (
        SELECT DISTINCT
            `HOSPITAL TRUST CODE` AS HOSPITAL_TRUST_CODE,
            `BNF CODE` AS BNF_CODE,
            `HOSPITAL TRUST` AS HOSPITAL_TRUST,
            `BNF NAME` AS BNF_NAME
        FROM {table_name}
    ) AS map_tbl
    ON map_tbl.HOSPITAL_TRUST_CODE = agg_tbl.HOSPITAL_TRUST_CODE
        AND map_tbl.BNF_CODE = agg_tbl.BNF_CODE
"""
spark.sql(query).show(5)

+-------------------+---------------+--------------+--------------------+--------------------+
|HOSPITAL_TRUST_CODE|       BNF_CODE|TOTAL_QUANTITY|      HOSPITAL_TRUST|            BNF_NAME|
+-------------------+---------------+--------------+--------------------+--------------------+
|              RAJ00|0103050L0AAAHAH|         141.0|MID AND SOUTH ESS...|Lansoprazole 30mg...|
|              RBD00|190201000AABLBL|         862.0|DORSET COUNTY HOS...|Exception Handler...|
|              R0B00|0303020G0AAADAD|          28.0|SOUTH TYNESIDE AN...|Montelukast 4mg g...|
|              RAT00|0404000M0BGADAU|       13187.0|NORTH EAST LONDON...|Medikinet XL 10mg...|
|              RBK00|1001010C0AAALAL|         236.0|WALSALL HEALTHCAR...|Diclofenac sodium...|
+-------------------+---------------+--------------+--------------------+--------------------+
only showing top 5 rows



#### Note on row ordering
Notice that the order of rows is different between the PySpark and Spark SQL examples above.
This demonstrates a key aspect to consider when writing code for Spark, which is that the processing does not take account of row order unless you explicitly tell it to (e.g. with an 'ORDER BY' clause).

When Spark executes a task, it splits the rows up and sends them to different nodes in the cluster, each of which can take a different amount of time to process and return the results.

Although the results are the same overall (as in they have the same schema and contain the same records), the order in which the records appear in the result sets is not guaranteed to be the same from run-to-run.
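
The plain-Python sketch below illustrates the point: two result sets can contain the same records and still differ in order. The three rows are borrowed from the join output above, while the second ordering is made up for illustration.

```python
from collections import Counter

# The same three result rows, as two different runs might return them
run_a = [("R0A00", 28.0), ("RBN00", 190.0), ("RAT00", 177.0)]
run_b = [("RAT00", 177.0), ("R0A00", 28.0), ("RBN00", 190.0)]

print(run_a == run_b)                    # False: lists compare position by position
print(Counter(run_a) == Counter(run_b))  # True: same records, ignoring order
print(sorted(run_a) == sorted(run_b))    # True: an explicit sort makes the order stable
```

In PySpark the equivalent of the explicit sort is `orderBy`, e.g. `joined_df.orderBy("HOSPITAL_TRUST_CODE", "BNF_CODE").show(5)`.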

### Writing Outputs
Spark is capable of writing outputs to a number of different formats, including csv, parquet and json.
But most often we want to write to a managed table in a Spark SQL database (note the data will actually be stored in parquet files under the hood, but we will simply see a table in the database as we would in e.g. Oracle SQL Developer).


#### PySpark
In PySpark we can use the `save` and `saveAsTable` methods of a DataFrame's writer (accessed via `df.write`): the former writes records to files in a specified location (e.g. a network drive), while the latter creates a table in the Spark SQL database and writes the records to it.

For general use `save` is more powerful since there are more options for configuration.
But in the example below we'll keep it simple and use `saveAsTable`.

Note: since `saveAsTable` will, by default, error if the table already exists, it can be useful to check whether the table exists before you try to write to it.
(Hint: this is probably good practice for other reasons too!). You can check if a table exists using `spark.catalog.tableExists`, passing the table name as the function argument.
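
The existence check described in the note could be wrapped up roughly as follows. This is only a sketch: the helper name `save_if_absent` is made up, `spark` is assumed to be an active Spark Session and `df` a DataFrame, and it assumes a PySpark version that provides `spark.catalog.tableExists` (added in 3.3).

```python
def save_if_absent(spark, df, table_name):
    """Write `df` to `table_name` only when no such table exists yet.

    Returns True if the table was written, False if it was left alone.
    `spark` is assumed to be an active SparkSession and `df` a DataFrame.
    """
    if spark.catalog.tableExists(table_name):
        print(f"Table {table_name} already exists, skipping write")
        return False
    df.write.saveAsTable(table_name)
    return True
```

For example, `save_if_absent(spark, joined_df, "results_20250130")` would leave an existing table untouched rather than raising an error or overwriting it.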

In [21]:
# There are many options to configure the table to be written
# We are only using two which provide the name and indicate
# that we should overwrite the table if it already exists
joined_df.write.saveAsTable("results_20250130", mode="overwrite")

Let's take a look at our new table.

In [22]:
spark.table("results_20250130").show(5)  # Look at our new table!

+-------------------+---------------+--------------+--------------------+--------------------+
|HOSPITAL_TRUST_CODE|       BNF_CODE|TOTAL_QUANTITY|      HOSPITAL_TRUST|            BNF_NAME|
+-------------------+---------------+--------------+--------------------+--------------------+
|              RA700|1202010U0AAAAAA|           8.0|UNIVERSITY HOSPIT...|Mometasone 50micr...|
|              R0A00|0404000V0AAACAC|         189.0|MANCHESTER UNIVER...|Guanfacine 1mg mo...|
|              RAT00|0404000M0BDAAAM|         878.0|NORTH EAST LONDON...|Concerta XL 18mg ...|
|              RAJ00|0603020T0AAACAC|        8836.0|MID AND SOUTH ESS...|Prednisolone 5mg ...|
|              RAT00|040201030AAANAN|         777.0|NORTH EAST LONDON...|Risperidone 500mi...|
+-------------------+---------------+--------------+--------------------+--------------------+
only showing top 5 rows



#### Spark SQL
To do the same in pure Spark SQL requires much more code than above, because we must first create the table and then write into it.
We will not cover this here, simply (and lazily) recommending that you use the line of PySpark shown above to write the table or follow the examples on the [Spark SQL docs](https://spark.apache.org/docs/latest/sql-ref-syntax-ddl-create-table-datasource.html)!

## Example Pipeline
Let's put our learning into practice and build a full end-to-end pipeline!

Our pipeline will have the following steps:
1. Read data
2. Select columns of interest
3. Derive a new column
4. Filter the rows
5. Aggregate the data
6. Write outputs to a table

In [23]:
# Configuration
input_table = "monthly_hospital_data_oct24"
output_table = "max_unit_cost_oct24_20250130"

# Input data
input_df = spark.table(input_table)

# Select columns of interest and rename
selected_cols = [
    F.col("HOSPITAL TRUST").alias("HOSPITAL_TRUST"),
    F.col("BNF NAME").alias("BNF_NAME"),
    F.col("TOTAL NIC").alias("TOTAL_NIC"),
    F.col("TOTAL QUANTITY").alias("TOTAL_QUANTITY"),
]
selected_df = input_df.select(selected_cols)

# Derive unit cost
unit_cost_col = F.round(F.col("TOTAL_NIC") / F.col("TOTAL_QUANTITY"), 4)
derived_df = selected_df.withColumn("UNIT_COST", unit_cost_col)

# Filter on BNF
bnf_condition_col = F.col("BNF_NAME").like("%Paracetamol 500mg tablets%")
filtered_df = derived_df.filter(bnf_condition_col)

# Group by HOSPITAL, BNF and max UNIT_COST
group_cols = [
    "HOSPITAL_TRUST",
    "BNF_NAME",
]
agg_cols = [
    F.max("UNIT_COST").alias("MAX_UNIT_COST")
]
aggregated_df = filtered_df.groupBy(group_cols).agg(*agg_cols)

# Write outputs
aggregated_df.write.saveAsTable(output_table, mode="overwrite")

# Show outputs (ordered)
(
    spark
    .table(output_table)
    .orderBy("HOSPITAL_TRUST", "BNF_NAME")
    .show()
)

+--------------------+--------------------+-------------+
|      HOSPITAL_TRUST|            BNF_NAME|MAX_UNIT_COST|
+--------------------+--------------------+-------------+
|ASHFORD & ST PETE...|Paracetamol 500mg...|       0.0194|
|BARKING HAVERING ...|Paracetamol 500mg...|       0.0195|
|BARTS HEALTH NHS ...|Paracetamol 500mg...|       0.0194|
|BEDFORDSHIRE HOSP...|Paracetamol 500mg...|       0.0196|
|BERKSHIRE HEALTHC...|Paracetamol 500mg...|       0.0195|
|BIRMINGHAM AND SO...|Paracetamol 500mg...|       0.0194|
|BLACK COUNTRY HEA...|Paracetamol 500mg...|       0.0194|
|BLACKPOOL TEACHIN...|Paracetamol 500mg...|       0.0194|
|BOLTON NHS FOUNDA...|Paracetamol 500mg...|       0.0195|
|BUCKINGHAMSHIRE H...|Paracetamol 500mg...|       0.0195|
|CORNWALL PARTNERS...|Paracetamol 500mg...|       0.0195|
|COUNTESS OF CHEST...|Paracetamol 500mg...|       0.0194|
|COUNTY DURHAM AND...|Paracetamol 500mg...|       0.0195|
|CUMBRIANORTHUMBER...|Paracetamol 500mg...|       0.0195|
|DARTFORD AND 

## Takeaways

### Key Takeaways
- Spark is a technology for processing and transforming data at scale.
- Spark works by distributing work over a cluster of computers, or nodes.
- You can interact with Spark through a number of APIs including PySpark (Python) and Spark SQL.
- We showed how to perform some of the fundamental operations in transforming data using PySpark and Spark SQL.
- And we built an example pipeline using what we've learned!

## Additional Resources
- [Apache Spark Documentation](https://spark.apache.org/docs/latest/)
- [PySpark API Reference](https://spark.apache.org/docs/latest/api/python/index.html)
- [Microsoft Fabric Training](https://learn.microsoft.com/en-us/training/modules/use-apache-spark-work-files-lakehouse/)
- [Palantir Style Guide](https://github.com/palantir/pyspark-style-guide)
