In [None]:
import pyspark.sql.functions as f
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    ArrayType,
    MapType,
)

# Complex Data Types - Part I

In this notebook, we will learn about complex data types in PySpark.

Until now, we have worked with simple data types like integers, floats, and strings. But in real-world applications, we often need to work with more complex data types.

Complex data types are usually built by combining several simple data types.

## Array Type

An array is an ordered collection of elements of the same type. It is useful for representing a collection of values that are related in some way.

Elements in an array can be accessed using an index.

Let's create a PySpark DataFrame with a column of values belonging to [ArrayType](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.types.ArrayType.html) to understand how it works.

In [None]:
# Create some sample data
data = [("Sara", ["Portugal", "Spain", "France"], 1, 2),
        ("John", ["UK", "Belgium", "Netherlands", "Denmark"], 2, 1),
        ("Steffano", ["Italy", "Croatia", "Switzerland", "Portugal", "UK"], 2, 5),
        ("Marina", ["Spain", "Portugal", "France", "UK"], 4, 1)]

# Define a schema for the DataFrame
schema = StructType([
    StructField("name", StringType(), True),
    StructField("visited_countries", ArrayType(StringType()), True),
    StructField("best_idx", IntegerType()),
    StructField("worst_idx", IntegerType())
])

# Create the DataFrame
df = spark.createDataFrame(data, schema=schema)

df.display()

And now let's check the schema:

In [None]:
df.printSchema()

You can see that the column `visited_countries` is an Array column with elements of type String. This column contains the countries visited by each person, ordered by the visitation date.

The columns `best_idx` and `worst_idx` hold the positions, within the `visited_countries` array, of the best and worst countries visited by each person, respectively.

As an introduction, let's explore some simple functions we can apply to `ArrayType` columns.

We are going to create two new columns with the best and worst country names. We need to access the values in `visited_countries` using the `best_idx` and `worst_idx`.

For that we'll use the PySpark SQL [element_at()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.element_at.html) function, which takes an Array column and an index (which can also be a column of indexes).

***Note:*** `element_at()` uses 1-based indexing: the first element of an array has index 1, not 0 as in Python. Bracket access and `getItem()` on an array column, by contrast, are 0-based.

In [None]:
df_best_worst = (
    df
    .withColumn(
        'best_country',
        f.element_at(f.col('visited_countries'), f.col('best_idx'))
    )
    .withColumn(
        'worst_country',
        f.element_at(f.col('visited_countries'), f.col('worst_idx'))
    )
)

df_best_worst.display()

Now let's find the second and third countries visited by each person.

For that, we slice the `visited_countries` array with the [slice()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.slice.html) function. It takes a start index (1-based) and a length, not an end index, so we start at index 2 and take 2 elements.

In [None]:
df_2nd_3rd = (
    df
    .withColumn(
        '2nd_3rd_countries',
        f.slice(f.col('visited_countries'), start=2, length=2)
    )
)

df_2nd_3rd.display()

Let's check how many countries each person visited using the [size()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.size.html) function.

In [None]:
df_nr_countries = (
    df
    .withColumn(
        'nr_countries',
        f.size(f.col('visited_countries'))
    )
)

df_nr_countries.display()

Finally, let's collapse the best and worst countries into a single Array column.

For that we need to create an array using the [array()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.array.html) function, and assign it to a new column.

In [None]:
(
    df_best_worst
    .withColumn(
        'best_worst_countries',
        f.array(f.col('best_country'), f.col('worst_country'))
    )
).display()

There are several more functions that can be applied to arrays. We'll see them in more detail in the next lesson.

## Struct Type

A column of type `StructType` represents a structured record with a fixed set of fields.

