# In [0]:
#01_fraud_detection notebook
# 03 - Fraud detection
# **Purpose**: Monitor fraud patterns, generate alerts, and create compliance reports
# **Schedule**: Run daily after Gold layer processing completes
from pyspark.sql.functions import col, date_format, add_months, current_date
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

@udf(FloatType())
def advanced_fraud_score(amount, diagnosis_code, provider_id, claim_frequency, member_risk_history):
    """Compute a rule-based fraud score between 0.0 and 1.0 for a single row.

    The score is the sum of independent rule contributions, capped at 1.0:

    * claim amount: +0.4 above 50000, +0.3 above 25000, +0.2 above 10000
    * diagnosis code on the suspicious list: +0.2
    * provider on the watchlist: +0.3
    * claim frequency above 10: +0.2
    * member risk history above 0.6: +0.2

    Args:
        amount: Claim amount; anything ``float()`` accepts, or null.
        diagnosis_code: Diagnosis code compared against the suspicious list.
        provider_id: Provider identifier compared against the watchlist.
        claim_frequency: Member claim count; numeric-like or null.
        member_risk_history: Member risk score; numeric-like or null.

    Returns:
        The capped score as a Python float.
    """

    def _as_float(value):
        # Null, empty and zero inputs count as 0.0. Unparseable inputs are
        # treated the same way so a single dirty row contributes nothing to
        # the score instead of raising inside the UDF.
        if not value:
            return 0.0
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    suspicious_diagnoses = ('E119', 'I10', 'M545', 'R558', 'Z798')
    high_risk_providers = ('PROV999', 'PROV888', 'PROV777')

    score = 0.0

    # Amount-based scoring (progressive: only the highest matching tier applies)
    amount_val = _as_float(amount)
    if amount_val > 50000:
        score += 0.4
    elif amount_val > 25000:
        score += 0.3
    elif amount_val > 10000:
        score += 0.2

    # Diagnosis code patterns
    if diagnosis_code in suspicious_diagnoses:
        score += 0.2

    # Provider watchlist
    if provider_id in high_risk_providers:
        score += 0.3

    # Claim frequency
    if _as_float(claim_frequency) > 10:
        score += 0.2

    # Member history
    if _as_float(member_risk_history) > 0.6:
        score += 0.2

    return min(score, 1.0)

