# Get functions from etl_functions.py

In [1]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType

def create_spark_session(app_name="casestudy"):
    return SparkSession.builder \
        .master("local[4]") \
        .appName(app_name) \
        .config("spark.sql.warehouse.dir", "/path/to/warehouse") \
        .config("hive.metastore.uris", "thrift://localhost:9083") \
        .config("spark.hadoop.hive.metastore.warehouse.dir", "/path/to/hive/warehouse") \
        .config("spark.sql.catalogImplementation", "hive") \
        .config("spark.eventLog.logBlockUpdates.enabled", True) \
        .enableHiveSupport() \
        .getOrCreate()

def read_csv_to_df(spark, file_path, schema=None, sep=",", infer_schema=False, header=True):
    hdfs_path = "hdfs://localhost:9000" + file_path
    
    if infer_schema:
        df = spark.read.csv(hdfs_path, sep=sep, header=header, inferSchema=True)
    else:
        df = spark.read.csv(hdfs_path, schema=schema, sep=sep, header=header)
    
    return df


def remove_duplicates(df):
    return df.dropDuplicates()

def check_nulls(df, column_name):
    return df.filter(col(column_name).isNull()).count() > 0

def calculate_total_paid_price_after_discount(transactions_df):
    return transactions_df.withColumn(
        "total_paid_price_after_discount",
        when(col("offer_1") == "1", col("unit_price") * 0.15)
        .when(col("offer_2") == "2", col("unit_price") * 0.25)
        .when(col("offer_3") == "3", col("unit_price") * 0.35)
        .when(col("offer_4") == "4", col("unit_price") * 0.45)
        .when(col("offer_5") == "5", col("unit_price") * 0.50)
        .otherwise(col("unit_price"))
    )

def add_offer_column(transactions_df):
    return transactions_df.withColumn(
        "offer",
        when(col("offer_1").isNotNull(), "1")
        .when(col("offer_2").isNotNull(), "2")
        .when(col("offer_3").isNotNull(), "3")
        .when(col("offer_4").isNotNull(), "4")
        .when(col("offer_5").isNotNull(), "5")
    )


def insert_into_hive_table(spark, df, table_name, table_location=None, primary_key=None):
    """
    Insert data from a DataFrame into a Hive table.

    Parameters:
    - spark: SparkSession object
    - df: Spark DataFrame
    - table_name: Name of the Hive table
    - table_location: Location of the external table (optional)
    - primary_key: Primary key column(s) to identify new records (default is None)
    """
    
    table_exists = spark._jsparkSession.catalog().tableExists(table_name)

    if table_name in ["casestudy.fact_sales", "casestudy.dim_sales"]:
        print(f"Inserting data into {table_name}.")
        df.write.mode('append').insertInto(table_name)
    else:
        if table_exists:
            print(f"Table {table_name} already exists. Inserting only new records.")
            existing_data = spark.table(table_name)

            if primary_key:
                new_data = df.join(existing_data, on=primary_key, how="left_anti")
                if new_data.count() > 0:
                    new_data.write.mode('append').saveAsTable(table_name)
                    print(f"Inserted {new_data.count()} new records into {table_name}.")
                else:
                    print(f"No new records to insert into {table_name}.")
            else:
                print(f"Primary key is required to identify new records for table {table_name}.")
        else:
            print(f"Table {table_name} does not exist. Creating and inserting data.")
            df.write.mode('overwrite').saveAsTable(table_name)

def create_fact_sales(transactions_df):
    return transactions_df.select(
        "transaction_id", "customer_id", "sales_agent_id", "branch_id", "product_id", "offer", "units", "unit_price", "total_paid_price_after_discount"
    )

def create_dim_sales(transactions_df):
    return transactions_df.select(
        "transaction_id",
        col("transaction_date").cast("timestamp").alias("transaction_date"),
        "is_online",
        "payment_method",
        "shipping_address",
        col("load_date").cast("timestamp").alias("load_date"),
        "load_source"
    )

