In [1]:
from pyspark.sql import SparkSession, DataFrame, functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType
import logging
from pathlib import Path
from functools import reduce

In [2]:
logger = logging.getLogger(__name__)
# The pipeline reports progress with logger.info(), but unconfigured Python
# logging only emits WARNING and above, so those messages would be silently
# dropped. Give this logger an INFO level and its own stream handler.
# The `if not logger.handlers` guard keeps a re-run of this cell from
# attaching a second handler (which would print every message twice).
logger.setLevel(logging.INFO)
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(_handler)

In [3]:
def create_spark_session(app_name="UrbanFresh_Data_Pipeline"):
    """Create (or reuse) the SparkSession used by the pipeline.

    Args:
        app_name: Application name shown in the Spark UI. Defaults to the
            pipeline's name, so existing callers are unaffected.

    Returns:
        The SparkSession produced by ``SparkSession.builder.getOrCreate()``,
        which reuses an already-running session when there is one.
    """
    # The key must be spelled exactly "spark.sql.adaptive.enabled". Spark
    # accepts arbitrary config keys without complaint, so the previous
    # misspelling ("adaptative") silently had no effect.
    return (
        SparkSession.builder
        .appName(app_name)
        .config("spark.sql.adaptive.enabled", "true")
        .getOrCreate()
    )

In [4]:
def extract_sales_data(spark, file_path):
    """Read one sales/orders CSV file into a Spark DataFrame.

    Every column is declared as a nullable string. Raw values such as
    ``$5.50`` prices therefore load as-is, and any type conversion has to
    happen in later cleaning steps.

    Args:
        spark: Active SparkSession.
        file_path: Path of the CSV file to read. Its first line is treated
            as a header row.

    Returns:
        A DataFrame with the columns order_id, customer_id, product_name,
        price, quantity, order_date and region.

    Raises:
        Any exception raised while setting up the read is logged and then
        re-raised unchanged.
    """
    logger.info(f"Extracting sales data from {file_path}")

    column_names = (
        "order_id",
        "customer_id",
        "product_name",
        "price",
        "quantity",
        "order_date",
        "region",
    )
    expected_schema = StructType(
        [StructField(column, StringType(), True) for column in column_names]
    )

    try:
        sales_df = spark.read.csv(
            file_path,
            schema=expected_schema,
            header=True,
            mode="PERMISSIVE",
        )
        # Spark reads lazily: this only means the read was set up, not that
        # every row has been parsed yet.
        logger.info("Sales data extracted successfully.")
        return sales_df
    except Exception as exc:
        logger.error(f"Error extracting sales data from {file_path}: {exc}")
        raise

In [17]:
def extract_all_data(spark, data_dir="data/raw"):
    """Load every raw orders CSV in ``data_dir`` and union them into one DataFrame.

    A file is selected when its name matches ``*.csv`` and contains
    ``orders`` (the ``orders`` check is case-insensitive). Files are combined
    in sorted path order, so the result does not depend on the order in
    which the filesystem happens to list them.

    Args:
        spark: Active SparkSession.
        data_dir: Directory holding the raw files, as a ``str`` or ``Path``.
            Defaults to ``data/raw``, resolved against the current working
            directory.

    Returns:
        A single DataFrame with the rows of all matching files, combined
        with ``DataFrame.unionByName``.

    Raises:
        FileNotFoundError: If no matching file exists in ``data_dir``.
        Exception: Any error raised while reading or combining the files is
            logged and re-raised.
    """
    data_dir = Path(data_dir)

    # Path.glob yields entries in filesystem order, which differs between
    # platforms; sorting makes the union order reproducible.
    files = sorted(
        file_path
        for file_path in data_dir.glob("*.csv")
        if "orders" in file_path.name.lower()
    )
    if not files:
        # Log once and raise; this check sits outside the try-block below so
        # the same failure is not reported a second time.
        message = f"No sales CSV files found in {data_dir.resolve()}"
        logger.error(message)
        raise FileNotFoundError(message)

    logger.info("Found %d orders CSV files in %s", len(files), data_dir)

    try:
        # Each file is read with the same explicit schema; unionByName then
        # aligns columns by name rather than by position.
        dataframes = (
            extract_sales_data(spark, str(file_path))
            for file_path in files
        )
        return reduce(DataFrame.unionByName, dataframes)
    except Exception as e:
        logger.error("Error during data extraction: %s", e)
        raise

In [18]:
def clean_customer_id(df):
    """Normalise ``customer_id`` values to the ``CUST_<digits>`` format.

    A value that does not start with ``CUST_`` but contains at least one
    digit is rewritten as ``CUST_`` followed by the first run of digits in
    it (e.g. ``1923`` -> ``CUST_1923``). Values that already start with
    ``CUST_``, values without digits, and nulls are left unchanged.

    Args:
        df: Spark DataFrame with a ``customer_id`` column (read as a string
            by ``extract_sales_data``).

    Returns:
        The DataFrame returned by ``withColumn`` with ``customer_id`` replaced.
    """
    customer_id = F.col("customer_id")
    digits = r"\d+"

    # Rewrite only IDs that lack the standard prefix AND carry a number to reuse.
    needs_prefix = (~customer_id.startswith("CUST_")) & customer_id.rlike(digits)
    prefixed_id = F.concat(
        F.lit("CUST_"),
        F.regexp_extract(customer_id, digits, 0),
    )

    normalised = F.when(needs_prefix, prefixed_id).otherwise(customer_id)
    return df.withColumn("customer_id", normalised)

In [None]:
# Start (or reuse) the Spark session, then load and union all raw orders files.
spark = create_spark_session()
df = extract_all_data(spark=spark)

                                                                                

Unnamed: 0,order_id,customer_id,product_name,price,quantity,order_date,region
0,MOB_3001,CUST_8821,Whole Wheat Tortillas,3.99,2,2024-10-15,North
1,MOB_3002,1923,Hummus,$5.50,1,2024-10-16,South
2,MOB_3003,CUST_4512,Salsa,$3.25,2,2024-10-17,East
3,MOB_3004,7634,Guacamole,4.99,1,2024-10-18,West
4,MOB_3005,CUST_9123,Tortilla Chips,$2.99,3,2024-10-19,North


In [21]:
# Normalise customer IDs, then preview a few rows. Calling limit() before
# toPandas() means only 5 rows are collected to the driver, instead of
# pulling the whole DataFrame into pandas just to display its head().
df = clean_customer_id(df)
df.limit(5).toPandas()

                                                                                

Unnamed: 0,order_id,customer_id,product_name,price,quantity,order_date,region
0,MOB_3001,CUST_8821,Whole Wheat Tortillas,3.99,2,2024-10-15,North
1,MOB_3002,CUST_1923,Hummus,$5.50,1,2024-10-16,South
2,MOB_3003,CUST_4512,Salsa,$3.25,2,2024-10-17,East
3,MOB_3004,CUST_7634,Guacamole,4.99,1,2024-10-18,West
4,MOB_3005,CUST_9123,Tortilla Chips,$2.99,3,2024-10-19,North
