In [0]:
# Databricks notebook source
#  %md
#  # Transaction Processing (S3/DBFS/Workspace/GDrive Sources)
#  
#  This notebook implements mechanism X and Y:
#  - X: every-second chunking of transactions (10,000 rows per chunk) to S3
#  - Y: polling S3 for newly arrived chunks, detecting PatId1, and writing detections in batches of 50 to S3
#  
#  Optional: downloads inputs from Google Drive into the cluster if configured.

# COMMAND ----------

#  %md
#  ## Setup

# COMMAND ----------

# If not already present on the cluster
#  %pip install boto3 psycopg2-binary pytz

# COMMAND ----------

import os
import io
import json
import time
from datetime import datetime
from typing import List, Optional, Set, Tuple
import threading

import pytz
import boto3
import pandas as pd
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import (
    col,
    lit,
    avg,
    sum as spark_sum,
    count as spark_count,
    desc,
    ntile,
    expr,
    current_timestamp,
    monotonically_increasing_id,
)
from pyspark.sql.window import Window

try:
    import psycopg2
    from psycopg2.extras import execute_values
except Exception:
    psycopg2 = None

# ---- Config (edit the defaults below or set the matching env vars) ----
# AWS credentials are deliberately NOT kept in source. They are read from the
# environment; when a variable is unset (or empty) the value is None, in which
# case boto3 falls back to its default credential provider chain.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") or None
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") or None
AWS_REGION = os.getenv("AWS_REGION", "eu-north-1")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "aws18082003")
S3_TRANSACTIONS_FOLDER = os.getenv("S3_TRANSACTIONS_FOLDER", "test/transactions/")
S3_DETECTIONS_FOLDER = os.getenv("S3_DETECTIONS_FOLDER", "test/detections/")

# Source selection: 's3', 'volume', 'dbfs', 'workspace', 'catalog', or 'gdrive'
INPUT_SOURCE = os.getenv("INPUT_SOURCE", "catalog")  # Changed default to catalog since that's working

# S3 URIs for input CSVs (if INPUT_SOURCE == 's3')
INPUT_CUSTOMER_IMPORTANCE_S3 = os.getenv("INPUT_CUSTOMER_IMPORTANCE_S3", "s3a://aws18082003/CustomerImportance.csv")
INPUT_TRANSACTIONS_S3 = os.getenv("INPUT_TRANSACTIONS_S3", "s3a://aws18082003/transactions.csv")

# Unity Catalog Volume paths (upload files to a UC Volume if using 'volume')
TRANSACTIONS_CSV_VOLUME = os.getenv("TRANSACTIONS_CSV_VOLUME", "")
CUSTOMER_IMPORTANCE_CSV_VOLUME = os.getenv("CUSTOMER_IMPORTANCE_CSV_VOLUME", "")

# DBFS file inputs (if INPUT_SOURCE == 'dbfs')
TRANSACTIONS_CSV_DBFS = os.getenv("TRANSACTIONS_CSV_DBFS", "dbfs:/FileStore/transactions.csv")
CUSTOMER_IMPORTANCE_CSV_DBFS = os.getenv("CUSTOMER_IMPORTANCE_CSV_DBFS", "dbfs:/FileStore/CustomerImportance.csv")

# Databricks Workspace absolute paths (if INPUT_SOURCE == 'workspace')
WORKSPACE_TRANSACTIONS_PATH = os.getenv("WORKSPACE_TRANSACTIONS_PATH", "/Workspace/Users/you@example.com/transactions.csv")
WORKSPACE_IMPORTANCE_PATH = os.getenv("WORKSPACE_IMPORTANCE_PATH", "/Workspace/Users/you@example.com/CustomerImportance.csv")

# Google Drive inputs (if INPUT_SOURCE == 'gdrive') - using mounted path
GDRIVE_MOUNT_PATH = os.getenv("GDRIVE_MOUNT_PATH", "/mnt/google_drive")
GDRIVE_TRANSACTIONS_FILENAME = os.getenv("GDRIVE_TRANSACTIONS_FILENAME", "transactions.csv")
GDRIVE_IMPORTANCE_FILENAME = os.getenv("GDRIVE_IMPORTANCE_FILENAME", "CustomerImportance.csv")

# Unity Catalog table names (if INPUT_SOURCE == 'catalog'); overridable like
# every other input location.
CATALOG_TRANSACTIONS_PATH = os.getenv("CATALOG_TRANSACTIONS_PATH", "datadump.test.transactions")
CATALOG_IMPORTANCE_PATH = os.getenv("CATALOG_IMPORTANCE_PATH", "datadump.test.customer_importance")

# Processing
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "10000"))

DETECTION_BATCH_SIZE = int(os.getenv("DETECTION_BATCH_SIZE", "50"))
MIN_TRANSACTIONS_FOR_UPGRADE = int(os.getenv("MIN_TRANSACTIONS_FOR_UPGRADE", "1000"))

# X: sleep between chunk uploads (seconds)
X_SLEEP_SECONDS = float(os.getenv("X_SLEEP_SECONDS", "1"))

