In [31]:
from pyspark.sql.types import (
    StructType, StructField,
    StringType, LongType, IntegerType, DoubleType
)

from pyspark.sql import SparkSession
import pandas as pd

pd.set_option('display.max_columns', None)

In [2]:
spark = (
    SparkSession.builder
        .appName("AML-Investigator-Copilot")
        # ---- core tuning (safe defaults; adjust per cluster) ----
        .config("spark.sql.shuffle.partitions", "200")
        .config("spark.sql.session.timeZone", "UTC")
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        # ---- CSV safety ----
        .config("spark.sql.csv.parser.columnPruning.enabled", "true")
        # ---- memory (optional; comment out if running on managed Spark) ----
        # .config("spark.driver.memory", "8g")
        # .config("spark.executor.memory", "8g")
        .getOrCreate()
)

spark.sparkContext.setLogLevel("WARN")

In [7]:
CSV_OPTIONS = {
    "header": "true",
    "mode": "FAILFAST",          # fail on malformed rows
    "quote": "\"",
    "escape": "\"",
    "multiLine": "false",
    "ignoreLeadingWhiteSpace": "true",
    "ignoreTrailingWhiteSpace": "true",
}
BASE_PATH = "/home/jovyan/work/investigator-data"

In [5]:
# -------------------------
# transactions.csv
# -------------------------
transactions_schema = StructType([
    StructField("txn_id", StringType(), False),
    StructField("party_id", StringType(), False),
    StructField("account_id", StringType(), False),
    StructField("instrument_type", StringType(), False),          # cash|wires|credit_cards|loans|ngi
    StructField("txn_timestamp_ms_utc", LongType(), False),       # epoch ms UTC
    StructField("direction", StringType(), False),                # debit|credit
    StructField("amount", DoubleType(), False),
    StructField("currency", StringType(), False),
    StructField("counterparty_id", StringType(), False),
    StructField("merchant_id", StringType(), True),               # nullable
    StructField("channel", StringType(), False),
    StructField("country", StringType(), False),
    StructField("state", StringType(), True),                     # nullable
    StructField("is_international", IntegerType(), False),        # 0/1
    StructField("description", StringType(), False),
])

# -------------------------
# parties.csv
# -------------------------
parties_schema = StructType([
    StructField("party_id", StringType(), False),
    StructField("party_type", StringType(), False),               # individual|business
    StructField("party_name", StringType(), False),
    StructField("industry", StringType(), True),                  # nullable
    StructField("country", StringType(), False),
    StructField("state", StringType(), True),                     # nullable
    StructField("onboarding_date", StringType(), False),          # YYYY-MM-DD
    StructField("expected_monthly_volume_usd", DoubleType(), False),
    StructField("expected_avg_txn_usd", DoubleType(), False),
    StructField("risk_rating", StringType(), False),              # low|medium|high
])

# -------------------------
# counterparties.csv
# -------------------------
counterparties_schema = StructType([
    StructField("counterparty_id", StringType(), False),
    StructField("counterparty_type", StringType(), False),        # individual|business
    StructField("country", StringType(), False),
])

# -------------------------
# merchants.csv
# -------------------------
merchants_schema = StructType([
    StructField("merchant_id", StringType(), False),
    StructField("merchant_name", StringType(), False),
    StructField("merchant_category", StringType(), False),
    StructField("country", StringType(), False),
    StructField("state", StringType(), True),                     # nullable
])

# -------------------------
# alerts_<model>.csv (cash/wires/ngi/credit_cards/loans)
# -------------------------
alerts_schema = StructType([
    StructField("alert_id", StringType(), False),
    StructField("party_id", StringType(), False),
    StructField("party_type", StringType(), True),                # nullable; join from parties if needed
    StructField("model_type", StringType(), False),               # cash|wires|credit_cards|loans|ngi
    StructField("model_version", StringType(), False),
    StructField("scenario_code", StringType(), False),
    StructField("alert_timestamp_ms_utc", LongType(), False),     # epoch ms UTC
    StructField("window_start_ms_utc", LongType(), False),
    StructField("window_end_ms_utc", LongType(), False),
    StructField("risk_score", DoubleType(), False),               # 0-100
    StructField("severity", StringType(), False),                 # low|medium|high
    StructField("trigger_summary", StringType(), False),
    StructField("supporting_txn_ids", StringType(), False),       # pipe-delimited txn_ids
    StructField("amount_total_usd", DoubleType(), False),
    StructField("txn_count", IntegerType(), False),
    StructField("features_json", StringType(), False),            # JSON string
    StructField("data_quality_flags", StringType(), False),       # pipe-delimited flags
])


In [9]:
transactions_df = (
    spark.read
        .options(**CSV_OPTIONS)
        .schema(transactions_schema)
        .csv(f"{BASE_PATH}/transactions.csv")
)

transactions_df.cache()

parties_df = (
    spark.read
        .options(**CSV_OPTIONS)
        .schema(parties_schema)
        .csv(f"{BASE_PATH}/parties.csv")
)

parties_df.cache()

