In [0]:
from pyspark.sql import functions as F
from delta.tables import DeltaTable
from pyspark.sql.functions import window
from config.config import *
from config.schemas import *

In [0]:
# Source data: the enriched claims table (name presumably comes from the config star-imports).
df_claims = spark.table(st_fact_claims_enriched)


# Duplicate billing: every (ClaimID, ClaimDate) pair that appears on more than one row.
claims_per_key = df_claims.groupBy("ClaimID", "ClaimDate").agg(
    F.count("*").alias("dup_count")
)
df_duplicates = claims_per_key.where(F.col("dup_count") > 1)


# Claim burst: flag members with too many claims inside one short time window.
# Both knobs are named so the rule can be tuned without editing the query.
BURST_WINDOW_DURATION = "1 day"  # width of the tumbling window over ClaimDate
BURST_MIN_CLAIMS = 2  # minimum claims per group to flag (kept low just to have output for the example)

df_burst = (
    df_claims
        # ClaimDate is kept as a grouping key so it is available on the result rows.
        # NOTE(review): because ClaimDate itself is a key, the window can never merge
        # rows with different ClaimDate values. If ClaimDate is a timestamp, only
        # identical timestamps are counted together — confirm this is intended.
        .groupBy("MemberID", "ClaimDate", window(F.col("ClaimDate"), BURST_WINDOW_DURATION))
        .count()
        .filter(F.col("count") >= BURST_MIN_CLAIMS)
)


# Suspicious member cluster: members whose claims all go through a single provider
# (sometimes an indicator of a provider fabricating claims).
df_member_provider = df_claims.select("MemberID", "ProviderID", "ClaimDate").distinct()

# Row count per provider over the distinct (MemberID, ProviderID, ClaimDate) set.
# NOTE(review): not referenced by the other cells visible here — confirm it is still needed.
df_provider_counts = df_member_provider.groupBy("ProviderID").count()

# Count DISTINCT providers per member. Counting the rows of df_member_provider
# would also count claim dates, so a member with one provider but several claim
# dates would be wrongly treated as non-exclusive.
df_exclusive_members = (
    df_member_provider
        .groupBy("MemberID")
        .agg(F.countDistinct("ProviderID").alias("count"))  # column name kept as "count"
        .filter(F.col("count") == 1)
)  # member has only 1 provider

# Bring ProviderID and ClaimDate back for the exclusive members.
df_exclusive_with_provider = df_exclusive_members.join(df_member_provider, "MemberID")

In [0]:
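# Normalize each check's output to the shared signal columns (MemberID, ProviderID, ClaimID, ClaimDate, signal); columns a check cannot supply are filled with null strings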

# Duplicate-claim signals: only ClaimID and ClaimDate are known for these rows,
# so MemberID and ProviderID are emitted as null strings.
dup_signals = df_duplicates.select(
    F.lit(None).cast("string").alias("MemberID"),
    F.lit(None).cast("string").alias("ProviderID"),
    F.col("ClaimID"),
    F.col("ClaimDate"),
    F.lit("duplicate_claim").alias("signal"),
)


# Claim-burst signals: keyed by member and claim date; ProviderID and ClaimID
# are not available at this grain and are emitted as null strings.
burst_signals = df_burst.select(
    F.col("MemberID"),
    F.lit(None).cast("string").alias("ProviderID"),
    F.lit(None).cast("string").alias("ClaimID"),
    F.col("ClaimDate"),
    F.lit("claim_burst").alias("signal"),
)


# Exclusive member/provider signals: member, provider and claim date are known;
# ClaimID is emitted as a null string.
exclusive_signals = df_exclusive_with_provider.select(
    F.col("MemberID"),
    F.col("ProviderID"),
    F.lit(None).cast("string").alias("ClaimID"),
    F.col("ClaimDate"),
    F.lit("exclusive_member_provider").alias("signal"),
)

# Union all into a single fraud_signals table.
# unionByName matches columns by name rather than by position, so a change in
# the column order of any one input cannot silently put values in the wrong column.
df_signals = (
    dup_signals
        .unionByName(burst_signals)
        .unionByName(exclusive_signals)
)

# Overwrite the target Delta table with the current set of signals.
df_signals.write.format("delta").mode("overwrite").saveAsTable(gt_fraud_signal)


In [0]:
# Render the combined signals in the notebook output.
# NOTE(review): this displays the in-memory df_signals DataFrame, not the table
# written above; presumably the plan is evaluated again here — confirm that is acceptable.
df_signals.display()