# Y: polling behavior
Y_POLL_INTERVAL_SECS = float(os.getenv("Y_POLL_INTERVAL_SECS", "1"))
Y_MAX_POLL_ROUNDS = int(os.getenv("Y_MAX_POLL_ROUNDS", "10"))  # safety to avoid infinite loop in demos

# Postgres (optional)
PG_HOST = os.getenv("PG_HOST", "")
PG_PORT = int(os.getenv("PG_PORT", "5432"))
PG_DB = os.getenv("PG_DB", "")
PG_USER = os.getenv("PG_USER", "")
PG_PASSWORD = os.getenv("PG_PASSWORD", "")
PG_TABLE_KEYS = os.getenv("PG_TABLE_KEYS", "processed_s3_keys")

# Time zone used when stamping detections (YStartTime / detectionTime).
IST = pytz.timezone("Asia/Kolkata")

# Which mechanism(s) the Execute cell runs: "x" | "y" | "both"
RUN_MODE = os.environ.get("RUN_MODE", "both")

# COMMAND ----------

#  %md
#  ## Utilities

# COMMAND ----------

def _to_s3_uri(s3a_uri: str) -> str:
    return s3a_uri.replace("s3a://", "s3://")


def _read_csv_from_s3_via_boto(s3_uri: str) -> pd.DataFrame:
    """Download one CSV object from S3 with boto3 and parse it with pandas.

    ``s3_uri`` is first passed through ``_to_s3_uri`` and then split into a
    bucket (the URI's network location) and a key (the path without its
    leading slash). The object body is decoded as UTF-8 before parsing.
    """
    from urllib.parse import urlparse

    parsed = urlparse(_to_s3_uri(s3_uri))
    s3 = boto3.client(
        "s3",
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        region_name=AWS_REGION,
    )
    response = s3.get_object(Bucket=parsed.netloc, Key=parsed.path.lstrip("/"))
    csv_text = response["Body"].read().decode("utf-8")
    return pd.read_csv(io.StringIO(csv_text))


# COMMAND ----------

#  %md
#  ## S3 Client

# COMMAND ----------

class S3Client:
    """Small wrapper around a boto3 S3 client bound to the configured bucket.

    The bucket and the transaction / detection key prefixes are taken from the
    module-level configuration when the instance is created.
    """

    def __init__(self):
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION,
        )
        self.bucket = S3_BUCKET_NAME
        self.tx_prefix = S3_TRANSACTIONS_FOLDER
        self.det_prefix = S3_DETECTIONS_FOLDER

    def create_bucket_if_not_exists(self) -> bool:
        """Ensure the configured bucket exists, creating it when it is missing.

        Returns:
            True once the bucket is known to exist.

        Raises:
            Any error from ``head_bucket`` other than "not found" (for example
            access denied), and any error raised by ``create_bucket``.
        """
        try:
            self.s3_client.head_bucket(Bucket=self.bucket)
            return True
        except self.s3_client.exceptions.ClientError as err:
            code = str(err.response.get("Error", {}).get("Code", ""))
            # Only a missing bucket should lead to creation; permission or
            # credential problems are surfaced as-is instead of being masked
            # by a follow-up create_bucket failure.
            if code not in ("404", "NoSuchBucket", "NotFound"):
                raise

        params = {"Bucket": self.bucket}
        # us-east-1 is the one region where no LocationConstraint is sent.
        if AWS_REGION != "us-east-1":
            params["CreateBucketConfiguration"] = {"LocationConstraint": AWS_REGION}
        self.s3_client.create_bucket(**params)
        return True

    def upload_transaction_chunk(self, df: pd.DataFrame, chunk_id: str) -> str:
        """Upload *df* as a CSV object under the transactions prefix.

        The key is ``<tx_prefix><YYYYmmdd_HHMMSS>_<chunk_id>.csv``.

        Returns:
            The S3 key that was written.
        """
        csv_data = df.to_csv(index=False)
        key = f"{self.tx_prefix}{datetime.now().strftime('%Y%m%d_%H%M%S')}_{chunk_id}.csv"
        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=csv_data.encode("utf-8"),
            ContentType="text/csv",
        )
        print(f"Uploaded chunk to s3://{self.bucket}/{key}")
        return key

    def upload_detections_batch(self, detections: list, batch_id: str) -> str:
        """Upload *detections* as an indented JSON document under the detections prefix.

        Values that ``json`` cannot serialise natively are converted with ``str``.

        Returns:
            The S3 key that was written.
        """
        body = json.dumps(detections, indent=2, default=str).encode("utf-8")
        key = f"{self.det_prefix}{datetime.now().strftime('%Y%m%d_%H%M%S')}_detections_batch_{batch_id}.json"
        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType="application/json",
        )
        print(f"Uploaded detections to s3://{self.bucket}/{key}")
        return key

    def list_keys(self, prefix: str) -> List[str]:
        """Return every object key in the bucket that starts with *prefix*.

        A paginator is used so that all result pages are read rather than
        only the first response of ``list_objects_v2``.
        """
        paginator = self.s3_client.get_paginator("list_objects_v2")
        keys: List[str] = []
        for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
            # Pages without matches carry no "Contents" entry.
            keys.extend(obj["Key"] for obj in (page.get("Contents", []) or []))
        return keys


# COMMAND ----------

#  %md
#  ## Optional Postgres tracking