counterparties_df = (
    spark.read
        .options(**CSV_OPTIONS)
        .schema(counterparties_schema)
        .csv(f"{BASE_PATH}/counterparties.csv")
)

merchants_df = (
    spark.read
        .options(**CSV_OPTIONS)
        .schema(merchants_schema)
        .csv(f"{BASE_PATH}/merchants.csv")
)

alerts_df = (
    spark.read
        .options(**CSV_OPTIONS)
        .schema(alerts_schema)
        .csv(f"{BASE_PATH}/alerts_*.csv")
)

alerts_df.cache()

print("Transactions:", transactions_df.count())
print("Parties:", parties_df.count())
print("Counterparties:", counterparties_df.count())
print("Merchants:", merchants_df.count())
print("Alerts:", alerts_df.count())

alerts_df.groupBy("model_type").count().show()


Transactions: 1000000
Parties: 50000
Counterparties: 200000
Merchants: 30000
Alerts: 10000
+------------+-----+
|  model_type|count|
+------------+-----+
|credit_cards| 2000|
|       loans| 2000|
|       wires| 2000|
|         ngi| 2000|
|        cash| 2000|
+------------+-----+



In [26]:
# alerts_df.toPandas()
alerts_df.select('trigger_summary').show(truncate=False)

+--------------------------------------------------------------------------------------------------+
|trigger_summary                                                                                   |
+--------------------------------------------------------------------------------------------------+
|credit_cards activity exceeded expected patterns: 8 txns totaling $566.15 (USD eq) in 7d window.  |
|credit_cards activity exceeded expected patterns: 11 txns totaling $566.21 (USD eq) in 14d window.|
|credit_cards activity exceeded expected patterns: 9 txns totaling $566.43 (USD eq) in 7d window.  |
|credit_cards activity exceeded expected patterns: 8 txns totaling $566.67 (USD eq) in 7d window.  |
|credit_cards activity exceeded expected patterns: 8 txns totaling $566.71 (USD eq) in 7d window.  |
|credit_cards activity exceeded expected patterns: 8 txns totaling $566.76 (USD eq) in 7d window.  |
|credit_cards activity exceeded expected patterns: 8 txns totaling $566.79 (USD eq) in 7d w

In [28]:
# ============================================
# AML Investigator Copilot
# Party-level Case Packet Spark Job + SQL Views
# ============================================
#
# Assumes you already created:
#   - spark session
#   - DataFrames: transactions_df, parties_df, counterparties_df, merchants_df, alerts_df
#   - Or temp views: transactions, parties, counterparties, merchants, alerts
#
# Uses the grounded schemas you requested earlier.
#
# Notes:
# - Uses epoch-ms UTC everywhere.
# - Builds deterministic "case packets" per party and window.
# - Produces both DataFrames and SQL views.
#
# --------------------------------------------

from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, LongType, IntegerType, DoubleType

# -------------------------
# Parameters
# -------------------------
DEFAULT_WINDOW_DAYS = 30
SUPPORT_TXN_MAX = 200          # max txns to include in packet (most recent by timestamp)
TOP_COUNTERPARTIES_MAX = 50
TOP_MERCHANTS_MAX = 50

# Case selection knobs: choose one strategy
# 1) Use alert windows (recommended): create a case per (party_id, window_start, window_end).
# 2) Or build rolling window (e.g., last 30 days from max txn time). We'll do #1.

# -------------------------
# Helper: parse supporting_txn_ids to array
# -------------------------
alerts_enriched = (
    alerts_df
    .withColumn("supporting_txn_id_arr", F.split(F.coalesce(F.col("supporting_txn_ids"), F.lit("")), r"\|"))
    .withColumn("supporting_txn_id_arr", F.expr("filter(supporting_txn_id_arr, x -> x is not null and x != '')"))
)

# -------------------------
# 0) Build Case Index: one "case" per party per alert window
# -------------------------
# case_id is stable hash of party_id + window_start + window_end
case_index = (
    alerts_enriched
    .select(
        "party_id",
        F.col("window_start_ms_utc").alias("case_window_start_ms_utc"),
        F.col("window_end_ms_utc").alias("case_window_end_ms_utc"),
    )
    .dropDuplicates()
    .withColumn(
        "case_id",
        F.sha2(
            F.concat_ws(
                "||",
                F.col("party_id"),
                F.col("case_window_start_ms_utc").cast("string"),
                F.col("case_window_end_ms_utc").cast("string"),
            ),
            256
        )
    )
)

case_index.createOrReplaceTempView("case_index")

# -------------------------
# 1) Join Parties into Case Index
# -------------------------
case_party = (
    case_index
    .join(
        parties_df.select(
            "party_id", "party_type", "party_name", "industry", "country", "state",
            "onboarding_date", "expected_monthly_volume_usd", "expected_avg_txn_usd", "risk_rating"
        ),
        on="party_id",
        how="left"
    )
)

case_party.createOrReplaceTempView("case_party")

# -------------------------
# 2) Alerts per Case (summary + details arrays)
# -------------------------
a = alerts_enriched.alias("a")
c = case_index.alias("c")

