In [0]:
%sql
-- Make sure the Unity Catalog volume that holds the streaming checkpoint
-- (and the Auto Loader schema location) exists before the stream starts.
CREATE VOLUME IF NOT EXISTS workspace.default.checkpoints;

-- MANUAL RESET ONLY: this cell used to drop the volume right after creating
-- it, which leaves no volume behind for the checkpoint path used by the later
-- cells. Dropping it also discards the stream's progress, so every source file
-- would be re-ingested and the additive Postgres counters would be applied a
-- second time. Run the statement below by hand only when a full reprocess
-- (with the Postgres tables truncated first) is intended.
-- DROP VOLUME IF EXISTS workspace.default.checkpoints;

In [0]:
import os

# Configure AWS credentials for both Spark and boto3.
#
# SECURITY: the access key and secret key used to be hardcoded in this cell.
# They are now read from a Databricks secret scope so they never appear in the
# notebook source or its revision history. Any key pair that was previously
# committed here must be treated as compromised and rotated.
# NOTE(review): the scope and key names below are placeholders - point them at
# the secret scope actually provisioned in this workspace.
AWS_SECRET_SCOPE = "aws"
aws_access_key_id = dbutils.secrets.get(scope=AWS_SECRET_SCOPE, key="aws_access_key_id")
aws_secret_access_key = dbutils.secrets.get(scope=AWS_SECRET_SCOPE, key="aws_secret_access_key")

# Set environment variables (works in both serverless and standard compute)
os.environ['AWS_ACCESS_KEY_ID'] = aws_access_key_id
os.environ['AWS_SECRET_ACCESS_KEY'] = aws_secret_access_key
os.environ['AWS_DEFAULT_REGION'] = 'ap-south-1'

# Additional S3-specific environment variables for Spark
os.environ['AWS_S3_ACCESS_KEY'] = aws_access_key_id
os.environ['AWS_S3_SECRET_KEY'] = aws_secret_access_key




In [0]:
import pyspark.sql.functions as F
from pyspark.sql.types import *

# --- Configuration Widgets ---
# Every tunable of the pipeline is exposed as a notebook widget so a job run
# can override it without editing the notebook.
dbutils.widgets.text("s3_raw_bucket", "pk-transactions-raw", "S3 Raw Bucket")
dbutils.widgets.text("s3_detections_bucket", "pk-transactions-detections", "S3 Detections Bucket")
dbutils.widgets.text("s3_temp_bucket", "pk-transactions-temp", "S3 Temp Bucket")

dbutils.widgets.text("customer_importance_table", "workspace.default.customer_importance", "Customer Importance Table Name")
dbutils.widgets.text("postgres_host", "databricks-postgres-db.cd424ookkd9v.ap-south-1.rds.amazonaws.com", "Postgres Host")
dbutils.widgets.text("postgres_user", "postgres", "Postgres User")
dbutils.widgets.text("postgres_db", "postgres", "Postgres DB")
# SECURITY: the widget no longer ships a real password as its default value.
# Leave it empty to read the password from the secret scope (see below); the
# widget is kept so existing jobs that pass the parameter keep working.
dbutils.widgets.text("postgres_password", "", "Postgres Password")

# --- Get Widget Values ---
s3_raw_path = f"s3a://{dbutils.widgets.get('s3_raw_bucket')}"
s3_detections_path = f"s3a://{dbutils.widgets.get('s3_detections_bucket')}"
customer_importance_table_name = dbutils.widgets.get('customer_importance_table')
postgres_host = dbutils.widgets.get('postgres_host')
postgres_user = dbutils.widgets.get('postgres_user')
postgres_db = dbutils.widgets.get('postgres_db')
# An explicitly supplied widget value wins; otherwise fall back to the secret.
# NOTE(review): scope/key names are placeholders - confirm against the
# workspace's secret scopes. The password that used to be the widget default
# should be rotated.
postgres_password = (
    dbutils.widgets.get('postgres_password')
    or dbutils.secrets.get(scope="postgres", key="postgres_password")
)

