In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, lead, col, broadcast, expr, collect_set, size, array_contains, min , collect_list, datediff, avg, datediff, when, count, desc

In [0]:
class Transformer:
    """Base class for the transformation steps defined in this file.

    Subclasses override :meth:`transform`; calling it on the base class
    raises ``NotImplementedError``.
    """

    def __init__(self):
        # The base class holds no state.
        return None

    def transform(self, InputDFs):
        """
        Transform the given inputs; must be overridden by subclasses.

        InputDFs is the collection of inputs handed to the transformer
        (subclasses in this file look DataFrames up with ``InputDFs.get``).

        Raises:
            NotImplementedError: always, on the base class.
        """
        message = "Subclasses should implement this method"
        raise NotImplementedError(message)

class FirstTransform(Transformer):
    def transform(self, InputDFs):
        """
        Find customers whose purchase directly following an iPhone was AirPods.

        Reads the 'transactionInput' and 'customerInput' DataFrames from
        InputDFs. Transactions are ordered per customer by transaction_date
        and each row is paired with the customer's next product (via `lead`);
        a customer qualifies when an 'iPhone' row is immediately followed by
        an 'AirPods' row.

        Returns a DataFrame with the columns customer_id, customer_name and
        location for the qualifying customers. Each qualifying customer_id is
        matched once against the customer data (presumably customer_id is
        unique there -- verify against the input).

        Intermediate DataFrames are printed with show() for inspection.
        """
        transactionDF = InputDFs.get('transactionInput')
        customerDF = InputDFs.get('customerInput')

        print("my input dataframe is below")
        transactionDF.show()

        # For every transaction, look ahead to the same customer's next purchase.
        window_spec = Window.partitionBy(col("customer_id")).orderBy("transaction_date")
        transformDF = transactionDF.withColumn(
            "next_prod", lead("product_name").over(window_spec)
        )

        print("Airpods after buying iPhone")
        transformDF.orderBy("customer_id", "transaction_date").show()

        print("getting the final customer id")
        # Filter while product_name / next_prod are still present, then project.
        # distinct() keeps a customer with several iPhone -> AirPods sequences
        # from being duplicated by the join below.
        bought_airpods_next = (col("product_name") == 'iPhone') & (col("next_prod") == 'AirPods')
        filteredDF = (
            transformDF
            .where(bought_airpods_next)
            .select(col("customer_id"))
            .distinct()
        )

        # Join the matching ids with the customer data to get all the details.
        # The id list is marked for broadcast to reduce shuffling in the join.
        airpodsAfterIphoneDF = customerDF.join(
            broadcast(filteredDF), on="customer_id", how="inner"
        )
        return airpodsAfterIphoneDF.select(
            col("customer_id"), col("customer_name"), col("location")
        )

class SecondTransform(Transformer):

    def transform(self, InputDFs):
        """
        Find customers whose distinct purchased products are exactly
        iPhone and AirPods.

        Reads 'transactionInput' and 'customerInput' from InputDFs and
        returns a DataFrame with customer_id, customer_name, location and
        collect_product (the set of products the customer bought).
        Intermediate and final results are printed with show().
        """
        transactions = InputDFs.get("transactionInput")
        customers = InputDFs.get("customerInput")

        # One row per customer holding the set of distinct products bought.
        products_per_customer = transactions.groupBy("customer_id").agg(
            collect_set(col("product_name")).alias('collect_product')
        )

        # Exactly two distinct products, and they are AirPods and iPhone.
        bought = col('collect_product')
        exactly_two = size(bought) == 2
        has_airpods = array_contains(bought, 'AirPods')
        has_iphone = array_contains(bought, 'iPhone')
        only_both = products_per_customer.filter(exactly_two & has_airpods & has_iphone)
        only_both.show()

        print("Following are the details: Customer who bought only Iphone and AirPods")

        # Enrich the matching customers with their details.
        with_details = only_both.join(customers, on="customer_id", how="inner")
        result = with_details.select(
            "customer_id", "customer_name", "location", "collect_product"
        )
        result.show()
        return result
    