alerts_case = (
    a.join(
        c,
        on=(
            (F.col("a.party_id") == F.col("c.party_id")) &
            (F.col("a.window_start_ms_utc") == F.col("c.case_window_start_ms_utc")) &
            (F.col("a.window_end_ms_utc") == F.col("c.case_window_end_ms_utc"))
        ),
        how="inner"
    )
    .select(
        F.col("c.case_id").alias("case_id"),
        F.col("a.party_id").alias("party_id"),
        F.col("a.model_type").alias("model_type"),
        F.col("a.model_version").alias("model_version"),
        F.col("a.scenario_code").alias("scenario_code"),
        F.col("a.alert_id").alias("alert_id"),
        F.col("a.alert_timestamp_ms_utc").alias("alert_timestamp_ms_utc"),
        F.col("a.window_start_ms_utc").alias("window_start_ms_utc"),
        F.col("a.window_end_ms_utc").alias("window_end_ms_utc"),
        F.col("a.risk_score").alias("risk_score"),
        F.col("a.severity").alias("severity"),
        F.col("a.trigger_summary").alias("trigger_summary"),
        F.col("a.supporting_txn_ids").alias("supporting_txn_ids"),
        F.col("a.supporting_txn_id_arr").alias("supporting_txn_id_arr"),
        F.col("a.amount_total_usd").alias("amount_total_usd"),
        F.col("a.txn_count").alias("txn_count"),
        F.col("a.features_json").alias("features_json"),
        F.col("a.data_quality_flags").alias("data_quality_flags"),
    )
)

alerts_case.createOrReplaceTempView("alerts_case")

# Summary per case
alerts_summary = (
    alerts_case
    .groupBy("case_id", "party_id")
    .agg(
        F.count("*").alias("alerts_count"),
        F.max("risk_score").alias("max_risk_score"),
        F.expr("percentile_approx(risk_score, 0.5)").alias("median_risk_score"),
        F.collect_set("model_type").alias("model_types"),
        F.collect_set("scenario_code").alias("scenario_codes"),
        F.collect_set("severity").alias("severities"),
        # Count by severity
        F.sum(F.when(F.col("severity") == "high", 1).otherwise(0)).alias("alerts_high"),
        F.sum(F.when(F.col("severity") == "medium", 1).otherwise(0)).alias("alerts_medium"),
        F.sum(F.when(F.col("severity") == "low", 1).otherwise(0)).alias("alerts_low"),
    )
)

alerts_summary.createOrReplaceTempView("alerts_summary")

# Detailed alerts array (ordered)
# We'll create a struct and sort it by alert_timestamp_ms_utc desc (Spark 3+ supports array_sort with lambda)
alerts_details = (
    alerts_case
    .withColumn(
        "alert_struct",
        F.struct(
            F.col("alert_timestamp_ms_utc").alias("ts"),
            F.col("alert_id").alias("alert_id"),
            F.col("model_type").alias("model_type"),
            F.col("scenario_code").alias("scenario_code"),
            F.col("risk_score").alias("risk_score"),
            F.col("severity").alias("severity"),
            F.col("trigger_summary").alias("trigger_summary"),
            F.col("supporting_txn_ids").alias("supporting_txn_ids"),
            F.col("amount_total_usd").alias("amount_total_usd"),
            F.col("txn_count").alias("txn_count"),
            F.col("features_json").alias("features_json"),
            F.col("data_quality_flags").alias("data_quality_flags"),
        )
    )
    .groupBy("case_id", "party_id")
    .agg(
        F.sort_array(F.collect_list("alert_struct"), asc=False).alias("alerts")
    )
)

alerts_details.createOrReplaceTempView("alerts_details")

# -------------------------
# 3) Transactions scoped to each case window
# -------------------------
tx_case = (
    transactions_df.alias("t")
    .join(
        case_index.alias("c"),
        on=[F.col("t.party_id") == F.col("c.party_id")],
        how="inner"
    )
    .where(
        (F.col("t.txn_timestamp_ms_utc") >= F.col("c.case_window_start_ms_utc")) &
        (F.col("t.txn_timestamp_ms_utc") <= F.col("c.case_window_end_ms_utc"))
    )
    .select(
        F.col("c.case_id"),
        F.col("t.*")
    )
)

tx_case.createOrReplaceTempView("tx_case")

# -------------------------
# 4) Case Timeline Aggregates (amounts, counts, intl, instrument mix)
# -------------------------
tx_summary = (
    tx_case
    .withColumn("amount_usd", F.when(F.col("currency") == "USD", F.col("amount"))
                .when(F.col("currency") == "EUR", F.col("amount") * F.lit(1.10))
                .when(F.col("currency") == "GBP", F.col("amount") * F.lit(1.28))
                .when(F.col("currency") == "CAD", F.col("amount") * F.lit(0.75))
                .when(F.col("currency") == "MXN", F.col("amount") * F.lit(0.058))
                .otherwise(F.col("amount")))
    .groupBy("case_id", "party_id")
    .agg(
        F.count("*").alias("txn_count_case"),
        F.sum("amount_usd").alias("amount_total_usd_case"),
        F.expr("percentile_approx(amount_usd, 0.5)").alias("median_amount_usd_case"),
        F.max("amount_usd").alias("max_amount_usd_case"),
        F.min("txn_timestamp_ms_utc").alias("first_txn_ms_utc_case"),
        F.max("txn_timestamp_ms_utc").alias("last_txn_ms_utc_case"),
        F.avg("is_international").alias("intl_ratio_case"),
        F.collect_set("instrument_type").alias("instrument_types_case"),
    )
)

