In [None]:
from pyspark.sql.functions import *
import pytest

# Load the raw orders data from the data lake

In [None]:
def load_data(path='abfss://<container>@<account-name>.dfs.core.windows.net/<path>/orders.csv'):
    """Read the raw orders CSV from the data lake into a Spark DataFrame.

    Args:
        path: Location of the orders CSV. Defaults to this notebook's
            placeholder ABFSS path; replace the ``<container>``,
            ``<account-name>`` and ``<path>`` placeholders (or pass another
            path) before running.

    Returns:
        A DataFrame whose column names come from the CSV header row and whose
        column types are inferred by Spark (``inferSchema``).
    """
    # Relies on the notebook-provided global ``spark`` session.
    df_orders = spark.read.csv(path, header='true', inferSchema='true')
    return df_orders

# Clean the data: remove orders with a duplicated OrderId

In [None]:
def remove_duplicate_orders(df, key_columns=("OrderId",)):
    """Drop orders that share the same key, keeping one row per key.

    Args:
        df: Orders DataFrame containing every column in ``key_columns``.
        key_columns: Column name(s) identifying an order. Defaults to
            ``OrderId``. A single column name string is also accepted.

    Returns:
        The DataFrame returned by ``df.dropDuplicates`` on the key columns,
        i.e. at most one row per distinct key.

    Raises:
        ValueError: If ``key_columns`` is empty.

    Note:
        When rows share a key but differ in other columns (e.g. OrderDate),
        this function does not control which row survives; do not rely on
        it being the earliest or latest.
    """
    if isinstance(key_columns, str):
        # Avoid list("OrderId") splitting the name into characters.
        key_columns = [key_columns]
    key_columns = list(key_columns)
    if not key_columns:
        raise ValueError("key_columns must contain at least one column name")
    return df.dropDuplicates(key_columns)

# Calculate reporting metrics

In [None]:
def calculate_sales_by_region(df):
    """Aggregate total sales per region.

    Args:
        df: Orders DataFrame with at least ``Region`` and ``TotalPrice``
            columns.

    Returns:
        A DataFrame with one row per ``Region`` and a ``TotalSales`` column
        holding the sum of ``TotalPrice`` for that region. Row order is not
        specified.
    """
    # Module-qualified import: the notebook's wildcard import of
    # pyspark.sql.functions shadows the built-in ``sum``, so be explicit.
    from pyspark.sql import functions as F

    # Name the aggregate directly with alias() instead of renaming Spark's
    # generated "sum(TotalPrice)" column: withColumnRenamed is a silent no-op
    # if that generated name ever differs.
    return df.groupBy("Region").agg(F.sum("TotalPrice").alias("TotalSales"))

# Save the output

In [None]:
def save_output(df, path='abfss://<container>@<account-name>.dfs.core.windows.net/<path>/output/'):
    """Write ``df`` as headed CSV to the data lake, then print it.

    Args:
        df: The DataFrame to persist (here, the regional sales report).
        path: Output directory. Existing contents are replaced
            (``mode("overwrite")``). Defaults to this notebook's placeholder
            ABFSS location.
    """
    # repartition(1) funnels all rows into one partition so a single CSV
    # part-file is written; fine for a small aggregated report, but costly
    # for large data.
    (
        df.repartition(1)
        .write.mode("overwrite")
        .option("header", "true")
        .csv(path)
    )

    # NOTE: show() is a separate Spark action, so df's lineage is evaluated
    # again here unless df was cached by the caller.
    df.show()

# Tests

In [None]:
# Column names (only names, no types) for the order rows used as test
# fixtures below; the order must match the tuple fields of each row.
orders_schema = [
    "OrderId",
    "OrderDate",
    "Region",
    "City",
    "Category",
    "Product",
    "Quantity",
    "UnitPrice",
    "TotalPrice",
]

