In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead,col,broadcast,collect_set,array_contains,size,datediff,avg,row_number
from pyspark.sql.types import IntegerType
from pyspark.sql import functions as F

class Transformer:
    """Base class for the transformation steps defined in this file.

    Subclasses override :meth:`transform`; the base implementation is a
    no-op that returns ``None``.
    """

    def __init__(self):
        """Create a transformer; the base class stores no state."""
        return None

    def transform(self, inputDf):
        """Placeholder transformation: ignore ``inputDf`` and return ``None``."""
        return None


class AirpodsAfterIphone(Transformer):
    """Find customers whose purchase right after an iPhone was AirPods."""

    def transform(self, inputDF):
        """Return the customers who bought AirPods immediately after an iPhone.

        Args:
            inputDF: Mapping that provides the input DataFrames under the keys
                ``"transactionInputDF"`` and ``"customerInputDF"``.

        Returns:
            DataFrame with the columns ``customer_id``, ``customer_name`` and
            ``location``. A customer with several iPhone -> AirPods sequences
            is matched only once instead of once per sequence.
        """
        transactionInputDF = inputDF.get("transactionInputDF")

        transactionInputDF.show()

        # Per customer, in transaction-date order, attach the product of the
        # following transaction to each row.
        windowspec = Window.partitionBy("customer_id").orderBy("transaction_date")

        transform_df = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowspec)
        )
        print("Airpods after buying iphone")
        transform_df.orderBy("customer_id", "transaction_date", "product_name").show()

        # Keep only the iPhone rows that are directly followed by AirPods.
        filtered_df = transform_df.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )

        filtered_df.orderBy("customer_id", "transaction_date", "product_name").show()

        # Only the customer id is needed for the join. Reducing to distinct ids
        # prevents duplicated customers in the result and keeps transaction
        # columns from clashing with customer columns of the same name.
        matching_customers_df = filtered_df.select("customer_id").distinct()

        customerInputDF = inputDF.get("customerInputDF")

        customerInputDF.show()

        join_df = customerInputDF.join(broadcast(matching_customers_df), "customer_id")

        print("Joined DF")

        result_df = join_df.select("customer_id", "customer_name", "location")
        result_df.show()

        return result_df



 
class onlyAirPodsAndIphone(Transformer):
    """Find customers whose purchases consist of exactly iPhone and AirPods."""

    def transform(self, inputDF):
        """Return customers who bought an iPhone and AirPods and nothing else.

        Args:
            inputDF: Mapping that provides the input DataFrames under the keys
                ``"transactionInputDF"`` and ``"customerInputDF"``.

        Returns:
            DataFrame with the columns ``customer_id``, ``customer_name``,
            ``location`` and ``products`` (the set of distinct product names
            the customer bought).
        """
        transactions = inputDF.get("transactionInputDF")
        transactions.show()

        # One row per customer holding the distinct product names they bought.
        products_per_customer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )
        products_per_customer.show(truncate=False)

        # Both products present and a set size of two means nothing else was bought.
        products_col = col("products")
        has_iphone = array_contains(products_col, "iPhone")
        has_airpods = array_contains(products_col, "AirPods")
        exactly_two = size(products_col) == 2

        matching_customers = products_per_customer.filter(
            has_iphone & has_airpods & exactly_two
        )
        matching_customers.show()

        customers = inputDF.get("customerInputDF")
        customers.show()

        joined = customers.join(broadcast(matching_customers), "customer_id")

        print("Joined DF in second workflow")

        output_columns = ["customer_id", "customer_name", "location", "products"]
        joined.select(*output_columns).show()

        return joined.select(*output_columns)

In [0]:
class AverageTimeDelay(Transformer):
    """Average delay between buying an iPhone and then AirPods, per customer."""

    def transform(self, inputDF):
        """Compute each customer's average day gap from an iPhone to the next AirPods.

        Args:
            inputDF: Mapping that provides the transaction DataFrame under the
                key ``"transactionInputDF"``.

        Returns:
            DataFrame with the columns ``customer_id`` and
            ``average_time_delay`` (mean of the day differences).
        """
        transactions = inputDF.get("transactionInputDF")

        # Each customer's transactions in date order.
        per_customer = Window.partitionBy("customer_id").orderBy("transaction_date")

        next_product = lead("product_name").over(per_customer)
        next_date = lead("transaction_date").over(per_customer)

        with_next = transactions.withColumn("next_purchased_product", next_product)
        with_delay = with_next.withColumn(
            "time_delay", datediff(next_date, col("transaction_date"))
        )

        # Only iPhone purchases directly followed by AirPods contribute.
        iphone_then_airpods = (col("product_name") == "iPhone") & (
            col("next_purchased_product") == "AirPods"
        )
        delays = with_delay.filter(iphone_then_airpods)

        average_delay = delays.groupBy("customer_id").agg(
            avg("time_delay").alias("average_time_delay")
        )

        average_delay.show()
        return average_delay.select("customer_id", "average_time_delay")


In [0]:
class TopSellingProducts(Transformer):
    """Rank products by revenue within each category and keep the top three."""

    def transform(self, inputDF):
        """Return the three highest-revenue products of every category.

        Args:
            inputDF: Mapping that provides the input DataFrames under the keys
                ``"transactionInputDF"`` and ``"productsInputDF"``.

        Returns:
            DataFrame with the columns ``category``, ``total_revenue`` and
            ``product_name_only``, limited to the rows ranked 1-3 by
            descending ``total_revenue`` within each category.
        """
        transactionInputDF = inputDF.get("transactionInputDF")
        productsInputDF = inputDF.get("productsInputDF")

        # The first space-separated token of the product name is the join key
        # against the transaction's product_name.
        productsInputDF = productsInputDF.withColumn(
            "product_name_only", F.split(F.col("product_name"), " ").getItem(0)
        )

        # Cast the price to int; prices that are NULL after the cast count as 0.
        productsInputDF = productsInputDF.withColumn(
            "price_casted", F.col("price").cast("int")
        )
        productsInputDF = productsInputDF.na.fill(value=0, subset=["price_casted"])
        productsInputDF.printSchema()
        productsInputDF.show()

        joinedDF = transactionInputDF.join(
            productsInputDF,
            productsInputDF.product_name_only == transactionInputDF.product_name,
        )

        print("Joined transaction and Product DF")
        joinedDF.show()

        # Highest revenue first: an ascending sort would rank the *worst*
        # sellers 1-3 and return them as the "top" products.
        revenue_desc = F.col("total_revenue").desc()

        topProductsDF = joinedDF.groupBy("category", "product_name_only").agg(
            F.sum("price_casted").alias("total_revenue")
        ).orderBy("category", revenue_desc)

        print("Aggregated Data")
        topProductsDF.show()

        windowSpec = Window.partitionBy("category").orderBy(revenue_desc)

        # row_number gives exactly three rows per category; ties in revenue are
        # broken arbitrarily.
        top3DF = topProductsDF.withColumn(
            "rank", row_number().over(windowSpec)
        ).filter(col("rank") <= 3)

        print("Top 3 products in each category")
        top3DF.show()

        return top3DF.select("category", "total_revenue", "product_name_only")