tx_summary.createOrReplaceTempView("tx_summary")

# Instrument breakdown
tx_instrument_breakdown = (
    tx_case
    .withColumn("amount_usd", F.when(F.col("currency") == "USD", F.col("amount"))
                .when(F.col("currency") == "EUR", F.col("amount") * F.lit(1.10))
                .when(F.col("currency") == "GBP", F.col("amount") * F.lit(1.28))
                .when(F.col("currency") == "CAD", F.col("amount") * F.lit(0.75))
                .when(F.col("currency") == "MXN", F.col("amount") * F.lit(0.058))
                .otherwise(F.col("amount")))
    .groupBy("case_id", "party_id", "instrument_type")
    .agg(
        F.count("*").alias("txn_count"),
        F.sum("amount_usd").alias("amount_total_usd")
    )
)

tx_instrument_breakdown.createOrReplaceTempView("tx_instrument_breakdown")

# -------------------------
# 5) Top Counterparties / Merchants (for network context)
# -------------------------
top_counterparties = (
    tx_case
    .withColumn("amount_usd", F.when(F.col("currency") == "USD", F.col("amount"))
                .when(F.col("currency") == "EUR", F.col("amount") * F.lit(1.10))
                .when(F.col("currency") == "GBP", F.col("amount") * F.lit(1.28))
                .when(F.col("currency") == "CAD", F.col("amount") * F.lit(0.75))
                .when(F.col("currency") == "MXN", F.col("amount") * F.lit(0.058))
                .otherwise(F.col("amount")))
    .groupBy("case_id", "party_id", "counterparty_id")
    .agg(
        F.count("*").alias("txn_count"),
        F.sum("amount_usd").alias("amount_total_usd"),
        F.avg("is_international").alias("intl_ratio"),
        F.max("txn_timestamp_ms_utc").alias("last_txn_ms_utc"),
    )
)

w_cp = Window.partitionBy("case_id", "party_id").orderBy(F.col("amount_total_usd").desc(), F.col("txn_count").desc())
top_counterparties_ranked = top_counterparties.withColumn("rnk", F.row_number().over(w_cp)).where(F.col("rnk") <= TOP_COUNTERPARTIES_MAX)

top_counterparties_enriched = (
    top_counterparties_ranked
    .join(counterparties_df, on="counterparty_id", how="left")
    .withColumn(
        "cp_struct",
        F.struct(
            F.col("amount_total_usd").alias("amount_total_usd"),
            F.col("txn_count").alias("txn_count"),
            F.col("intl_ratio").alias("intl_ratio"),
            F.col("last_txn_ms_utc").alias("last_txn_ms_utc"),
            F.col("counterparty_id").alias("counterparty_id"),
            F.col("counterparty_type").alias("counterparty_type"),
            F.col("country").alias("country"),
        )
    )
    .groupBy("case_id", "party_id")
    .agg(F.collect_list("cp_struct").alias("top_counterparties"))
)

top_counterparties_enriched.createOrReplaceTempView("top_counterparties_enriched")

top_merchants = (
    tx_case
    .where(F.col("merchant_id").isNotNull())
    .withColumn("amount_usd", F.when(F.col("currency") == "USD", F.col("amount"))
                .when(F.col("currency") == "EUR", F.col("amount") * F.lit(1.10))
                .when(F.col("currency") == "GBP", F.col("amount") * F.lit(1.28))
                .when(F.col("currency") == "CAD", F.col("amount") * F.lit(0.75))
                .when(F.col("currency") == "MXN", F.col("amount") * F.lit(0.058))
                .otherwise(F.col("amount")))
    .groupBy("case_id", "party_id", "merchant_id")
    .agg(
        F.count("*").alias("txn_count"),
        F.sum("amount_usd").alias("amount_total_usd"),
        F.max("txn_timestamp_ms_utc").alias("last_txn_ms_utc"),
    )
)

w_m = Window.partitionBy("case_id", "party_id").orderBy(F.col("amount_total_usd").desc(), F.col("txn_count").desc())
top_merchants_ranked = top_merchants.withColumn("rnk", F.row_number().over(w_m)).where(F.col("rnk") <= TOP_MERCHANTS_MAX)

