In [None]:
# We can also use Snowpark for our analyses!
# Obtain the active Snowpark session; the cells below reach Snowflake
# through `session` (session.table, session.sql, session.file.put).
from snowflake.snowpark.context import get_active_session
session = get_active_session()

In [None]:
import logging
from colorama import Fore, Style, init

# Initialize colorama for cross-platform color support
init(autoreset=True)

# Color mapping by log level.
# ColorFormatter looks up record.levelno here; a level that is not listed
# (e.g. a custom numeric level) is rendered without a color prefix.
LOG_COLORS = {
    logging.DEBUG: Fore.CYAN,
    logging.INFO: Fore.GREEN,
    logging.WARNING: Fore.YELLOW,
    logging.ERROR: Fore.RED,
    logging.CRITICAL: Fore.MAGENTA
}

class ColorFormatter(logging.Formatter):
    """Formatter that wraps each rendered record in a per-level color code."""

    def format(self, record):
        """Render `record` with the base formatter, then colorize the result.

        The color prefix comes from LOG_COLORS (empty string for levels that
        are not listed) and the text is always terminated with Style.RESET_ALL.
        """
        rendered = super().format(record)
        prefix = LOG_COLORS.get(record.levelno, "")
        return prefix + rendered + Style.RESET_ALL

def get_logger(name="AppLogger", level=logging.INFO):
    """Return the logger called `name`, set to `level`, with colored output.

    Args:
        name: Logger name passed to logging.getLogger.
        level: Level applied to the logger on every call.

    Returns:
        The configured logging.Logger. A stream handler using ColorFormatter
        is attached only when the logger has no handlers yet, so calling this
        again (e.g. re-running the cell) does not duplicate output.
    """
    log = logging.getLogger(name)
    log.setLevel(level)

    if log.handlers:
        # Already wired up by an earlier call; only the level is refreshed.
        return log

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        ColorFormatter(
            "[%(asctime)s] [%(levelname)s] %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    log.addHandler(stream_handler)
    return log

# Notebook-wide logger used by the cells below; DEBUG lets every level through.
logger = get_logger("DemoLogger", logging.DEBUG)

In [None]:
# Import python packages
import os
from snowflake.snowpark import Session
import pandas as pd

session = get_active_session()

# Source table for the whole notebook; `claims` is a lazy Snowpark DataFrame.
table_name = "ML_CREDIT.RAW_DATA.INSURANCE_CLAIMS"
claims = session.table(table_name)

# Preview a few rows, then log the row count of the table.
claims.show()

total_rows = claims.count()
logger.info(f"Total de registros: {total_rows}")

In [None]:
from snowflake.snowpark.functions import avg, col,when, count, lit  # noqa: E402
from snowflake.snowpark.functions import max as max_  # noqa: E402
from snowflake.snowpark.functions import sum as sum_  # noqa: E402


# ----------------------------------------------------------------------
# Helpers shared by the feature sections below
# ----------------------------------------------------------------------
def _encode_categories(column_name, mapping, default):
    """Build a CASE expression translating labels of `column_name` to numbers.

    Args:
        column_name: Name of the source column in `claims`.
        mapping: Dict of {label: numeric value}; one WHEN branch is emitted
            per entry, in dict order. Must contain at least one entry.
        default: Value handed to .otherwise() when no label matches.

    Returns:
        The un-aliased Snowpark CASE expression.
    """
    source = col(column_name)
    branches = iter(mapping.items())
    label, value = next(branches)
    expression = when(source == label, value)
    for label, value in branches:
        expression = expression.when(source == label, value)
    return expression.otherwise(default)


def _equals_flag(column_name, label):
    """Return a 1/0 expression that is 1 when `column_name` equals `label`."""
    return when(col(column_name) == label, 1).otherwise(0)


# Columns carried by every per-topic frame so they can be joined back together.
key_columns = [col(key) for key in ("POLICYNUMBER", "MONTH", "WEEKOFMONTH")]

# 0. Fix AGE: values <= 0 or > 120 are treated as missing (NULL);
#    everything else is cast to int.
age = col("AGE")
age_is_implausible = (age <= 0) | (age > 120)
claims = claims.with_column(
    "AGE", when(age_is_implausible, lit(None)).otherwise(age.cast("int"))
)

# 1. TEMPORAL FEATURES - Critical for fraud detection
# Day-range labels -> representative number of days; unmatched labels become 0.
day_range_to_days = {
    "more than 30": 35,
    "15 to 30": 22,
    "8 to 15": 11,
    "1 to 7": 4,
}

temporal_features = claims.select(
    *key_columns,
    _encode_categories("DAYS_POLICY_CLAIM", day_range_to_days, 0).alias(
        "DAYS_TO_CLAIM_NUM"
    ),
    _encode_categories("DAYS_POLICY_ACCIDENT", day_range_to_days, 0).alias(
        "POLICY_AGE_AT_ACCIDENT"
    ),
    # Suspicious if claim month differs from accident month
    when(col("MONTH") != col("MONTHCLAIMED"), 1).otherwise(0).alias("MONTH_MISMATCH"),
    # Suspicious if day of week differs
    when(col("DAYOFWEEK") != col("DAYOFWEEKCLAIMED"), 1)
    .otherwise(0)
    .alias("DAY_MISMATCH"),
)

# 2. VEHICLE FEATURES
vehicle_age_years = {
    "new": 0,
    "1 year": 1,
    "2 years": 2,
    "3 years": 3,
    "4 years": 4,
    "5 years": 5,
    "6 years": 6,
    "7 years": 7,
    "more than 7": 9,
}
vehicle_price_value = {
    "less than 20000": 15000,
    "20000 to 29000": 24500,
    "30000 to 39000": 34500,
    "40000 to 59000": 49500,
    "60000 to 69000": 64500,
    "more than 69000": 80000,
}
# Risk score per vehicle category; categories not listed score 1 as well.
vehicle_category_risk = {"Sport": 3, "Utility": 2, "Sedan": 1}

vehicle_features = claims.select(
    *key_columns,
    # Unknown labels stay NULL for age and price.
    _encode_categories("AGEOFVEHICLE", vehicle_age_years, None).alias(
        "VEHICLE_AGE_NUM"
    ),
    _encode_categories("VEHICLEPRICE", vehicle_price_value, None).alias(
        "VEHICLE_PRICE_NUM"
    ),
    _encode_categories("VEHICLECATEGORY", vehicle_category_risk, 1).alias(
        "VEHICLE_RISK"
    ),
)

# 3. DEMOGRAPHIC FEATURES
policyholder_age_value = {
    "16 to 17": 16.5,
    "18 to 20": 19,
    "21 to 25": 23,
    "26 to 30": 28,
    "31 to 35": 33,
    "36 to 40": 38,
    "41 to 50": 45.5,
    "51 to 65": 58,
    "over 65": 72,
}

# Age risk: under 25 -> 3, 25-35 -> 2, 36-60 -> 1, over 60 -> 2.
# Rows matching no branch (e.g. the NULL ages produced in step 0) get 2.
age_risk = (
    when(age < 25, 3)
    .when(age.between(25, 35), 2)
    .when(age.between(36, 60), 1)
    .when(age > 60, 2)
    .otherwise(2)
)

demographic_features = claims.select(
    *key_columns,
    age_risk.alias("AGE_RISK"),
    # Binary encodings
    _equals_flag("SEX", "Male").alias("IS_MALE"),
    _equals_flag("MARITALSTATUS", "Single").alias("IS_SINGLE"),
    # Driver rating is passed through unchanged
    col("DRIVERRATING"),
    _encode_categories("AGEOFPOLICYHOLDER", policyholder_age_value, None).alias(
        "POLICYHOLDER_AGE"
    ),
)

# 4. CLAIM RISK FACTORS - Most important for fraud detection
past_claims_value = {"none": 0, "1": 1, "2 to 4": 3, "more than 4": 6}
supplements_value = {"none": 0, "1 to 2": 1.5, "3 to 5": 4, "more than 5": 7}
# NOTE(review): any label not listed here falls through to 0, the same score
# as "no change" -- confirm the source has no more-recent bucket that should
# score higher than "1 year".
address_change_score = {
    "1 year": 1,
    "2 to 3 years": 0.5,
    "4 to 8 years": 0.2,
    "no change": 0,
}

claim_risk_features = claims.select(
    *key_columns,
    # Fault
    _equals_flag("FAULT", "Policy Holder").alias("POLICYHOLDER_FAULT"),
    # Deductible
    col("DEDUCTIBLE"),
    # Past claims (strong fraud indicator)
    _encode_categories("PASTNUMBEROFCLAIMS", past_claims_value, 0).alias(
        "PAST_CLAIMS"
    ),
    # Documentation flags (strong fraud indicators)
    _equals_flag("POLICEREPORTFILED", "No").alias("NO_POLICE_REPORT"),
    _equals_flag("WITNESSPRESENT", "No").alias("NO_WITNESS"),
    # Claim supplements
    _encode_categories("NUMBEROFSUPPLIMENTS", supplements_value, 0).alias(
        "SUPPLEMENTS"
    ),
    # Address change (fraud red flag)
    _encode_categories("ADDRESSCHANGE_CLAIM", address_change_score, 0).alias(
        "ADDRESS_CHANGE"
    ),
    # Location and agent
    _equals_flag("ACCIDENTAREA", "Urban").alias("URBAN_ACCIDENT"),
    _equals_flag("AGENTTYPE", "External").alias("EXTERNAL_AGENT"),
    # Target
    col("FRAUDFOUND_P").alias("IS_FRAUD"),
)

# 5. POLICY AGGREGATIONS - Historical behavior
# NOTE(review): FRAUD_COUNT_POLICY sums the target (FRAUDFOUND_P) over all rows
# of a policy, including the row it is later joined back onto. Treat it as
# target-derived and keep it out of model inputs unless that is intended.
policy_metrics = [
    count(col("POLICYNUMBER")).alias("TOTAL_CLAIMS_POLICY"),
    sum_(col("FRAUDFOUND_P")).alias("FRAUD_COUNT_POLICY"),
    avg(col("DEDUCTIBLE")).alias("AVG_DEDUCTIBLE_POLICY"),
    avg(col("DRIVERRATING")).alias("AVG_RATING_POLICY"),
]
policy_agg = claims.group_by("POLICYNUMBER").agg(*policy_metrics)


In [None]:
# ======================================================================
# JOIN ALL FEATURES
# ======================================================================
logger.info("Joining all feature sets")

# Joining on a list of column names keeps a single copy of each key column.
# The claim-risk frame is the left side throughout, so it drives the row set.
join_keys = ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"]

all_features = claim_risk_features
for feature_frame in (temporal_features, vehicle_features, demographic_features):
    all_features = all_features.join(feature_frame, join_keys, "left")

# Policy-level aggregates are keyed by policy number only.
all_features = all_features.join(policy_agg, "POLICYNUMBER", "left")

# ======================================================================
# INTERACTION FEATURES - Capture complex fraud patterns
# ======================================================================
logger.info("Creating interaction features")

# 1/0 indicators reused inside the products below.
is_highest_age_risk = when(col("AGE_RISK") == 3, 1).otherwise(0)
is_new_policy = when(col("POLICY_AGE_AT_ACCIDENT") < 15, 1).otherwise(0)

interaction_columns = [
    # Days-to-claim score multiplied by the "no police report" flag.
    # NOTE(review): the product grows with the number of days, so a larger
    # value means a slower claim -- confirm this matches the "quick" intent.
    (col("DAYS_TO_CLAIM_NUM") * col("NO_POLICE_REPORT")).alias("QUICK_NO_POLICE"),
    # Vehicle depreciation vs price
    (col("VEHICLE_AGE_NUM") * col("VEHICLE_PRICE_NUM") / 10000).alias(
        "VEHICLE_DEPRECIATION"
    ),
    # External agent in urban area
    (col("EXTERNAL_AGENT") * col("URBAN_ACCIDENT")).alias("EXTERNAL_URBAN"),
    # Address change with past claims
    (col("ADDRESS_CHANGE") * col("PAST_CLAIMS")).alias("ADDRESS_PAST_CLAIMS"),
    # Highest age-risk bucket combined with the vehicle risk score
    (is_highest_age_risk * col("VEHICLE_RISK")).alias("YOUNG_SPORT"),
    # New policy with claim: the deductible, kept only for policies < 15 days old
    (is_new_policy * col("DEDUCTIBLE")).alias("NEW_POLICY_CLAIM"),
    # No documentation (police + witness)
    (col("NO_POLICE_REPORT") * col("NO_WITNESS")).alias("NO_DOCUMENTATION"),
]

# Keep every joined column and append the interactions in the order above.
final_features = all_features.select("*", *interaction_columns)

# ======================================================================
# CLASS IMBALANCE HANDLING
# ======================================================================
logger.info("Calculating class weights for imbalanced data")

# Count fraud rows and total rows in the same table the features were built
# from (`table_name`), instead of repeating the table name in the SQL text.
fraud_stats = session.sql(
    f"""
    SELECT
        SUM(CASE WHEN FRAUDFOUND_P = 1 THEN 1 ELSE 0 END) AS FRAUD_COUNT,
        COUNT(*) AS TOTAL_COUNT
    FROM {table_name}
"""
).collect()[0]

# SUM over zero rows yields NULL (None in Python); treat that as zero.
fraud_count = fraud_stats["FRAUD_COUNT"] or 0
total_count = fraud_stats["TOTAL_COUNT"]

# Fail with an explicit message rather than a bare ZeroDivisionError below.
if total_count == 0:
    raise ValueError(f"{table_name} is empty; cannot compute class weights")
if fraud_count == 0:
    raise ValueError(
        f"{table_name} has no rows with FRAUDFOUND_P = 1; "
        "the fraud class weight is undefined"
    )

fraud_ratio = fraud_count / total_count

logger.info(f"Fraud ratio: {fraud_ratio:.4f} ({fraud_count}/{total_count})")

# Add sample weights (inverse of class frequency): fraud rows weigh
# (non-fraud share / fraud share), all other rows weigh 1.0.
weighted_features = final_features.select(
    "*",
    when(col("IS_FRAUD") == 1, (1 - fraud_ratio) / fraud_ratio)
    .otherwise(1.0)
    .alias("SAMPLE_WEIGHT"),
)

# ======================================================================
# REGISTER FEATURE VIEWS
# ======================================================================
# NOTE(review): nothing is registered in this cell -- it only previews the
# weighted feature frame. Confirm whether a Feature Store step is still pending.
logger.info("Registering feature views in Feature Store")
weighted_features.show()

In [None]:
from snowflake.snowpark.functions import col, when, mean, median

def _fill_nulls_with(frame, column_name, fill_value):
    """Return `frame` with NULLs in `column_name` replaced by `fill_value`."""
    target = col(column_name)
    return frame.with_column(
        column_name,
        when(target.is_null(), fill_value).otherwise(target),
    )


logger.info("Imputing null values in numeric features")

# NOTE(review): the interaction columns built earlier from these inputs
# (NEW_POLICY_CLAIM uses DEDUCTIBLE, VEHICLE_DEPRECIATION uses
# VEHICLE_PRICE_NUM) are not recomputed here, so they keep any NULLs.

# DEDUCTIBLE: NULLs take the column median.
median_row = weighted_features.select(median(col("DEDUCTIBLE"))).collect()[0]
median_deductible = median_row[0]
weighted_features = _fill_nulls_with(
    weighted_features, "DEDUCTIBLE", median_deductible
)

# VEHICLE_PRICE_NUM: NULLs take the column mean.
mean_row = weighted_features.select(mean(col("VEHICLE_PRICE_NUM"))).collect()[0]
mean_vehicle_price = mean_row[0]
weighted_features = _fill_nulls_with(
    weighted_features, "VEHICLE_PRICE_NUM", mean_vehicle_price
)

logger.info("Imputation completed")
weighted_features.show()


In [None]:
# ======================================================================
# MODEL TRAINING - LOGISTIC REGRESSION (Luis Vejarano)
# ======================================================================
# Assuming weighted_features is your Snowpark DataFrame from previous cells
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import joblib

# Convert to pandas for training
training_data = weighted_features.to_pandas()

# Columns that must not be used as model inputs:
#   - POLICYNUMBER / MONTH / WEEKOFMONTH: row keys used for the joins
#   - SAMPLE_WEIGHT: a weight derived from the label, not a feature
#   - FRAUD_COUNT_POLICY: SUM(FRAUDFOUND_P) per policy, joined back onto every
#     row of that policy, so it contains the row's own label (target leakage)
exclude_cols = [
    "POLICYNUMBER",
    "MONTH",
    "WEEKOFMONTH",
    "SAMPLE_WEIGHT",
    "FRAUD_COUNT_POLICY",
]
feature_cols = [c for c in training_data.columns if c not in exclude_cols and c != "IS_FRAUD"]

# No train/validation/test split or cross-validation is done yet, so every
# metric below is measured on the training data and will be optimistic.
X = training_data[feature_cols]
y = training_data["IS_FRAUD"]

# Calculate class imbalance ratio for class_weight
fraud_count = y.sum()
non_fraud_count = len(y) - fraud_count
if fraud_count == 0 or non_fraud_count == 0:
    # Without both classes the ratio is undefined (or infinite) and the
    # classifier cannot be fitted; stop with a clear message.
    raise ValueError(
        "IS_FRAUD must contain both fraud and non-fraud rows "
        f"(fraud={fraud_count}, non_fraud={non_fraud_count})"
    )
class_weight_ratio = non_fraud_count / fraud_count
logger.info(f"Calculated class_weight ratio: {class_weight_ratio:.2f}")

# Feature scaling (critical for Logistic Regression)
logger.info("Scaling features with StandardScaler...")
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Configure Logistic Regression with class imbalance handling
logger.info("Starting model training with Logistic Regression...")
model = LogisticRegression(
    max_iter=1000,                                    # Maximum iterations for convergence
    class_weight={0: 1.0, 1: class_weight_ratio},    # Class imbalance handling
    random_state=42,                                  # Reproducibility
    solver='lbfgs',                                   # Optimization algorithm
    C=1.0,                                           # Regularization strength (inverse)
    verbose=0,                                       # No verbose output
    n_jobs=1                                        # Single process
)

# Train the model.
# SAMPLE_WEIGHT is deliberately NOT passed to fit(): it carries the same
# non-fraud/fraud ratio as class_weight, and scikit-learn multiplies the two,
# which would weight fraud rows by the ratio squared.
model.fit(X_scaled, y)

# Evaluation on the training data (see the note above about the missing split)
y_pred_proba = model.predict_proba(X_scaled)[:, 1]
y_pred = model.predict(X_scaled)

# Calculate multiple metrics
auc = roc_auc_score(y, y_pred_proba)
accuracy = accuracy_score(y, y_pred)
precision = precision_score(y, y_pred)
recall = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)

