In [0]:
from pyspark.sql.window import Window
from pyspark.sql import functions as f
from pyspark.sql.functions import lead
from pyspark.sql.types import IntegerType

In [0]:
class Transformer:
    """Base class for a transformation step over a set of named inputs.

    The constructor only stores the inputs; subclasses are expected to
    override ``transform`` with the actual logic.
    """

    def __init__(self, inputDFs):
        """Remember the inputs for later use by ``transform``.

        Args:
            inputDFs: Container holding the inputs of the transformation.
                Presumably a dict mapping a name to a DataFrame (subclasses
                in this file look entries up with ``.get(name)``) -- it is
                stored as-is, without copying or validation.
        """
        self.inputDFs = inputDFs

    def transform(self):
        """Do nothing; placeholder meant to be overridden by subclasses.

        Returns:
            None.
        """
        return None

class FirstTransformer(Transformer):
    """Find customers whose purchase right after an iPhone was AirPods."""

    def transform(self):
        """
        Customers who bought AirPods immediately after buying an iPhone.

        Reads ``transactionInputDF`` and ``customerInputDF`` from
        ``self.inputDFs``. For every customer the transactions are ordered by
        ``transaction_date`` and each row is paired with the product of the
        customer's next transaction; rows where an ``iPhone`` is directly
        followed by ``AirPods`` are kept and joined to the customer data on
        ``customer_id``.

        Intermediate DataFrames are printed with ``show()`` /
        ``printSchema()`` for inspection.

        Returns:
            DataFrame with the columns ``customer_id``, ``customer_name`` and
            ``location``, de-duplicated so that a customer with several
            iPhone -> AirPods sequences is not repeated.

        Raises:
            ValueError: if ``transactionInputDF`` or ``customerInputDF`` is
                missing from ``self.inputDFs``.
        """
        transactionInputDF = self.inputDFs.get("transactionInputDF")
        customerInputDF = self.inputDFs.get("customerInputDF")

        # Fail fast with a clear message instead of an AttributeError on None
        # after several Spark jobs have already been executed.
        for inputName, inputDF in (
            ("transactionInputDF", transactionInputDF),
            ("customerInputDF", customerInputDF),
        ):
            if inputDF is None:
                raise ValueError(
                    f"FirstTransformer requires '{inputName}' in inputDFs"
                )

        print("transactionInputDF in transform")

        # Ensure customer_id is of IntegerType so it matches the join key of
        # customerInputDF; values that cannot be cast become null.
        transactionInputDF = transactionInputDF.withColumn(
            "customer_id", f.col("customer_id").cast(IntegerType())
        )

        transactionInputDF.printSchema()
        transactionInputDF.show()

        # NOTE(review): transactions of one customer sharing the same
        # transaction_date have no defined relative order in this window, so
        # lead() may pick any of them -- confirm whether a tiebreaker exists.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        # Product bought in the customer's next transaction (null for the last one).
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(window=windowSpec)
        )

        print("with next_product_name")

        transformedDF.orderBy("customer_id", "transaction_date").show()

        filteredDF = transformedDF.filter(
            (f.col("product_name") == "iPhone") & (f.col("next_product_name") == "AirPods")
        )

        print("Customers who bought Airpods after buying iPhone")

        filteredDF.orderBy("customer_id", "transaction_date").show()

        # customer_id is already IntegerType here (cast above), so only the
        # rows without a usable key need to be removed.
        filteredDF = filteredDF.filter(f.col("customer_id").isNotNull())

        print("filteredDF Schema")
        filteredDF.printSchema()
        filteredDF.show()

        print("customerInputDF")

        customerInputDF = customerInputDF.filter(f.col("customer_id").isNotNull())

        customerInputDF.show()

        # Ensure customer_id is of IntegerType in customerInputDF
        customerInputDF = customerInputDF.withColumn(
            "customer_id", f.col("customer_id").cast(IntegerType())
        )

        print("customerInputDF Schema")
        customerInputDF.printSchema()

        # Inner join on customer_id; the filtered transactions are hinted as
        # the broadcast side of the join.
        joinedDF = customerInputDF.join(
            f.broadcast(filteredDF),
            "customer_id"
        )

        print("Joined DF")
        joinedDF.show()

        # The join yields one row per matching transaction; the projected
        # columns carry no transaction data, so collapse the repeats.
        return joinedDF.select(
            "customer_id",
            "customer_name",
            "location"
        ).distinct()