top_merchants_enriched = (
    top_merchants_ranked
    .join(merchants_df, on="merchant_id", how="left")
    .withColumn(
        "m_struct",
        F.struct(
            F.col("amount_total_usd").alias("amount_total_usd"),
            F.col("txn_count").alias("txn_count"),
            F.col("last_txn_ms_utc").alias("last_txn_ms_utc"),
            F.col("merchant_id").alias("merchant_id"),
            F.col("merchant_name").alias("merchant_name"),
            F.col("merchant_category").alias("merchant_category"),
            F.col("country").alias("country"),
            F.col("state").alias("state"),
        )
    )
    .groupBy("case_id", "party_id")
    .agg(F.collect_list("m_struct").alias("top_merchants"))
)

top_merchants_enriched.createOrReplaceTempView("top_merchants_enriched")

# -------------------------
# 6) Supporting Transactions (from supporting_txn_ids across alerts)
# -------------------------
# Explode supporting txn ids per case, de-dup, join to tx_case and keep most recent N
supporting_txn_ids = (
    alerts_case
    .select("case_id", "party_id", F.explode_outer("supporting_txn_id_arr").alias("txn_id"))
    .where(F.col("txn_id").isNotNull() & (F.col("txn_id") != ""))
    .dropDuplicates(["case_id", "party_id", "txn_id"])
)

supporting_txn_ids.createOrReplaceTempView("supporting_txn_ids")

supporting_tx = (
    supporting_txn_ids
    .join(tx_case.select("case_id","party_id","txn_id","txn_timestamp_ms_utc","instrument_type","direction","amount","currency",
                         "counterparty_id","merchant_id","channel","country","state","is_international","description"),
          on=["case_id","party_id","txn_id"],
          how="left")
)

w_stx = Window.partitionBy("case_id","party_id").orderBy(F.col("txn_timestamp_ms_utc").desc())
supporting_tx_limited = (
    supporting_tx
    .withColumn("rnk", F.row_number().over(w_stx))
    .where(F.col("rnk") <= SUPPORT_TXN_MAX)
    .drop("rnk")
)

supporting_tx_packet = (
    supporting_tx_limited
    .withColumn(
        "txn_struct",
        F.struct(
            F.col("txn_timestamp_ms_utc").alias("ts"),
            "txn_id",
            "instrument_type",
            "direction",
            "amount",
            "currency",
            "counterparty_id",
            "merchant_id",
            "channel",
            "country",
            "state",
            "is_international",
            "description",
        )
    )
    .groupBy("case_id","party_id")
    .agg(F.sort_array(F.collect_list("txn_struct"), asc=False).alias("supporting_transactions"))
)

supporting_tx_packet.createOrReplaceTempView("supporting_tx_packet")

# -------------------------
# 7) Case Packet (party-level) as a single row per (case_id, party_id)
# -------------------------
case_packet_df = (
    case_party
    .join(alerts_summary, on=["case_id","party_id"], how="left")
    .join(alerts_details, on=["case_id","party_id"], how="left")
    .join(tx_summary, on=["case_id","party_id"], how="left")
    .join(top_counterparties_enriched, on=["case_id","party_id"], how="left")
    .join(top_merchants_enriched, on=["case_id","party_id"], how="left")
    .join(supporting_tx_packet, on=["case_id","party_id"], how="left")
    .select(
        "case_id",
        "party_id",
        "party_type",
        "party_name",
        "industry",
        "country",
        "state",
        "onboarding_date",
        "expected_monthly_volume_usd",
        "expected_avg_txn_usd",
        "risk_rating",
        "case_window_start_ms_utc",
        "case_window_end_ms_utc",

        # Alerts summary
        "alerts_count",
        "alerts_high",
        "alerts_medium",
        "alerts_low",
        "max_risk_score",
        "median_risk_score",
        "model_types",
        "scenario_codes",
        "severities",

        # Tx summary
        "txn_count_case",
        "amount_total_usd_case",
        "median_amount_usd_case",
        "max_amount_usd_case",
        "first_txn_ms_utc_case",
        "last_txn_ms_utc_case",
        "intl_ratio_case",
        "instrument_types_case",

        # Details arrays
        "alerts",
        "top_counterparties",
        "top_merchants",
        "supporting_transactions",
    )
)

case_packet_df.createOrReplaceTempView("case_packet")

In [32]:
case_packet_df.toPandas()