# Log all metrics
logger.info(f"Training AUC: {auc:.4f}")
logger.info(f"Accuracy: {accuracy:.4f}")
logger.info(f"Precision: {precision:.4f}")
logger.info(f"Recall: {recall:.4f}")
logger.info(f"F1-Score: {f1:.4f}")
logger.info("Model trained successfully with Logistic Regression and class imbalance handling")

In [None]:
# ======================================================================
# SAVE METRICS TO TABLE - Luis Vejarano
# ======================================================================
import json
from datetime import datetime

metrics_table = "ML_CREDIT.ANALYTICS.MODEL_METRICS"

# Create metrics table if it doesn't exist
session.sql(f"""
CREATE TABLE IF NOT EXISTS {metrics_table} (
    MODEL_NAME STRING,
    MODEL_VERSION STRING,
    FRAMEWORK STRING,
    TRAINING_DATE TIMESTAMP_NTZ,
    AUC FLOAT,
    ACCURACY FLOAT,
    PRECISION FLOAT,
    RECALL FLOAT,
    F1_SCORE FLOAT,
    CREATED_BY STRING
)
""").collect()

# Insert metrics.
# The target columns are listed explicitly so the statement does not depend on
# the physical column order of a table that may already exist (CREATE TABLE
# IF NOT EXISTS leaves an existing table untouched).
insert_sql = f"""
INSERT INTO {metrics_table}
    (MODEL_NAME, MODEL_VERSION, FRAMEWORK, TRAINING_DATE,
     AUC, ACCURACY, PRECISION, RECALL, F1_SCORE, CREATED_BY)
VALUES (?, ?, ?, CURRENT_TIMESTAMP(), ?, ?, ?, ?, ?, ?)
"""