# --- Checkpoint Location on a Unity Catalog Volume ---
# Despite its name, `s3_checkpoint_location` points at the UC volume, not S3.
volume_checkpoint_path = "/Volumes/workspace/default/checkpoints/mechanism_y"
s3_checkpoint_location = volume_checkpoint_path

# Ensure the directory exists. This is idempotent.
dbutils.fs.mkdirs(s3_checkpoint_location)

# --- JDBC Connection Properties ---
jdbc_url = f"jdbc:postgresql://{postgres_host}:5432/{postgres_db}"
connection_properties = { "user": postgres_user, "password": postgres_password, "driver": "org.postgresql.Driver" }

# --- Define Transaction Schema from Source Table ---
# The stream is read with the exact schema of the existing transactions table,
# so the table must be readable before the stream can be defined. Any failure
# while loading it is re-raised with a message naming the table.
try:
    source_table_for_schema = spark.table("workspace.default.transactions")
    transactions_schema = source_table_for_schema.schema
    print("Successfully derived transactions schema from source table.")
except Exception as schema_error:
    raise Exception(
        "Could not find source table 'workspace.default.transactions' to derive schema."
    ) from schema_error

print("\nConfiguration complete.")

In [0]:
from datetime import datetime
import pytz
import psycopg2
from pyspark.sql.window import Window
import pyspark.sql.functions as F
from pyspark.sql.functions import regexp_replace, trim, col, when, upper
from pyspark.sql.types import StructType, StructField, StringType

# Start time of mechanism Y, captured once when this cell runs and formatted
# as 'YYYY-MM-DD HH:MM:SS' in the Asia/Kolkata (IST) timezone. It is stamped
# onto every detection row emitted by process_batch.
Y_START_TIME = (
    datetime.now(pytz.timezone('Asia/Kolkata'))
    .strftime('%Y-%m-%d %H:%M:%S')
)

def clean_quoted_strings(df, string_columns):
    """Strip one leading and one trailing single quote from the given columns.

    Args:
        df: DataFrame to clean.
        string_columns: Column names to process; names that are not present
            in ``df`` are skipped.

    Returns:
        A DataFrame in which each listed, existing column has a single quote
        at the very start and/or very end of its value removed. Quotes inside
        the value are left untouched.
    """
    cleaned = df
    for name in string_columns:
        if name not in cleaned.columns:
            continue
        unquoted = regexp_replace(col(name), "^'|'$", "")
        cleaned = cleaned.withColumn(name, unquoted)
    return cleaned

def normalize_gender(df, gender_column="gender"):
    """Collapse the gender column onto the values 'M', 'F' and 'Unknown'.

    The comparison is done on the upper-cased value: 'M', 'MALE' and '1' map
    to 'M'; 'F', 'FEMALE' and '2' map to 'F'; anything not matched by those
    two rules becomes 'Unknown'.

    Args:
        df: DataFrame containing ``gender_column``.
        gender_column: Name of the column to overwrite (default "gender").

    Returns:
        ``df`` with ``gender_column`` replaced by the normalized value.
    """
    gender_upper = upper(col(gender_column))
    is_male = gender_upper.isin(['M', 'MALE', '1'])
    is_female = gender_upper.isin(['F', 'FEMALE', '2'])

    normalized = (
        when(is_male, 'M')
        .when(is_female, 'F')
        .otherwise('Unknown')  # Unknown/Other
    )
    return df.withColumn(gender_column, normalized)

def create_empty_detections_df():
    """Return a zero-row DataFrame carrying the detections schema.

    The schema has six nullable string columns: YStartTime_IST,
    detectionTime_IST, patternId, ActionType, customerName and MerchantId.
    It is used as the placeholder result for a pattern that produced nothing.
    """
    column_names = (
        "YStartTime_IST",
        "detectionTime_IST",
        "patternId",
        "ActionType",
        "customerName",
        "MerchantId",
    )
    detections_schema = StructType(
        [StructField(name, StringType(), True) for name in column_names]
    )
    return spark.createDataFrame([], detections_schema)

