### Description
_Mechanism Y starts at the same time as X and ingests the S3 stream as 
soon as transaction chunk files become available. It detects the patterns below as soon as possible and writes 
the detections to S3, 50 at a time, each batch to a unique file._

_**Patterns:**_

_**PatId1** - A customer who is in the top 1% for a given merchant by total number of 
transactions and in the bottom 1% by weight; the merchant wants to UPGRADE (actionType) 
them. Upgrading only begins once the total number of transactions for the merchant exceeds 50K._

_**PatId2** - A customer whose average transaction value for a given merchant is below Rs 23 and who has made 
at least 80 transactions with that merchant; the merchant wants to mark them as CHILD (actionType) as soon as possible._

_**PatId3** - Merchants where the number of female customers is less than the number of male customers overall 
and the number of female customers is greater than 100 are marked DEI-NEEDED (actionType)._

In [0]:
import os
import time
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

In [0]:
# Spark session used by every cell below
spark = (
    SparkSession.builder
    .appName("Batch_MechanismY")
    .getOrCreate()
)

# Input locations: customer importance file and incoming transaction chunks
customer_s3_path = "s3://customer-transactions-detection/input/customers/"
transaction_s3_path = "s3://customer-transactions-detection/input/transactions/"

# Delta staging area that accumulates every transaction chunk processed so far
stage_path = "/Volumes/workspace/devdolphins/staging/merchant/"

# Destination for the detection files
detection_path = "s3://customer-transactions-detection/output/"

Defining schemas for customerImportance and Transactions DataFrames.

In [0]:
# Explicit schemas so the CSV readers do not have to infer column types.

# Customer importance file: one weight/fraud record per row, all columns nullable.
customer_importance_schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("customer", StringType()),
        ("merchant", StringType()),
        ("weight", DoubleType()),
        ("typeTrans", StringType()),
        ("fraud", IntegerType()),
    )
])

# Transaction chunk files: one transaction per row (nullable by default).
transaction_schema = StructType([
    StructField(column_name, column_type)
    for column_name, column_type in (
        ("step", IntegerType()),
        ("customer", StringType()),
        ("age", StringType()),
        ("gender", StringType()),
        ("zipcodeOri", StringType()),
        ("merchant", StringType()),
        ("zipMerchant", StringType()),
        ("category", StringType()),
        ("amount", DoubleType()),
        ("fraud", IntegerType()),
    )
])

Load the CustomerImportance file from the S3 bucket and aggregate it to get the total weight and the total fraud count per customer and merchant.

In [0]:
# Load the customer importance CSV with the explicit schema and mark it
# with a broadcast hint for the joins performed later.
customer_df = broadcast(
    spark.read.csv(
        customer_s3_path,
        schema=customer_importance_schema,
        header=True,
    )
)

# One row per (customer, merchant): summed weight and summed fraud flag.
agg_customer_df = (
    customer_df
    .groupBy("customer", "merchant")
    .agg(
        sum("weight").alias("total_weight"),
        sum("fraud").alias("fraud_rate"),
    )
)

The **process_chunk** function performs the detections and writes the output to the S3 bucket for each chunk received in the transactions S3 bucket.

In [0]:
def _format_detections(detections, pattern_id, action_type, customer_col):
    """Project detection rows onto the common output layout.

    Args:
        detections: DataFrame with at least ``YStartTime`` and ``merchant``.
        pattern_id: Value written to the ``patternId`` column.
        action_type: Value written to the ``ActionType`` column.
        customer_col: Column expression used for ``customerName``.

    Returns:
        DataFrame with columns YStartTime, detectionTime, patternId,
        ActionType, customerName, MerchantId (in that order).
    """
    return detections.select(
        col("YStartTime"),
        # Driver-side timestamp taken when the plan is built.
        lit(datetime.now().isoformat()).alias("detectionTime"),
        lit(pattern_id).alias("patternId"),
        lit(action_type).alias("ActionType"),
        customer_col.alias("customerName"),
        col("merchant").alias("MerchantId"),
    )