# Bind values in placeholder order; TRAINING_DATE is filled by the server.
# float() converts the metric values to plain Python floats before binding.
params = [
    "vehicle_insurance_fraud_detector",  # MODEL_NAME
    "1.0.2_logistic",                    # MODEL_VERSION
    "logistic_regression",               # FRAMEWORK
    float(auc),                          # AUC
    float(accuracy),                     # ACCURACY
    float(precision),                    # PRECISION
    float(recall),                       # RECALL
    float(f1),                           # F1_SCORE
    "Luis Vejarano"                      # CREATED_BY
]

session.sql(insert_sql, params=params).collect()

# Report the table actually written to rather than a second hard-coded name.
logger.info(f"Metrics saved to {metrics_table}")

# Show saved metrics
session.sql(f"SELECT * FROM {metrics_table} ORDER BY TRAINING_DATE DESC LIMIT 5").show()

In [None]:
# ======================================================================
# REGISTER MODEL - Save model artifact to Snowflake Stage (Luis Bejarano)
# ======================================================================
# Prereqs: model and scaler are already trained in previous cells

import os, json, tempfile, joblib
from snowflake.snowpark.context import get_active_session

session = get_active_session()

# Fully qualified Snowflake locations for the artifact and its registry row.
database = "ML_CREDIT"
stage = f"{database}.MODELS.ML_MODELS_STAGE"
registry_table = f"{database}.MODELS.MODEL_REGISTRY_LOGISTIC"
artifact_filename = "fraud_detector_logistic_v1_0_2.joblib"
artifact_uri = f"@{stage}/{artifact_filename}"