def main():
    """Run the daily fraud-monitoring pass and return the critical-alert count.

    Steps, in order:

    1. Display the top 20 rows of the real-time fraud alerts table.
    2. Compute and display per-month fraud trends for the last six months.
    3. Join recent Critical/High alerts with member, provider and member
       summary data, score them with ``advanced_fraud_score`` and keep rows
       scoring at least 0.5.
    4. Publish the result as the temp view ``critical_alerts`` and display it.
    5. Print a summary.

    Expects ``spark`` and ``display`` to be supplied by the notebook runtime;
    neither is defined or imported in this file.

    Returns:
        Number of rows in the ``critical_alerts`` result.
    """
    # =====================================================
    # 1. Read from Gold tables for monitoring
    # =====================================================
    print("Reading gold tables for fraud monitoring...")
    # Only the alerts table is needed as a DataFrame (for the summary count);
    # every other table is read directly by the SQL statements below.
    fraud_alerts = spark.table("medisure_jen.gold.gold_realtime_fraud_alerts")

    # Display the first 20 alerts by severity, then recency.
    # NOTE(review): the query has no date filter, so this is not limited to
    # the current day. If alert_severity is a string column, DESC sorts it
    # alphabetically rather than by severity — confirm the column type.
    print("Current fraud alerts:")
    display(spark.sql("""
    SELECT * FROM medisure_jen.gold.gold_realtime_fraud_alerts 
    ORDER BY alert_severity DESC, alert_timestamp DESC
    LIMIT 20
    """))

    # =====================================================
    # 1.2 Fraud Trends Analysis (Essential for monitoring)
    # =====================================================
    print("Analyzing fraud trends...")
    fraud_trends = spark.sql("""
    SELECT 
      processing_month,
      COUNT(*) as total_claims,
      SUM(CASE WHEN fraud_risk_score > 0.7 THEN 1 ELSE 0 END) as high_risk_claims,
      ROUND(AVG(fraud_risk_score), 3) as avg_fraud_score,
      ROUND(SUM(CASE WHEN fraud_risk_score > 0.7 THEN claim_amount ELSE 0 END), 2) as high_risk_amount
    FROM medisure_jen.gold.gold_claims_analytics
    WHERE processing_month >= date_format(add_months(current_date(), -6), 'yyyy-MM')
    GROUP BY processing_month
    ORDER BY processing_month
    """)

    print("Fraud trends analysis:")
    display(fraud_trends)

    # =====================================================
    # 2. Read additional tables for fraud detection
    # =====================================================
    # The member, provider and member-summary tables are read by the join
    # query in step 3, so no separate DataFrame handles are created here.
    print("Reading additional tables for enhanced fraud detection...")

    # =====================================================
    # 3. Apply fraud scoring (detection part)
    # =====================================================
    print("Applying enhanced fraud scoring...")
    # LEFT JOINs keep every alert; when no member summary row matches today's
    # summary_period the frequency/risk columns are NULL, which the scoring
    # UDF treats as 0.
    alerts_with_enhanced_scoring = spark.sql("""
    SELECT 
      a.claim_id, a.member_id, a.provider_id, a.claim_amount,
      a.diagnosis_code, a.alert_severity, a.alert_reason, a.alert_timestamp,
      split(m.member_name, ' ')[0] as first_name,
      array_join(slice(split(m.member_name, ' '), 2, size(split(m.member_name, ' ')) - 1), ' ') as last_name,
      p.provider_name,
      mem.claims_count as member_claim_frequency,
      mem.member_risk_score as member_risk_history
    FROM medisure_jen.gold.gold_realtime_fraud_alerts a
    LEFT JOIN medisure_jen.silver.silver_members m ON a.member_id = m.member_id
    LEFT JOIN medisure_jen.silver.silver_providers p ON a.provider_id = p.provider_id
    LEFT JOIN medisure_jen.gold.gold_member_claims_summary mem ON a.member_id = mem.member_id 
      AND mem.summary_period = date_format(current_date(), 'yyyy-MM-dd')
    WHERE a.alert_severity IN ('Critical', 'High')
      AND date(a.alert_timestamp) >= date_add(current_date(), -7)
    """)

    # Also expose the UDF to SQL under the same name.
    spark.udf.register("advanced_fraud_score", advanced_fraud_score)
    critical_alerts = alerts_with_enhanced_scoring.withColumn(
        "enhanced_fraud_score",
        advanced_fraud_score(
            col("claim_amount"),
            col("diagnosis_code"),
            col("provider_id"),
            col("member_claim_frequency"),
            col("member_risk_history"),
        ),
    ).filter(col("enhanced_fraud_score") >= 0.5).orderBy(col("claim_amount").desc())

    # =====================================================
    # 4. Save results for next tasks
    # =====================================================
    critical_alerts.createOrReplaceTempView("critical_alerts")
    display(critical_alerts)

    # =====================================================
    # 5. Return simple count for routing + monitoring info
    # =====================================================
    fraud_count = critical_alerts.count()

    # first() returns None when no row matches, which covers both an empty
    # trends result and a result that simply has no row for the current
    # month yet. Indexing the result unconditionally would raise TypeError.
    current_month_row = (
        fraud_trends
        .filter(col("processing_month") == date_format(current_date(), "yyyy-MM"))
        .select("high_risk_claims")
        .first()
    )
    high_risk_claims = current_month_row[0] if current_month_row is not None else "N/A"

    print("="*60)
    print("FRAUD DETECTION SUMMARY")
    print("="*60)
    print(f"Critical Alerts Detected: {fraud_count}")
    # NOTE(review): this counts every row in the alerts table, not only
    # today's rows, despite the label — confirm the intended scope.
    print(f"Total Fraud Alerts Today: {fraud_alerts.count()}")
    print(f"High Risk Claims (Trends): {high_risk_claims}")
    print("="*60)

    return fraud_count

# Execute main function
# Runs the fraud-monitoring pass when this file is executed as the main
# module. The count returned by main() is discarded here.
# NOTE(review): if a downstream task is meant to route on that count, it
# presumably has to be surfaced explicitly — confirm how the job consumes it.
if __name__ == "__main__":
    main()