In [None]:
import pandas as pd
import pyspark.sql.functions as f
from datetime import datetime
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    ArrayType,
    MapType,
    FloatType,
    DateType,
)

# Complex Data Types

In this notebook, we will learn about complex data types in PySpark.

Until now, we have worked with simple data types like integers, floats, and strings. But in real-world applications, we often need to work with more complex data types.

These complex data types are usually built by combining multiple simple data types.

## Array Type

An array is an ordered collection of elements of the same type. It is useful for representing a collection of values that are related in some way.

Elements in an array can be accessed using an index.

Let's create a PySpark DataFrame with a column of values belonging to [ArrayType](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.types.ArrayType.html) to understand how it works.

In [None]:
# Create some sample data
data = [("Sara", ["Portugal", "Spain", "France"], 1, 2),
        ("John", ["UK", "Belgium", "Netherlands", "Denmark"], 2, 1),
        ("Steffano", ["Italy", "Croatia", "Switzerland", "Portugal", "UK"], 2, 5),
        ("Marina", ["Spain", "Portugal", "France", "UK"], 4, 1)]

# Define a schema for the DataFrame
schema = StructType([
    StructField("name", StringType(), True),
    StructField("visited_countries", ArrayType(StringType()), True),
    StructField("best_idx", IntegerType()),
    StructField("worst_idx", IntegerType())
])

# Create the DataFrame
df = spark.createDataFrame(data, schema=schema)

df.display()

And now let's check the schema:

In [None]:
df.printSchema()

You can see that the column `visited_countries` is an Array column with elements of type String. This column contains the countries visited by each person, ordered by the visitation date.

The columns `best_idx` and `worst_idx` hold the positions, within the `visited_countries` array, of the best and worst countries visited by each person, respectively.

As an introduction, let's explore some simple functions we can apply to `ArrayType` columns.

We are going to create two new columns with the best and worst country names. We need to access the values in `visited_countries` using the `best_idx` and `worst_idx`.

For that we'll use the PySpark SQL [element_at()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.element_at.html) function, which receives an Array column and an index (which can also be a column of indexes). 

***Note:*** In Spark, the first element of an array has index 1, not 0 like in Python.

In [None]:
df_best_worst = (
    df
    .withColumn(
        'best_country',
        f.element_at(f.col('visited_countries'), f.col('best_idx'))
    )
    .withColumn(
        'worst_country',
        f.element_at(f.col('visited_countries'), f.col('worst_idx'))
    )
)

df_best_worst.display()
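The 1-based indexing rule can be mimicked in plain Python. The helper below is only a model of the behaviour documented for `element_at()` on arrays (positive indexes count from the start, negative ones from the end, and an out-of-range index gives null when ANSI mode is off). The name `element_at` is reused here for illustration and this is not Spark's implementation.

```python
def element_at(array, index):
    """Plain-Python model of Spark's 1-based element_at() for arrays."""
    if index == 0:
        # Spark raises an error for index 0 because indexes start at 1
        raise ValueError("element_at indexes start at 1")
    if abs(index) > len(array):
        # Spark returns NULL here when ANSI mode is disabled
        return None
    # Positive indexes count from the start, negative ones from the end
    return array[index - 1] if index > 0 else array[index]


visited = ["Portugal", "Spain", "France"]  # Sara's row
best = element_at(visited, 1)   # best_idx = 1
worst = element_at(visited, 2)  # worst_idx = 2
last = element_at(visited, -1)  # last country visited
```
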

Now let's find the second and third countries visited by each person.

For that, we use the [slice()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.slice.html) function, which takes a 1-based start position and a length. Starting at index 2 with a length of 2 returns the elements at indexes 2 and 3 of the `visited_countries` array.

In [None]:
df_2nd_3rd = (
    df
    .withColumn(
        '2nd_3rd_countries',
        f.slice(f.col('visited_countries'), start=2, length=2)
    )
)

df_2nd_3rd.display()
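To make the `start` and `length` arguments concrete, here is a small plain-Python sketch of the same slicing rule. It only models a positive, 1-based `start`. The helper name `slice_array` is hypothetical and not part of PySpark.

```python
def slice_array(array, start, length):
    """Plain-Python model of slice() for a positive, 1-based start."""
    if start < 1 or length < 0:
        raise ValueError("this sketch only handles start >= 1 and length >= 0")
    begin = start - 1  # convert the 1-based start to a 0-based offset
    return array[begin:begin + length]


john = ["UK", "Belgium", "Netherlands", "Denmark"]  # John's row
second_third = slice_array(john, start=2, length=2)
```
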

Finally, let's collapse the best and worst countries into a single Array column.

For that we need to create an array using the [array()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.array.html) function, and assign it to a new column.

In [None]:
(
    df_best_worst
    .withColumn(
        'best_worst_countries',
        f.array(f.col('best_country'), f.col('worst_country'))
    )
).display()

There are several more functions that can be applied to arrays. We'll see them in more detail later.

## Struct Type

A Struct represents a structured record with a fixed set of fields. Each field is described by a [StructField](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.types.StructField.html#pyspark.sql.types.StructField), which holds the field's name, its data type, and whether it is nullable.

Structs are useful for representing structured data where each field has a name and a type. Fields in a struct can be accessed by their name.

In PySpark, structs are commonly used to represent nested or hierarchical data structures within a DataFrame. As you may recall, we use this data type to define the schema of a DataFrame when loading data from a file.

Let's create a PySpark DataFrame with a column of values belonging to [StructType](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.types.StructType.html) to understand how it works.

In [None]:
# Create some sample data
data = [(1, ("Sara", 30, "Female")),
        (2, ("John", 35, "Male")),
        (3, ("Steffano", 40, "Male")),
        (4, ("Marina", 20, "Female"))]

# Define the schema for the DataFrame
schema = StructType([
    # Define the struct fields
    StructField("person_id", IntegerType(), False),
    StructField(
        "person_info",
        # A struct field can also have type StructType
        StructType([
            # Which in turn contains struct fields
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("gender", StringType(), True)
        ]),
        False
    )
])


# Create the DataFrame
df = spark.createDataFrame(data, schema=schema)

df.display()

df.printSchema()

As you can see in the schema, `person_info` is of type `StructType`. This struct has one field for the name, another for the age, and another for the gender of the person.

Let's get each person's name, age and gender in separate columns. We can do this by getting the field values from the struct using the [getField()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.Column.getField.html) column function.

In [None]:
df_new_columns = (
    df
    .select(
        f.col('person_id'),
        f.col('person_info').getField('name').alias('name'),
        f.col('person_info').getField('age').alias('age'),
        f.col('person_info').getField('gender').alias('gender')
    )
)

df_new_columns.display()

You can also get the field values like this:

In [None]:
df_new_columns = (
    df
    .select(
        'person_id',
        'person_info.name',
        'person_info.age',
        'person_info.gender',
    )
)

df_new_columns.display()
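The dotted names used above are paths into the nested record. The plain-Python sketch below models that lookup with nested dictionaries. The helper `get_field` is hypothetical and only illustrates the idea that `person_info.name` means "the `name` field inside `person_info`", with the resulting column named after the last part of the path.

```python
def get_field(row, path):
    """Follow a dotted path such as 'person_info.name' through nested dicts."""
    value = row
    for name in path.split("."):
        value = value[name]
    return value


row = {"person_id": 1,
       "person_info": {"name": "Sara", "age": 30, "gender": "Female"}}

paths = ["person_id", "person_info.name", "person_info.age", "person_info.gender"]

# The output column takes the last part of the path as its name
flat_row = {path.split(".")[-1]: get_field(row, path) for path in paths}
```
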

We can reverse the process and create the struct column again using the [struct()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.struct.html) function.

In [None]:
df_struct = (
    df_new_columns
    .select(
        f.col('person_id'),
        f.struct(
            f.col('name'),
            f.col('age'),
            f.col('gender')
        ).alias('person_info')
    )
)

df_struct.display()

## Map Type

A Map is an unordered collection of key-value pairs. It is useful for representing a collection of values that are indexed by a key.

The Map data type resembles a Python dictionary, where each key is associated with a value.

Unlike structs, maps are not limited to a fixed set of fields and their keys and values may vary across rows.

All keys in a map are of the same type, and all values are of the same type as well.

Let's create a PySpark DataFrame with a column of values belonging to [MapType](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.types.MapType.html) to understand how it works.

In [None]:
# Create some sample data
data = [("Sara", {"Maths": 90, "Science": 75, "History": 95}),
        ("John", {"Biology": 85, "Maths": 75, "Chemistry": 88, "Literature": 75}),
        ("Steffano", {"Maths": 90, "Science": 75, "Economy": 85}),
        ("Marina", {"Science": 68, "History": 100})]

# Define the schema for the DataFrame
schema = StructType([
    StructField("name", StringType(), True),
    StructField("grades", MapType(StringType(), IntegerType()), True)
])

# Create the DataFrame
df = spark.createDataFrame(data, schema=schema)

df.display()

df.printSchema()


The DataFrame has one column `name` and another `grades` of type `MapType`, which contains the grades for different subjects. Since different students may have taken different subjects, it makes sense to use a `MapType` to save this data instead of a `StructType`.

The grades are saved as values between 0 and 100. Let's convert them to values between 0 and 20.

To do this, we can transform the values in a Map column using the [transform_values()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.transform_values.html) function.

In [None]:
(
    df
    .select(
        f.col('name'),
        f.transform_values('grades', lambda k, v: v*20/100).alias('transformed_grades')
        )
).display()
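Since a Map resembles a Python dictionary, the same rescaling can be written as a dictionary comprehension. This is only a plain-Python analogy for one row, not how Spark evaluates the lambda. Note that the division turns the integer grades into floating-point values, and the same happens to the map values in Spark.

```python
grades = {"Maths": 90, "Science": 75, "History": 95}  # Sara's row

# Apply the same function to every value, keeping the keys unchanged
transformed_grades = {subject: grade * 20 / 100
                      for subject, grade in grades.items()}
```
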

Now, let's find the subjects where the grade is greater than 80 and save their grades in a new column using the [map_filter()](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.map_filter.html) function.

In [None]:
(
    df
    .select(
        f.col('name'),
        f.col('grades'),
        f.map_filter('grades', lambda k, v: v > 80).alias('grades_above_80')
        )
).display()
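In plain-Python terms, filtering a map keeps only the key-value pairs for which the predicate is true. The sketch below shows this for a single row as an analogy. It is not Spark code.

```python
grades = {"Biology": 85, "Maths": 75, "Chemistry": 88, "Literature": 75}  # John's row

# Keep only the pairs whose value satisfies the predicate v > 80
grades_above_80 = {subject: grade
                   for subject, grade in grades.items()
                   if grade > 80}
```
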

## Complex data type functions

Now that we've explored PySpark's complex data types and some basic functions that can be used with each, let's explore some more functions with a concrete example.

First, let's start by creating a sample dataset with information about orders from a Fruit Shop.

In [None]:
# Sample data
data = [
    {"order_id": 5642, "order_date": datetime.strptime("2024-05-18", "%Y-%m-%d").date(),
    "items": [
        {"name": "Apple", "quantity": 1.0, "price": 2.99},
        {"name": "Banana", "quantity": 1.7, "price": 1.99}],
    'items_discount': ['Apple']},
    {"order_id": 9762, "order_date": datetime.strptime("2024-05-02", "%Y-%m-%d").date(),
    "items": [
        {"name": "Strawberry", "quantity": 0.5, "price": 6.99},
        {"name": "Apple", "quantity": 3.0, "price": 2.99},
        {"name": "Peach", "quantity": 2.5, "price": 3.39}],
    'items_discount': ['Apple', 'Peach']},
    {"order_id": 3652, "order_date": datetime.strptime("2024-05-23", "%Y-%m-%d").date(),
    "items": [
        {"name": "Banana", "quantity": 1.5, "price": 1.99}],
    'items_discount': []},
    {"order_id": 1276, "order_date": datetime.strptime("2024-05-10", "%Y-%m-%d").date(),
    "items": [
        {"name": "Apple", "quantity": 2.0, "price": 2.99},
        {"name": "Banana", "quantity": 0.5, "price": 1.99},
        {"name": "Strawberry", "quantity": 1.0, "price": 6.99},
        {"name": "Strawberry", "quantity": 1.0, "price": 6.99},
        {"name": "Peach", "quantity": 1.0, "price": 3.39}],
    'items_discount': ['Peach', 'Banana']},
    {"order_id": 8763, "order_date": datetime.strptime("2024-05-14", "%Y-%m-%d").date(),
    "items": [
        {"name": "Strawberry", "quantity": 1.0, "price": 6.99},
        {"name": "Peach", "quantity": 1.0, "price": 3.39},
        {"name": "Mango", "quantity": 1.5, "price": 5.99}],
    'items_discount': ['Mango']},
    {"order_id": 7652, "order_date": datetime.strptime("2024-05-22", "%Y-%m-%d").date(),
    "items": [
        {"name": "Banana", "quantity": 1.0, "price": 1.99},
        {"name": "Mango", "quantity": 1.5, "price": 5.99}],
    'items_discount': ['Mango', 'Banana']},
    {"order_id": 7631, "order_date": datetime.strptime("2024-05-22", "%Y-%m-%d").date(),
    "items": [
        {"name": "Banana", "quantity": 1.0, "price": 1.99},
        {"name": "Banana", "quantity": 2.5, "price": 1.99}],
    'items_discount': []}
]

# Define the schema
schema = StructType([
    StructField('order_id', IntegerType(), False),
    StructField('order_date', DateType(), False),
    StructField(
        'items',
        ArrayType(
            StructType([
                StructField('name', StringType(), False),
                StructField('quantity', FloatType(), False),
                StructField('price', FloatType(), False)
            ]),
            False
        ),
        False
    ),
    StructField("items_discount", ArrayType(StringType()), True)
])


# Create DataFrame
df_fruitshop = spark.createDataFrame(data, schema=schema)

df_fruitshop.display()

In [None]:
df_fruitshop.printSchema()

As you can see, the `items` column has a complex data type, the `ArrayType`.

The arrays in this column hold elements of another complex type, the `StructType`, which is composed of a list of `StructField` elements. These fields are named `name`, `quantity` and `price`, have types `StringType`, `FloatType` and `FloatType`, and are all non-nullable (`nullable = false`), as defined in the schema above.

The `items_discount` column is simply an array of strings.

The next sections will show how to use some functions to manipulate these complex data types.

### Transform

The [transform()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.transform.html) function applies a function to each element of an array column.

Let's say we want to get the name of each product in the `items` column.

In [None]:
# Define the transform function. This receives a column and a function to apply to the column.
fn = f.transform(f.col('items'), lambda x: x.name)

df_fruitshop_transformed = (
    df_fruitshop
    # Apply the transformation to the dataframe and save the results to a new column
    .withColumn('item_names', fn)
)

df_fruitshop_transformed.display()

Check the schema of the new DataFrame to see the type of the `item_names` column.

In [None]:
df_fruitshop_transformed.printSchema()

The `item_names` column is an array of strings, where each string is the name of a product.

### Size

The [size()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.size.html) function returns the size of an array or map.

Let's say we want to know the number of items in each order.

In [None]:
df_fruitshop_size = (
    df_fruitshop_transformed
    .withColumn(
        'nr_items',
        f.size(f.col('items'))
    )
)

df_fruitshop_size.display()

### Array distinct

The [array_distinct()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.array_distinct.html) function returns an array of distinct elements in the input array.

Let's get the list of unique items in each order, as well as the number of unique items.

In [None]:
df_fruitshop_unique = (
    df_fruitshop_size
    .withColumn(
        'unique_item_names',
        f.array_distinct(f.col('item_names'))
    )
    .withColumn(
        'nr_unique_items',
        f.size(f.col('unique_item_names'))
    )
)

df_fruitshop_unique.display()
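The three steps so far (`transform()`, `size()` and `array_distinct()`) can be traced for a single order in plain Python. This is only a model of the results for order 1276. Keeping the first occurrence of each name in its original position is an assumption of this sketch, not a guarantee taken from the Spark documentation.

```python
# Items of order 1276 from the sample data
items = [
    {"name": "Apple", "quantity": 2.0, "price": 2.99},
    {"name": "Banana", "quantity": 0.5, "price": 1.99},
    {"name": "Strawberry", "quantity": 1.0, "price": 6.99},
    {"name": "Strawberry", "quantity": 1.0, "price": 6.99},
    {"name": "Peach", "quantity": 1.0, "price": 3.39},
]

item_names = [item["name"] for item in items]        # like transform()
nr_items = len(items)                                # like size()
unique_item_names = list(dict.fromkeys(item_names))  # like array_distinct()
nr_unique_items = len(unique_item_names)
```
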

### Array contains

The [array_contains()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.array_contains.html) function checks if an array contains a specific value.

Create a new column of boolean values indicating if the order contains the `Banana` item.

In [None]:
df_fruitshop_contains = (
    df_fruitshop_unique
    .withColumn(
        'contains_banana',
        f.array_contains(f.col('unique_item_names'), 'Banana')
    )
)

df_fruitshop_contains.display()

Since this function returns a column of boolean values, we could also use it to filter the DataFrame and keep only the orders in which `Banana` was purchased.

In [None]:
df_fruitshop_filtered = (
    df_fruitshop_unique
    .filter(
        f.array_contains(f.col('unique_item_names'), 'Banana')
    )
)

df_fruitshop_filtered.display()

### Explode

The [explode()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.explode.html) function creates a new row for each element in an array or map.

Let's try it out. Transform the original DataFrame to get a new DataFrame with one row for each unique item purchased in each order.

In [None]:
df_fruitshop_exploded = (
    df_fruitshop_contains
    .select(
        'order_id',
        'order_date',
        f.explode(f.col('unique_item_names')).alias('item')
    )
)

df_fruitshop_exploded.display()
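One detail worth keeping in mind is that `explode()` produces no row at all for an empty or null array (`explode_outer()` keeps such rows with a null instead). The plain-Python sketch below models this using the `items_discount` arrays of three orders from the sample data. The helper `explode_rows` is hypothetical.

```python
def explode_rows(rows, array_key, out_key):
    """Model of explode(): one output row per array element."""
    exploded = []
    for row in rows:
        rest = {k: v for k, v in row.items() if k != array_key}
        # An empty (or missing) array contributes no rows at all
        for element in row[array_key] or []:
            exploded.append({**rest, out_key: element})
    return exploded


orders = [
    {"order_id": 5642, "items_discount": ["Apple"]},
    {"order_id": 9762, "items_discount": ["Apple", "Peach"]},
    {"order_id": 3652, "items_discount": []},
]

discount_rows = explode_rows(orders, "items_discount", "discounted_item")
```

Order 3652 has no discounted items, so it does not appear in `discount_rows`.
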

### Collect list

The [collect_list()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.collect_list.html) function is an aggregate function that can be applied over a [Window](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Window.html) or after the [groupBy()](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.groupBy.html) method.

It allows you to collect the results of an aggregation into an array.

Let's revert the `explode()` operation and collect the list of unique items purchased in each order.

In [None]:
df_fruitshop_collected = (
    df_fruitshop_exploded
    .groupBy('order_id', 'order_date')
    .agg(
        f.collect_list(f.col('item')).alias('unique_item_names')
    )
)

df_fruitshop_collected.display()
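Conceptually, grouping and collecting is the inverse of exploding: rows that share a key are gathered back into one list. The sketch below models this in plain Python for a few exploded rows. Unlike this sketch, Spark does not guarantee the order of the collected elements, because it depends on the order of the rows after a shuffle.

```python
from collections import defaultdict

# A few rows as they would look after explode()
exploded = [
    {"order_id": 5642, "item": "Apple"},
    {"order_id": 5642, "item": "Banana"},
    {"order_id": 3652, "item": "Banana"},
]

# Group by the key and append each item to that key's list
grouped = defaultdict(list)
for row in exploded:
    grouped[row["order_id"]].append(row["item"])

collected = [{"order_id": order_id, "unique_item_names": names}
             for order_id, names in grouped.items()]
```
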

### Inline

The [inline()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.inline.html) function explodes an array of structs into a table, producing one row per struct and one column per struct field.

Let's say we want to create a new table with the name, price and quantity of each item purchased in each order.

In [None]:
(
    df_fruitshop
    .select(
        'order_id',
        f.inline(f.col('items'))
    )
).display()
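As a plain-Python picture of what `inline()` does, each struct in the array becomes its own row and each struct field becomes a column. The helper `inline_rows` below is hypothetical and only models the shape of the result for one order.

```python
def inline_rows(rows, array_key):
    """Model of inline(): one row per struct, one column per struct field."""
    result = []
    for row in rows:
        rest = {k: v for k, v in row.items() if k != array_key}
        for struct in row[array_key]:
            # Merge the remaining columns with the struct's own fields
            result.append({**rest, **struct})
    return result


orders = [
    {"order_id": 5642,
     "items": [{"name": "Apple", "quantity": 1.0, "price": 2.99},
               {"name": "Banana", "quantity": 1.7, "price": 1.99}]},
]

item_rows = inline_rows(orders, "items")
```
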

### Other functions

There are a lot more collection functions you can use with complex types, like:

- `concat()`: Concatenate multiple arrays.
- `arrays_zip()`: Combine multiple arrays into a single array of structs.
- `array_except()`: Return an array containing elements from the first array that are not present in the second array.
- `array_intersect()`: Return an array containing elements that are present in both input arrays.
- `array_union()`: Return an array containing elements from both input arrays, without duplicates.
- `array_sort()`: Sort the elements of an array in ascending order.
- `array_max()`: Find the maximum value in an array.
- `array_position()`: Find the position (index) of a specific value in an array.
- `flatten()`: Flatten a nested array structure.
- `map_filter()`: Return a map containing only the key-value pairs that satisfy a predicate.

etc.

A complete list of functions can be found in the [PySpark SQL Functions documentation](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#collection-functions).
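To give a feel for the set-like functions in the list above, here is a plain-Python sketch of `array_except()`, `array_intersect()` and `array_union()` applied to order 9762, comparing the purchased items with the discounted ones. The helpers only model the documented results (no duplicates in the output). The element order they produce is an assumption of this sketch.

```python
def array_except(first, second):
    """Elements of first that are not in second, without duplicates."""
    return [x for x in dict.fromkeys(first) if x not in second]


def array_intersect(first, second):
    """Elements present in both arrays, without duplicates."""
    return [x for x in dict.fromkeys(first) if x in second]


def array_union(first, second):
    """Elements from both arrays, without duplicates."""
    return list(dict.fromkeys(first + second))


item_names = ["Strawberry", "Apple", "Peach"]  # order 9762
items_discount = ["Apple", "Peach"]

full_price = array_except(item_names, items_discount)
discounted = array_intersect(item_names, items_discount)
all_names = array_union(item_names, items_discount)
```
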

---

In this notebook, we learned about complex data types in PySpark, such as ArrayType, StructType, and MapType.

Go to the `exercises` notebook to practice what you've learned!

But before that, save the `df_fruitshop` DataFrame as a Parquet file in DBFS so you can use it in the exercises.

In [None]:
df_fruitshop.write.parquet('/FileStore/lp-big-data/fruitshop.parquet')