# COMMAND ----------

def pg_available() -> bool:
    """Tell whether Postgres tracking can be used.

    True only when the psycopg2 import succeeded and host, database, user
    and password are all non-empty.
    """
    if psycopg2 is None:
        return False
    return all((PG_HOST, PG_DB, PG_USER, PG_PASSWORD))


def pg_connect():
    """Open a new psycopg2 connection from the PG_* settings.

    Returns None when ``pg_available()`` reports that tracking is not usable.
    """
    if pg_available():
        settings = {
            "host": PG_HOST,
            "port": PG_PORT,
            "database": PG_DB,
            "user": PG_USER,
            "password": PG_PASSWORD,
        }
        return psycopg2.connect(**settings)
    return None


def pg_ensure_tables():
    """Create the processed-keys table if it does not exist yet.

    Does nothing when no Postgres connection is available. The connection is
    always closed, even if the statement or the commit fails.
    """
    conn = pg_connect()
    if conn is None:
        return
    try:
        # The table name comes from configuration (PG_TABLE_KEYS), not from
        # request data, which is why it is interpolated into the statement.
        with conn.cursor() as cur:
            cur.execute(
                f"""
        CREATE TABLE IF NOT EXISTS {PG_TABLE_KEYS} (
            s3_key TEXT PRIMARY KEY,
            processed_at TIMESTAMPTZ DEFAULT NOW()
        )
        """
            )
        conn.commit()
    finally:
        conn.close()


def pg_load_processed_keys() -> Set[str]:
    """Return the set of S3 keys recorded in the processed-keys table.

    Returns an empty set when no Postgres connection is available. The
    connection is always closed, even if the query fails.
    """
    conn = pg_connect()
    if conn is None:
        return set()
    try:
        with conn.cursor() as cur:
            cur.execute(f"SELECT s3_key FROM {PG_TABLE_KEYS}")
            rows = cur.fetchall()
    finally:
        conn.close()
    return {r[0] for r in rows}


def pg_mark_keys_processed(keys: List[str]):
    """Record *keys* in the processed-keys table, ignoring ones already present.

    Does nothing for an empty list or when no Postgres connection is
    available. The connection is always closed, even if the insert or the
    commit fails.
    """
    if not keys:
        return
    conn = pg_connect()
    if conn is None:
        return
    try:
        with conn.cursor() as cur:
            execute_values(
                cur,
                f"INSERT INTO {PG_TABLE_KEYS} (s3_key) VALUES %s ON CONFLICT DO NOTHING",
                [(k,) for k in keys],
            )
        conn.commit()
    finally:
        conn.close()


# COMMAND ----------

#  %md
#  ## Normalization helpers

# COMMAND ----------

def normalize_txns(df: DataFrame) -> DataFrame:
    """Bring a transactions DataFrame to the column layout used by detection.

    Column names are lower-cased, known aliases are renamed to
    ``merchant_id`` / ``customer_id`` / ``transaction_id`` /
    ``transaction_type``, ``amount`` is cast to double, and missing optional
    columns are filled with defaults.

    Raises:
        ValueError: if no merchant column can be identified.
    """
    cols = [c.lower() for c in df.columns]
    df = df.toDF(*cols)
    renames = {
        "merchant": "merchant_id",
        "merchantid": "merchant_id",
        "customer": "customer_id",
        "customerid": "customer_id",
        "user": "customer_id",
        "userid": "customer_id",
        "trans_id": "transaction_id",
        "transactionid": "transaction_id",
        "txnid": "transaction_id",
        "txntype": "transaction_type",
        "type": "transaction_type",
    }
    for src, dst in renames.items():
        # Skip the rename when the target already exists: renaming onto an
        # existing name would leave two columns with the same name.
        if src in df.columns and dst not in df.columns:
            df = df.withColumnRenamed(src, dst)
    if "merchant_id" not in df.columns:
        raise ValueError(f"merchant_id not found. Available: {df.columns}")
    if "customer_id" not in df.columns:
        df = df.withColumn("customer_id", lit(""))
    if "amount" not in df.columns:
        df = df.withColumn("amount", lit(0.0).cast("double"))
    else:
        df = df.withColumn("amount", col("amount").cast("double"))
    if "transaction_type" not in df.columns:
        df = df.withColumn("transaction_type", lit("PURCHASE"))
    if "transaction_id" not in df.columns:
        # Synthetic ids: unique per row, but not consecutive.
        df = df.withColumn("transaction_id", monotonically_increasing_id().cast("string"))
    if "timestamp" not in df.columns:
        df = df.withColumn("timestamp", current_timestamp().cast("string"))
    return df