Unnamed: 0,case_id,party_id,party_type,party_name,industry,country,state,onboarding_date,expected_monthly_volume_usd,expected_avg_txn_usd,risk_rating,case_window_start_ms_utc,case_window_end_ms_utc,alerts_count,alerts_high,alerts_medium,alerts_low,max_risk_score,median_risk_score,model_types,scenario_codes,severities,txn_count_case,amount_total_usd_case,median_amount_usd_case,max_amount_usd_case,first_txn_ms_utc_case,last_txn_ms_utc_case,intl_ratio_case,instrument_types_case,alerts,top_counterparties,top_merchants,supporting_transactions
0,000064bcf2defa62a3d128181729c0eedf62f6dd8cfaba...,P0006731,business,Smith United Solutions,Real Estate,CAN,,2024-10-18,79397.35,285.82,medium,1766083627747,1767293227747,1,0,1,0,81.65,81.65,[ngi],[NGI_RAPID_MOVEMENT],[medium],,,,,,,,,"[{'ts': 1767293227747, 'alert_id': 'ALNG000163...",,,"[{'ts': None, 'txn_id': 'T000000817706', 'inst..."
1,00526de54b44c98ca1cfd32a8413e5582e742238f86f31...,P0008775,individual,Alex Anderson,Real Estate,USA,NJ,2022-04-08,8262.46,85.23,low,1766716504285,1767321304285,1,1,0,0,97.00,97.00,[cash],[CASH_STRUCTURING],[high],,,,,,,,,"[{'ts': 1767321304285, 'alert_id': 'ALCA000194...",,,"[{'ts': None, 'txn_id': 'T000000959123', 'inst..."
2,007e7dec3a14413bf74312506048ffa4497fb4446adb4c...,P0040608,individual,Jamie Thomas,Healthcare,USA,OH,2022-03-01,22289.15,129.33,low,1766748855139,1767353655139,1,1,0,0,86.80,86.80,[credit_cards],[CC_VELOCITY_ANOMALY],[high],1.0,3541.67,3541.67,3541.67,1.766954e+12,1.766954e+12,1.0,[ngi],"[{'ts': 1767353655139, 'alert_id': 'ALCR000173...","[{'amount_total_usd': 3541.67, 'txn_count': 1,...",,"[{'ts': None, 'txn_id': 'T000000948777', 'inst..."
3,0091feabf7204de8714c2485297f2f12cd12563cecea93...,P0024287,individual,Robin Moore,Construction,DEU,,2018-10-23,52468.18,871.17,low,1766873579357,1767478379357,1,0,0,1,47.75,47.75,[wires],[WIRE_HIGH_RISK_GEO],[low],,,,,,,,,"[{'ts': 1767478379357, 'alert_id': 'ALWI000095...",,,"[{'ts': None, 'txn_id': 'T000000958624', 'inst..."
4,00c59720523474fb2afcc8bd03562d4ac5e7afe7c7ad5e...,P0032507,business,Wilson United Solutions,Retail,USA,NJ,2023-02-06,20107.49,307.59,low,1766833777189,1767438577189,1,0,0,1,6.40,6.40,[cash],[CASH_STRUCTURING],[low],,,,,,,,,"[{'ts': 1767438577189, 'alert_id': 'ALCA000012...",,,"[{'ts': None, 'txn_id': 'T000000945979', 'inst..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,ffc5fa89007fa48ff0175b95bbce03298fea2e2b146603...,P0001786,business,Davis Rapid LLC,Auto,USA,NY,2015-05-29,11799.18,130.55,low,1766232470491,1767442070491,1,0,1,0,78.15,78.15,[ngi],[NGI_RAPID_MOVEMENT],[medium],,,,,,,,,"[{'ts': 1767442070491, 'alert_id': 'ALNG000156...",,,"[{'ts': None, 'txn_id': 'T000000938041', 'inst..."
9996,ffcb8665267195eada9103d98a4b609e96687bcbb16d96...,P0025418,individual,Jordan Brown,Hospitality,USA,MI,2022-05-15,150344.59,605.62,low,1766143399746,1767352999746,1,0,1,0,74.00,74.00,[loans],[LOAN_UNUSUAL_PAYDOWN],[medium],1.0,35916.96,35916.96,35916.96,1.767172e+12,1.767172e+12,0.0,[loans],"[{'ts': 1767352999746, 'alert_id': 'ALLO000148...","[{'amount_total_usd': 35916.96, 'txn_count': 1...",,"[{'ts': 1767172192283.0, 'txn_id': 'T000000405..."
9997,ffcfe340be92299c82dec632c17b284e6c7f6c41bacd01...,P0021560,individual,Taylor Hernandez,Auto,USA,PA,2019-05-03,58742.03,399.04,low,1766219610278,1767429210278,1,0,0,1,15.90,15.90,[credit_cards],[CC_VELOCITY_ANOMALY],[low],,,,,,,,,"[{'ts': 1767429210278, 'alert_id': 'ALCR000031...",,,"[{'ts': None, 'txn_id': 'T000000959853', 'inst..."
9998,ffd69f5bba7c4c7bbd47a31ad60938d55e281f1099b1cf...,P0038121,business,Smith Prime Trading,Auto,MEX,,2019-04-07,1785.07,119.04,low,1766864217931,1767469017931,1,0,0,1,59.85,59.85,[cash],[CASH_STRUCTURING],[low],,,,,,,,,"[{'ts': 1767469017931, 'alert_id': 'ALCA000119...",,,"[{'ts': None, 'txn_id': 'T000000884231', 'inst..."


In [33]:
# -------------------------
# 8) Spark SQL View Layer
# -------------------------
# These create logical views for investigation + LLM consumption.
# You can materialize as tables as needed.

# v_alert_summary: one row per case_id
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_alert_summary AS
SELECT
  c.case_id,
  c.party_id,
  c.case_window_start_ms_utc,
  c.case_window_end_ms_utc,
  s.alerts_count,
  s.alerts_high,
  s.alerts_medium,
  s.alerts_low,
  s.max_risk_score,
  s.median_risk_score,
  s.model_types,
  s.scenario_codes,
  s.severities
FROM case_index c
LEFT JOIN alerts_summary s
  ON c.case_id = s.case_id AND c.party_id = s.party_id
""")

# v_tx_timeline_daily: daily buckets (good for charts + narratives)
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_tx_timeline_daily AS
SELECT
  t.case_id,
  t.party_id,
  from_unixtime(CAST(t.txn_timestamp_ms_utc/1000 AS BIGINT), 'yyyy-MM-dd') AS txn_date_utc,
  t.instrument_type,
  COUNT(*) AS txn_count,
  SUM(
    CASE currency
      WHEN 'USD' THEN amount
      WHEN 'EUR' THEN amount*1.10
      WHEN 'GBP' THEN amount*1.28
      WHEN 'CAD' THEN amount*0.75
      WHEN 'MXN' THEN amount*0.058
      ELSE amount
    END
  ) AS amount_total_usd
FROM tx_case t
GROUP BY t.case_id, t.party_id, from_unixtime(CAST(t.txn_timestamp_ms_utc/1000 AS BIGINT), 'yyyy-MM-dd'), t.instrument_type
""")

# v_instrument_breakdown: one row per case + instrument
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_instrument_breakdown AS
SELECT
  case_id,
  party_id,
  instrument_type,
  txn_count,
  amount_total_usd
FROM tx_instrument_breakdown
""")

# v_top_counterparties: exploded top counterparties for easy querying
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_top_counterparties AS
SELECT
  case_id,
  party_id,
  cp.counterparty_id,
  cp.counterparty_type,
  cp.country,
  cp.txn_count,
  cp.amount_total_usd,
  cp.intl_ratio,
  cp.last_txn_ms_utc
FROM top_counterparties_enriched
LATERAL VIEW explode(top_counterparties) e AS cp
""")

