In [1]:
from pyspark.sql import SparkSession, DataFrame, functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType
import logging
from pathlib import Path
from functools import reduce

In [2]:
# Notebook-wide logger; INFO level so the pipeline's progress messages are
# eligible for output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# With no handler attached, records fall through to logging's last-resort
# handler, which only emits WARNING and above -- every logger.info() call in
# this notebook would be silently dropped. Attach one stream handler; the
# guard keeps re-running this cell from stacking duplicate handlers.
if not logger.handlers:
    _log_handler = logging.StreamHandler()
    _log_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(_log_handler)
    # Do not also hand records to the root logger, so a root handler
    # configured elsewhere cannot print every message twice.
    logger.propagate = False

In [3]:
def create_spark_session() -> "SparkSession":
    """Build (or reuse) the SparkSession used by this pipeline.

    Returns:
        The session returned by ``getOrCreate()``, with the application name
        "UrbanFresh_Data_Pipeline" and adaptive query execution requested.
    """
    return (
        SparkSession.builder
        .appName("UrbanFresh_Data_Pipeline")
        # The option is spelled "adaptive". Spark accepts unknown keys without
        # complaint, so the previous "adaptative" spelling never took effect.
        .config("spark.sql.adaptive.enabled", "true")
        .getOrCreate()
    )

In [4]:
def extract_sales_data(spark, file_path):
    """Read one sales CSV file into a DataFrame of raw string columns.

    Every column is declared as a nullable string, so values are kept as text
    and no type coercion happens at read time.

    Args:
        spark: Session whose reader is used to load the file.
        file_path: Location of the CSV file; the first row is read as a header.

    Returns:
        DataFrame with the columns order_id, customer_id, product_name, price,
        quantity, order_date and region.

    Raises:
        Exception: Any error raised while reading is logged and re-raised
            unchanged.
    """
    logger.info(f"Extracting sales data from {file_path}")

    column_names = (
        "order_id",
        "customer_id",
        "product_name",
        "price",
        "quantity",
        "order_date",
        "region",
    )
    expected_schema = StructType(
        [StructField(name, StringType(), True) for name in column_names]
    )

    try:
        reader = spark.read.schema(expected_schema)
        sales_df = reader.csv(file_path, header=True, mode="PERMISSIVE")
        logger.info("Sales data extracted successfully.")
    except Exception as e:
        logger.error(f"Error extracting sales data from {file_path}: {e}")
        raise

    return sales_df

