In [0]:
from pyspark.sql.window import Window 
from pyspark.sql.functions import lead, col, broadcast, collect_set,size,array_contains

class Transformer:
    """
    Base class for DataFrame transformations.

    Subclasses override ``transform``, which receives a mapping of named
    input DataFrames and returns a result DataFrame.
    """

    def __init__(self):
        """Nothing to initialise; the base class carries no state."""

    def transform(self, inputDFs):
        """
        Build an output DataFrame from ``inputDFs``.

        Args:
            inputDFs: mapping of input names to DataFrames; the keys that are
                required depend on the concrete subclass.

        Raises:
            NotImplementedError: always -- subclasses must override this.
        """
        message = "The method 'transform' is not implemented in the base class."
        raise NotImplementedError(message)

class AirpodsAfterIphoneTransformer(Transformer):
    """
    Finds customers whose purchase immediately following an iPhone was AirPods.
    """

    def transform(self, inputDFs):
        """
        Return the customers who bought AirPods as the very next purchase
        after an iPhone.

        Each customer's transactions are ordered by ``transaction_date`` and
        every row is compared with the following one via ``lead``. An iPhone
        followed by some other product before AirPods therefore does not
        qualify.

        Args:
            inputDFs: mapping that must provide ``"transactionInputDF"``
                (columns read: ``customer_id``, ``product_name``,
                ``transaction_date``) and ``"customerInputDF"`` (columns
                read: ``customer_id``, ``customer_name``, ``location``).

        Returns:
            DataFrame with columns ``customer_id``, ``customer_name`` and
            ``location``. Qualifying customer ids are de-duplicated before
            the join, so a customer who made the iPhone -> AirPods pair more
            than once is not repeated.

        Raises:
            ValueError: if ``transactionInputDF`` or ``customerInputDF`` is
                missing from ``inputDFs``.
        """
        transactionInputDF = inputDFs.get("transactionInputDF")
        if transactionInputDF is None:
            raise ValueError("transactionInputDF is not found in inputDFs.")

        # Check the second input before any Spark action (.show()) runs, so a
        # missing input fails fast instead of after several jobs.
        customerInputDF = inputDFs.get("customerInputDF")
        if customerInputDF is None:
            raise ValueError("customerInputDF is not found in inputDFs.")

        print("transactionInputDF in transform:")
        transactionInputDF.show()

        # NOTE(review): rows with the same transaction_date have no defined
        # relative order, so lead() may pair them either way; add a
        # tie-breaker column to orderBy if the data has one.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        # The last transaction of each customer gets a null next_product_name,
        # which can never equal "AirPods" in the filter below.
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )

        print("Transactions with next_product_name:")
        transformedDF.orderBy("customer_id", "transaction_date", "product_name").show()

        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )
        print("Airpods after buying iPhone:")
        filteredDF.show()

        customerInputDF.show()

        # Broadcast only the distinct join keys. Otherwise a customer with
        # several qualifying pairs appears several times in the result. Also,
        # if the transactions carried a column that shares a name with a
        # customer column (e.g. ``location``), the final select would be
        # ambiguous.
        qualifyingCustomersDF = filteredDF.select("customer_id").distinct()

        joinDF = customerInputDF.join(
            broadcast(qualifyingCustomersDF),
            "customer_id"
        )

        print("JOINED DF")
        joinDF.show()

        return joinDF.select(
            "customer_id",
            "customer_name",
            "location"
        )

class OnlyAirpodsAndIphone(Transformer):
    """
    Selects customers whose distinct purchased products are exactly
    "iPhone" and "AirPods" -- both bought, and nothing else.
    """

    def transform(self, inputDFs):
        """
        Return the customers who bought only iPhone and AirPods.

        Args:
            inputDFs: mapping that must provide ``"transactionInputDF"``
                (columns read: ``customer_id``, ``product_name``) and
                ``"customerInputDF"`` (columns read: ``customer_id``,
                ``customer_name``, ``location``).

        Returns:
            DataFrame with columns ``customer_id``, ``customer_name`` and
            ``location``.

        Raises:
            ValueError: if ``transactionInputDF`` or ``customerInputDF`` is
                missing from ``inputDFs``.
        """
        transactions = inputDFs.get("transactionInputDF")
        if transactions is None:
            raise ValueError("transactionInputDF is not found in inputDFs.")

        print("transactionInputDF in transform:")
        transactions.show()

        # One row per customer holding the set of distinct products bought.
        # NOTE(review): collect_set is documented to skip nulls, so a null
        # product_name would not count as a third product -- confirm intended.
        productsPerCustomer = (
            transactions
            .groupBy("customer_id")
            .agg(collect_set("product_name").alias("products"))
        )
        print("Grouped DF")
        productsPerCustomer.show()

        # Containing both products while the set has exactly two elements
        # means nothing else was bought.
        products = col("products")
        hasIphone = array_contains(products, "iPhone")
        hasAirpods = array_contains(products, "AirPods")
        hasNothingElse = size(products) == 2
        onlyBothDF = productsPerCustomer.filter(hasIphone & hasAirpods & hasNothingElse)
        print("Only AirPods And Iphone")
        onlyBothDF.show()

        customers = inputDFs.get("customerInputDF")
        if customers is None:
            raise ValueError("customerInputDF is not found in inputDFs.")
        customers.show()

        joined = customers.join(broadcast(onlyBothDF), "customer_id")
        print("JOINED DF")
        joined.show()

        return joined.select("customer_id", "customer_name", "location")