def normalize_importance(df: DataFrame) -> DataFrame:
    """Bring a customer-importance DataFrame to the layout used by detection.

    Column names are lower-cased, known aliases are renamed to
    ``customer_id`` / ``merchant_id`` / ``transaction_type`` / ``weightage``,
    ``weightage`` is cast to double, and any required column that is still
    missing is filled with a default. A warning is printed for every column
    that had to be filled, because defaulted ids or weights usually mean the
    input has an unexpected schema.
    """
    cols = [c.lower() for c in df.columns]
    df = df.toDF(*cols)
    renames = {
        "source": "customer_id",
        "target": "merchant_id",
        "typetrans": "transaction_type",
        "type": "transaction_type",
        "weight": "weightage",
    }
    for src, dst in renames.items():
        # Skip the rename when the target already exists: renaming onto an
        # existing name would leave two columns with the same name.
        if src in df.columns and dst not in df.columns:
            df = df.withColumnRenamed(src, dst)
    if "weightage" in df.columns:
        df = df.withColumn("weightage", col("weightage").cast("double"))
    for req, default in [
        ("customer_id", lit("")),
        ("merchant_id", lit("")),
        ("transaction_type", lit("PURCHASE")),
        ("weightage", lit(0.1)),
    ]:
        if req not in df.columns:
            print(f"WARNING: importance data has no '{req}' column; using a default value. Available: {df.columns}")
            df = df.withColumn(req, default)
    return df


# COMMAND ----------

#  %md
#  ## Input loading (S3/DBFS/Workspace/Volume/Catalog/GDrive)

# COMMAND ----------

def load_from_gdrive(spark: SparkSession) -> tuple[DataFrame, DataFrame]:
    """Load the transactions and importance CSVs from the mounted Drive path.

    Four strategies are tried in order, each one only after the previous one
    failed: (1) read the files through ``dbutils.fs``, (2) copy them to
    ``dbfs:/FileStore`` and read with Spark, (3) read from a Unity Catalog
    volume path, (4) read from a workspace path with pandas.

    Returns:
        ``(transactions, importance)`` as Spark DataFrames.

    Raises:
        Exception: if all four strategies fail; the message lists each error.
    """
    # NOTE(review): `dbutils` is not defined in this file; presumably it is the
    # global injected by the Databricks notebook runtime.
    # Use the mounted Google Drive path in Databricks
    tx_path = f"{GDRIVE_MOUNT_PATH}/{GDRIVE_TRANSACTIONS_FILENAME}"
    im_path = f"{GDRIVE_MOUNT_PATH}/{GDRIVE_IMPORTANCE_FILENAME}"

    # List files in the mounted directory to verify
    print(f"Listing files in {GDRIVE_MOUNT_PATH}:")
    try:
        files = dbutils.fs.ls(GDRIVE_MOUNT_PATH)
        for file in files:
            print(f"  - {file.name} (size: {file.size} bytes)")
    except Exception as e:
        # Listing is informational only; loading is still attempted below.
        print(f"Warning: Could not list files in {GDRIVE_MOUNT_PATH}: {e}")

    print(f"Loading transactions from: {tx_path}")
    print(f"Loading importance from: {im_path}")

    try:
        # Method 1: Try to read entire files using dbutils.fs.read()
        print("Attempting to read files using dbutils.fs.read()...")

        # Accessibility probe only: the returned bytes are not needed.
        dbutils.fs.head(tx_path, 1)
        dbutils.fs.head(im_path, 1)

        # Read entire files
        # NOTE(review): `read` does not appear among the documented dbutils.fs
        # commands; if it is unavailable this raises and Method 2 takes over.
        # Confirm against the runtime in use.
        tx_content = dbutils.fs.read(tx_path)
        im_content = dbutils.fs.read(im_path)

        # Convert to pandas DataFrames
        txns_pd = pd.read_csv(io.StringIO(tx_content))
        importance_pd = pd.read_csv(io.StringIO(im_content))

        # Convert to Spark DataFrames
        txns = spark.createDataFrame(txns_pd)
        importance = spark.createDataFrame(importance_pd)

        print(f"Successfully loaded {txns.count()} transactions and {importance.count()} importance records")
        return txns, importance

    except Exception as e:
        print(f"Method 1 failed: {e}")
        print("Trying Method 2: Copy to DBFS first...")

        try:
            # Method 2: Copy files to DBFS and then read
            dbfs_tx_path = f"dbfs:/FileStore/{GDRIVE_TRANSACTIONS_FILENAME}"
            dbfs_im_path = f"dbfs:/FileStore/{GDRIVE_IMPORTANCE_FILENAME}"

            print(f"Copying {tx_path} to {dbfs_tx_path}")
            dbutils.fs.cp(tx_path, dbfs_tx_path)
            print(f"Copying {im_path} to {dbfs_im_path}")
            dbutils.fs.cp(im_path, dbfs_im_path)

            print("Files copied successfully, now reading from DBFS...")
            txns = spark.read.csv(dbfs_tx_path, header=True, inferSchema=True)
            importance = spark.read.csv(dbfs_im_path, header=True, inferSchema=True)

            print(f"Successfully loaded {txns.count()} transactions and {importance.count()} importance records from DBFS")
            return txns, importance

        except Exception as e2:
            print(f"Method 2 also failed: {e2}")
            print("Trying Method 3: Use Unity Catalog path...")

            try:
                # Method 3: Try to access via Unity Catalog path
                # NOTE(review): the catalog/schema/volume segments below are
                # placeholders and must be edited before this can succeed.
                catalog_tx_path = f"/Volumes/your_catalog/your_schema/your_volume/{GDRIVE_TRANSACTIONS_FILENAME}"
                catalog_im_path = f"/Volumes/your_catalog/your_schema/your_volume/{GDRIVE_IMPORTANCE_FILENAME}"

                print(f"Attempting to read from Unity Catalog: {catalog_tx_path}")
                txns = spark.read.csv(catalog_tx_path, header=True, inferSchema=True)
                importance = spark.read.csv(catalog_im_path, header=True, inferSchema=True)

                print(f"Successfully loaded from Unity Catalog: {txns.count()} transactions and {importance.count()} importance records")
                return txns, importance

            except Exception as e3:
                print(f"Method 3 also failed: {e3}")
                print("Trying Method 4: Use workspace path...")

                try:
                    # Method 4: Try to access via workspace path
                    # NOTE(review): the user folder below is a placeholder.
                    workspace_tx_path = f"/Workspace/Users/your_email@domain.com/{GDRIVE_TRANSACTIONS_FILENAME}"
                    workspace_im_path = f"/Workspace/Users/your_email@domain.com/{GDRIVE_IMPORTANCE_FILENAME}"

                    print(f"Attempting to read from workspace: {workspace_tx_path}")
                    txns_pd = pd.read_csv(workspace_tx_path)
                    importance_pd = pd.read_csv(workspace_im_path)

                    txns = spark.createDataFrame(txns_pd)
                    importance = spark.createDataFrame(importance_pd)

                    print(f"Successfully loaded from workspace: {txns.count()} transactions and {importance.count()} importance records")
                    return txns, importance

                except Exception as e4:
                    print(f"All methods failed. Final error: {e4}")
                    raise Exception(
                        f"Could not load files from any source. Errors: Method1={e}, Method2={e2}, Method3={e3}, Method4={e4}"
                    ) from e4