def _clean_batch(batch_df):
    """Strip wrapping quotes, normalize gender and cast age for one micro-batch."""
    string_columns = ["customer", "merchant", "category", "gender", "zipcodeOri", "zipMerchant"]
    cleaned_df = clean_quoted_strings(batch_df, string_columns)
    cleaned_df = normalize_gender(cleaned_df)
    return cleaned_df.withColumn("age", col("age").cast("integer"))


def _load_importance_df():
    """Load the customer-importance table with transaction-style column names."""
    importance_df = (
        spark.table(customer_importance_table_name)
        .withColumnRenamed("Source", "customer")
        .withColumnRenamed("Target", "merchant")
        .withColumnRenamed("Weight", "weight")
        .withColumnRenamed("typeTrans", "category")
        .select("customer", "merchant", "weight", "category")
    )
    # Same quote cleaning as the transactions so the join keys line up.
    return clean_quoted_strings(importance_df, ["customer", "merchant", "category"])


def _collect_batch_stats(batch_df, importance_df):
    """Run the three per-batch aggregations and collect them to the driver.

    Returns:
        (merchant_rows, customer_rows, gender_rows) - lists of Rows holding
        the per-merchant transaction count, the per customer/merchant
        weight and amount sums and counts, and the per merchant/gender
        distinct customer count.
    """
    merchant_rows = batch_df.groupBy("merchant").count().collect()

    # Sum and count the weight directly. The previous avg(weight) * count(*)
    # reconstruction over-stated both weight_sum and weight_count whenever
    # some matched rows had a NULL weight, because avg() skips NULLs while
    # count(*) does not.
    weight = F.col("weight").cast("double")
    customer_rows = (
        batch_df.join(importance_df, ["customer", "merchant", "category"], "inner")
        .groupBy("customer", "merchant")
        .agg(
            F.count("*").alias("batch_txn_count"),
            F.sum(weight).alias("batch_weight_sum"),
            F.count(weight).alias("batch_weight_count"),
            F.sum("amount").alias("batch_amount_sum"),
            F.count("amount").alias("batch_amount_count"),
        )
        .collect()
    )

    gender_rows = (
        batch_df.groupBy("merchant", "gender")
        .agg(F.countDistinct("customer").alias("distinct_customers"))
        .collect()
    )
    return merchant_rows, customer_rows, gender_rows


