# Stateful Fraud Detection with applyInPandasWithState

This notebook demonstrates **advanced streaming fraud detection** using PySpark's `applyInPandasWithState` API.

## What is applyInPandasWithState?

`applyInPandasWithState` (introduced in Apache Spark 3.4) is a Structured Streaming API that enables:
- **Stateful processing**: Maintain state across micro-batches per key (e.g., user_id, IP address)
- **Pandas UDFs**: Process data using familiar Pandas operations
- **Complex logic**: Implement sophisticated fraud detection rules with historical context
- **Bounded state**: Timeouts flag idle keys so the function can drop their state with `state.remove()` (cleanup is not automatic)
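Per-key state can be illustrated without Spark. The following is a plain-Python analogy, not Spark's API. A dict stands in for the checkpointed state store, and the hypothetical `process_micro_batch` reads a key's state, updates it, and writes it back for each row. This mirrors how `state.get` and `state.update` are used later in the notebook.

```python
from typing import Dict, Iterable, List, Optional, Tuple

# Hypothetical stand-in for Spark's per-key state store: a plain dict keyed by user_id.
# The real applyInPandasWithState keeps this state in the checkpointed state store.
StateStore = Dict[str, Tuple[int, Optional[str]]]  # (transaction_count, last_ip)


def process_micro_batch(
    store: StateStore, batch: Iterable[Tuple[str, str]]
) -> List[Tuple[str, int, int]]:
    """Process (user_id, ip) rows, reading and updating each key's state.

    Returns (user_id, running_count, ip_changed) for each input row.
    """
    out = []
    for user_id, ip in batch:
        count, last_ip = store.get(user_id, (0, None))  # like state.exists / state.get
        count += 1
        ip_changed = int(last_ip is not None and ip != last_ip)
        store[user_id] = (count, ip)  # like state.update
        out.append((user_id, count, ip_changed))
    return out


store: StateStore = {}
print(process_micro_batch(store, [("u1", "1.1.1.1"), ("u2", "2.2.2.2")]))  # [('u1', 1, 0), ('u2', 1, 0)]
print(process_micro_batch(store, [("u1", "9.9.9.9")]))  # [('u1', 2, 1)]: state carried over
```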

## Fraud Detection Features

This notebook calculates real-time fraud indicators based on:

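1. **Transaction Velocity**: Number of a user's transactions in the last 10 minutes and the last hour
2. **IP Address Changes**: Whether the IP changed since the previous transaction, plus a running count of changes per user
3. **Location Anomalies**: Haversine distance from the previous transaction and the implied travel speed
4. **Amount Patterns**: Ratios and z-score of each amount relative to the user's history
5. **Time-based Patterns**: Seconds elapsed since the user's previous transaction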
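The velocity feature counts how many of a user's transactions fall in a trailing time window. The sketch below shows one way to maintain that count incrementally with `collections.deque`, evicting timestamps that are older than the window. The helper `count_in_window` is hypothetical. The notebook itself keeps a list of the last 50 timestamps and counts within it.

```python
from collections import deque
from datetime import datetime, timedelta
from typing import Deque


def count_in_window(history: Deque[datetime], now: datetime, window: timedelta) -> int:
    """Append `now`, evict timestamps older than `now - window`, return the count.

    Assumes per-user timestamps arrive in non-decreasing order (the notebook
    sorts each micro-batch by timestamp before processing it).
    """
    history.append(now)
    cutoff = now - window
    while history and history[0] < cutoff:  # entries at exactly the cutoff are kept
        history.popleft()
    return len(history)


t0 = datetime(2024, 1, 1, 12, 0, 0)
hist: Deque[datetime] = deque()
counts = [count_in_window(hist, t0 + timedelta(minutes=m), timedelta(minutes=10))
          for m in (0, 2, 4, 15)]
print(counts)  # [1, 2, 3, 1]
```

Each window length needs its own deque (one for 10 minutes and one for 1 hour). In exchange, the count is exact and is not capped by a fixed history length.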

## Architecture

```
Streaming Source (Rate/Kafka)
    ↓
Feature Generation (TransactionDataGenerator)
    ↓
Group by Key (user_id)
    ↓
applyInPandasWithState
  • Maintain transaction history per user
  • Calculate velocity features
  • Detect location anomalies
  • Track IP changes
  • Compute fraud scores
    ↓
Write to Lakebase PostgreSQL (foreachBatch)
    ↓
Real-time Feature Serving (<10ms latency)
```

## Prerequisites

- Run `00_setup.ipynb` first to provision Lakebase PostgreSQL
- Databricks Runtime 13.0+ (Apache Spark 3.4+, the first release with `applyInPandasWithState`)
- Lakebase PostgreSQL instance configured and accessible


In [None]:
# Import required libraries
from pyspark.sql import SparkSession
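# Explicit imports: `from pyspark.sql.functions import *` would shadow Python's
# built-in sum/max/min/abs, which detect_fraud relies on.
from pyspark.sql.types import (
    ArrayType, DoubleType, IntegerType, StringType,
    StructField, StructType, TimestampType,
)
# GroupState is needed for the detect_fraud type hint
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout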
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Iterator, Tuple
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("Imports successful")


In [None]:
# Import utility modules
from utils.data_generator import TransactionDataGenerator
from utils.lakebase_client import LakebaseClient

# Initialize data generator
data_gen = TransactionDataGenerator(spark)

print("Utility modules loaded")


In [None]:
# Configure Lakebase connection
LAKEBASE_CONFIG = {
    "instance_name": "neha-lakebase-demo",
    "database": "databricks_postgres"
}

# Initialize Lakebase client
lakebase = LakebaseClient(**LAKEBASE_CONFIG)

# Test connection
if lakebase.test_connection():
    print("Connected to Lakebase PostgreSQL")
else:
    raise RuntimeError("Failed to connect to Lakebase")


## Step 1: Generate Streaming Transaction Data

Generate synthetic transaction data with fraud indicators.
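The per-user logic is easier to check offline against a small, deterministic set of rows. The generator below is a hypothetical helper and is not part of `utils.data_generator`. It only emits the columns that `detect_fraud` reads. The value formats, such as `user_000`, are assumptions.

```python
import random
import uuid
from datetime import datetime, timedelta
from typing import Dict, List


def make_transactions(n: int, num_users: int = 5, seed: int = 42) -> List[Dict]:
    """Hypothetical offline generator producing rows with the columns detect_fraud reads.

    Not the TransactionDataGenerator implementation; it only mimics the column names.
    """
    rng = random.Random(seed)  # seeded, so the same seed gives the same rows
    start = datetime(2024, 1, 1, 12, 0, 0)
    rows = []
    for i in range(n):
        rows.append({
            "transaction_id": str(uuid.UUID(int=rng.getrandbits(128))),
            "user_id": f"user_{rng.randrange(num_users):03d}",
            "timestamp": start + timedelta(seconds=i * 6),  # one row every 6 seconds
            "amount": round(rng.uniform(5.0, 500.0), 2),
            "merchant_id": f"merchant_{rng.randrange(50):03d}",
            "ip_address": f"10.0.{rng.randrange(4)}.{rng.randrange(1, 255)}",
            "latitude": rng.uniform(-90.0, 90.0),
            "longitude": rng.uniform(-180.0, 180.0),
        })
    return rows
```

Loading these rows into `pd.DataFrame(make_transactions(200))` and grouping by `user_id` yields per-user frames in the shape the stateful function expects.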


In [None]:
# Generate streaming transaction data
df_transactions = data_gen.generate_transaction_data(
    rows_per_second=10,
    num_users=100,
    fraud_ratio=0.1
)

print("Schema of streaming transactions:")
df_transactions.printSchema()


## Step 2: Define State and Output Schemas

Define the structure for maintaining state across streaming batches and the output schema.
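In PySpark, `GroupState.get` returns the state as a tuple, and `GroupState.update` expects a tuple. In both cases the values are positional and follow the field order of the state schema. The sketch below shows small helpers that convert between names and positions so the order cannot drift. `STATE_FIELDS`, `to_state_tuple`, and `from_state_tuple` are hypothetical names, and the field list is copied from `state_schema` in the cell below.

```python
from typing import Any, Dict, Sequence, Tuple

# Field order copied from state_schema; tuples read from or written to state use this order.
STATE_FIELDS: Sequence[str] = (
    "user_id", "transaction_count", "last_transaction_time", "last_ip_address",
    "last_latitude", "last_longitude", "ip_change_count", "total_amount",
    "avg_amount", "max_amount", "transaction_times", "recent_amounts",
)


def to_state_tuple(values: Dict[str, Any]) -> Tuple[Any, ...]:
    """Order a name->value dict by STATE_FIELDS; fail loudly on missing or extra keys."""
    missing = set(STATE_FIELDS) - set(values)
    extra = set(values) - set(STATE_FIELDS)
    if missing or extra:
        raise KeyError(f"missing={sorted(missing)} extra={sorted(extra)}")
    return tuple(values[name] for name in STATE_FIELDS)


def from_state_tuple(row: Sequence[Any]) -> Dict[str, Any]:
    """Inverse of to_state_tuple: attach field names to positional state values."""
    return dict(zip(STATE_FIELDS, row))
```

In the notebook, `state_schema.fieldNames()` could replace the hard-coded tuple, so the helper stays in sync with the schema.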


In [None]:
# Define state schema - what we track for each user
state_schema = StructType([
    StructField("user_id", StringType(), False),
    StructField("transaction_count", IntegerType(), False),
    StructField("last_transaction_time", TimestampType(), False),
    StructField("last_ip_address", StringType(), True),
    StructField("last_latitude", DoubleType(), True),
    StructField("last_longitude", DoubleType(), True),
    StructField("ip_change_count", IntegerType(), False),
    StructField("total_amount", DoubleType(), False),
    StructField("avg_amount", DoubleType(), False),
    StructField("max_amount", DoubleType(), False),
    StructField("transaction_times", ArrayType(TimestampType()), False),
    StructField("recent_amounts", ArrayType(DoubleType()), False)
])

# Define output schema - fraud features per transaction
output_schema = StructType([
    StructField("transaction_id", StringType(), False),
    StructField("user_id", StringType(), False),
    StructField("timestamp", TimestampType(), False),
    StructField("amount", DoubleType(), False),
    StructField("merchant_id", StringType(), False),
    StructField("ip_address", StringType(), False),
    StructField("latitude", DoubleType(), False),
    StructField("longitude", DoubleType(), False),
    
    # Fraud detection features
    StructField("user_transaction_count", IntegerType(), False),
    StructField("transactions_last_hour", IntegerType(), False),
    StructField("transactions_last_10min", IntegerType(), False),
    StructField("ip_changed", IntegerType(), False),
    StructField("ip_change_count_total", IntegerType(), False),
    StructField("distance_from_last_km", DoubleType(), True),
    StructField("velocity_kmh", DoubleType(), True),
    StructField("amount_vs_user_avg_ratio", DoubleType(), True),
    StructField("amount_vs_user_max_ratio", DoubleType(), True),
    StructField("amount_zscore", DoubleType(), True),
    StructField("seconds_since_last_transaction", DoubleType(), True),
    StructField("is_rapid_transaction", IntegerType(), False),
    StructField("is_impossible_travel", IntegerType(), False),
    StructField("is_amount_anomaly", IntegerType(), False),
    StructField("fraud_score", DoubleType(), False),
    StructField("is_fraud_prediction", IntegerType(), False)
])

print("State and output schemas defined")


## Step 3: Implement Stateful Fraud Detection Function

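The functions below compute per-transaction features for each user. Rows are processed in timestamp order, and per-user state (counts, last IP and location, recent amounts) is carried across micro-batches.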


In [None]:
def calculate_haversine_distance(lat1, lon1, lat2, lon2):
    """
    Calculate distance between two geographic points in kilometers.
    """
    if pd.isna(lat1) or pd.isna(lon1) or pd.isna(lat2) or pd.isna(lon2):
        return None
    
    R = 6371.0  # Earth radius in kilometers
    
    # Convert to radians
    lat1_rad = np.radians(lat1)
    lon1_rad = np.radians(lon1)
    lat2_rad = np.radians(lat2)
    lon2_rad = np.radians(lon2)
    
    # Haversine formula
    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad
    a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
    
    return R * c

print("Distance calculation helper function defined")
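A worked check of the distance and impossible-travel logic is a useful sanity test. The sketch below re-implements the same haversine formula with the standard `math` module, so it runs without NumPy, and applies the notebook's km/h conversion and 800 km/h threshold. The function names are hypothetical.

```python
import math
from typing import Optional


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Same haversine formula as calculate_haversine_distance, using math instead of numpy."""
    r = 6371.0  # Earth radius in km, as in the notebook
    p1, p2 = math.radians(lat1), math.radians(lat2)
    dlat = p2 - p1
    dlon = math.radians(lon2 - lon1)
    a = math.sin(dlat / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dlon / 2) ** 2
    return 2 * r * math.atan2(math.sqrt(a), math.sqrt(1 - a))


def velocity_kmh(distance_km: float, seconds: float) -> Optional[float]:
    """Implied speed between consecutive transactions; None if elapsed time is not positive."""
    return distance_km / seconds * 3600 if seconds > 0 else None


# One degree of longitude on the equator is 2*pi*6371/360, about 111.19 km.
d = haversine_km(0.0, 0.0, 0.0, 1.0)
v = velocity_kmh(d, 60.0)  # the same two points, 60 seconds apart
print(round(d, 2), round(v))  # 111.19 6672
print(v > 800)  # True: exceeds the notebook's impossible-travel threshold
```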


In [None]:
def detect_fraud(
    key: Tuple[str],
    pdf_iter: Iterator[pd.DataFrame],
    state: GroupState
) -> Iterator[pd.DataFrame]:
    """
    Stateful fraud detection function using applyInPandasWithState.
    
    Args:
        key: Tuple containing the grouping key (user_id)
        pdf_iter: Iterator of Pandas DataFrames for this key
        state: GroupState object to maintain state across batches
    
    Yields:
        pd.DataFrame: Enriched transactions with fraud features
    """
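    user_id = key[0]

    # On a processing-time timeout, Spark invokes this function with no new rows and
    # state.hasTimedOut set. State is not dropped automatically, so remove it here.
    if state.hasTimedOut:
        state.remove()
        return

    # Process each micro-batch for this user
    for pdf in pdf_iter:
        if pdf.empty:
            continue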
        
        # Sort by timestamp
        pdf = pdf.sort_values('timestamp')
        
        # Initialize or retrieve state. GroupState.get returns a tuple whose
        # positions follow state_schema, so unpack it positionally (not by name).
        if state.exists:
            (
                _,  # user_id (same as the grouping key)
                prev_count,
                prev_last_time,
                prev_ip,
                prev_lat,
                prev_lon,
                prev_ip_changes,
                prev_total_amount,
                prev_avg_amount,
                prev_max_amount,
                prev_times,
                prev_amounts,
            ) = state.get
            prev_times = list(prev_times)
            prev_amounts = list(prev_amounts)
        else:
            prev_count = 0
            prev_last_time = None
            prev_ip = None
            prev_lat = None
            prev_lon = None
            prev_ip_changes = 0
            prev_total_amount = 0.0
            prev_avg_amount = 0.0
            prev_max_amount = 0.0
            prev_times = []
            prev_amounts = []
        
        # Initialize output columns
        results = []
        
        # Process each transaction
        for idx, row in pdf.iterrows():
            current_time = row['timestamp']
            current_ip = row['ip_address']
            current_lat = row['latitude']
            current_lon = row['longitude']
            current_amount = row['amount']
            
            # Update transaction count
            prev_count += 1
            
            # Calculate time-based features
            if prev_last_time is not None:
                time_diff = (current_time - prev_last_time).total_seconds()
            else:
                time_diff = None
            
            # IP change detection
            ip_changed = 0
            if prev_ip is not None and current_ip != prev_ip:
                ip_changed = 1
                prev_ip_changes += 1
            
            # Geographic distance calculation
            distance_km = None
            velocity_kmh = None
            if prev_lat is not None and prev_lon is not None:
                distance_km = calculate_haversine_distance(
                    prev_lat, prev_lon, current_lat, current_lon
                )
                if distance_km is not None and time_diff is not None and time_diff > 0:
                    velocity_kmh = (distance_km / time_diff) * 3600  # km/h
            
            # Amount-based features, measured against the user's history *before*
            # this transaction (otherwise the max ratio could never exceed 1.0)
            amount_vs_avg_ratio = current_amount / prev_avg_amount if prev_avg_amount > 0 else 1.0
            amount_vs_max_ratio = current_amount / prev_max_amount if prev_max_amount > 0 else 1.0

            # Z-score against the recent-amounts window (needs at least 3 prior amounts)
            amount_zscore = None
            if len(prev_amounts) >= 3:
                amounts_std = np.std(prev_amounts)
                if amounts_std > 0:
                    amount_zscore = (current_amount - np.mean(prev_amounts)) / amounts_std

            # Fold the current transaction into the running aggregates
            prev_total_amount += current_amount
            prev_avg_amount = prev_total_amount / prev_count
            prev_max_amount = max(prev_max_amount, current_amount)

            # Update recent transactions list (keep the last 50, so the window
            # counts below are capped at 50)
            prev_times.append(current_time)
            prev_amounts.append(current_amount)
            if len(prev_times) > 50:
                prev_times = prev_times[-50:]
                prev_amounts = prev_amounts[-50:]
            
            # Count transactions in time windows
            one_hour_ago = current_time - timedelta(hours=1)
            ten_min_ago = current_time - timedelta(minutes=10)
            
            trans_last_hour = sum(1 for t in prev_times if t >= one_hour_ago)
            trans_last_10min = sum(1 for t in prev_times if t >= ten_min_ago)
            
            # Fraud indicators
            is_rapid = 1 if trans_last_10min >= 5 else 0
            is_impossible_travel = 1 if velocity_kmh is not None and velocity_kmh > 800 else 0
            is_amount_anomaly = 1 if amount_zscore is not None and abs(amount_zscore) > 3 else 0
            
            # Calculate fraud score (0-100)
            fraud_score = 0.0
            if is_rapid:
                fraud_score += 20
            if is_impossible_travel:
                fraud_score += 30
            if is_amount_anomaly:
                fraud_score += 25
            if prev_ip_changes >= 5:
                fraud_score += 15
            if trans_last_hour >= 10:
                fraud_score += 10
            fraud_score = min(fraud_score, 100.0)
            
            # Fraud prediction (threshold at 50)
            is_fraud_pred = 1 if fraud_score >= 50 else 0
            
            # Append result
            results.append({
                'transaction_id': row['transaction_id'],
                'user_id': user_id,
                'timestamp': current_time,
                'amount': current_amount,
                'merchant_id': row['merchant_id'],
                'ip_address': current_ip,
                'latitude': current_lat,
                'longitude': current_lon,
                'user_transaction_count': prev_count,
                'transactions_last_hour': trans_last_hour,
                'transactions_last_10min': trans_last_10min,
                'ip_changed': ip_changed,
                'ip_change_count_total': prev_ip_changes,
                'distance_from_last_km': distance_km,
                'velocity_kmh': velocity_kmh,
                'amount_vs_user_avg_ratio': amount_vs_avg_ratio,
                'amount_vs_user_max_ratio': amount_vs_max_ratio,
                'amount_zscore': amount_zscore,
                'seconds_since_last_transaction': time_diff,
                'is_rapid_transaction': is_rapid,
                'is_impossible_travel': is_impossible_travel,
                'is_amount_anomaly': is_amount_anomaly,
                'fraud_score': fraud_score,
                'is_fraud_prediction': is_fraud_pred
            })
            
            # Update state for next transaction
            prev_last_time = current_time
            prev_ip = current_ip
            prev_lat = current_lat
            prev_lon = current_lon
        
        # Update state. GroupState.update expects a tuple ordered like state_schema.
        # pandas Timestamps are converted to Python datetimes for the TimestampType fields.
        def to_py(t):
            return t.to_pydatetime() if isinstance(t, pd.Timestamp) else t

        state.update((
            user_id,
            prev_count,
            to_py(prev_last_time),
            prev_ip,
            prev_lat,
            prev_lon,
            prev_ip_changes,
            prev_total_amount,
            prev_avg_amount,
            prev_max_amount,
            [to_py(t) for t in prev_times],
            prev_amounts,
        ))
        # PySpark's setTimeoutDuration takes an int in milliseconds; a string such as
        # "1 hour" is rejected.
        state.setTimeoutDuration(60 * 60 * 1000)  # 1 hour
        
        # Yield results
        if results:
            yield pd.DataFrame(results)

print("Fraud detection function defined")


## Step 4: Apply Stateful Processing

Apply the fraud detection function to the streaming data using `applyInPandasWithState`.


In [None]:
# Apply stateful fraud detection
df_with_fraud_features = df_transactions \
    .withWatermark("timestamp", "10 minutes") \
    .groupBy("user_id") \
    .applyInPandasWithState(
        detect_fraud,
        output_schema,
        state_schema,
        "append",
        GroupStateTimeout.ProcessingTimeTimeout
    )

print("Stateful processing configured")
print("\nOutput schema:")
df_with_fraud_features.printSchema()


## Step 5: Create Fraud Features Table in Lakebase

Create the PostgreSQL table to store fraud detection features.


In [None]:
# Create fraud features table in Lakebase
create_table_sql = """
CREATE TABLE IF NOT EXISTS fraud_features (
    transaction_id VARCHAR(100) PRIMARY KEY,
    user_id VARCHAR(100) NOT NULL,
    timestamp TIMESTAMP NOT NULL,
    amount DOUBLE PRECISION NOT NULL,
    merchant_id VARCHAR(100),
    ip_address VARCHAR(50),
    latitude DOUBLE PRECISION,
    longitude DOUBLE PRECISION,
    
    -- Velocity features
    user_transaction_count INTEGER,
    transactions_last_hour INTEGER,
    transactions_last_10min INTEGER,
    
    -- IP features
    ip_changed INTEGER,
    ip_change_count_total INTEGER,
    
    -- Location features
    distance_from_last_km DOUBLE PRECISION,
    velocity_kmh DOUBLE PRECISION,
    
    -- Amount features
    amount_vs_user_avg_ratio DOUBLE PRECISION,
    amount_vs_user_max_ratio DOUBLE PRECISION,
    amount_zscore DOUBLE PRECISION,
    
    -- Time features
    seconds_since_last_transaction DOUBLE PRECISION,
    
    -- Fraud indicators
    is_rapid_transaction INTEGER,
    is_impossible_travel INTEGER,
    is_amount_anomaly INTEGER,
    fraud_score DOUBLE PRECISION,
    is_fraud_prediction INTEGER,
    
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Indexes for fast queries
CREATE INDEX IF NOT EXISTS idx_fraud_user_id ON fraud_features(user_id);
CREATE INDEX IF NOT EXISTS idx_fraud_timestamp ON fraud_features(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_fraud_score ON fraud_features(fraud_score DESC);
CREATE INDEX IF NOT EXISTS idx_fraud_prediction ON fraud_features(is_fraud_prediction);
"""

# Execute table creation
with lakebase.get_connection() as conn:
    with conn.cursor() as cur:
        cur.execute(create_table_sql)
    conn.commit()

print("fraud_features table created in Lakebase PostgreSQL")


## Step 6: Write Fraud Features to Lakebase

Stream fraud features to Lakebase PostgreSQL for real-time serving.
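`foreachBatch` provides at-least-once delivery. After a failure, a micro-batch can be replayed, and rows whose `transaction_id` already exists would then violate the primary key. The internals of `lakebase.write_streaming_batch` are not shown here, so it is unknown whether it already handles this case. The sketch below builds an idempotent PostgreSQL upsert (`INSERT ... ON CONFLICT ... DO UPDATE`) as one possible approach. `build_upsert_sql` is a hypothetical helper that assumes psycopg-style `%s` placeholders.

```python
from typing import Sequence


def build_upsert_sql(table: str, columns: Sequence[str], key: str) -> str:
    """Build a parameterized PostgreSQL upsert keyed on `key`.

    Identifiers are assumed to be trusted, fixed names; they are not escaped.
    """
    if key not in columns:
        raise ValueError(f"conflict key {key!r} must be one of the inserted columns")
    col_list = ", ".join(columns)
    placeholders = ", ".join(["%s"] * len(columns))
    updates = ", ".join(f"{c} = EXCLUDED.{c}" for c in columns if c != key)
    # With only the key column there is nothing to update, so skip duplicates instead
    action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"
    return (
        f"INSERT INTO {table} ({col_list}) VALUES ({placeholders}) "
        f"ON CONFLICT ({key}) {action}"
    )


print(build_upsert_sql("fraud_features", ["transaction_id", "fraud_score"], "transaction_id"))
```

With a DB-API cursor, the statement can be run as `cur.executemany(sql, rows)`, where each row is a tuple in the same column order. A replayed batch then overwrites existing rows instead of failing.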


In [None]:
# Define foreachBatch function to write to Lakebase
def write_to_lakebase(batch_df, batch_id):
    """
    Write each micro-batch to Lakebase PostgreSQL.
    """
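    # foreachBatch is at-least-once: after a failure a batch can be replayed, so the
    # write should tolerate transaction_ids that already exist (e.g. an upsert).
    # Cache the micro-batch because it is evaluated several times below.
    batch_df.persist()
    try:
        if batch_df.isEmpty():
            return

        logger.info(f"Processing batch {batch_id} with {batch_df.count()} rows")

        # Write to Lakebase using client
        lakebase.write_streaming_batch(batch_df, "fraud_features")

        logger.info(f"Batch {batch_id} written to Lakebase")
    finally:
        batch_df.unpersist()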


# Start streaming query to Lakebase
query_lakebase = df_with_fraud_features \
    .writeStream \
    .outputMode("append") \
    .foreachBatch(write_to_lakebase) \
    .option("checkpointLocation", "/tmp/fraud_detection_checkpoint") \
    .trigger(processingTime="10 seconds") \
    .start()

print("Streaming to Lakebase PostgreSQL...")
print(f"Query ID: {query_lakebase.id}")
print(f"Status: {query_lakebase.status}")


## Step 7: Monitor and Query Fraud Features

Query fraud features from Lakebase for real-time insights.


In [None]:
# Wait for data to process
import time
print("Waiting 30 seconds for data to process...")
time.sleep(30)

# Query top users by fraud score
query_results = """
SELECT 
    user_id,
    COUNT(*) as total_transactions,
    SUM(is_fraud_prediction) as predicted_frauds,
    AVG(fraud_score) as avg_fraud_score,
    MAX(fraud_score) as max_fraud_score,
    SUM(is_rapid_transaction) as rapid_transactions,
    SUM(is_impossible_travel) as impossible_travels,
    SUM(is_amount_anomaly) as amount_anomalies
FROM fraud_features
GROUP BY user_id
ORDER BY predicted_frauds DESC
LIMIT 10
"""

with lakebase.get_connection() as conn:
    result_df = pd.read_sql(query_results, conn)

print("\nTop 10 Users by Predicted Fraud Count:")
display(result_df)


In [None]:
# Query high-risk transactions
high_risk_query = """
SELECT 
    transaction_id,
    user_id,
    timestamp,
    amount,
    fraud_score,
    is_rapid_transaction,
    is_impossible_travel,
    is_amount_anomaly,
    transactions_last_10min,
    velocity_kmh
FROM fraud_features
WHERE fraud_score >= 50
ORDER BY fraud_score DESC, timestamp DESC
LIMIT 20
"""

with lakebase.get_connection() as conn:
    high_risk_df = pd.read_sql(high_risk_query, conn)

print("\nHigh-Risk Transactions (fraud_score >= 50):")
display(high_risk_df)


In [None]:
# Real-time feature serving example - Get features for specific user
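def get_user_fraud_features(user_id: str):
    """
    Get a user's 10 most recent fraud-feature rows from Lakebase PostgreSQL.
    The lookup uses the idx_fraud_user_id index; the notebook targets <10 ms
    query latency, not counting connection setup.
    """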
    query = """
    SELECT 
        transaction_id,
        timestamp,
        amount,
        user_transaction_count,
        transactions_last_hour,
        transactions_last_10min,
        fraud_score,
        is_fraud_prediction
    FROM fraud_features
    WHERE user_id = %s
    ORDER BY timestamp DESC
    LIMIT 10
    """
    
    with lakebase.get_connection() as conn:
        df = pd.read_sql(query, conn, params=(user_id,))
    
    return df

# Example: Get features for a user
sample_user = "user_001"
user_features = get_user_fraud_features(sample_user)

print(f"\nRecent transactions for {sample_user}:")
display(user_features)


## Step 8: Stop Streaming Queries

When done, stop the streaming query.


In [None]:
# Stop streaming query
if query_lakebase.isActive:
    query_lakebase.stop()
    print("Streaming query stopped")
else:
    print("Streaming query was already stopped")


## Summary

This notebook demonstrated advanced streaming fraud detection using `applyInPandasWithState`:

### Key Capabilities

1. **Stateful Streaming**: Using `applyInPandasWithState` for complex fraud detection logic
2. **State Management**: Maintaining user transaction history across micro-batches
3. **Fraud Features**:
   - Transaction velocity (counts in time windows)
   - IP address change tracking
   - Geographic anomalies (impossible travel detection)
   - Amount-based anomalies (z-score, ratios)
   - Composite fraud scores (0-100)
4. **Real-time Serving**: Writing features to Lakebase PostgreSQL for <10ms query latency
5. **Production Patterns**: Proper state timeout, watermarking, and checkpointing

### Key Benefits of applyInPandasWithState

- **Stateful**: Maintain context across streaming batches per user
- **Flexible**: Implement any logic using Python/Pandas
- **Scalable**: Keys are distributed across partitions, so different users are processed in parallel
- **Bounded**: Processing-time timeouts (1 hour in this example) let the function drop idle users' state with `state.remove()`
- **Fault-tolerant**: State stored in checkpoints

### Fraud Detection Logic

**Fraud Score Calculation (0-100 points):**
- Rapid transactions (5+ in 10 min): +20 points
- Impossible travel (>800 km/h): +30 points  
- Amount anomaly (z-score > 3): +25 points
- Frequent IP changes (5+ total): +15 points
- High velocity (10+ in 1 hour): +10 points

**Fraud Prediction:** Score >= 50 triggers fraud flag
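Isolating the rules above in a pure function makes them easy to unit-test and to tune, for example for the threshold A/B tests listed under Next Steps. The sketch below is minimal. The constant names and `score_transaction` are hypothetical, while the weights and thresholds are the ones used in `detect_fraud`.

```python
from typing import Optional, Tuple

# Thresholds copied from the rules listed above
RAPID_MIN_10MIN = 5
IMPOSSIBLE_KMH = 800.0
ZSCORE_LIMIT = 3.0
IP_CHANGES_MIN = 5
HOURLY_MIN = 10
FRAUD_THRESHOLD = 50.0


def score_transaction(
    transactions_last_10min: int,
    transactions_last_hour: int,
    velocity_kmh: Optional[float],
    amount_zscore: Optional[float],
    ip_change_count_total: int,
) -> Tuple[float, int]:
    """Return (fraud_score, is_fraud_prediction) using the notebook's point rules."""
    score = 0.0
    if transactions_last_10min >= RAPID_MIN_10MIN:
        score += 20
    if velocity_kmh is not None and velocity_kmh > IMPOSSIBLE_KMH:
        score += 30
    if amount_zscore is not None and abs(amount_zscore) > ZSCORE_LIMIT:
        score += 25
    if ip_change_count_total >= IP_CHANGES_MIN:
        score += 15
    if transactions_last_hour >= HOURLY_MIN:
        score += 10
    score = min(score, 100.0)
    return score, int(score >= FRAUD_THRESHOLD)


# Impossible travel (+30) plus an amount anomaly (+25) crosses the threshold.
print(score_transaction(1, 1, 6672.0, 4.2, 0))  # (55.0, 1)
```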

### Real-time Architecture

```
Streaming Transactions
    ↓
applyInPandasWithState (per user_id)
  • Track transaction history
  • Calculate velocity features
  • Detect location anomalies
  • Monitor IP changes
  • Compute fraud scores
    ↓
Lakebase PostgreSQL (foreachBatch)
    ↓
Real-time Queries (<10ms latency)
```

### Next Steps

- Integrate with ML models for enhanced fraud scoring
- Add alerting for high-risk transactions
- Connect to downstream systems (dashboards, notification services)
- Tune state timeout and processing trigger intervals
- Add more sophisticated fraud detection rules (device fingerprinting, network analysis)
- Implement A/B testing for fraud detection thresholds