def create_dim_product(transactions_df):
    return transactions_df.select(
        "product_id",
        "product_name",
        "product_category",
        col("load_date").cast("timestamp").alias("load_date"),
        "load_source"
    ).distinct()

def create_dim_customer(transactions_df):
    schema = StructType([
        StructField("customer_id", IntegerType(), True),
        StructField("customer_fname", StringType(), True),
        StructField("customer_lname", StringType(), True),
        StructField("customer_email", StringType(), True),
        StructField("load_date", TimestampType(), True),
        StructField("load_source", StringType(), True)
    ])
    
    return transactions_df.select(
        "customer_id",
        "customer_fname",
        "cusomter_lname",
        "cusomter_email",
        col("load_date").cast("timestamp").alias("load_date"),
        "load_source"
    ).distinct()

def create_dim_branch(branches_df):
    return branches_df.select(
        "branch_id",
        "location",
        col("establish_date").cast(TimestampType()).alias("establish_date"),
        "class",
        col("load_date").cast(TimestampType()).alias("load_date"),
        "load_source"
    ).distinct()

def create_dim_agent(agents_df):
    # Cleanse null values and handle column references correctly
    cleaned_agents_df = agents_df.withColumn("hire_date", when(col("hire_date") == "", None).otherwise(col("hire_date"))) \
                                .withColumn("load_date", when(col("load_date") == "", None).otherwise(col("load_date")))

    return cleaned_agents_df.select(
        col("sales_person_id"),
        col("name").alias("sales_agent_name"),  # Ensure correct column reference here
        col("hire_date").cast(TimestampType()).alias("hire_date"),
        col("load_date").cast(TimestampType()).alias("load_date"),
        col("load_source")
    ).distinct()

# Run the main.py

In [2]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, BooleanType
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, BooleanType

spark = create_spark_session()
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, BooleanType

# Define schema for transactions DataFrame
transactions_schema = StructType([
    StructField("transaction_date", StringType(), True),
    StructField("transaction_id", StringType(), True),
    StructField("customer_id", StringType(), True),
    StructField("customer_fname", StringType(), True),
    StructField("cusomter_lname", StringType(), True),
    StructField("cusomter_email", StringType(), True),
    StructField("sales_agent_id", StringType(), True),
    StructField("branch_id", StringType(), True),
    StructField("product_id", StringType(), True),
    StructField("product_name", StringType(), True),
    StructField("product_category", StringType(), True),
    StructField("offer_1", StringType(), True),
    StructField("offer_2", StringType(), True),
    StructField("offer_3", StringType(), True),
    StructField("offer_4", StringType(), True),
    StructField("offer_5", StringType(), True),
    StructField("units", StringType(), True),
    StructField("unit_price", StringType(), True),
    StructField("is_online", StringType(), True),
    StructField("payment_method", StringType(), True),
    StructField("shipping_address", StringType(), True),
    StructField("load_date", StringType(), True),
    StructField("load_source", StringType(), True)
])

# Define schema for branches DataFrame
branches_schema = StructType([
    StructField("branch_id", StringType(), True),
    StructField("location", StringType(), True),
    StructField("establish_date", StringType(), True),
    StructField("class", StringType(), True),
    StructField("load_date", StringType(), True),
    StructField("load_source", StringType(), True)
])

# Define schema for agents DataFrame
agents_schema = StructType([
    StructField("sales_person_id", StringType(), True),
    StructField("name", StringType(), True),
    StructField("hire_date", StringType(), True),
    StructField("load_date", StringType(), True),
    StructField("load_source", StringType(), True)
])


# Reading CSV files into DataFrames with inferred schemas
transactions_df = read_csv_to_df(spark, "/casestudy/day183/hour14/sales_transactions.csv", schema=transactions_schema)
branches_df = read_csv_to_df(spark, "/casestudy/day183/hour14/branches.csv", schema=branches_schema)
agents_df = read_csv_to_df(spark, "/casestudy/day183/hour14/sales_agents.csv", schema=agents_schema)


In [3]:
transactions_df.show(5)