# 1) Serialize the model together with its scaler (both are needed to score
#    new data) into a temporary directory that is removed afterwards.
logger.info("Saving model and scaler to local file...")
with tempfile.TemporaryDirectory() as tmpdir:
    local_path = os.path.join(tmpdir, artifact_filename)
    bundle = {'model': model, 'scaler': scaler}
    joblib.dump(bundle, local_path)

    # 2) Upload while the temporary file still exists. Compression is off so
    #    the staged file keeps the name recorded in artifact_uri.
    logger.info(f"Uploading model to Snowflake Stage: {stage}")
    stage_location = f"@{stage}"
    session.file.put(
        local_path, stage_location, overwrite=True, auto_compress=False
    )

# 3) Create registry table for Logistic Regression models
create_registry_sql = f"""
CREATE TABLE IF NOT EXISTS {registry_table} (
  MODEL_NAME STRING,
  MODEL_VER  STRING,
  STAGE_NAME STRING,
  ARTIFACT_URI STRING,
  FRAMEWORK STRING,
  METRICS VARIANT,
  CREATED_AT TIMESTAMP_NTZ,
  CREATED_BY STRING
)
"""
session.sql(create_registry_sql).collect()

# 4) Metrics stored alongside the model, as plain floats so they serialize to JSON
metrics_json = {
    label: float(score)
    for label, score in (
        ("auc", auc),
        ("accuracy", accuracy),
        ("precision", precision),
        ("recall", recall),
        ("f1", f1),
    )
}