Each field inside a Struct is described by a [StructField](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.types.StructField.html#pyspark.sql.types.StructField), which holds the field's name, data type, and nullability.

Structs are useful for representing structured data where each field has a name and a type. Fields in a struct can be accessed by their name.

In PySpark, structs are commonly used to represent nested or hierarchical data structures within a DataFrame. As you may recall, we use this data type to define the schema of a DataFrame when loading data from a file.

You can think of Structs as Python dictionaries, where the dictionary keys and value types are predefined and are the same for all records.

Let's create a PySpark DataFrame with a column of values belonging to [StructType](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.types.StructType.html) to understand how it works.

In [None]:
# Create some sample data
data = [(1, ("Sara", 30, "Female")),
        (2, ("John", 35, "Male")),
        (3, ("Steffano", 40, "Male")),
        (4, ("Marina", 20, "Female"))]

# Define the schema for the DataFrame
schema = StructType([
    # Define the struct fields
    StructField("person_id", IntegerType(), False),
    StructField(
        "person_info",
        # A struct field can also have type StructType
        StructType([
            # Which in turn contains struct fields
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("gender", StringType(), True)
        ]),
        False
    )
])


# Create the DataFrame
df = spark.createDataFrame(data, schema=schema)

df.display()

df.printSchema()

As you can see in the schema, `person_info` is of type `StructType`. This struct has one field for the name, another for the age, and another for the gender of the person.

Let's get each person's name, age and gender in separate columns. We can do this by getting the field values from the struct using the [getField()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.Column.getField.html) column function.

In [None]:
df_new_columns = (
    df
    .select(
        f.col('person_id'),
        f.col('person_info').getField('name').alias('name'),
        f.col('person_info').getField('age').alias('age'),
        f.col('person_info').getField('gender').alias('gender')
    )
)

df_new_columns.display()

You can also get the field values with dot notation, passing `struct_column.field_name` strings directly to `select()`:

In [None]:
df_new_columns = (
    df
    .select(
        'person_id',
        'person_info.name',
        'person_info.age',
        'person_info.gender',
    )
)

df_new_columns.display()

We can reverse the process and create the struct column again using the [struct()](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.struct.html) function.

In [None]:
df_struct = (
    df_new_columns
    .select(
        f.col('person_id'),
        f.struct('name','age','gender').alias('person_info')
    )
)

df_struct.display()

## Map Type

A Map is an unordered collection of key-value pairs. It is useful for representing a collection of values that are indexed by a key.

The Map data type also resembles a Python dictionary, where each key is associated with a value.

Unlike structs, maps are not limited to a fixed set of fields: the keys present, and their values, may vary from row to row.

All keys in a map are of the same type, and all values are of the same type as well.

Let's create a PySpark DataFrame with a column of values belonging to [MapType](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.types.MapType.html) to understand how it works.

In [None]:
# Create some sample data
data = [("Sara", {"Maths": 90, "Science": 75, "History": 95}),
        ("John", {"Biology": 85, "Maths": 75, "Chemistry": 88, "Literature": 75}),
        ("Steffano", {"Maths": 90, "Science": 75, "Economy": 85}),
        ("Marina", {"Science": 68, "History": 100})]

# Define the schema for the DataFrame
schema = StructType([
    StructField("name", StringType(), True),
    StructField("grades", MapType(StringType(), IntegerType()), True)
])

# Create the DataFrame
df = spark.createDataFrame(data, schema=schema)

df.display()

df.printSchema()


The DataFrame has a `name` column and a `grades` column of type `MapType`, which holds each student's grades keyed by subject. Since different students may have taken different subjects, it makes sense to store this data in a `MapType` instead of a `StructType`.

The grades are saved as values between 0 and 100. Let's convert them to values between 0 and 20.

To do this, we can transform the values in a Map column using the [transform_values()](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.transform_values.html) function.

In [None]:
(
    df.withColumn('transformed_grades', f.transform_values('grades', lambda k, v: v*20/100))
).display()

Now, let's find the subjects where the grade is greater than 80 and store those subjects and their grades in a new column using the [map_filter()](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.map_filter.html) function.

In [None]:
(
    df.withColumn('grades_above_80', f.map_filter('grades', lambda k, v: v > 80))
).display()

---

In this notebook, we learned about complex data types in PySpark, such as ArrayType, StructType, and MapType.

Go to the `exercises-part1` notebook to practice what you've learned!