class ThirdTransform(Transformer):

    def transform(self, InputDFs):
        """
        List, per customer, the products bought after their first purchase date.

        Reads 'transactionInput' from InputDFs. Transactions whose
        transaction_date is strictly later than the customer's earliest
        transaction_date are collected into a list.

        Returns a DataFrame with customer_id and collect_product, ordered by
        customer_id. Intermediate and final results are printed with show().
        """
        transactions = InputDFs.get("transactionInput")

        # Earliest transaction date per customer. `min` here is the Spark
        # aggregate imported from pyspark.sql.functions, not the builtin.
        first_purchase = transactions.groupBy("customer_id").agg(
            min("transaction_date").alias("min_date")
        )

        # Attach each customer's first purchase date to all of their transactions.
        with_first_date = transactions.join(first_purchase, on="customer_id", how="left")
        with_first_date.show()

        print("List of products bought by each customer after their initial purchase")

        # Keep only purchases made after the first purchase date.
        later_purchases = with_first_date.filter(expr("transaction_date > min_date"))
        products_by_customer = later_purchases.groupBy("customer_id").agg(
            collect_list("product_name").alias('collect_product')
        )
        result = products_by_customer.orderBy("customer_id")
        result.show()

        return result
    
class FourthTransform(Transformer):

    def transform(self, InputDFs):
        """
        Compute, per customer, the average number of days between iPhone
        and AirPods purchases.

        Reads 'transactionInput' and 'customerInput' from InputDFs. Every
        iPhone transaction of a customer is paired with every AirPods
        transaction of the same customer, and the day differences
        (AirPods date minus iPhone date) are averaged. No ordering filter is
        applied, so a pair where AirPods came first contributes a negative
        delay.

        Returns a DataFrame with customer_id, customer_name, location and
        average_time_delay. The final result is printed with show().
        """
        transactions = InputDFs.get("transactionInput")
        customers = InputDFs.get("customerInput")

        # Split the transactions by product; aliases keep the two sides
        # distinguishable after the self-join.
        iphone_purchases = transactions.filter(col("product_name") == 'iPhone').alias('a')
        airpods_purchases = transactions.filter(col("product_name") == 'AirPods').alias('b')

        same_customer = col("a.customer_id") == col("b.customer_id")
        purchase_pairs = iphone_purchases.join(airpods_purchases, same_customer)

        # Days from the iPhone purchase to the AirPods purchase, for each pair.
        days_between = datediff(col("b.transaction_date"), col("a.transaction_date"))
        with_delay = purchase_pairs.withColumn("time_delay", days_between)

        average_delay = with_delay.groupBy("a.customer_id").agg(
            avg("time_delay").alias("average_time_delay")
        )

        print("Getting the final result")

        # Attach customer details; the 'c' alias avoids an ambiguous customer_id.
        customer_match = col("a.customer_id") == col("c.customer_id")
        with_details = average_delay.join(customers.alias('c'), on=customer_match, how="inner")
        result = with_details.select(
            col("a.customer_id"),
            col("c.customer_name"),
            col("c.location"),
            col("average_time_delay"),
        )
        result.show()

        return result
    
class FifthTransform(Transformer):

    def transform(self, InputDFs):
        """
        Extract the top 3 products by total revenue.

        Reads 'transactionInput' and 'productInput' from InputDFs. Product
        names in the product data are mapped to the shorter names used for
        the join with the transactions, transactions are counted per
        product, and total_revenue is computed as price * count.

        Returns a DataFrame limited to the 3 rows with the highest
        total_revenue. The result is printed with show().
        """
        transactions = InputDFs.get("transactionInput")
        products = InputDFs.get("productInput")

        # Rewrite product names so both DataFrames can be joined on product_name.
        name_fixes = (
            ('iPhone SE', 'iPhone'),
            ('AirPods Pro', 'AirPods'),
            ('MacBook Air', 'MacBook'),
            ('iPad Mini', 'iPad'),
        )
        for long_name, short_name in name_fixes:
            products = products.withColumn(
                "product_name",
                when(col("product_name") == long_name, short_name).otherwise(col("product_name")),
            )

        print("Getting the top 3 products by total revenue")

        # Number of transactions per product. NOTE: alias() here names the
        # DataFrame, not the column -- the count column is still "count".
        units_sold = transactions.groupBy("product_name").count().alias("product_sold")

        sold_with_price = units_sold.join(products, on="product_name", how="inner")
        with_revenue = sold_with_price.withColumn("total_revenue", expr("price*count"))
        top_three = with_revenue.orderBy(col("total_revenue").desc()).limit(3)

        top_three.show()

        return top_three
