# In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead
from pyspark.sql.functions import broadcast, collect_set, size, array_contains, col

class Transformer:
    """Common interface for the DataFrame transformations in this module.

    Subclasses override ``transform``; this base version does nothing.
    """

    def __init__(self):
        """Create a transformer. The base class keeps no state."""

    def transform(self, inputDFs):
        """Produce a result from the named input DataFrames in ``inputDFs``.

        The base implementation is a no-op and returns ``None``.
        """

class AirpodsAfterIphone(Transformer):
    def transform(self, inputDFs):
        """
        Customers who bought AirPods as the very next transaction after an
        iPhone purchase.

        Each customer's transactions are ordered by ``transaction_date`` and
        every row is paired with the following row's ``product_name`` via
        ``lead``. A customer qualifies when some "iPhone" row is immediately
        followed by an "AirPods" row. An AirPods purchase made later, with
        other products in between, does not count.

        Parameters
        ----------
        inputDFs : dict
            Expected to provide "transactionInputDF" (read for customer_id,
            transaction_date, product_name) and "customerInputDF" (read for
            customer_id, customer_name, location).

        Returns
        -------
        DataFrame
            Columns customer_id, customer_name, location. Each customer row
            from customerInputDF appears at most once, no matter how many
            iPhone -> AirPods pairs that customer has.

        Intermediate DataFrames are printed with ``show()`` for debugging.
        """
        transactionInputDF = inputDFs.get("transactionInputDF")
        print("transactionInputDF in transform")
        transactionInputDF.show()

        # NOTE(review): rows with the same transaction_date have no tiebreaker,
        # so their relative order (and therefore lead()) is not deterministic.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )
        print("customer bought Airpods after iphone")
        transformedDF.orderBy("customer_id", "transaction_date", "product_name").show()

        # lead() is null on each customer's last transaction. The equality test
        # is then null, so filter() drops that row.
        filteredDF = transformedDF.filter(
            (transformedDF.product_name == "iPhone") & (transformedDF.next_product_name == "AirPods")
        )
        filteredDF.orderBy("customer_id", "transaction_date", "product_name").show()

        # A customer can match on several iPhone -> AirPods pairs. Joining those
        # rows directly would repeat the customer once per pair, so reduce to
        # distinct ids first. Keeping only the join key also stops transaction
        # columns from being carried into (or clashing with) customer columns.
        matchingCustomerIdsDF = filteredDF.select("customer_id").distinct()

        customerInputDF = inputDFs.get("customerInputDF")
        customerInputDF.show()

        joinDF = customerInputDF.join(
            broadcast(matchingCustomerIdsDF),
            "customer_id"
        )

        print("JOINED DF")
        joinDF.show()
        return joinDF.select(
            "customer_id",
            "customer_name",
            "location"
        )

class OnlyAirpodsAndIphone(Transformer):
    def transform(self, inputDFs):
        """
        Customers who have bought only AirPods and iPhones.

        Qualifying customers have exactly two distinct (non-null) product
        names across all their transactions, and those two are "iPhone" and
        "AirPods".

        Parameters
        ----------
        inputDFs : dict
            Expected to provide "transactionInputDF" (read for customer_id,
            product_name) and "customerInputDF" (read for customer_id,
            customer_name, location).

        Returns
        -------
        DataFrame
            Columns customer_id, customer_name, location.

        Intermediate DataFrames are printed with ``show()`` for debugging.
        """
        transactions = inputDFs.get("transactionInputDF")
        print("transactionInputDF in transform")

        productsPerCustomer = (
            transactions
            .groupBy("customer_id")
            .agg(collect_set("product_name").alias("products"))
        )
        print("Grouped DF")
        productsPerCustomer.show()

        # A two-element set that holds both products can hold nothing else.
        products = col("products")
        hasOnlyIphoneAndAirpods = (
            (size(products) == 2)
            & array_contains(products, "iPhone")
            & array_contains(products, "AirPods")
        )
        qualifying = productsPerCustomer.filter(hasOnlyIphoneAndAirpods)
        qualifying.show()

        customers = inputDFs.get("customerInputDF")
        customers.show()

        # groupBy yields one row per customer_id, so this join adds no duplicates.
        joined = customers.join(broadcast(qualifying), "customer_id")
        print("JOINED DF")
        joined.show()

        return joined.select("customer_id", "customer_name", "location")
    