def load_from_catalog(spark: SparkSession) -> tuple[DataFrame, DataFrame]:
    """Load transactions and importance data from Unity Catalog.

    The configured names are first read as tables. If that fails, the same
    names are reinterpreted as volumes and the CSV files
    ``transactions.csv`` / ``CustomerImportance.csv`` are read from them.

    Returns:
        ``(transactions, importance)`` as Spark DataFrames.

    Raises:
        Exception: if both the table read and the volume fallback fail.
    """
    print("Loading from Unity Catalog tables:")
    print(f"  Transactions: {CATALOG_TRANSACTIONS_PATH}")
    print(f"  Importance: {CATALOG_IMPORTANCE_PATH}")

    try:
        # Read directly from Unity Catalog tables
        txns = spark.read.table(CATALOG_TRANSACTIONS_PATH)
        importance = spark.read.table(CATALOG_IMPORTANCE_PATH)

        print(f"Successfully loaded {txns.count()} transactions and {importance.count()} importance records from Unity Catalog")

        # Debug: Show schemas
        print("DEBUG: Transactions schema:")
        txns.printSchema()
        print("DEBUG: Importance schema:")
        importance.printSchema()

        # Debug: Show sample data
        print("DEBUG: Sample transactions data:")
        txns.show(5, truncate=False)
        print("DEBUG: Sample importance data:")
        importance.show(5, truncate=False)

        return txns, importance

    except Exception as e:
        print(f"Error reading from Unity Catalog: {e}")
        print("Trying alternative catalog path format...")

        try:
            # Alternative format: try with /Volumes prefix. Volume paths use
            # slash-separated segments, so the dotted name is converted
            # ("a.b.c" -> "a/b/c") before it is placed under /Volumes.
            tx_volume_dir = CATALOG_TRANSACTIONS_PATH.replace(".", "/")
            im_volume_dir = CATALOG_IMPORTANCE_PATH.replace(".", "/")
            alt_tx_path = f"/Volumes/{tx_volume_dir}/transactions.csv"
            alt_im_path = f"/Volumes/{im_volume_dir}/CustomerImportance.csv"

            print(f"Trying alternative path: {alt_tx_path}")
            txns = spark.read.csv(alt_tx_path, header=True, inferSchema=True)
            importance = spark.read.csv(alt_im_path, header=True, inferSchema=True)

            print(f"Successfully loaded using alternative path: {txns.count()} transactions and {importance.count()} importance records")
            return txns, importance

        except Exception as e2:
            print(f"Alternative path also failed: {e2}")
            raise Exception(f"Could not load from Unity Catalog. Errors: {e}, Alternative: {e2}") from e2