# v_top_merchants: exploded top merchants
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_top_merchants AS
SELECT
  case_id,
  party_id,
  m.merchant_id,
  m.merchant_name,
  m.merchant_category,case_index
  m.country,
  m.state,
  m.txn_count,
  m.amount_total_usd,
  m.last_txn_ms_utc
FROM top_merchants_enriched
LATERAL VIEW explode(top_merchants) e AS m
""")

# v_supporting_transactions: exploded supporting txns (LLM-ready evidence list)
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_supporting_transactions AS
SELECT
  case_id,
  party_id,
  st.ts AS txn_timestamp_ms_utc,
  st.txn_id,
  st.instrument_type,
  st.direction,
  st.amount,
  st.currency,
  st.counterparty_id,
  st.merchant_id,
  st.channel,
  st.country,
  st.state,
  st.is_international,
  st.description
FROM supporting_tx_packet
LATERAL VIEW explode(supporting_transactions) e AS st
""")

# v_case_packet: the complete, LLM-ready packet (1 row per case)
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_case_packet AS
SELECT * FROM case_packet
""")

print("✅ Created case_packet_df and SQL views: v_case_packet, v_alert_summary, v_tx_timeline_daily, v_instrument_breakdown, v_top_counterparties, v_top_merchants, v_supporting_transactions")


✅ Created case_packet_df and SQL views: v_case_packet, v_alert_summary, v_tx_timeline_daily, v_instrument_breakdown, v_top_counterparties, v_top_merchants, v_supporting_transactions


In [34]:

# -------------------------
# 9) Outputs / Examples
# -------------------------
# Example: show one packet
spark.sql("SELECT case_id, party_id, max_risk_score, alerts_count, txn_count_case, amount_total_usd_case FROM v_case_packet ORDER BY max_risk_score DESC LIMIT 20").show(truncate=False)

+----------------------------------------------------------------+--------+--------------+------------+--------------+---------------------+
|case_id                                                         |party_id|max_risk_score|alerts_count|txn_count_case|amount_total_usd_case|
+----------------------------------------------------------------+--------+--------------+------------+--------------+---------------------+
|17a0d12cf6496500fdd0f797097ff9ff861ebc94ee8ca713ebf43fc7c4b26505|P0047017|100.0         |1           |NULL          |NULL                 |
|cd312b05e77f53643de0709e84457e3a6cb1be8aa8fd7d51861fd84e6cacb262|P0005105|100.0         |1           |2             |654.7853             |
|4739377456d421ce1ea2cbdd9d11c83f3ef30143fcc79abc26846bfe0c5cb637|P0004670|100.0         |1           |1             |1953.59              |
|18b4dd24f9fd6628e18b0689dd92f1a3cf054f46c396325107ef74ccde37be9a|P0020127|100.0         |1           |NULL          |NULL                 |
|112e98b94173

In [36]:

# Example: daily timeline for a case
spark.sql("SELECT * FROM v_tx_timeline_daily WHERE case_id = 'ba6a7e48245f3df7b8b7afbc8e48376ba13036c49d1110d6eddbd05daf0449c3' ORDER BY txn_date_utc, instrument_type").show(200, truncate=False)

+----------------------------------------------------------------+--------+------------+---------------+---------+----------------+
|case_id                                                         |party_id|txn_date_utc|instrument_type|txn_count|amount_total_usd|
+----------------------------------------------------------------+--------+------------+---------------+---------+----------------+
|ba6a7e48245f3df7b8b7afbc8e48376ba13036c49d1110d6eddbd05daf0449c3|P0041929|2025-12-31  |cash           |1        |935.22          |
+----------------------------------------------------------------+--------+------------+---------------+---------+----------------+



In [37]:

# Example: top counterparties for a case
spark.sql("SELECT * FROM v_top_counterparties WHERE case_id = 'ba6a7e48245f3df7b8b7afbc8e48376ba13036c49d1110d6eddbd05daf0449c3' ORDER BY amount_total_usd DESC LIMIT 50").show(truncate=False)

+----------------------------------------------------------------+--------+---------------+-----------------+-------+---------+----------------+----------+---------------+
|case_id                                                         |party_id|counterparty_id|counterparty_type|country|txn_count|amount_total_usd|intl_ratio|last_txn_ms_utc|
+----------------------------------------------------------------+--------+---------------+-----------------+-------+---------+----------------+----------+---------------+
|ba6a7e48245f3df7b8b7afbc8e48376ba13036c49d1110d6eddbd05daf0449c3|P0041929|CP00048808     |individual       |MEX    |1        |935.22          |0.0       |1767185376324  |
+----------------------------------------------------------------+--------+---------------+-----------------+-------+---------+----------------+----------+---------------+



In [38]:
# -------------------------
# 10) Optional: persist to parquet (recommended for scale)
# -------------------------
# OUT_PATH = "s3://your-bucket/aml_copilot/curated/"
case_packet_df.write.mode("overwrite").parquet(f"{BASE_PATH}/case_packet")
spark.table("v_tx_timeline_daily").write.mode("overwrite").parquet(f"{BASE_PATH}/tx_timeline_daily")

In [39]:
# 2) Create a “case list” view (triage queue)
# Investigators need a queue ordered by risk and recency.
spark.sql("""
CREATE OR REPLACE TEMP VIEW v_case_queue AS
SELECT
  case_id,
  party_id,
  party_type,
  party_name,
  risk_rating,
  max_risk_score,
  alerts_count,
  txn_count_case,
  amount_total_usd_case,
  last_txn_ms_utc_case,
  case_window_start_ms_utc,
  case_window_end_ms_utc