def _write_stats_to_postgres(merchant_rows, customer_rows, gender_rows):
    """Apply the collected batch statistics to Postgres in one transaction.

    All three upserts are committed together; on a psycopg2 error the
    transaction is rolled back and the error re-raised. The connection is
    always closed.

    NOTE(review): every upsert is additive (existing value + batch value), so
    applying the same micro-batch twice double-counts it. Confirm how batch
    retries are expected to be handled.
    NOTE(review): customer_count adds the per-batch distinct customer count,
    so a customer seen in two different batches is counted twice.
    """
    merchant_upsert = """INSERT INTO merchant_transaction_counts (merchant_id, transaction_count) 
                       VALUES (%s, %s) 
                       ON CONFLICT (merchant_id) 
                       DO UPDATE SET transaction_count = merchant_transaction_counts.transaction_count + EXCLUDED.transaction_count;"""
    customer_upsert = """INSERT INTO customer_merchant_stats (customer_id, merchant_id, transaction_count, weight_sum, weight_count, amount_sum, amount_count)
                           VALUES (%s, %s, %s, %s, %s, %s, %s)
                           ON CONFLICT (customer_id, merchant_id)
                           DO UPDATE SET 
                               transaction_count = customer_merchant_stats.transaction_count + EXCLUDED.transaction_count,
                               weight_sum = customer_merchant_stats.weight_sum + EXCLUDED.weight_sum,
                               weight_count = customer_merchant_stats.weight_count + EXCLUDED.weight_count,
                               amount_sum = customer_merchant_stats.amount_sum + EXCLUDED.amount_sum,
                               amount_count = customer_merchant_stats.amount_count + EXCLUDED.amount_count;"""
    gender_upsert = """INSERT INTO merchant_gender_stats (merchant_id, gender, customer_count)
                       VALUES (%s, %s, %s)
                       ON CONFLICT (merchant_id, gender)
                       DO UPDATE SET customer_count = merchant_gender_stats.customer_count + EXCLUDED.customer_count;"""

    conn = None
    cursor = None
    try:
        conn = psycopg2.connect(
            host=postgres_host,
            dbname=postgres_db,
            user=postgres_user,
            password=postgres_password,
            connect_timeout=10
        )
        cursor = conn.cursor()

        # --- Part 1: Update Merchant Counts in Postgres ---
        print(f"Updating merchant counts for {len(merchant_rows)} merchants...")
        cursor.executemany(
            merchant_upsert,
            [(row['merchant'], row['count']) for row in merchant_rows]
        )

        # --- Part 2: Update Customer-Merchant Statistics ---
        if customer_rows:
            print(f"Updating customer-merchant stats for {len(customer_rows)} combinations...")
            cursor.executemany(
                customer_upsert,
                [
                    (row['customer'], row['merchant'], row['batch_txn_count'],
                     # sum() is NULL when every weight/amount in the group is NULL
                     float(row['batch_weight_sum'] or 0), row['batch_weight_count'],
                     float(row['batch_amount_sum'] or 0), row['batch_amount_count'])
                    for row in customer_rows
                ]
            )
        else:
            print("No matching records found between batch and importance data")

        # --- Part 3: Update Customer Gender Stats per Merchant ---
        print(f"Updating gender stats for {len(gender_rows)} merchant-gender combinations...")
        cursor.executemany(
            gender_upsert,
            [(row['merchant'], row['gender'], row['distinct_customers']) for row in gender_rows]
        )

        conn.commit()
        print("Database updates committed successfully")

    except psycopg2.Error as db_error:
        print(f"Database error: {db_error}")
        if conn:
            conn.rollback()
        raise

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


def _detection_columns(pattern_id, action_type, customer_col, y_start_time, detection_time):
    """Build the six output columns shared by every detection DataFrame."""
    return [
        y_start_time.alias("YStartTime_IST"),
        detection_time.alias("detectionTime_IST"),
        F.lit(pattern_id).alias("patternId"),
        F.lit(action_type).alias("ActionType"),
        customer_col.alias("customerName"),
        F.col("merchant_id").alias("MerchantId"),
    ]


def _detect_patid1(y_start_time, detection_time):
    """PatId1: top 10% transaction count + bottom 10% weight + merchant >50K transactions."""
    eligible_merchants_df = (
        spark.read.jdbc(
            url=jdbc_url,
            table="merchant_transaction_counts",
            properties=connection_properties
        )
        .filter(F.col("transaction_count") > 50000)
        .select("merchant_id")
    )
    if eligible_merchants_df.isEmpty():
        print("No eligible merchants for PatId1 (>50K transactions)")
        return create_empty_detections_df()
    print(f"Found {eligible_merchants_df.count()} eligible merchants for PatId1")

    customer_stats_df = spark.read.jdbc(
        url=jdbc_url,
        table="(SELECT customer_id, merchant_id, transaction_count, CASE WHEN weight_count > 0 THEN weight_sum/weight_count ELSE 0 END as avg_weight FROM customer_merchant_stats WHERE weight_count > 0) as stats",
        properties=connection_properties
    )
    if customer_stats_df.isEmpty():
        print("No customer stats available for PatId1")
        return create_empty_detections_df()

    # Calculate percentiles per merchant
    merchant_window = Window.partitionBy("merchant_id")
    percentiles_df = customer_stats_df.withColumn(
        "txn_percentile", F.percent_rank().over(merchant_window.orderBy("transaction_count"))
    ).withColumn(
        "weight_percentile", F.percent_rank().over(merchant_window.orderBy("avg_weight"))
    )

    # Left-semi join on the column NAME keeps only the stats columns. The
    # previous expression join left two `merchant_id` columns in the result,
    # so the later F.col("merchant_id") was ambiguous; the resulting error was
    # swallowed and PatId1 silently produced no detections.
    detections = (
        percentiles_df.join(eligible_merchants_df, "merchant_id", "left_semi")
        .filter((F.col("txn_percentile") >= 0.90) & (F.col("weight_percentile") <= 0.10))
        .select(*_detection_columns("PatId1", "UPGRADE", F.col("customer_id"), y_start_time, detection_time))
    )
    print(f"PatId1 detections: {detections.count()}")
    return detections