# 5) Insert model registry record; METRICS is bound as JSON text and parsed
#    into a VARIANT by PARSE_JSON on the server.
insert_sql = f"""
INSERT INTO {registry_table}
  (MODEL_NAME, MODEL_VER, STAGE_NAME, ARTIFACT_URI, FRAMEWORK, METRICS, CREATED_AT, CREATED_BY)
SELECT ?, ?, ?, ?, ?, PARSE_JSON(?), CURRENT_TIMESTAMP(), ?
"""

# NOTE(review): CREATED_BY is spelled "Luis Bejarano" here but "Luis Vejarano"
# in the metrics cell -- confirm which spelling is intended.
params = [
    "vehicle_insurance_fraud_detector",  # MODEL_NAME
    "1.0.2_logistic",                    # MODEL_VER
    "DEV",                               # STAGE_NAME (lifecycle stage, not the Snowflake stage)
    artifact_uri,                        # ARTIFACT_URI
    "logistic_regression",               # FRAMEWORK
    json.dumps(metrics_json),            # METRICS
    "Luis Bejarano",                     # CREATED_BY
]
session.sql(insert_sql, params=params).collect()

logger.info(f"✓ Model saved to: {artifact_uri}")
logger.info(f"✓ Model registered in: {registry_table}")

# 6) Verify registration by showing the most recent registry row
logger.info("Latest registered model:")
latest_sql = f"SELECT * FROM {registry_table} ORDER BY CREATED_AT DESC LIMIT 1"
session.sql(latest_sql).show()