def load_inputs_normalized(spark: SparkSession) -> tuple[DataFrame, DataFrame]:
    """Load both inputs from the configured source and normalize their columns.

    The source is chosen by ``INPUT_SOURCE`` (compared case-insensitively,
    surrounding whitespace ignored): ``s3``, ``workspace``, ``volume``,
    ``catalog``, ``gdrive`` or ``dbfs``. Any other value falls back to
    ``dbfs`` after printing a warning.

    Returns:
        ``(transactions, importance)`` after ``normalize_txns`` /
        ``normalize_importance``.

    Raises:
        ValueError: if the ``volume`` source is selected but a volume path is
            not configured.
    """
    source = (INPUT_SOURCE or "").strip().lower()
    if source == "s3":
        txns_pd = _read_csv_from_s3_via_boto(INPUT_TRANSACTIONS_S3)
        importance_pd = _read_csv_from_s3_via_boto(INPUT_CUSTOMER_IMPORTANCE_S3)
        txns = spark.createDataFrame(txns_pd)
        importance = spark.createDataFrame(importance_pd)
    elif source == "workspace":
        txns_pd = pd.read_csv(WORKSPACE_TRANSACTIONS_PATH)
        importance_pd = pd.read_csv(WORKSPACE_IMPORTANCE_PATH)
        txns = spark.createDataFrame(txns_pd)
        importance = spark.createDataFrame(importance_pd)
    elif source == "volume":
        # The volume paths default to "", so fail with an actionable message
        # instead of passing an empty path to Spark.
        if not (TRANSACTIONS_CSV_VOLUME and CUSTOMER_IMPORTANCE_CSV_VOLUME):
            raise ValueError(
                "INPUT_SOURCE='volume' requires TRANSACTIONS_CSV_VOLUME and "
                "CUSTOMER_IMPORTANCE_CSV_VOLUME to be set"
            )
        txns = spark.read.csv(TRANSACTIONS_CSV_VOLUME, header=True, inferSchema=True)
        importance = spark.read.csv(CUSTOMER_IMPORTANCE_CSV_VOLUME, header=True, inferSchema=True)
    elif source == "catalog":
        txns, importance = load_from_catalog(spark)
    elif source == "gdrive":
        txns, importance = load_from_gdrive(spark)
    else:  # dbfs
        if source != "dbfs":
            # Keep the historical fallback, but make a mistyped source visible.
            print(f"WARNING: unknown INPUT_SOURCE={INPUT_SOURCE!r}; falling back to 'dbfs'")
        txns = spark.read.csv(TRANSACTIONS_CSV_DBFS, header=True, inferSchema=True)
        importance = spark.read.csv(CUSTOMER_IMPORTANCE_CSV_DBFS, header=True, inferSchema=True)

    txns = normalize_txns(txns)
    importance = normalize_importance(importance)
    return txns, importance


# COMMAND ----------

#  %md
#  ## Pattern Detection (PatId1)

# COMMAND ----------

class PatternDetector:
    """Detects pattern PatId1 and formats the resulting detection records.

    The IST time at which the detector was created is stored and stamped on
    every detection as ``YStartTime``.
    """

    def __init__(self):
        self.ist_start = datetime.now(IST)
        self.counter = 0  # number of detection records created so far

    def _create_detection(self, customer_id: str, merchant_id: str) -> dict:
        """Build one PatId1 / UPGRADE detection record for the given pair."""
        self.counter += 1
        return {
            "YStartTime": self.ist_start.strftime("%Y-%m-%d %H:%M:%S"),
            "detectionTime": datetime.now(IST).strftime("%Y-%m-%d %H:%M:%S"),
            "patternId": "PatId1",
            "ActionType": "UPGRADE",
            "customerName": customer_id or "",
            "MerchantId": merchant_id or "",
        }

    def detect_patid1(
        self,
        spark: SparkSession,
        txn_df: DataFrame,
        importance_df: DataFrame,
        debug: bool = True,
    ) -> List[dict]:
        """Return detection records for customers matching PatId1.

        A (merchant, customer) pair is eligible when all of these hold:
        the customer is in the first of ten tiles of that merchant's customers
        ordered by descending transaction count; the merchant has at least
        ``MIN_TRANSACTIONS_FOR_UPGRADE`` transactions in ``txn_df``; and the
        pair's average weightage is at or below the merchant's approximate
        10th-percentile average weightage.

        Args:
            spark: Unused by the detection itself; kept for interface
                compatibility.
            txn_df: Transactions with ``merchant_id`` and ``customer_id``.
            importance_df: Importance rows with ``merchant_id``,
                ``customer_id`` and ``weightage``.
            debug: When True (the default) intermediate results are printed.
                Each printed table is an extra Spark action, so pass False to
                skip them.

        Returns:
            One detection dict per eligible pair; an empty list when either
            input is None or empty.
        """

        def _log(message: str) -> None:
            if debug:
                print(message)

        def _show(title: str, frame, rows: int = 5) -> None:
            if debug:
                print(title)
                frame.show(rows, truncate=False)

        if txn_df is None or txn_df.limit(1).count() == 0:
            _log("DEBUG: Transaction DataFrame is empty or None")
            return []
        if importance_df is None or importance_df.limit(1).count() == 0:
            _log("DEBUG: Importance DataFrame is empty or None")
            return []

        if debug:
            print(f"DEBUG: Starting detection with {txn_df.count()} transactions and {importance_df.count()} importance records")

        # Show sample data
        _show("DEBUG: Sample transactions:", txn_df)
        _show("DEBUG: Sample importance:", importance_df)

        # Transactions per (merchant, customer).
        cust_stats = (
            txn_df.groupBy("merchant_id", "customer_id")
            .agg(spark_count(lit(1)).alias("total_transactions"))
        )
        _show("DEBUG: Customer stats:", cust_stats)

        # Transactions per merchant.
        merchant_totals = (
            cust_stats.groupBy("merchant_id")
            .agg(spark_sum("total_transactions").alias("merchant_total_txns"))
        )
        _show("DEBUG: Merchant totals:", merchant_totals)

        # Tile 1 of 10 holds each merchant's most active customers.
        w = Window.partitionBy("merchant_id").orderBy(desc("total_transactions"))
        cust_ranked = cust_stats.withColumn("tile10", ntile(10).over(w))
        top10 = cust_ranked.filter(col("tile10") == 1)
        _show("DEBUG: Top 10% customers:", top10)

        imp_avg = (
            importance_df.groupBy("merchant_id", "customer_id")
            .agg(avg("weightage").alias("avg_weightage"))
        )
        _show("DEBUG: Importance averages:", imp_avg)

        merchant_thresholds = (
            imp_avg.groupBy("merchant_id")
            .agg(expr("percentile_approx(avg_weightage, 0.10)").alias("w10"))
        )
        _show("DEBUG: Merchant thresholds (10th percentile):", merchant_thresholds)

        top10_with_weight = (
            top10.join(imp_avg, ["merchant_id", "customer_id"], "left")
            .join(merchant_totals, ["merchant_id"], "left")
            .join(merchant_thresholds, ["merchant_id"], "left")
        )
        _show("DEBUG: Top 10% with weights and thresholds:", top10_with_weight)

        # Weight condition shared by the real filter and the debug-only view.
        low_weight = (
            (col("avg_weightage").isNotNull())
            & (col("w10").isNotNull())
            & (col("avg_weightage") <= col("w10"))
        )
        enough_txns = col("merchant_total_txns") >= lit(MIN_TRANSACTIONS_FOR_UPGRADE)

        eligible = top10_with_weight.where(enough_txns & low_weight).select("merchant_id", "customer_id")

        _log(f"DEBUG: MIN_TRANSACTIONS_FOR_UPGRADE = {MIN_TRANSACTIONS_FOR_UPGRADE}")
        _show("DEBUG: Eligible customers after all filters:", eligible)

        if debug:
            # Additional debugging: Show merchants that meet the threshold
            print("DEBUG: Merchants that meet the transaction threshold:")
            merchants_above_threshold = merchant_totals.filter(enough_txns)
            merchants_above_threshold.show(10, truncate=False)

            # Show how many customers would be eligible if we ignored the transaction threshold
            print("DEBUG: Customers that would be eligible without transaction threshold:")
            eligible_without_threshold = top10_with_weight.where(low_weight).select("merchant_id", "customer_id")
            eligible_without_threshold.show(5, truncate=False)

        rows = eligible.distinct().collect()
        _log(f"DEBUG: Final detections: {len(rows)}")
        return [self._create_detection(r["customer_id"], r["merchant_id"]) for r in rows]

    def batch(self, detections: List[dict], batch_size: Optional[int] = None):
        """Yield *detections* in consecutive slices of at most *batch_size*.

        Args:
            detections: Records to split.
            batch_size: Slice length; defaults to ``DETECTION_BATCH_SIZE``.

        Raises:
            ValueError: if the effective batch size is smaller than 1 (raised
                when iteration starts).
        """
        size = DETECTION_BATCH_SIZE if batch_size is None else batch_size
        if size < 1:
            raise ValueError(f"batch size must be >= 1, got {size}")
        for i in range(0, len(detections), size):
            yield detections[i : i + size]


