In [None]:
import pyspark.sql.functions as f
import matplotlib.pyplot as plt
import seaborn as sns

In the PySpark DataFrames - Part 3 notebook we've seen more methods and functions that you can use to manipulate DataFrames in PySpark to perform some data analysis.

This notebook contains more questions that you can answer using PySpark DataFrame methods and SQL functions.

Some of them will require you to look into the documentation to find the right function to use, which we think is the best way to learn how to use PySpark.

So, let's get started!

First, let's get the preprocessed orders and products dataframes.

In [None]:
# Both preprocessed files are read with the same CSV settings: a header row,
# schema inference turned on, and a comma as the field separator.
csv_options = {"header": "true", "inferSchema": "true", "sep": ","}

df_orders = (
    spark.read.format("csv")
    .options(**csv_options)
    .load("/FileStore/lp-big-data/orders-data/orders-preprocessed.csv")
)

df_products = (
    spark.read.format("csv")
    .options(**csv_options)
    .load("/FileStore/lp-big-data/orders-data/products-preprocessed.csv")
)

Join the tables on the product_id column and store the result in a new dataframe called _df_orders_products_.

When joining, make sure to keep all orders, even the ones without a matching product in the products dataframe.

In [None]:
# A left join keeps every row of df_orders; orders whose product_id has no
# match in df_products get nulls in the product columns.
df_orders_products = df_orders.join(df_products, on='product_id', how='left')

df_orders_products.display()

1. What is the ID of the most profitable customer ever, i.e. the one with the highest total profit across all of their orders?

In [None]:
# Total profit per customer across all of their rows.
profit_by_customer = df_orders_products.groupBy('customer_id').agg(
    f.sum('profit').alias('total_profit')
)

# Sort from highest to lowest total and keep only the first customer.
profit_by_customer.sort(f.col('total_profit').desc()).limit(1).display()

2. Who were the top 3 customers with the highest number of orders in 2017?

In [None]:
# Restrict to rows whose order_year is 2017 before counting.
orders_2017 = df_orders_products.where(f.col('order_year') == 2017)

# NOTE(review): f.count counts rows with a non-null order_id. If one order can
# span several rows, countDistinct may be what is intended — confirm.
orders_per_customer = orders_2017.groupBy('customer_id').agg(
    f.count('order_id').alias('total_orders')
)

# Three customers with the largest counts.
orders_per_customer.sort(f.col('total_orders').desc()).limit(3).display()

3. What is the name of the supplier with the highest number of orders delivered with speed 'Fast'?

*Note:* The delivery speed can be found in the `delivery_speed` column.

In [None]:
# Keep only the rows delivered with speed 'Fast'.
fast_orders = df_orders_products.where(f.col('delivery_speed') == 'Fast')

# Group on both id and name so the supplier name is shown in the result.
fast_orders_per_supplier = fast_orders.groupBy('supplier_id', 'supplier_name').agg(
    f.count('order_id').alias('total_orders')
)

# Supplier with the largest count of 'Fast' rows.
fast_orders_per_supplier.sort(f.col('total_orders').desc()).limit(1).display()

4. For each product line, what is the ratio between average profit and average revenue?

In [None]:
# Average profit and average revenue per product line, plus their ratio.
(
    df_orders_products
    .groupBy('product_line')
    .agg(
        f.avg('profit').alias('avg_profit'),
        f.avg('revenue').alias('avg_revenue'),
    )
    # After the aggregation only product_line, avg_profit and avg_revenue
    # exist, so the ratio has to be built from the aliased averages; the raw
    # 'profit' and 'revenue' columns can no longer be resolved here.
    .withColumn('ratio', f.col('avg_profit') / f.col('avg_revenue'))
).display()

5. Consider only suppliers who have delivered products to at least 150 different customers. Which supplier has the greatest variety of products?

In [None]:
# Per supplier: number of distinct products and number of distinct customers.
(
    df_orders_products
    .groupBy(['supplier_id', 'supplier_name'])
    .agg(
        f.countDistinct('product_id').alias('unique_products'),
        f.countDistinct('customer_id').alias('nr_unique_customers')
    )
    # "At least 150 different customers" is inclusive, so a supplier with
    # exactly 150 distinct customers must be kept (>=, not >).
    .filter(f.col('nr_unique_customers') >= 150)
    # The supplier with the most distinct products comes first.
    .orderBy(f.desc('unique_products'))
    .limit(1)
).display()

6. Is there a relationship between delivery speed and product line?

**Bonus:** Create a visualization to better explore this relationship.

In [None]:
# Cross-tabulation: one row per delivery_speed, one column per product_line
# value, each cell holding the count of order_id for that combination.
orders_by_speed = df_orders_products.groupBy('delivery_speed')
df_relationship = orders_by_speed.pivot('product_line').agg(f.count('order_id'))

In [None]:
# Convert the pivoted Spark DataFrame to pandas and index it by delivery speed,
# so the remaining columns are exactly the product-line counts to plot.
df_pandas = (
    df_relationship
    .toPandas()
    .set_index('delivery_speed')
    # A speed/product-line pair with no orders has no count in the pivot and
    # arrives here as NaN, which turns the column into floats. The heatmap
    # annotations use fmt="d", which only accepts integers, so fill the gaps
    # with 0 and cast back to int.
    .fillna(0)
    .astype(int)
)

# Draw on an explicit Axes instead of relying on pyplot's implicit state.
fig, ax = plt.subplots(figsize=(10, 6))  # Adjust the figure size as needed
sns.heatmap(df_pandas, annot=True, fmt="d", cmap="YlGnBu", cbar=True, ax=ax)

# Title and axis labels so the figure can be read on its own.
ax.set_title('Delivery Speed VS Product Line')
ax.set_ylabel('Delivery Speed')
ax.set_xlabel('Product Line')

# Show the heatmap
plt.show()