In [5]:
def extract_all_data(spark, data_dir="data/raw") -> "DataFrame":
    """Load every order CSV in a directory and union them into one DataFrame.

    A file qualifies when it matches ``*.csv`` directly inside ``data_dir`` and
    its name contains "orders" (case-insensitive).

    Args:
        spark: Session passed through to ``extract_sales_data``.
        data_dir: Directory to scan, as a ``str`` or ``Path``. Defaults to
            "data/raw", the location that used to be hard-coded.

    Returns:
        The union, by column name, of the DataFrames read from all matching
        files.

    Raises:
        FileNotFoundError: If no matching file exists in ``data_dir``.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    data_dir = Path(data_dir)

    try:
        # glob() yields entries in arbitrary filesystem order; sorting makes
        # the union order reproducible across runs and machines.
        files = sorted(
            file_path
            for file_path in data_dir.glob("*.csv")
            if "orders" in file_path.name.lower()
        )
        if not files:
            logger.warning(f"No CSV files found in {data_dir}")

            raise FileNotFoundError(
                f"No sales CSV files found in {data_dir.resolve()}"
            )

        logger.info(f"Found {len(files)} CSV files in {data_dir}")

        # Read the files one at a time and fold them into a single DataFrame.
        dataframes = (
            extract_sales_data(spark, str(file_path))
            for file_path in files
        )
        return reduce(DataFrame.unionByName, dataframes)

    except Exception as e:
        logger.error(f"Error during data extraction: {e}")
        raise

In [6]:
def clean_customer_id(df: "DataFrame") -> "DataFrame":
    """Normalize customer_id values to the ``CUST_<digits>`` format.

    An id is rewritten when it does not start with "CUST_" but contains at
    least one digit: the first run of digits is kept and prefixed with
    "CUST_". Every other value is passed through unchanged.

    Args:
        df: DataFrame with a string ``customer_id`` column.

    Returns:
        DataFrame with the same rows and the normalized ``customer_id``.
    """
    customer_id = F.col("customer_id")
    # Rows that are not using the standard format but carry digits to reuse.
    needs_fix = (~customer_id.startswith("CUST_")) & customer_id.rlike("\\d+")

    df_clean = df.withColumn(
        "customer_id",
        F.when(
            needs_fix,
            F.concat(F.lit("CUST_"), F.regexp_extract(customer_id, "\\d+", 0)),
        ).otherwise(customer_id),
    )

    # withColumn never adds or drops rows, so comparing row counts before and
    # after always reported 0. Count the rows matching the rewrite condition.
    cleaned_count = df.filter(needs_fix).count()
    logger.info(f"{cleaned_count} records had their customer_id cleaned.")
    return df_clean

In [7]:
def clean_price_column(df: "DataFrame") -> "DataFrame":
    """Turn the raw ``price`` text into a numeric ``unit_price`` plus a flag.

    ``unit_price`` is 0.0 when ``price`` is null; otherwise it is ``price``
    with everything except digits, "." and "-" removed, cast to double.
    ``price_quality_flag`` is one of CHECK_INVALID_PRICE (null unit_price),
    CHECK_NEGATIVE_PRICE, CHECK_ZERO_PRICE, CHECK_HIGH_PRICE (above 1000) or
    OK. The original ``price`` column is dropped.

    Args:
        df: DataFrame with a string ``price`` column.

    Returns:
        DataFrame with ``unit_price`` and ``price_quality_flag`` added and
        ``price`` removed.
    """
    raw_price = F.col("price")
    # Keep the minus sign: the previous pattern stripped it as well, turning
    # negative prices positive so CHECK_NEGATIVE_PRICE could never fire.
    numeric_text = F.regexp_replace(raw_price, "[^0-9.-]", "")

    df = df.withColumns({
        "unit_price":
            F.when(raw_price.isNull(), F.lit(0.0).cast("double"))
             .otherwise(numeric_text.cast("double"))
    })

    unit_price = F.col("unit_price")
    df = df.withColumns({
        "price_quality_flag":
            # A null unit_price (text that did not cast to a number) makes
            # every comparison below null, so without this first branch such
            # rows were labelled "OK".
            F.when(unit_price.isNull(), "CHECK_INVALID_PRICE")
             .when(unit_price < 0, "CHECK_NEGATIVE_PRICE")
             .when(unit_price == 0, "CHECK_ZERO_PRICE")
             .when(unit_price > 1000, "CHECK_HIGH_PRICE")
             .otherwise("OK")
    })

    flagged_count = df.filter(F.col("price_quality_flag") != "OK").count()
    logger.info(f"Price column cleaned. {flagged_count} records flagged for review.")

    return df.drop("price")

In [8]:
def standardize_date_column(df: "DataFrame") -> "DataFrame":
    """Parse ``order_date`` text written in several layouts into one column.

    The formats are tried in this order and the first non-null result wins:
    yyyy/MM/dd, MM-dd-yyyy, dd-MM-yyyy, yyyy-MM-dd, MM/dd/yyyy. Because
    MM-dd-yyyy precedes dd-MM-yyyy, a value both could parse is read
    month-first. Rows left null after all attempts are counted and logged.

    Args:
        df: DataFrame with a string ``order_date`` column.

    Returns:
        DataFrame whose ``order_date`` holds the parsed value (null when no
        format produced a result).
    """
    date_formats = (
        "yyyy/MM/dd",
        "MM-dd-yyyy",
        "dd-MM-yyyy",
        "yyyy-MM-dd",
        "MM/dd/yyyy",
    )
    candidates = [F.to_date(F.col("order_date"), fmt) for fmt in date_formats]

    df = df.withColumn("order_date", F.coalesce(*candidates))

    unparsed_count = df.filter(F.col("order_date").isNull()).count()
    message = f"{unparsed_count} records with unparseable dates."
    # Warn only when there is something to act on; a zero count used to be
    # reported at WARNING level on every run.
    if unparsed_count > 0:
        logger.warning(message)
    else:
        logger.info(message)

    return df

In [9]:
def remove_test_data(df: "DataFrame") -> "DataFrame":
    """Drop test records and rows that lack an order or customer id.

    A row is removed when ``customer_id`` or ``product_name`` contains
    "test_" (case-insensitive), or when ``order_id`` or ``customer_id`` is
    null. A null ``product_name`` alone does not remove a row.

    Args:
        df: DataFrame with ``order_id``, ``customer_id`` and ``product_name``.

    Returns:
        The filtered DataFrame.
    """
    def _has_test_marker(column_name):
        # contains() yields NULL for a NULL input. Inside the OR chain that
        # made the negated predicate NULL, and filter() then dropped the row
        # even though it was neither test data nor missing an id. Coalescing
        # to False treats a missing value as "not test data".
        return F.coalesce(
            F.lower(F.col(column_name)).contains("test_"),
            F.lit(False),
        )

    is_unwanted = (
        _has_test_marker("customer_id")
        | _has_test_marker("product_name")
        | F.col("order_id").isNull()
        | F.col("customer_id").isNull()
    )
    df_cleaned = df.filter(~is_unwanted)

    logger.info(f"{df.count()-df_cleaned.count()} records removed after filtering.")
    return df_cleaned

In [10]:
def handle_duplicates(df: "DataFrame") -> "DataFrame":
    """Keep one row per ``order_id`` and log how many rows were dropped.

    Args:
        df: DataFrame with an ``order_id`` column.

    Returns:
        DataFrame de-duplicated on ``order_id``.
    """
    # NOTE(review): no ordering is applied before de-duplicating, so which of
    # several rows sharing an order_id is kept is not controlled here.
    key_columns = ["order_id"]
    df_deduped = df.dropDuplicates(key_columns)

    rows_before = df.count()
    rows_after = df_deduped.count()
    logger.info(f"{rows_before - rows_after} duplicate records removed based on order_id.")

    return df_deduped

In [14]:
def transform_data(df: "DataFrame") -> "DataFrame":
    """Apply the cleaning steps and add the derived columns.

    Steps, in order: remove test/invalid rows, normalize customer ids, clean
    prices, standardize dates, then add ``quantity`` (integer),
    ``total_amount`` (unit_price * quantity), ``processing_date`` (current
    date), ``year`` and ``month`` (from ``order_date``).

    Args:
        df: Raw DataFrame of string columns.

    Returns:
        The cleaned and enriched DataFrame.
    """
    quantity = F.col("quantity").cast(IntegerType())

    return (
        df
        # Drop test rows first: clean_customer_id rewrites an id such as
        # "TEST_123" to "CUST_123", which would hide it from the "test_"
        # filter if the cleaning ran beforehand.
        .transform(remove_test_data)
        .transform(clean_customer_id)
        .transform(clean_price_column)
        .transform(standardize_date_column)
        .withColumns({
            "quantity": quantity,
            # Multiply by the integer-cast quantity. All expressions in one
            # withColumns call are evaluated against the incoming frame, so a
            # bare F.col("quantity") here would still be the raw text column.
            "total_amount": (F.col("unit_price") * quantity).cast(DoubleType()),
            "processing_date": F.current_date(),
            "year": F.year(F.col("order_date")),
            "month": F.month(F.col("order_date"))
        })
    )

In [15]:
# Run the pipeline end to end: create the session, load the raw order files,
# apply the transformations, then preview the first 20 rows without truncation.
df = extract_all_data(spark=create_spark_session())
df_transformed = df.transform(transform_data)
df_transformed.show(n=20, truncate=False)

0 records with unparseable dates.


+--------+-----------+---------------------+--------+----------+------+----------+------------------+------------+---------------+----+-----+
|order_id|customer_id|product_name         |quantity|order_date|region|unit_price|price_quality_flag|total_amount|processing_date|year|month|
+--------+-----------+---------------------+--------+----------+------+----------+------------------+------------+---------------+----+-----+
|MOB_3001|CUST_8821  |Whole Wheat Tortillas|2       |2024-10-15|North |3.99      |OK                |7.98        |2026-02-11     |2024|10   |
|MOB_3002|CUST_1923  |Hummus               |1       |2024-10-16|South |5.5       |OK                |5.5         |2026-02-11     |2024|10   |
|MOB_3003|CUST_4512  |Salsa                |2       |2024-10-17|East  |3.25      |OK                |6.5         |2026-02-11     |2024|10   |
|MOB_3004|CUST_7634  |Guacamole            |1       |2024-10-18|West  |4.99      |OK                |4.99        |2026-02-11     |2024|10   |
|MOB_3