# COMMAND ----------

#  %md
#  ## Mechanism X (chunk to S3) and Y (poll S3, detect, write)

# COMMAND ----------

def run_mechanism_x() -> List[str]:
    """Mechanism X: upload the transactions to S3 in CHUNK_SIZE-row CSV chunks.

    The normalized transactions are collected to the driver once, sliced into
    consecutive chunks, and uploaded one by one with a pause of
    ``X_SLEEP_SECONDS`` after each upload.

    Returns:
        The S3 keys of the uploaded chunks, in upload order.

    Raises:
        ValueError: if ``CHUNK_SIZE`` is smaller than 1.
    """
    if CHUNK_SIZE < 1:
        raise ValueError(f"CHUNK_SIZE must be >= 1, got {CHUNK_SIZE}")

    spark = SparkSession.builder.getOrCreate()
    s3 = S3Client()
    s3.create_bucket_if_not_exists()

    txns, _ = load_inputs_normalized(spark)

    # Collect once; the row count comes from the collected frame instead of
    # running a separate Spark count job.
    txns_pd = txns.toPandas()
    total = len(txns_pd)

    uploaded_keys: List[str] = []
    for chunk_index, start in enumerate(range(0, total, CHUNK_SIZE)):
        chunk_pd = txns_pd.iloc[start : start + CHUNK_SIZE]
        chunk_id = f"chunk_{chunk_index}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        key = s3.upload_transaction_chunk(chunk_pd, chunk_id)
        uploaded_keys.append(key)
        print(f"X: uploaded {key}")
        time.sleep(X_SLEEP_SECONDS)
    return uploaded_keys


