# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

def process_batch(df, batch_id):
    """Print one streaming micro-batch to stdout.

    Intended as the ``foreachBatch`` callback of the streaming query: Spark
    calls it once per micro-batch with that batch's rows and its id.

    Args:
        df: DataFrame holding the rows of this micro-batch.
        batch_id: Identifier of the micro-batch, echoed in the header line.
    """
    header = "Batch ID: {}".format(batch_id)
    print(header)
    # Disable truncation so long cells (e.g. collected arrays) print in full.
    df.show(truncate=False)

# Create (or reuse) the Spark session for this application.
spark = (
    SparkSession.builder
    .appName("CustomerSegmentation")
    .getOrCreate()
)
# Only surface ERROR-level logs so the per-batch output stays readable.
spark.sparkContext.setLogLevel("ERROR")

# Field layout of each incoming JSON event record. Streaming file sources
# require the schema to be supplied up front rather than inferred.
_event_fields = [
    ("user_id", StringType()),
    ("event_type", StringType()),
    ("timestamp", TimestampType()),
    ("product_id", StringType()),
    ("category", StringType()),
    ("price", DoubleType()),
]
schema = StructType([StructField(field_name, field_type) for field_name, field_type in _event_fields])

# Streaming DataFrame of events parsed from JSON files under data/stream.
stream = spark.readStream.schema(schema).json("data/stream")

# Per-user activity summary over event time:
#  - the watermark tolerates events arriving up to 1 minute late
#    (relative to the latest event time seen);
#  - events are bucketed into 5-minute tumbling windows per user_id;
#  - each bucket collects the distinct event types observed.
# NOTE: Spark only uses the watermark to drop old window state when the
# query runs in "append" or "update" output mode.
late_tolerant_events = stream.withWatermark("timestamp", "1 minute")
events_by_user_window = late_tolerant_events.groupBy(
    window("timestamp", "5 minutes"),
    col("user_id"),
)
aggregated = events_by_user_window.agg(collect_set("event_type").alias("event_types"))

# Priority-ordered segmentation rule: the first matching WHEN wins, so a user
# with a 'purchase' in the window is a 'Buyer' even if they also carted or
# viewed. Windows containing none of purchase/cart/view fall back to 'Unknown'.
segment_rule_sql = """
        CASE 
            WHEN array_contains(event_types, 'purchase') THEN 'Buyer'
            WHEN array_contains(event_types, 'cart') THEN 'Cart abandoner'
            WHEN array_contains(event_types, 'view') THEN 'Lurker'
            ELSE 'Unknown'
        END
    """
segmented = aggregated.withColumn("segment", expr(segment_rule_sql))

# Start the streaming query. Each micro-batch of segmented rows is handed to
# process_batch, which prints it.
#
# Output mode is "update" rather than "complete": complete mode must retain
# every aggregate ever produced, so Spark ignores the watermark declared on
# the event stream and window state grows without bound. In update mode,
# Spark can drop state for windows older than the watermark. Each batch
# then contains only the window rows that changed.
#
# No .format(...) is set: foreachBatch installs its own sink, so a
# format("console") here would be silently overridden.
#
# The query runs in the background. Call query.awaitTermination() to block,
# e.g. when running this as a script instead of in a notebook.
query = (
    segmented.writeStream
    .outputMode("update")
    .foreachBatch(process_batch)
    .start()
)