In [None]:
def test_orders_with_duplicated_order_id_are_removed():
    """Two rows sharing OrderId 10 must collapse into a single order."""

    # Arrange: identical OrderId, different OrderDate
    df = spark.createDataFrame(
        [
            (10, "01/01/2020", "North", "Chicago", "Bars", "Carrot", 33, 1.77, 58.41),
            (10, "11/03/2020", "North", "Chicago", "Bars", "Carrot", 33, 1.77, 58.41),
        ],
        orders_schema,
    )

    # Act
    df_result = remove_duplicate_orders(df)

    # Assert
    # Check for None explicitly: a DataFrame's truthiness says nothing about
    # whether it holds rows.
    assert df_result is not None, "No data frame returned from remove_duplicate_orders()"

    expected_orders = 1
    number_of_orders = df_result.count()
    assert number_of_orders == expected_orders, f'Expected {expected_orders} order after remove_duplicate_orders() but {number_of_orders} returned.'

    surviving_ids = [row["OrderId"] for row in df_result.collect()]
    assert surviving_ids == [10], f'Expected OrderId 10 to survive remove_duplicate_orders() but got {surviving_ids}.'

In [None]:
def test_similar_orders_with_different_order_id_are_not_removed():
    """Rows that differ only in OrderId are distinct orders and must all be kept."""

    # Arrange: every column equal except OrderId
    df = spark.createDataFrame(
        [
            (10, "01/01/2020", "North", "Chicago", "Bars", "Carrot", 33, 1.77, 58.41),
            (11, "01/01/2020", "North", "Chicago", "Bars", "Carrot", 33, 1.77, 58.41),
            (12, "01/01/2020", "North", "Chicago", "Bars", "Carrot", 33, 1.77, 58.41),
        ],
        orders_schema,
    )

    # Act
    df_result = remove_duplicate_orders(df)

    # Assert
    expected_orders = 3
    number_of_orders = df_result.count()
    assert number_of_orders == expected_orders, f'Expected {expected_orders} orders after remove_duplicate_orders() but {number_of_orders} returned.'

    surviving_ids = sorted(row["OrderId"] for row in df_result.collect())
    assert surviving_ids == [10, 11, 12], f'Expected OrderIds [10, 11, 12] to survive remove_duplicate_orders() but got {surviving_ids}.'

In [None]:
def test_regional_sales_are_calculated_correctly():
    """TotalSales per region must equal the sum of that region's TotalPrice values."""

    # Arrange: three East orders and one West order
    df = spark.createDataFrame(
        [
            (7, "19/01/2020", "East", "Boston", "Crackers", "Whole Wheat", 149, 3.49, 520.01),
            (8, "22/01/2020", "West", "Los Angeles", "Bars", "Carrot", 51, 1.77, 90.27),
            (9, "25/01/2020", "East", "New York", "Bars", "Carrot", 100, 1.77, 177.00),
            (10, "28/01/2020", "East", "New York", "Snacks", "Potato Chips", 28, 1.35, 37.8),
        ],
        orders_schema,
    )

    # Act
    df_result = calculate_sales_by_region(df)

    # Assert
    # Sums of doubles carry rounding error, and Spark may add partial sums in
    # any order, so compare approximately rather than with exact ==.
    expected_sales = {"East": 734.81, "West": 90.27}
    for region, expected in expected_sales.items():
        rows = df_result.where(df_result["Region"] == region).collect()
        assert len(rows) == 1, f'Expected exactly one row for {region} region but {len(rows)} returned.'
        actual = rows[0]["TotalSales"]
        assert actual == pytest.approx(expected), f'Expected regional sales to be {expected} for {region} region but {actual} returned.'

# Run the workflow

In [None]:
# Run the unit tests first: a failing assertion stops the cell before the
# ETL below reads from or writes to the data lake.
for unit_test in (
    test_orders_with_duplicated_order_id_are_removed,
    test_similar_orders_with_different_order_id_are_not_removed,
    test_regional_sales_are_calculated_correctly,
):
    unit_test()

# ETL: load raw orders -> drop duplicated orders -> aggregate per region -> save.
# Each stage keeps its own name so intermediate frames stay inspectable.
df_orders = load_data()
df_unique_orders = remove_duplicate_orders(df_orders)
df_sales_by_region = calculate_sales_by_region(df_unique_orders)
save_output(df_sales_by_region)