In [None]:
import pyspark.sql.functions as f
from pyspark.sql import Window
from datetime import datetime
from pyspark.sql.types import *

Let's answer some questions about our FruitShop orders data.

Create the DataFrame and recall the schema

In [None]:
# Sample orders, one tuple per order:
#   (order_id, order_date as "YYYY-MM-DD", [(name, amount, unit_price), ...], discounted item names)
_raw_orders = [
    (5642, "2024-05-18",
     [("Apple", 1.0, 2.99), ("Banana", 1.7, 1.99)],
     ['Apple']),
    (9762, "2024-05-02",
     [("Strawberry", 0.5, 6.99), ("Apple", 3.0, 2.99), ("Peach", 2.5, 3.39)],
     ['Apple', 'Peach']),
    (3652, "2024-05-23",
     [("Banana", 1.5, 1.99)],
     []),
    (1276, "2024-05-10",
     [("Apple", 2.0, 2.99), ("Banana", 0.5, 1.99), ("Strawberry", 1.0, 6.99),
      ("Strawberry", 1.0, 6.99), ("Peach", 1.0, 3.39)],
     ['Peach', 'Banana']),
    (8763, "2024-05-14",
     [("Strawberry", 1.0, 6.99), ("Peach", 1.0, 3.39), ("Mango", 1.5, 5.99)],
     ['Mango']),
    (7652, "2024-05-22",
     [("Banana", 1.0, 1.99), ("Mango", 1.5, 5.99)],
     ['Mango', 'Banana']),
    (7631, "2024-05-22",
     [("Banana", 1.0, 1.99), ("Banana", 2.5, 1.99)],
     []),
]

# Expand the compact tuples into the nested dict records that createDataFrame consumes.
data = [
    {
        "order_id": order_id,
        "order_date": datetime.strptime(order_day, "%Y-%m-%d").date(),
        "items": [
            {"name": fruit, "amount": qty, "unit_price": price}
            for fruit, qty, price in order_items
        ],
        "items_discount": discounted,
    }
    for order_id, order_day, order_items, discounted in _raw_orders
]

# One line of an order. amount/unit_price are single-precision FloatType, so
# aggregated values may show float rounding noise (e.g. 2.99 is not stored exactly).
item_struct = StructType([
    StructField('name', StringType(), nullable=False),
    StructField('amount', FloatType(), nullable=False),
    StructField('unit_price', FloatType(), nullable=False),
])

# Order-level schema: every order has a non-null array of non-null item structs,
# while the list of discounted item names may itself be null.
schema = StructType([
    StructField('order_id', IntegerType(), nullable=False),
    StructField('order_date', DateType(), nullable=False),
    StructField('items', ArrayType(item_struct, containsNull=False), nullable=False),
    StructField("items_discount", ArrayType(StringType()), nullable=True),
])

# Build the DataFrame and show the schema it ended up with.
df_fruitshop = spark.createDataFrame(data, schema=schema)
df_fruitshop.printSchema()

Answer the following questions. Good luck!

1. What is the average number of distinct items (fruits, counted by name) sold per order?

In [None]:
# Distinct fruit names per order: project each item struct to its name, then dedupe.
distinct_item_names = f.array_distinct(
    f.transform('items', lambda item: item.getField('name'))
)

(
    df_fruitshop
    .select(f.size(distinct_item_names).alias('num_items'))
    # Global aggregate over all orders
    .agg(f.avg('num_items').alias('average_items_per_order'))
).display()

2. For each fruit, what is the total amount sold among all orders?

***Hint:*** Start by transforming the data into a more usable format.

In [None]:
(
    df_fruitshop
    # One row per order line, kept as a struct column
    .select(f.explode('items').alias('item'))
    # Group on the nested fruit name and add up the quantities
    .groupBy(f.col('item.name').alias('name'))
    .agg(f.sum('item.amount').alias('total_amount'))
).display()

3. What was the total amount of `Peach` sold across all orders in which `Peach` was on discount?

In [None]:
(
    df_fruitshop
    # Explode items to get one row per item
    .select(
        'order_id',
        'items_discount',
        f.explode(f.col('items')).alias('item')
    )
    # Filter for Peach items with discount
    .filter(
        (f.col('item.name') == 'Peach')
        & (f.array_contains(f.col('items_discount'), 'Peach'))
    )
    .select(
        f.sum('item.amount').alias('total_amount')
    )
).display()

4. What is the price of the most expensive item in each order?

In [None]:
# Array of unit prices, one per item of the order
prices = f.transform('items', lambda item: item.getField('unit_price'))

(
    df_fruitshop
    # Keep every original column and append the prices array plus its maximum
    .select(
        '*',
        prices.alias('items_price'),
        f.array_max(prices).alias('max_price'),
    )
).display()

5. What is the name of the most expensive item in each order?

***Hint:*** Use a window function or a groupBy to solve this question.

In [None]:
# Solution with window function
window = Window.partitionBy('order_id')

(
    df_fruitshop
    # One row per order line: order_id, name, amount, unit_price
    .select('order_id', f.inline('items'))
    # Highest unit price within each order, attached to every line of that order
    .withColumn(
        'max_price',
        f.max(f.col('unit_price')).over(window)
    )
    # Keep only the line(s) priced at the order maximum. Different fruits tied
    # at the maximum price all survive, one row each.
    .filter(f.col('unit_price') == f.col('max_price'))
    .select(
        'order_id',
        'name'
    )
    # Deduplicate AFTER projecting: the same fruit can appear on several lines of
    # one order (order 1276 lists Strawberry twice). Those lines could differ in
    # `amount`, which a whole-row dropDuplicates would not collapse.
    .dropDuplicates()
).display()

In [None]:
(
    df_fruitshop
    .select('order_id', f.inline('items'))
    .groupBy('order_id')
    # Sorting before groupBy does not make `first` deterministic (the shuffle
    # may reorder rows). Instead, take the max of a (unit_price, name) struct:
    # structs compare field by field, so this picks the highest price, breaking
    # price ties by the lexicographically largest name.
    .agg(f.max(f.struct('unit_price', 'name')).alias('top_item'))
    .select('order_id', f.col('top_item.name').alias('name'))
).display()

Solution without `inline` or `explode`

**NOTE:** This code requires Spark 3.4.0 or later, because the `comparator` argument of `array_sort` was added in that version. Make sure your cluster uses runtime 13.3 LTS or above; you can select the desired runtime under 'Create new resource'.

In [None]:
(
    df_fruitshop
    .withColumn(
        'sorted_items',
        # array_sort receives a column and a comparator taking two array elements.
        # The comparator must return a negative integer, 0, or a positive integer
        # when the first element should come before, tie with, or come after the
        # second. Returning 1 when x is cheaper than y puts expensive items
        # first (descending unit_price). Equal prices must return 0: returning -1
        # for both (a, b) and (b, a) breaks the comparator contract, and duplicate
        # lines such as the two Strawberry rows of order 1276 do compare equal.
        f.array_sort(
            'items',
            lambda x, y: (
                f.when(x.getField('unit_price') < y.getField('unit_price'), 1)
                .when(x.getField('unit_price') > y.getField('unit_price'), -1)
                .otherwise(0)
            )
        )
    )
    .select(
        'order_id',
        # 'sorted_items' is an array of structs. element_at (1-based) returns its
        # first struct, i.e. the most expensive item, and getField reads a field of it.
        f.element_at('sorted_items', 1).getField('name').alias('name'),
        f.element_at('sorted_items', 1).getField('unit_price').alias('max_price')
    )
).display()