def process_chunk(chunk_path, batch_id):
    """Stage one transaction chunk, run the three detections, write the results.

    The chunk is appended to the Delta staging table at ``stage_path``; the
    aggregates are computed over everything staged so far, and detections are
    emitted only for customers/merchants that appear in the current chunk.
    Detections are written under ``detection_path`` as CSV outputs of at most
    50 rows each, named ``detection{batch_id}_group{n}.csv``.

    Args:
        chunk_path: Location of the transaction chunk (CSV with a header row).
        batch_id: Identifier embedded in the output file names.
    """
    print(f"Processing {chunk_path}")
    process_time = datetime.now().isoformat()

    # Read transaction chunk
    df = spark.read.csv(chunk_path, header=True, schema=transaction_schema) \
        .withColumn("YStartTime", lit(process_time))

    df.write.format("delta").mode("append").save(stage_path)

    stage_df = spark.read.format("delta").load(stage_path)

    # Joining on the raw chunk rows would emit one detection per transaction,
    # so reduce the chunk to its distinct keys first.
    chunk_pairs = df.select("customer", "merchant", "YStartTime").distinct()
    chunk_merchants = df.select("merchant", "YStartTime").distinct()

    # Per-merchant aggregates over all staged data.
    # "transactions" is a row count: the PatId1 gate is defined on the number
    # of transactions, not on the summed amount.
    # NOTE(review): gender is compared with the quote characters included
    # ("'F'" / "'M'") — presumably that is how the raw CSV stores it; confirm.
    merchant_txn = stage_df.groupBy("merchant").agg(
        count("*").alias("transactions"),
        countDistinct(when(col("gender") == "'F'", col("customer"))).alias("no_of_female_customers"),
        countDistinct(when(col("gender") == "'M'", col("customer"))).alias("no_of_male_customers"),
    )

    # Per-(customer, merchant) aggregates over all staged data.
    customer_txn = stage_df.groupBy("customer", "merchant").agg(
        count("*").alias("txn_count"),
        avg("amount").alias("avg_txn"),
    )

    # Pattern 1: top 1% by transaction count and bottom 1% by weight, ranked
    # among the customers of the same merchant. Partitioning by
    # (customer, merchant) would make every row tie with itself, giving a
    # percent_rank of 0 everywhere, so the window is per merchant.
    active_merchants = merchant_txn.filter(col("transactions") > 50000).select("merchant")
    merchant_window = Window.partitionBy("merchant")
    ranked_customers = customer_txn.join(active_merchants, "merchant") \
        .join(agg_customer_df, ["customer", "merchant"]) \
        .withColumn("rank", percent_rank().over(merchant_window.orderBy("txn_count"))) \
        .withColumn("low_weight_rank", percent_rank().over(merchant_window.orderBy("total_weight"))) \
        .filter((col("rank") >= 0.99) & (col("low_weight_rank") <= 0.01))
    pattern1 = ranked_customers.join(chunk_pairs, ["customer", "merchant"])
    pat1_formatted = _format_detections(pattern1, "PatId1", "UPGRADE", col("customer"))

    # Pattern 2: low average transaction value with at least 80 transactions.
    pattern2 = chunk_pairs.join(customer_txn, ["customer", "merchant"]) \
        .filter((col("avg_txn") < 23) & (col("txn_count") >= 80))
    pat2_formatted = _format_detections(pattern2, "PatId2", "CHILD", col("customer"))

    # Pattern 3: more than 100 female customers but fewer than male customers.
    pattern3 = chunk_merchants.join(merchant_txn, "merchant") \
        .filter((col("no_of_female_customers") > 100) & (col("no_of_female_customers") < col("no_of_male_customers")))
    pat3_formatted = _format_detections(pattern3, "PatId3", "DEI-NEEDED", lit(""))

    all_detections = pat1_formatted.unionByName(pat2_formatted).unionByName(pat3_formatted)

    # Chunk detections in 50-record batches. The frame is re-evaluated for
    # every group below, so the ordering must be a total order; otherwise
    # row numbers (and hence group membership) could differ between
    # evaluations. (patternId, MerchantId, customerName) is unique per row
    # because each pattern was built from distinct keys.
    row_window = Window.orderBy("patternId", "MerchantId", "customerName")
    all_detections = all_detections.withColumn(
        "group_id", floor((row_number().over(row_window) - 1) / 50)
    )

    group_ids = sorted(
        row["group_id"] for row in all_detections.select("group_id").distinct().collect()
    )

    for group in group_ids:
        chunk_df = all_detections.filter(col("group_id") == group).drop("group_id")
        filename = f"detection{batch_id}_group{group}.csv"
        output_path = os.path.join(detection_path, filename)
        chunk_df.write.mode("overwrite").option("header", True).csv(output_path)
        print(f"Wrote detection chunk to {output_path}")

Poll the transactions S3 bucket for newly arrived chunks and run the pattern detections on each one. The code cell below calls the **process_chunk** function to perform the detections and write the output to the S3 bucket.

In [0]:
# Poll the transactions location, process each new chunk once, and stop after
# TIMEOUT_LIMIT seconds without any new chunk.
batch_id = 0
no_new_chunk_seconds = 0
TIMEOUT_LIMIT = 60   # seconds of inactivity before the loop terminates
POLL_INTERVAL = 1    # seconds to sleep between directory listings
# Track processed chunks
processed_chunks = set()
# Monotonic timestamp of the last time a chunk was processed; idle time is
# measured from the clock so that listing latency is included.
last_activity = time.monotonic()

while True:
    # NOTE(review): names are sorted lexicographically, so "chunk_10" sorts
    # before "chunk_2" — confirm the producer's naming makes this the
    # intended order.
    chunk_dirs = sorted(
        entry.name
        for entry in dbutils.fs.ls(transaction_s3_path)
        if entry.name.startswith("chunk_")
    )
    new_chunks = [name for name in chunk_dirs if name not in processed_chunks]

    if new_chunks:
        for chunk in new_chunks:
            chunk_path = os.path.join(transaction_s3_path, chunk)
            print(chunk_path)
            process_chunk(chunk_path, batch_id)
            # Mark as processed only after process_chunk returned normally.
            processed_chunks.add(chunk)
            batch_id += 1
        last_activity = time.monotonic()
        no_new_chunk_seconds = 0
    else:
        no_new_chunk_seconds = int(time.monotonic() - last_activity)
        print(f"No new chunks... waiting ({no_new_chunk_seconds}/{TIMEOUT_LIMIT})")
        if no_new_chunk_seconds >= TIMEOUT_LIMIT:
            print(f"No new chunks for {TIMEOUT_LIMIT} seconds. Terminating gracefully.")
            break

    time.sleep(POLL_INTERVAL)