def run_mechanism_y_once(keys: Optional[List[str]] = None):
    """Mechanism Y, single pass: run PatId1 detection on the given S3 chunks.

    Each chunk is downloaded, parsed as CSV, checked against the importance
    data, and its detections are uploaded in batches. Keys handled
    successfully are recorded in Postgres when tracking is available, even if
    a later key fails.

    Args:
        keys: S3 keys to process. ``None`` means every key under
            ``S3_TRANSACTIONS_FOLDER``; an empty list means nothing to do.
    """
    spark = SparkSession.builder.getOrCreate()
    s3 = S3Client()

    if keys is None:
        keys = s3.list_keys(S3_TRANSACTIONS_FOLDER)
    if not keys:
        return

    _, importance = load_inputs_normalized(spark)

    detector = PatternDetector()
    processed_now: List[str] = []

    try:
        for key in keys:
            # Keys ending in "/" are folder placeholders, not CSV chunks.
            if key.endswith("/"):
                print(f"Y: skipping folder placeholder {key}")
                continue

            obj = s3.s3_client.get_object(Bucket=s3.bucket, Key=key)
            chunk_pd = pd.read_csv(io.StringIO(obj["Body"].read().decode("utf-8")))
            chunk_spark = spark.createDataFrame(chunk_pd)

            dets = detector.detect_patid1(spark, chunk_spark, importance)
            if dets:
                for i, batch in enumerate(detector.batch(dets)):
                    s3.upload_detections_batch(batch, f"{key.replace('/', '_')}_{i}")
            print(f"Y: processed {key}, detections={len(dets)}")
            processed_now.append(key)
    finally:
        # Persist what was completed so a failure on one chunk does not cause
        # the earlier chunks to be reprocessed after a restart.
        if processed_now and pg_available():
            pg_mark_keys_processed(processed_now)


def run_mechanism_y_streaming():
    """Mechanism Y, polling loop: process chunks as they appear in S3.

    Polls the transactions prefix up to ``Y_MAX_POLL_ROUNDS`` times, sleeping
    ``Y_POLL_INTERVAL_SECS`` after every round. Keys already handled, either
    earlier in this run or recorded in Postgres when tracking is available,
    are not processed again.
    """
    s3 = S3Client()

    already_done: Set[str] = set()
    if pg_available():
        pg_ensure_tables()
        already_done = pg_load_processed_keys()
        print(f"Loaded {len(already_done)} processed keys from Postgres")

    for round_no in range(1, Y_MAX_POLL_ROUNDS + 1):
        listed = s3.list_keys(S3_TRANSACTIONS_FOLDER)
        pending = [key for key in listed if key not in already_done]
        if not pending:
            print(f"Y round {round_no}: no new keys")
        else:
            print(f"Y round {round_no}: found {len(pending)} new keys")
            run_mechanism_y_once(pending)
            already_done.update(pending)
        time.sleep(Y_POLL_INTERVAL_SECS)


# COMMAND ----------

#  %md
#  ## Execute

# COMMAND ----------

print(f"INPUT_SOURCE={INPUT_SOURCE}")

# Keys uploaded by mechanism X (stays None when X is not run).
uploaded = None

if RUN_MODE == "both":
    # Start Y streaming concurrently, then run X
    y_thread = threading.Thread(target=run_mechanism_y_streaming, daemon=True)
    y_thread.start()
    uploaded = run_mechanism_x()
    # Wait for the poller instead of sleeping a fixed time: Y stops on its own
    # after Y_MAX_POLL_ROUNDS rounds, and joining guarantees it is no longer
    # writing detections when the next cell lists the outputs.
    y_thread.join()
elif RUN_MODE == "x":
    uploaded = run_mechanism_x()
elif RUN_MODE == "y":
    run_mechanism_y_streaming()
else:
    print(f"WARNING: unknown RUN_MODE={RUN_MODE!r}; expected 'x', 'y' or 'both' - nothing was run")

# COMMAND ----------

#  %md
#  ## List outputs in S3

# COMMAND ----------

# Preview of what is in the bucket: at most the first 10 keys of the first
# listing response for each prefix.
s3 = S3Client()

resp_tx = s3.s3_client.list_objects_v2(
    Bucket=S3_BUCKET_NAME,
    Prefix=S3_TRANSACTIONS_FOLDER,
)
print("Transactions in S3:")
tx_preview = resp_tx.get("Contents", [])[:10]
for o in tx_preview:
    print(" -", o["Key"])

resp_det = s3.s3_client.list_objects_v2(
    Bucket=S3_BUCKET_NAME,
    Prefix=S3_DETECTIONS_FOLDER,
)
print("\nDetections in S3:")
det_preview = resp_det.get("Contents", [])[:10]
for o in det_preview:
    print(" -", o["Key"])


INPUT_SOURCE=catalog
Loading from Unity Catalog tables:
  Transactions: datadump.test.transactions
  Importance: datadump.test.customer_importance
Y round 1: no new keys
Successfully loaded 594643 transactions and 1189286 importance records from Unity Catalog
DEBUG: Transactions schema:
Y round 2: no new keys
root
 |-- step: long (nullable = true)
 |-- customer: string (nullable = true)
 |-- age: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- zipcodeOri: string (nullable = true)
 |-- merchant: string (nullable = true)
 |-- zipMerchant: string (nullable = true)
 |-- category: string (nullable = true)
 |-- amount: double (nullable = true)
 |-- fraud: long (nullable = true)

DEBUG: Importance schema:
root
 |-- step: long (nullable = true)
 |-- customer: string (nullable = true)
 |-- age: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- zipcodeOri: string (nullable = true)
 |-- merchant: string (nullable = true)
 |-- zipMerchant: string (nullable = t