def _detect_patid2(y_start_time, detection_time):
    """PatId2: average transaction value <23 AND at least 80 transactions."""
    customer_stats_df = spark.read.jdbc(
        url=jdbc_url,
        table="(SELECT customer_id, merchant_id, transaction_count, CASE WHEN amount_count > 0 THEN amount_sum/amount_count ELSE 0 END as avg_txn_value FROM customer_merchant_stats WHERE amount_count > 0) as stats",
        properties=connection_properties
    )
    if customer_stats_df.isEmpty():
        print("No customer stats available for PatId2")
        return create_empty_detections_df()

    detections = (
        customer_stats_df
        .filter((F.col("avg_txn_value") < 23) & (F.col("transaction_count") >= 80))
        .select(*_detection_columns("PatId2", "CHILD", F.col("customer_id"), y_start_time, detection_time))
    )
    print(f"PatId2 detections: {detections.count()}")
    return detections


def _detect_patid3(y_start_time, detection_time):
    """PatId3: fewer female than male customers AND more than 100 female customers."""
    gender_pivot_query = """
            SELECT merchant_id,
                   COALESCE(SUM(CASE WHEN gender = 'F' THEN customer_count END), 0) as female_customers,
                   COALESCE(SUM(CASE WHEN gender = 'M' THEN customer_count END), 0) as male_customers
            FROM merchant_gender_stats 
            GROUP BY merchant_id
            HAVING COALESCE(SUM(CASE WHEN gender = 'F' THEN customer_count END), 0) > 0
            """
    gender_stats_df = spark.read.jdbc(
        url=jdbc_url,
        table=f"({gender_pivot_query}) as gender_stats",
        properties=connection_properties
    )
    if gender_stats_df.isEmpty():
        print("No gender stats available for PatId3")
        return create_empty_detections_df()

    detections = (
        gender_stats_df
        .filter(
            (F.col("female_customers") < F.col("male_customers")) &
            (F.col("female_customers") > 100)
        )
        # PatId3 is a merchant-level pattern, so customerName is left blank.
        .select(*_detection_columns("PatId3", "DEI-NEEDED", F.lit(""), y_start_time, detection_time))
    )
    print(f"PatId3 detections: {detections.count()}")
    return detections


def _run_detector(pattern_id, detector, y_start_time, detection_time):
    """Run one pattern detector best-effort; a failure yields an empty result.

    The failure is reported with a full traceback so that a broken detector
    is visible in the run output instead of looking like "no detections".
    """
    import traceback

    try:
        print(f"Running {pattern_id} detection...")
        return detector(y_start_time, detection_time)
    except Exception as e:
        print(f"{pattern_id} detection error: {e}")
        traceback.print_exc()
        return create_empty_detections_df()