FROM case_packet
ORDER BY max_risk_score DESC, last_txn_ms_utc_case DESC
""")


DataFrame[]

In [40]:
# 3) Define the “LLM packet payload” (normalized JSON shape)

# Even though case_packet already has arrays, you’ll want a stable, minimal JSON to send to the LLM:
# - party profile
# - alert summary + top N alert details
# - tx summary
# - daily timeline
# - top counterparties/merchants
# - supporting transactions (evidence list)
packet_for_llm = spark.sql("""
SELECT
  cp.case_id,
  to_json(named_struct(
    'case_id', cp.case_id,
    'party', named_struct(
      'party_id', cp.party_id,
      'party_type', cp.party_type,
      'party_name', cp.party_name,
      'industry', cp.industry,
      'country', cp.country,
      'state', cp.state,
      'onboarding_date', cp.onboarding_date,
      'risk_rating', cp.risk_rating,
      'expected_monthly_volume_usd', cp.expected_monthly_volume_usd,
      'expected_avg_txn_usd', cp.expected_avg_txn_usd
    ),
    'window', named_struct(
      'start_ms_utc', cp.case_window_start_ms_utc,
      'end_ms_utc', cp.case_window_end_ms_utc
    ),
    'alerts_summary', named_struct(
      'alerts_count', cp.alerts_count,
      'alerts_high', cp.alerts_high,
      'alerts_medium', cp.alerts_medium,
      'alerts_low', cp.alerts_low,
      'max_risk_score', cp.max_risk_score,
      'median_risk_score', cp.median_risk_score,
      'model_types', cp.model_types,
      'scenario_codes', cp.scenario_codes
    ),
    'alerts', cp.alerts,
    'tx_summary', named_struct(
      'txn_count_case', cp.txn_count_case,
      'amount_total_usd_case', cp.amount_total_usd_case,
      'median_amount_usd_case', cp.median_amount_usd_case,
      'max_amount_usd_case', cp.max_amount_usd_case,
      'first_txn_ms_utc_case', cp.first_txn_ms_utc_case,
      'last_txn_ms_utc_case', cp.last_txn_ms_utc_case,
      'intl_ratio_case', cp.intl_ratio_case,
      'instrument_types_case', cp.instrument_types_case
    ),
    'top_counterparties', cp.top_counterparties,
    'top_merchants', cp.top_merchants,
    'supporting_transactions', cp.supporting_transactions
  )) AS case_packet_json
FROM case_packet cp
""")

packet_for_llm.createOrReplaceTempView("packet_for_llm")


In [43]:
packet_for_llm.write.mode("overwrite").parquet(f"{BASE_PATH}/case_packet_json")