In [None]:
import pandas as pd
import pyspark.sql.functions as f
from datetime import datetime
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    ArrayType,
    MapType,
    FloatType,
    DateType,
)

# Complex data types - Part II

In the previous notebook, we explored PySpark's complex data types and some basic functions for working with them. In this notebook, we continue with more advanced functions for manipulating arrays and structs.

First, let's create a sample dataset with orders from a fruit shop.

This dataset may remind you of the one used in the Part I exercises, but this time we will be able to perform much more complex operations on it.

In [None]:
# Sample data
data = [
    {"order_id": 5642, "order_date": datetime.strptime("2024-05-18", "%Y-%m-%d").date(),
    "items": [
        {"name": "Apple", "amount": 1.0, "unit_price": 2.99},
        {"name": "Banana", "amount": 1.7, "unit_price": 1.99}],
    'items_discount': ['Apple']},
    {"order_id": 9762, "order_date": datetime.strptime("2024-05-02", "%Y-%m-%d").date(),
    "items": [
        {"name": "Strawberry", "amount": 0.5, "unit_price": 6.99},
        {"name": "Apple", "amount": 3.0, "unit_price": 2.99},
        {"name": "Peach", "amount": 2.5, "unit_price": 3.39}],
    'items_discount': ['Apple', 'Peach']},
    {"order_id": 3652, "order_date": datetime.strptime("2024-05-23", "%Y-%m-%d").date(),
    "items": [
        {"name": "Banana", "amount": 1.5, "unit_price": 1.99}],
    'items_discount': []},
    {"order_id": 1276, "order_date": datetime.strptime("2024-05-10", "%Y-%m-%d").date(),
    "items": [
        {"name": "Apple", "amount": 2.0, "unit_price": 2.99},
        {"name": "Banana", "amount": 0.5, "unit_price": 1.99},
        {"name": "Strawberry", "amount": 1.0, "unit_price": 6.99},
        {"name": "Strawberry", "amount": 1.0, "unit_price": 6.99},
        {"name": "Peach", "amount": 1.0, "unit_price": 3.39}],
    'items_discount': ['Peach', 'Banana']},
    {"order_id": 8763, "order_date": datetime.strptime("2024-05-14", "%Y-%m-%d").date(),
    "items": [
        {"name": "Strawberry", "amount": 1.0, "unit_price": 6.99},
        {"name": "Peach", "amount": 1.0, "unit_price": 3.39},
        {"name": "Mango", "amount": 1.5, "unit_price": 5.99}],
    'items_discount': ['Mango']},
    {"order_id": 7652, "order_date": datetime.strptime("2024-05-22", "%Y-%m-%d").date(),
    "items": [
        {"name": "Banana", "amount": 1.0, "unit_price": 1.99},
        {"name": "Mango", "amount": 1.5, "unit_price": 5.99}],
    'items_discount': ['Mango', 'Banana']},
    {"order_id": 7631, "order_date": datetime.strptime("2024-05-22", "%Y-%m-%d").date(),
    "items": [
        {"name": "Banana", "amount": 1.0, "unit_price": 1.99},
        {"name": "Banana", "amount": 2.5, "unit_price": 1.99},],
    'items_discount': []}
]

# Define the schema
schema = StructType([
    StructField('order_id', IntegerType(), False),
    StructField('order_date', DateType(), False),
    StructField(
        'items',
        ArrayType(
            StructType([
                StructField('name', StringType(), False),
                StructField('amount', FloatType(), False),
                StructField('unit_price', FloatType(), False)
            ]),
            False
        ),
        False
    ),
    StructField("items_discount", ArrayType(StringType()), True)
])


# Create DataFrame
df_fruitshop = spark.createDataFrame(data, schema=schema)

df_fruitshop.display()

In [None]:
df_fruitshop.printSchema()

As you can see, the `items` column has a complex data type: `ArrayType`.

Each element of these arrays is itself of another complex type, `StructType`, which is composed of a list of `StructField` elements.

These fields are named `name`, `amount` and `unit_price`, with types `StringType`, `FloatType` and `FloatType`, and none of them can be NULL.

The `items_discount` column is simply an array of strings.

The next sections show how to manipulate these complex data types with some of PySpark's collection functions.

### Transform

The [transform()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.transform.html) function applies a function to each element of an array column.

Let's say we want to get the name of each product in the `items` column.

In [None]:
# Build the transform expression: it takes an array column and a function that is applied to each element of the array.
fn = f.transform(f.col('items'), lambda x: x['name'])

df_fruitshop_transformed = (
    df_fruitshop
    # Apply the transformation to the dataframe and save the results to a new column
    .withColumn('item_names', fn)
)

df_fruitshop_transformed.display()

Let's check the schema of the new DataFrame to see the type of the `item_names` column.

In [None]:
df_fruitshop_transformed.printSchema()

The `item_names` column is an array of strings, where each string is the name of a product.

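Conceptually, `transform()` behaves like a list comprehension evaluated once per row. The plain-Python sketch below mimics that per-row behaviour on two of the sample orders. It is only a mental model (no Spark is involved), and the helper `transform_row` is hypothetical, not part of PySpark. The second call shows that the lambda may combine several struct fields, for example to compute a line total.

```python
# Plain-Python model of f.transform() for a single row (no Spark involved).
# transform_row is a hypothetical helper, not part of PySpark.

def transform_row(array, fn):
    """Apply fn to every element of one row's array, keeping order and length."""
    if array is None:  # Spark returns NULL when the input array is NULL
        return None
    return [fn(x) for x in array]

# Two orders from the sample data, with their `items` arrays
orders = [
    {"order_id": 5642, "items": [
        {"name": "Apple", "amount": 1.0, "unit_price": 2.99},
        {"name": "Banana", "amount": 1.7, "unit_price": 1.99}]},
    {"order_id": 3652, "items": [
        {"name": "Banana", "amount": 1.5, "unit_price": 1.99}]},
]

for order in orders:
    # Equivalent of f.transform(f.col('items'), lambda x: x['name'])
    order["item_names"] = transform_row(order["items"], lambda x: x["name"])
    # A lambda can also combine fields, e.g. amount * unit_price per item
    order["line_totals"] = transform_row(
        order["items"], lambda x: round(x["amount"] * x["unit_price"], 2)
    )

print([o["item_names"] for o in orders])
# → [['Apple', 'Banana'], ['Banana']]
```

In PySpark, the line-total idea would be written as `f.transform(f.col('items'), lambda x: x['amount'] * x['unit_price'])`.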
### Array distinct

The [array_distinct()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.array_distinct.html) function returns an array of distinct elements in the input array.

Let's get the list of unique items in each order, as well as the number of unique items, which we compute with the `size()` function.

In [None]:
df_fruitshop_unique = (
    df_fruitshop_transformed  # DataFrame that already has the item_names column
    .withColumn(
        'unique_item_names',
        f.array_distinct(f.col('item_names'))
    )
    .withColumn(
        'nr_unique_items',
        f.size(f.col('unique_item_names'))
    )
)

df_fruitshop_unique.display()

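To see what `array_distinct()` and `size()` do for a single row, here is a plain-Python sketch using the `item_names` of order 1276, where `Strawberry` appears twice. The helper `array_distinct_row` is hypothetical. It keeps the first occurrence of each value in its original order; treat that ordering as a property of this sketch rather than something your Spark code should rely on.

```python
# Plain-Python model of f.array_distinct() and f.size() for one row.
# array_distinct_row is a hypothetical helper, not part of PySpark.

def array_distinct_row(array):
    """Drop repeated values, keeping the first occurrence of each one."""
    seen = set()
    result = []
    for value in array:
        if value not in seen:
            seen.add(value)
            result.append(value)
    return result

# item_names of order 1276 in the sample data ('Strawberry' appears twice)
item_names = ["Apple", "Banana", "Strawberry", "Strawberry", "Peach"]

unique_item_names = array_distinct_row(item_names)
nr_unique_items = len(unique_item_names)  # what f.size() would return

print(unique_item_names, nr_unique_items)
# → ['Apple', 'Banana', 'Strawberry', 'Peach'] 4
```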
### Array contains

The [array_contains()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.array_contains.html) function checks if an array contains a specific value.

Let's create a new column of boolean values indicating whether each order contains a `Banana`.

In [None]:
df_fruitshop_contains = (
    df_fruitshop_unique
    .withColumn(
        'contains_banana',
        f.array_contains(f.col('unique_item_names'), 'Banana')
    )
)

df_fruitshop_contains.display()

Since this function returns a column of boolean values, we could also use it to filter the DataFrame and keep only the orders in which `Banana` was purchased.

In [None]:
df_fruitshop_filtered = (
    df_fruitshop_unique
    .filter(
        f.array_contains(f.col('unique_item_names'), 'Banana')
    )
)

df_fruitshop_filtered.display()

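Both uses of `array_contains()` follow the same per-row logic: compute one boolean per row, then either store it as a column or use it to keep rows. The plain-Python sketch below models this on four of the sample orders (their `unique_item_names`). One detail it does not model: in Spark, `array_contains()` returns NULL when the array itself is NULL, and `filter()` drops rows whose condition is NULL.

```python
# Plain-Python model of f.array_contains() used as a column and as a filter.
# The order list below mirrors unique_item_names from the sample data.

orders = [
    {"order_id": 5642, "unique_item_names": ["Apple", "Banana"]},
    {"order_id": 9762, "unique_item_names": ["Strawberry", "Apple", "Peach"]},
    {"order_id": 3652, "unique_item_names": ["Banana"]},
    {"order_id": 8763, "unique_item_names": ["Strawberry", "Peach", "Mango"]},
]

# withColumn('contains_banana', f.array_contains(...)): one boolean per row
for order in orders:
    order["contains_banana"] = "Banana" in order["unique_item_names"]

# filter(f.array_contains(...)): keep only the rows where the check is True
banana_orders = [o["order_id"] for o in orders if o["contains_banana"]]

print(banana_orders)
# → [5642, 3652]
```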
### Explode

The [explode()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.explode.html) function creates a new row for each element in an array or map.

Let's try it out. Starting from `df_fruitshop_unique`, create a new DataFrame with one row for each unique item purchased in each order.

In [None]:
df_fruitshop_exploded = (
    df_fruitshop_unique
    .select(
        'order_id',
        'order_date',
        f.explode(f.col('unique_item_names')).alias('item')
    )
)

df_fruitshop_exploded.display()

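One behaviour of `explode()` is easy to miss: rows whose array is empty or NULL produce no output rows at all, so they silently disappear. This matters for a column such as `items_discount`, which is empty for some orders. `f.explode_outer()` keeps such rows, with a NULL element instead. The plain-Python sketch below models both on three sample orders; the helper `explode_rows` is hypothetical.

```python
# Plain-Python model of f.explode() and f.explode_outer() (no Spark involved).
# explode_rows is a hypothetical helper, not part of PySpark.

def explode_rows(rows, array_key, outer=False):
    """Emit one (order_id, element) pair per array element."""
    result = []
    for row in rows:
        array = row[array_key]
        if array:  # non-empty array: one output row per element
            for element in array:
                result.append((row["order_id"], element))
        elif outer:  # explode_outer keeps the row, with a NULL element
            result.append((row["order_id"], None))
        # plain explode drops rows whose array is empty or NULL
    return result

orders = [
    {"order_id": 5642, "items_discount": ["Apple"]},
    {"order_id": 3652, "items_discount": []},
    {"order_id": 9762, "items_discount": ["Apple", "Peach"]},
]

print(explode_rows(orders, "items_discount"))
# → [(5642, 'Apple'), (9762, 'Apple'), (9762, 'Peach')]
print(explode_rows(orders, "items_discount", outer=True))
# → [(5642, 'Apple'), (3652, None), (9762, 'Apple'), (9762, 'Peach')]
```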
### Collect list

The [collect_list()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.collect_list.html) function is an aggregate function that can be applied over a [Window](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Window.html) or after the [groupBy()](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.groupBy.html) method.

It gathers the values of a column within each group (or window) into an array, keeping duplicates.

Let's revert the `explode()` operation and collect the list of unique items purchased in each order.

In [None]:
df_fruitshop_collected = (
    df_fruitshop_exploded
    .groupBy('order_id', 'order_date')
    .agg(
        # Gather the exploded items of each order back into an array
        f.collect_list(f.col('item')).alias('unique_item_names')
    )
)

df_fruitshop_collected.display()

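The grouping logic behind `collect_list()` can be pictured as appending each value to a per-group list. The plain-Python sketch below uses a few hypothetical `(order_id, item)` rows that include a duplicate, to contrast it with `f.collect_set()`, which drops duplicates. Keep in mind that in Spark the element order inside the collected array is not guaranteed, because it depends on the row order after the shuffle.

```python
# Plain-Python model of groupBy(...).agg(f.collect_list(...)) and f.collect_set(...).
from collections import defaultdict

# Hypothetical exploded rows: (order_id, item), with 'Strawberry' repeated
exploded = [
    (1276, "Apple"), (1276, "Strawberry"), (5642, "Apple"),
    (1276, "Strawberry"), (5642, "Banana"),
]

collected_list = defaultdict(list)
collected_set = defaultdict(set)
for order_id, item in exploded:
    collected_list[order_id].append(item)  # collect_list keeps duplicates
    collected_set[order_id].add(item)      # collect_set drops duplicates

print(dict(collected_list))
# → {1276: ['Apple', 'Strawberry', 'Strawberry'], 5642: ['Apple', 'Banana']}
```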
### Inline

The [inline()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.inline.html) function explodes an array of structs into a table: one row per struct and one column per struct field.

Let's say we want to create a new table with the `name`, `amount` and `unit_price` of each item purchased in each order.

In [None]:
(
    df_fruitshop_unique
    .select(
        'order_id',
        f.inline(f.col('items'))
    )
).display()

Now let's use the `explode` function to see the difference between the two.

In [None]:
(
    df_fruitshop_unique
    .select(
        'order_id',
        f.explode(f.col('items'))
    )
).display()

The difference is that `explode` creates a new row for each element of the array and puts the whole struct in a single column (named `col` by default), while `inline` also creates a new column for each field of the struct.

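The shape difference between the two outputs can be sketched in plain Python on the items of order 5642: `explode` yields rows with a single struct column, while `inline` spreads the struct fields into separate columns. This is only a model of the resulting row layout, not Spark code.

```python
# Plain-Python model contrasting f.explode() and f.inline() on an array of structs.
items = [
    {"name": "Apple", "amount": 1.0, "unit_price": 2.99},
    {"name": "Banana", "amount": 1.7, "unit_price": 1.99},
]
order_id = 5642

# explode: one row per struct; the whole struct lands in a single column ('col')
exploded = [{"order_id": order_id, "col": item} for item in items]

# inline: one row per struct; each struct field becomes its own column
inlined = [{"order_id": order_id, **item} for item in items]

print(list(exploded[0]))  # → ['order_id', 'col']
print(list(inlined[0]))   # → ['order_id', 'name', 'amount', 'unit_price']
```

After `explode`, the fields are still reachable, e.g. with `.select('order_id', 'col.*')`, which yields the same columns as `inline` here.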
### Other functions

There are many more collection functions you can use with complex types, including:

- `concat()`: Concatenate multiple arrays.
- `arrays_zip()`: Combine multiple arrays into a single array of structs.
- `array_except()`: Return an array containing the elements of the first array that are not present in the second array, without duplicates.
- `array_intersect()`: Return an array containing the elements present in both input arrays, without duplicates.
- `array_union()`: Return an array containing the elements of both input arrays, without duplicates.
- `array_sort()`: Sort the elements of an array in ascending order (NULL elements are placed at the end).
- `array_max()`: Find the maximum value in an array.
- `array_position()`: Find the 1-based position of the first occurrence of a value in an array (0 if it is not found).
- `flatten()`: Flatten a nested array (an array of arrays) into a single array.
- `map_filter()`: Return a map containing only the key-value pairs that satisfy a predicate.

A complete list of functions can be found in the [PySpark SQL Functions documentation](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#collection-functions).

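The set-like functions are easiest to grasp on a concrete row. The sketch below models a few of them in plain Python for order 1276 (all purchased item names vs. its `items_discount`). The helper names simply mirror the Spark functions; they are hypothetical plain-Python functions, not PySpark code, and they keep first-seen order, which is not something to rely on in Spark. In PySpark you would write, for example, `f.array_except(f.col('unique_item_names'), f.col('items_discount'))`.

```python
# Plain-Python models of array_except, array_intersect, array_union,
# array_position and map_filter. Helper names are hypothetical.

def _distinct(values):
    """Remove duplicates, keeping first-seen order."""
    return list(dict.fromkeys(values))

def array_except(a, b):
    return _distinct(x for x in a if x not in b)

def array_intersect(a, b):
    return _distinct(x for x in a if x in b)

def array_union(a, b):
    return _distinct(list(a) + list(b))

def array_position(a, value):
    # 1-based position of the first match, 0 when the value is absent
    return a.index(value) + 1 if value in a else 0

# Order 1276 from the sample data
item_names = ["Apple", "Banana", "Strawberry", "Strawberry", "Peach"]
items_discount = ["Peach", "Banana"]

print(array_except(item_names, items_discount))     # → ['Apple', 'Strawberry']
print(array_intersect(item_names, items_discount))  # → ['Banana', 'Peach']
print(array_union(item_names, items_discount))      # → ['Apple', 'Banana', 'Strawberry', 'Peach']
print(array_position(item_names, "Strawberry"))     # → 3

# map_filter(col, f): keep only the key-value pairs that satisfy the predicate
amounts = {"Apple": 2.0, "Banana": 0.5, "Peach": 1.0}
print({k: v for k, v in amounts.items() if v >= 1.0})  # → {'Apple': 2.0, 'Peach': 1.0}
```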
---

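In this notebook, we explored several functions for working with PySpark's complex data types, mainly `ArrayType` and `StructType`: `transform()`, `array_distinct()`, `array_contains()`, `explode()`, `collect_list()` and `inline()`.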

Go to the `exercises` notebook to practice what you've learned!

But before that, save the `df_fruitshop` DataFrame as a Parquet file in DBFS so that you can use it in the exercises.

In [None]:
# Overwrite any previous copy so the cell can be re-run safely
df_fruitshop.write.mode('overwrite').parquet('/FileStore/lp-big-data/fruitshop.parquet')