def _write_detections(all_detections, max_records_per_file=50):
    """Append the detections to the detections bucket as headered CSV files."""
    detection_count = all_detections.count()
    if detection_count == 0:
        print("No detections found in this batch")
        return
    print(f"Found {detection_count} total detections in this batch.")

    # Output column names required by the assignment.
    final_detections = (
        all_detections
        .withColumnRenamed("YStartTime_IST", "YStartTime(IST)")
        .withColumnRenamed("detectionTime_IST", "detectionTime(IST)")
    )

    # Ceiling division: enough partitions for ~max_records_per_file rows each.
    num_partitions = max(1, -(-detection_count // max_records_per_file))
    (
        final_detections.repartition(num_partitions)
        .write.mode("append")
        .format("csv")
        .option("header", "true")
        # repartition() only balances rows approximately; this option is what
        # actually caps the number of records written to a single file.
        .option("maxRecordsPerFile", max_records_per_file)
        .save(s3_detections_path)
    )
    print(f"Written {detection_count} detections across {num_partitions} partitions "
          f"(max {max_records_per_file} records per file)")


def process_batch(batch_df, batch_id):
    """foreachBatch handler: clean a micro-batch, update Postgres, emit detections.

    Steps:
      1. Clean the batch (quotes, gender, age) and print basic quality info.
      2. Aggregate the batch and additively upsert the results into the
         merchant_transaction_counts, customer_merchant_stats and
         merchant_gender_stats tables in Postgres (single transaction).
      3. Re-read the accumulated Postgres stats and evaluate PatId1-3. Each
         detector is best-effort: a failure is logged and treated as "no
         detections" for that pattern.
      4. Append all detections to the detections bucket as CSV.

    Args:
        batch_df: The micro-batch DataFrame passed by Structured Streaming.
        batch_id: The micro-batch id (used for logging only).

    Raises:
        Exception: Any error outside the best-effort detectors is logged with
            its traceback and re-raised so the streaming query fails.
    """
    import traceback

    print(f"--- Processing Batch ID: {batch_id} ---")
    if batch_df.isEmpty():
        print("Empty batch - skipping processing")
        return

    try:
        print("=== Data Cleaning and Validation ===")
        batch_df = _clean_batch(batch_df)

        # Data quality validation
        print(f"Batch records: {batch_df.count()}")
        gender_values = [row['gender'] for row in batch_df.select('gender').distinct().collect()]
        print(f"Gender values after cleaning: {gender_values}")

        importance_df = _load_importance_df()

        print("=== Database Operations ===")
        # Finish all Spark work first so the Postgres connection is only held
        # open for the short write transaction.
        merchant_rows, customer_rows, gender_rows = _collect_batch_stats(batch_df, importance_df)
        _write_stats_to_postgres(merchant_rows, customer_rows, gender_rows)

        print("=== Pattern Detection ===")
        y_start_time = F.lit(Y_START_TIME)
        detection_time = F.lit(datetime.now(pytz.timezone('Asia/Kolkata')).strftime('%Y-%m-%d %H:%M:%S'))

        patid1_detections = _run_detector("PatId1", _detect_patid1, y_start_time, detection_time)
        patid2_detections = _run_detector("PatId2", _detect_patid2, y_start_time, detection_time)
        patid3_detections = _run_detector("PatId3", _detect_patid3, y_start_time, detection_time)

        all_detections = patid1_detections.unionByName(patid2_detections).unionByName(patid3_detections)
        _write_detections(all_detections)

    except Exception as e:
        print(f"Critical error in process_batch: {str(e)}")
        traceback.print_exc()
        # Bare raise keeps the original traceback intact.
        raise

    print(f"--- Completed Batch ID: {batch_id} ---\n")


In [0]:
# Source: Auto Loader ("cloudFiles") reading headered CSV files from the raw
# bucket with the schema derived from the transactions table. The Auto Loader
# schema location shares the checkpoint directory on the volume.
df_stream = (
    spark.readStream
    .format("cloudFiles")
    .schema(transactions_schema)
    .options(**{
        "cloudFiles.format": "csv",
        "cloudFiles.schemaLocation": s3_checkpoint_location,
        "header": "true",
    })
    .load(s3_raw_path)
)

# Sink: hand every micro-batch to process_batch. The availableNow trigger
# processes the data that is currently available and then stops the query.
streaming_query = (
    df_stream.writeStream
    .trigger(availableNow=True)
    .option("checkpointLocation", s3_checkpoint_location)
    .foreachBatch(process_batch)
    .start()
)

print("Streaming query started successfully.")
streaming_query.awaitTermination()
print("Stream processing complete.")