+----------------+----------------+-----------+--------------+--------------+--------------------+--------------+---------+----------+------------+----------------+-------+-------+-------+-------+-------+-----+----------+---------+--------------+----------------+----------+-----------+
|transaction_date|  transaction_id|customer_id|customer_fname|cusomter_lname|      cusomter_email|sales_agent_id|branch_id|product_id|product_name|product_category|offer_1|offer_2|offer_3|offer_4|offer_5|units|unit_price|is_online|payment_method|shipping_address| load_date|load_source|
+----------------+----------------+-----------+--------------+--------------+--------------------+--------------+---------+----------+------------+----------------+-------+-------+-------+-------+-------+-----+----------+---------+--------------+----------------+----------+-----------+
|      2023-10-25|trx-072037549384|      85550|          Emma|        Wilson|emma.wilson@outlo...|           2.0|      3.0|         3|     

In [4]:
# Removing duplicates
transactions_df = remove_duplicates(transactions_df)
branches_df = remove_duplicates(branches_df)
agents_df = remove_duplicates(agents_df)

# Checking for null values in key columns
if check_nulls(transactions_df, "transaction_id"):
    print("Null values found in transaction_id column")
if check_nulls(branches_df, "branch_id"):
    print("Null values found in branch_id column")
if check_nulls(agents_df, "sales_person_id"):
    print("Null values found in sales_person_id column")

# Calculating total paid price after discount and adding offer column
transactions_df = calculate_total_paid_price_after_discount(transactions_df)
transactions_df = add_offer_column(transactions_df)

# Creating fact and dimension tables
fact_sales = create_fact_sales(transactions_df)
dim_sales = create_dim_sales(transactions_df)
dim_product = create_dim_product(transactions_df)
dim_customer = create_dim_customer(transactions_df)
dim_branch = create_dim_branch(branches_df)
dim_agent = create_dim_agent(agents_df)

In [5]:
fact_sales.show(5)

+----------------+-----------+--------------+---------+----------+-----+-----+----------+-------------------------------+
|  transaction_id|customer_id|sales_agent_id|branch_id|product_id|offer|units|unit_price|total_paid_price_after_discount|
+----------------+-----------+--------------+---------+----------+-----+-----+----------+-------------------------------+
|trx-829015484650|      85500|           8.0|      1.0|        30| null|    2|     24.99|                          24.99|
|trx-318134583182|      85484|           6.0|      1.0|        24|    1|    8|     49.99|                          49.99|
|trx-738773442038|      85557|          10.0|      4.0|        13|    1|    3|    149.99|                         149.99|
|trx-562759580036|      85532|          null|     null|        18|    4|   10|    149.99|                         149.99|
|trx-641002331340|      85551|          null|     null|        22| null|    1|     79.99|                          79.99|
+----------------+------

In [6]:
dim_sales.show(5)


+----------------+-------------------+---------+--------------+--------------------+-------------------+-----------+
|  transaction_id|   transaction_date|is_online|payment_method|    shipping_address|          load_date|load_source|
+----------------+-------------------+---------+--------------+--------------------+-------------------+-----------+
|trx-829015484650|2023-12-11 00:00:00|       no|   Credit Card|                null|2024-07-01 00:00:00|    source1|
|trx-318134583182|2022-07-11 00:00:00|       no|          Cash|                null|2024-07-01 00:00:00|    source1|
|trx-738773442038|2023-11-10 00:00:00|       no|          Cash|                null|2024-07-01 00:00:00|    source1|
|trx-562759580036|2023-04-16 00:00:00|      yes|        PayPal|22 Eden Street/Mi...|2024-07-01 00:00:00|    source1|
|trx-641002331340|2023-11-16 00:00:00|      yes|        Stripe|5124 Williston Ro...|2024-07-01 00:00:00|    source1|
+----------------+-------------------+---------+--------------+-

In [7]:
# Inserting data into Hive tables
    insert_into_hive_table(spark, fact_sales, "casestudy.fact_sales")  # Fact sales - directly insert
    insert_into_hive_table(spark, dim_sales, "casestudy.dim_sales")    # Dim sales - directly insert

Inserting data into casestudy.fact_sales.
Inserting data into casestudy.dim_sales.
