In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, col, trim, lower, to_date, broadcast, collect_set, array_contains, size

class Transformer:
    """Base type for transformation steps; subclasses override ``transform``."""

    def __init__(self):
        """Create a transformer. The base class keeps no state."""

    def transform(self, inputDFs):
        """Placeholder transformation that does nothing.

        Args:
            inputDFs: Input handed over by the caller; ignored by the base
                implementation.

        Returns:
            None.
        """
        return None

class AirpodsAfterIphoneTransformer(Transformer):
    """Find customers whose purchase directly after an iPhone was Airpods."""

    def transform(self, inputDFs):
        """Join customers to their iPhone-then-Airpods transactions.

        Args:
            inputDFs: Object queried with ``.get`` for the keys
                "transactionInputDF" and "customerInputDF".

        Returns:
            The inner join, on "customer_id", of the customer DataFrame with
            the transaction rows whose product is an iPhone and whose next
            product (per customer, ordered by transaction date) is Airpods.
        """
        transactions = inputDFs.get("transactionInputDF")
        customers = inputDFs.get("customerInputDF")

        # Column expressions used below, named once for readability.
        parsedDate = to_date(col("transaction_date"), "yyyy-MM-dd")
        trimmedName = trim(col("product_name"))
        perCustomerByDate = Window.partitionBy("customer_id").orderBy(
            "transaction_date"
        )

        # Parse the date first so the window orders on a DateType column.
        datedDF = transactions.withColumn("transaction_date", parsedDate)

        # The name is trimmed before lead() runs, so next_product_name is
        # derived from the already-trimmed values.
        cleanedDF = datedDF.withColumn("product_name", trimmedName)
        withNextDF = cleanedDF.withColumn(
            "next_product_name", lead("product_name").over(perCustomerByDate)
        )

        # Debugging: show every transaction alongside the following product
        print("Transformed DataFrame with next product name:")
        withNextDF.orderBy("customer_id", "transaction_date").show(
            n=50, truncate=False
        )

        # Case-insensitive match on the current and the following product.
        currentIsIphone = lower(col("product_name")) == "iphone"
        nextIsAirpods = lower(col("next_product_name")) == "airpods"
        matchesDF = withNextDF.filter(currentIsIphone & nextIsAirpods)

        # Debugging: show the matching transactions
        print("Filtered DataFrame with customers buying iPhone followed by Airpods:")
        matchesDF.orderBy("customer_id", "transaction_date").show(
            n=50, truncate=False
        )

        # Attach customer details; the filtered side carries a broadcast hint.
        resultDF = customers.join(
            broadcast(matchesDF), on="customer_id", how="inner"
        )

        # Debugging: show the customer details of the matches
        print("Customer details for customers who bought Airpods after buying an iPhone:")
        resultDF.select("customer_id", "customer_name", "location").show(
            n=50, truncate=False
        )

        return resultDF
    
class OnlyAirpodsAndIphone(Transformer):
    """Find customers whose distinct purchases are exactly iPhone and AirPods."""

    def transform(self, inputDFs):
        """
        Customers who have only bought iPhone and Airpods and nothing else

        Product names are compared after trimming whitespace and lower-casing,
        so spelling variants such as "Airpods" / "AirPods" or "iPhone " are
        treated as the same product. The original names are still returned in
        the "products" column.

        Args:
            inputDFs: Object queried with ``.get`` for the keys
                "transactionInputDF" and "customerInputDF".

        Returns:
            The inner join, on "customer_id", of the customer DataFrame with
            one row per qualifying customer holding "customer_id" and
            "products" (the set of distinct product names as recorded).
        """

        transactionInputDF = inputDFs.get("transactionInputDF")

        print("transactionInputDF in transform")

        # Helper column holding the normalized names. It exists only for the
        # filter below and is dropped before anything is returned.
        normalizedColumn = "normalized_products"
        normalizedName = lower(trim(col("product_name")))

        # Using collect_set to get distinct purchases grouped per customer:
        # once as recorded (returned) and once normalized (used for matching).
        groupedDF = transactionInputDF.groupBy("customer_id").agg(
            collect_set("product_name").alias("products"),
            collect_set(normalizedName).alias(normalizedColumn),
        )

        # Debugging: show grouped data
        print("Grouped DataFrame with distinct purchases grouped per customer :")
        groupedDF.select("customer_id", "products").show(50, False)

        # Filtering grouped customers on those who have bought only an iPhone
        # and AirPods. Counting the normalized set keeps case or whitespace
        # variants of one product from being counted as different products.
        filteredDF = groupedDF.filter(
            array_contains(col(normalizedColumn), "iphone")
            & array_contains(col(normalizedColumn), "airpods")
            & (size(col(normalizedColumn)) == 2)
        ).drop(normalizedColumn)

        # Debugging: show grouped data filtered for customers who have bought only an iPhone and AirPods
        print("Filtered DataFrame with customers who have bought iPhone and Airpods:")
        filteredDF.show(50, False)

        customerInputDF = inputDFs.get("customerInputDF")

        # Joining customer table to get details about the filtered customers
        joinedDF = customerInputDF.join(
            broadcast(filteredDF),
            "customer_id",
            "inner"
        )

        # Debugging: show filtered data
        print("Customer details for customers who only bought Airpods and an iPhone:")
        joinedDF.select(
            "customer_id",
            "customer_name",
            "location").show(50, False)

        return joinedDF