In [0]:
from pyspark.sql.window import Window 
from pyspark.sql.functions import lead,col, broadcast, collect_set, size, array_contains

class Transformer:
    """Common interface for the DataFrame transformation steps in this file.

    Subclasses override :meth:`transform`. The base implementation does
    nothing and returns ``None``.
    """

    def __init__(self):
        """Create a transformer; the base class keeps no state."""

    def transform(self, inputDFs):
        """Transform the input DataFrames.

        Args:
            inputDFs: collection of named input DataFrames; its exact shape
                is defined by each subclass.

        Returns:
            ``None`` in the base class; subclasses return their result.
        """
        return None

class AirpodsAfterIphoneTransformer(Transformer):
    """Find customers whose purchase right after an iPhone was AirPods."""

    def transform(self, inputDFs):
        """Return customers who bought AirPods immediately after an iPhone.

        Args:
            inputDFs: dict-like object read with ``.get``. It must provide
                ``"transactionInputDF"`` (used columns: ``customer_id``,
                ``transaction_date``, ``product_name``) and
                ``"customerInputDF"`` (used columns: ``customer_id``,
                ``customer_name``, ``location``).

        Returns:
            DataFrame with columns ``customer_id``, ``customer_name`` and
            ``location``. There is one row per matching row of
            ``customerInputDF``, no matter how many iPhone->AirPods pairs the
            customer has.

        Intermediate DataFrames are printed with ``show()`` for debugging,
        which triggers Spark actions.
        """
        transactionInputDF = inputDFs.get("transactionInputDF")

        print("transactionInputDF in Transform")
        transactionInputDF.show()

        # Attach to each purchase the product the same customer bought next.
        # NOTE(review): purchases sharing one transaction_date have no defined
        # order, so "next" is ambiguous for same-day purchases.
        window_spec = Window.partitionBy("customer_id").orderBy("transaction_date")
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(window_spec)
        )

        # NOTE: despite the label, this shows every transaction together with
        # its next product, not only the matching ones.
        print("Customers who bought Airpords after buying Iphone")
        transformedDF.orderBy("customer_id", "transaction_date", "product_name").show()

        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )
        filteredDF.orderBy("customer_id", "transaction_date", "product_name").show()

        # A customer can have several iPhone->AirPods pairs. Joining on the raw
        # pair rows would return that customer once per pair, so reduce to
        # distinct ids first. This also keeps transaction columns out of the
        # join, so the final select cannot hit ambiguous column names.
        matchingCustomerIds = filteredDF.select("customer_id").distinct()

        customerInputDF = inputDFs.get("customerInputDF")
        customerInputDF.show()

        joinDF = customerInputDF.join(
            broadcast(matchingCustomerIds),
            "customer_id",
        )

        print("Customers who bought airpords immediately after buying iphone")
        joinDF.show()

        return joinDF.select(
            "customer_id",
            "customer_name",
            "location",
        )

class OnlyIphoneAndAirpodsTransformer(Transformer):
    """Select customers whose distinct purchases are exactly iPhone and AirPods."""

    def transform(self, inputDFs):
        """Return customers who bought only iPhone and AirPods.

        The distinct ``product_name`` values per customer must contain both
        ``"iPhone"`` and ``"AirPods"`` and nothing else.

        Args:
            inputDFs: dict-like object read with ``.get``. It must provide
                ``"transactionInputDF"`` (used columns: ``customer_id``,
                ``product_name``) and ``"customerInputDF"`` (used columns:
                ``customer_id``, ``customer_name``, ``location``).

        Returns:
            DataFrame with columns ``customer_id``, ``customer_name`` and
            ``location`` for the qualifying customers.

        Intermediate DataFrames are printed with ``show()`` for debugging.
        """
        transactions = inputDFs.get("transactionInputDF")

        # NOTE(review): this label is printed without showing the DataFrame,
        # unlike the sibling transformer.
        print("transactionInputDF in Transform")

        # One row per customer holding the set of distinct products bought.
        productsPerCustomer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )

        productsCol = col("products")
        hasIphone = array_contains(productsCol, "iPhone")
        hasAirpods = array_contains(productsCol, "AirPods")
        # With both required products present, a set size of 2 rules out
        # any other product.
        hasNothingElse = size(productsCol) == 2
        onlyThoseTwo = productsPerCustomer.filter(hasIphone & hasAirpods & hasNothingElse)

        print("Only Airpods and iPhone")
        onlyThoseTwo.show()

        customers = inputDFs.get("customerInputDF")
        customers.show()

        # onlyThoseTwo has one row per customer, so the join cannot fan out.
        matched = customers.join(broadcast(onlyThoseTwo), "customer_id")

        print("Customers who only bought airpords and iphone")
        matched.show()

        resultColumns = ["customer_id", "customer_name", "location"]
        return matched.select(*resultColumns)

