In [1]:
# We can also use Snowpark for our analyses!
# Fetch the Snowpark session that is active in this notebook's context; it is
# bound to the global `session`, and later cells call get_active_session() again.
from snowflake.snowpark.context import get_active_session
session = get_active_session()

In [2]:
import logging
from colorama import Fore, Style, init

# Initialize colorama for cross-platform color support (autoreset=True).
init(autoreset=True)

# ANSI foreground color used for each standard logging level.
LOG_COLORS = dict(
    zip(
        (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL),
        (Fore.CYAN, Fore.GREEN, Fore.YELLOW, Fore.RED, Fore.MAGENTA),
    )
)

class ColorFormatter(logging.Formatter):
    """Formatter that wraps each rendered record in the ANSI color of its level.

    Levels missing from LOG_COLORS get no color prefix; every line ends with a
    style reset so the color does not bleed into subsequent output.
    """

    def format(self, record):
        rendered = super().format(record)
        prefix = LOG_COLORS.get(record.levelno, "")
        return prefix + rendered + Style.RESET_ALL

def get_logger(name="AppLogger", level=logging.INFO):
    """Return a named logger that writes colorized records to a stream handler.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to the logger on every call.

    Returns:
        The configured ``logging.Logger``.

    The colored handler is attached only when the logger has no handlers yet,
    so re-running the cell does not stack duplicate handlers. Propagation to
    the root logger is switched off: otherwise, once a root handler exists
    (e.g. after ``logging.basicConfig``), every record would be printed twice.
    """
    named_logger = logging.getLogger(name)
    named_logger.setLevel(level)

    if not named_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            ColorFormatter("[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%H:%M:%S")
        )
        named_logger.addHandler(handler)

    # Records are already emitted by our own handler; don't hand them to root too.
    named_logger.propagate = False
    return named_logger

logger = get_logger("DemoLogger", logging.DEBUG)

2

# Extract

In [None]:
# Import python packages
import os
from snowflake.snowpark import Session
import pandas as pd

# Re-acquire the active Snowpark session (same call as the first cell).
session = get_active_session()

# Snowpark DataFrame over the raw claims table; `claims` is the input to every
# feature-engineering step in the Transform section.
table_name = "ML_CREDIT.RAW_DATA.INSURANCE_CLAIMS"
claims = session.table(table_name)
claims.show()  # prints a preview of the table's first rows


# count() executes a query against the table to get the row count.
logger.info(f"Total records: {claims.count()}")

# Transform

In [None]:
from snowflake.snowpark.functions import avg, col,when, count, lit  # noqa: E402
from snowflake.snowpark.functions import max as max_  # noqa: E402
from snowflake.snowpark.functions import sum as sum_  # noqa: E402


# 0. Fix AGE (likely missing value)
# Ages <= 0 or > 120 are treated as invalid and replaced with NULL; all other
# ages are cast to int. `claims` is rebound, so later steps see the cleaned AGE.
claims = claims.with_column(
    "AGE",
    when((col("AGE") <= 0) | (col("AGE") > 120), lit(None)).otherwise(
        col("AGE").cast("int")
    ),
)
# 1. TEMPORAL FEATURES - Critical for fraud detection
# Each feature frame carries the row key (POLICYNUMBER, MONTH, WEEKOFMONTH)
# used to join the frames back together.
temporal_features = claims.select(
    col("POLICYNUMBER"),
    col("MONTH"),
    col("WEEKOFMONTH"),
    # Convert categorical time ranges to numeric
    # Range labels become approximate day counts (bucket midpoints, 35 for the
    # open-ended bucket). Any other label, or NULL, falls through to 0.
    when(col("DAYS_POLICY_CLAIM") == "more than 30", 35)
    .when(col("DAYS_POLICY_CLAIM") == "15 to 30", 22)
    .when(col("DAYS_POLICY_CLAIM") == "8 to 15", 11)
    .when(col("DAYS_POLICY_CLAIM") == "1 to 7", 4)
    .otherwise(0)
    .alias("DAYS_TO_CLAIM_NUM"),
    # Same bucket -> day-count mapping for the policy-to-accident gap.
    when(col("DAYS_POLICY_ACCIDENT") == "more than 30", 35)
    .when(col("DAYS_POLICY_ACCIDENT") == "15 to 30", 22)
    .when(col("DAYS_POLICY_ACCIDENT") == "8 to 15", 11)
    .when(col("DAYS_POLICY_ACCIDENT") == "1 to 7", 4)
    .otherwise(0)
    .alias("POLICY_AGE_AT_ACCIDENT"),
    # Suspicious if claim month differs from accident month
    # (a NULL on either side makes the comparison NULL, which maps to 0).
    when(col("MONTH") != col("MONTHCLAIMED"), 1)
    .otherwise(0)
    .alias("MONTH_MISMATCH"),
    # Suspicious if day of week differs
    when(col("DAYOFWEEK") != col("DAYOFWEEKCLAIMED"), 1)
    .otherwise(0)
    .alias("DAY_MISMATCH"),
)


# 2. VEHICLE FEATURES
# Keyed by (POLICYNUMBER, MONTH, WEEKOFMONTH) like the other feature frames.
vehicle_features = claims.select(
    col("POLICYNUMBER"),
    col("MONTH"),
    col("WEEKOFMONTH"),
    # Vehicle age numeric
    # Labels become ages in years ("more than 7" -> 9). Unlisted labels -> NULL.
    # NOTE(review): NULLs here propagate into VEHICLE_DEPRECIATION and into the
    # model matrix; LogisticRegression rejects NaN input — confirm every label
    # in the source data is covered.
    when(col("AGEOFVEHICLE") == "new", 0)
    .when(col("AGEOFVEHICLE") == "1 year", 1)
    .when(col("AGEOFVEHICLE") == "2 years", 2)
    .when(col("AGEOFVEHICLE") == "3 years", 3)
    .when(col("AGEOFVEHICLE") == "4 years", 4)
    .when(col("AGEOFVEHICLE") == "5 years", 5)
    .when(col("AGEOFVEHICLE") == "6 years", 6)
    .when(col("AGEOFVEHICLE") == "7 years", 7)
    .when(col("AGEOFVEHICLE") == "more than 7", 9)
    .otherwise(None)
    .alias("VEHICLE_AGE_NUM"),
    # Vehicle price midpoint
    # Price bands become representative prices (open-ended bands use 15000 and
    # 80000). Unlisted labels -> NULL.
    when(col("VEHICLEPRICE") == "less than 20000", 15000)
    .when(col("VEHICLEPRICE") == "20000 to 29000", 24500)
    .when(col("VEHICLEPRICE") == "30000 to 39000", 34500)
    .when(col("VEHICLEPRICE") == "40000 to 59000", 49500)
    .when(col("VEHICLEPRICE") == "60000 to 69000", 64500)
    .when(col("VEHICLEPRICE") == "more than 69000", 80000)
    .otherwise(None)
    .alias("VEHICLE_PRICE_NUM"),
    # Risk factors based on insurance industry data
    # Sport=3, Utility=2, Sedan and any other category=1.
    when(col("VEHICLECATEGORY") == "Sport", 3)
    .when(col("VEHICLECATEGORY") == "Utility", 2)
    .when(col("VEHICLECATEGORY") == "Sedan", 1)
    .otherwise(1)
    .alias("VEHICLE_RISK"),
)

# 3. DEMOGRAPHIC FEATURES
# Keyed by (POLICYNUMBER, MONTH, WEEKOFMONTH) like the other feature frames.
demographic_features = claims.select(
    col("POLICYNUMBER"),
    col("MONTH"),
    col("WEEKOFMONTH"),
    # Age risk (young and elderly are higher risk)
    # <25 -> 3, 25-35 -> 2, 36-60 -> 1 (between() bounds are inclusive).
    # The "> 60" branch and otherwise() both give 2, so NULL ages (set by the
    # AGE cleanup step) also score 2.
    when(col("AGE") < 25, 3)
    .when(col("AGE").between(25, 35), 2)
    .when(col("AGE").between(36, 60), 1)
    .when(col("AGE") > 60, 2)
    .otherwise(2)
    .alias("AGE_RISK"),
    # Binary encodings
    when(col("SEX") == "Male", 1).otherwise(0).alias("IS_MALE"),
    when(col("MARITALSTATUS") == "Single", 1).otherwise(0).alias("IS_SINGLE"),
    # Driver rating (already numeric)
    col("DRIVERRATING"),
    # Policyholder age midpoint
    # Age bands become approximate ages; unlisted labels -> NULL.
    when(col("AGEOFPOLICYHOLDER") == "16 to 17", 16.5)
    .when(col("AGEOFPOLICYHOLDER") == "18 to 20", 19)
    .when(col("AGEOFPOLICYHOLDER") == "21 to 25", 23)
    .when(col("AGEOFPOLICYHOLDER") == "26 to 30", 28)
    .when(col("AGEOFPOLICYHOLDER") == "31 to 35", 33)
    .when(col("AGEOFPOLICYHOLDER") == "36 to 40", 38)
    .when(col("AGEOFPOLICYHOLDER") == "41 to 50", 45.5)
    .when(col("AGEOFPOLICYHOLDER") == "51 to 65", 58)
    .when(col("AGEOFPOLICYHOLDER") == "over 65", 72)
    .otherwise(None)
    .alias("POLICYHOLDER_AGE"),
)

 # 4. CLAIM RISK FACTORS - Most important for fraud detection

# Claim-level red flags plus the training target, keyed by
# (POLICYNUMBER, MONTH, WEEKOFMONTH). Every categorical label not listed in a
# mapping below falls through to that mapping's otherwise() value.
# NOTE(review): confirm each mapping covers all labels present in the data —
# e.g. if ADDRESSCHANGE_CLAIM has a bucket more recent than "1 year", it is
# currently scored 0, the same as "no change".
claim_risk_features = claims.select(
    col("POLICYNUMBER"),
    col("MONTH"),
    col("WEEKOFMONTH"),
    # Fault
    when(col("FAULT") == "Policy Holder", 1)
    .otherwise(0)
    .alias("POLICYHOLDER_FAULT"),
    # Deductible
    col("DEDUCTIBLE"),
    # Past claims (strong fraud indicator)
    # Buckets become representative counts ("2 to 4" -> 3, "more than 4" -> 6).
    when(col("PASTNUMBEROFCLAIMS") == "none", 0)
    .when(col("PASTNUMBEROFCLAIMS") == "1", 1)
    .when(col("PASTNUMBEROFCLAIMS") == "2 to 4", 3)
    .when(col("PASTNUMBEROFCLAIMS") == "more than 4", 6)
    .otherwise(0)
    .alias("PAST_CLAIMS"),
    # Documentation flags (strong fraud indicators)
    when(col("POLICEREPORTFILED") == "No", 1)
    .otherwise(0)
    .alias("NO_POLICE_REPORT"),
    when(col("WITNESSPRESENT") == "No", 1).otherwise(0).alias("NO_WITNESS"),
    # Claim supplements
    when(col("NUMBEROFSUPPLIMENTS") == "none", 0)
    .when(col("NUMBEROFSUPPLIMENTS") == "1 to 2", 1.5)
    .when(col("NUMBEROFSUPPLIMENTS") == "3 to 5", 4)
    .when(col("NUMBEROFSUPPLIMENTS") == "more than 5", 7)
    .otherwise(0)
    .alias("SUPPLEMENTS"),
    # Address change (fraud red flag)
    # More recent address change -> higher score (1 year = 1 ... no change = 0).
    when(col("ADDRESSCHANGE_CLAIM") == "1 year", 1)
    .when(col("ADDRESSCHANGE_CLAIM") == "2 to 3 years", 0.5)
    .when(col("ADDRESSCHANGE_CLAIM") == "4 to 8 years", 0.2)
    .when(col("ADDRESSCHANGE_CLAIM") == "no change", 0)
    .otherwise(0)
    .alias("ADDRESS_CHANGE"),
    # Location and agent
    when(col("ACCIDENTAREA") == "Urban", 1).otherwise(0).alias("URBAN_ACCIDENT"),
    when(col("AGENTTYPE") == "External", 1).otherwise(0).alias("EXTERNAL_AGENT"),
    # Target
    col("FRAUDFOUND_P").alias("IS_FRAUD"),
)

# 5. POLICY AGGREGATIONS - Historical behavior
# One row per POLICYNUMBER, joined back onto every claim of that policy.
# Deliberately NO per-policy SUM(FRAUDFOUND_P): that sum includes the label of
# the very row being scored, so joining it back as a feature leaks IS_FRAUD
# straight into the model inputs (and makes every metric meaningless).

policy_agg = claims.group_by("POLICYNUMBER").agg(
    count(col("POLICYNUMBER")).alias("TOTAL_CLAIMS_POLICY"),
    avg(col("DEDUCTIBLE")).alias("AVG_DEDUCTIBLE_POLICY"),
    avg(col("DRIVERRATING")).alias("AVG_RATING_POLICY"),
)

In [None]:
# ======================================================================
# JOIN ALL FEATURES
# ======================================================================
logger.info("Joining all feature sets")

# All per-topic frames share the same row key. Passing the key as a list of
# column names (USING-style join) keeps a single copy of each key column.
_row_keys = ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"]

all_features = claim_risk_features
for _feature_frame in (temporal_features, vehicle_features, demographic_features):
    all_features = all_features.join(_feature_frame, _row_keys, "left")

# Policy-level aggregates are keyed by POLICYNUMBER alone.
all_features = all_features.join(policy_agg, "POLICYNUMBER", "left")

# ======================================================================
# INTERACTION FEATURES - Capture complex fraud patterns
# ======================================================================
logger.info("Creating interaction features")

# Every source column is kept ("*"); the products below are appended.
final_features = all_features.select(
    "*",
    # Quick claim + no police report = very suspicious.
    # DAYS_TO_CLAIM_NUM grows with the delay, so multiplying it directly would
    # give the *slowest* claims the highest score. Flag short delays instead,
    # using the same "< 15 days" cut-off as NEW_POLICY_CLAIM below.
    (
        when(col("DAYS_TO_CLAIM_NUM") < 15, 1).otherwise(0) * col("NO_POLICE_REPORT")
    ).alias("QUICK_NO_POLICE"),
    # Vehicle depreciation vs price (NULL if either input is NULL)
    (col("VEHICLE_AGE_NUM") * col("VEHICLE_PRICE_NUM") / 10000).alias(
        "VEHICLE_DEPRECIATION"
    ),
    # External agent in urban area
    (col("EXTERNAL_AGENT") * col("URBAN_ACCIDENT")).alias("EXTERNAL_URBAN"),
    # Address change with past claims
    (col("ADDRESS_CHANGE") * col("PAST_CLAIMS")).alias("ADDRESS_PAST_CLAIMS"),
    # Young driver (AGE_RISK == 3) scaled by the vehicle-category risk score
    (when(col("AGE_RISK") == 3, 1).otherwise(0) * col("VEHICLE_RISK")).alias(
        "YOUNG_SPORT"
    ),
    # New policy with claim: deductible when the policy was < 15 days old
    (
        when(col("POLICY_AGE_AT_ACCIDENT") < 15, 1).otherwise(0) * col("DEDUCTIBLE")
    ).alias("NEW_POLICY_CLAIM"),
    # No documentation (police + witness)
    (col("NO_POLICE_REPORT") * col("NO_WITNESS")).alias("NO_DOCUMENTATION"),
)

# ======================================================================
# CLASS IMBALANCE HANDLING
# ======================================================================
logger.info("Calculating class weights for imbalanced data")

# Count the classes on `final_features` itself. The ratio then describes
# exactly the rows that receive a weight below (including any effect of the
# joins above), and the raw table name is not hard-coded a second time.
fraud_stats = final_features.agg(
    sum_(when(col("IS_FRAUD") == 1, 1).otherwise(0)).alias("FRAUD_COUNT"),
    count(lit(1)).alias("TOTAL_COUNT"),
).collect()[0]

# SUM over zero rows is NULL, hence the `or 0`.
fraud_count = int(fraud_stats["FRAUD_COUNT"] or 0)
total_count = int(fraud_stats["TOTAL_COUNT"] or 0)
if total_count == 0 or fraud_count == 0:
    # The inverse-frequency weight below is undefined without any fraud rows.
    raise ValueError(
        f"Cannot compute class weights: {fraud_count} fraud rows out of {total_count}"
    )
fraud_ratio = fraud_count / total_count

logger.info(f"Fraud ratio: {fraud_ratio:.4f} ({fraud_count}/{total_count})")

# Add sample weights (inverse of class frequency): each fraud row weighs
# (1 - ratio) / ratio == #non-fraud / #fraud; every other row weighs 1.0.
weighted_features = final_features.select(
    "*",
    when(col("IS_FRAUD") == 1, (1 - fraud_ratio) / fraud_ratio)
    .otherwise(1.0)
    .alias("SAMPLE_WEIGHT"),
)

# ======================================================================
# PREVIEW WEIGHTED FEATURES
# ======================================================================
# No Feature Store registration happens in this notebook yet; this step only
# prints a preview of the weighted feature set.
logger.info("Previewing weighted feature set (Feature Store registration not implemented)")
weighted_features.show()

# Model

In [None]:
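# ---------------------------
# CELL 2 - Helper functions
# These helpers use names (np, StratifiedKFold, StandardScaler, roc_auc_score,
# calibration_curve, plt, ...) that are imported in the "CELL 1 - Imports &
# config" cell further down; run that cell before calling any helper.
# ---------------------------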

def prepare_data_from_snowpark(df_snp, exclude_cols=None, target_col='IS_FRAUD', weight_col='SAMPLE_WEIGHT'):
    """Materialize a feature table in pandas and split it into model inputs.

    Parameters
    ----------
    df_snp : snowpark.DataFrame or pandas.DataFrame
        Objects exposing ``to_pandas()`` are converted through it; anything else
        is treated as a pandas DataFrame and copied (the input is not modified).
    exclude_cols : list of str, optional
        Identifier columns kept out of the feature matrix.
        Defaults to ``["POLICYNUMBER", "MONTH", "WEEKOFMONTH"]``.
    target_col : str
        Label column. Rows where it is missing are dropped; it is cast to int.
    weight_col : str
        Per-row sample-weight column; never used as a feature.

    Returns
    -------
    tuple
        ``(X, y, sample_weight, df)``. ``sample_weight`` is None when
        ``weight_col`` is absent; ``df`` is the filtered frame re-indexed 0..n-1.
    """
    skip = ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"] if exclude_cols is None else exclude_cols

    df = df_snp.to_pandas() if hasattr(df_snp, 'to_pandas') else df_snp.copy()

    labelled = df[target_col].notna()
    df = df.loc[labelled].reset_index(drop=True)

    feature_cols = [
        name for name in df.columns
        if name not in skip and name not in (target_col, weight_col)
    ]

    sample_weight = df[weight_col] if weight_col in df.columns else None
    return df[feature_cols], df[target_col].astype(int), sample_weight, df


def make_oof_predictions(clf_builder, X, y, sample_weight=None, n_splits=5, scaler_needed=True):
    """Run stratified K-fold CV and collect out-of-fold positive-class scores.

    Parameters
    ----------
    clf_builder : callable
        ``clf_builder()`` must return a fresh, unfitted sklearn-like estimator.
    X : pandas.DataFrame
        Feature matrix (indexed positionally with ``.iloc``).
    y : pandas.Series
        Binary target aligned with ``X``.
    sample_weight : pandas.Series, optional
        Per-row training weights. If an estimator's ``fit`` rejects
        ``sample_weight`` (TypeError), that fold is trained unweighted and a
        warning is logged.
    n_splits : int
        Number of folds (shuffled, ``random_state=42``).
    scaler_needed : bool
        Standardize features per fold (scaler fitted on the training part only).

    Returns
    -------
    oof_proba : np.ndarray
        Out-of-fold score for every row of ``X``.
    models : list of dict
        One ``{'model': estimator, 'scaler': StandardScaler or None}`` per fold.
    cv_metrics : dict
        ``{'auc': ROC AUC of the out-of-fold scores}``.
    """

    def _positive_scores(model, X_eval):
        # Prefer predicted probabilities; otherwise squash the decision function
        # through a sigmoid; as a last resort use hard class predictions.
        # Only a *missing* method triggers the fallback — real errors propagate.
        if hasattr(model, "predict_proba"):
            return model.predict_proba(X_eval)[:, 1]
        if hasattr(model, "decision_function"):
            return 1 / (1 + np.exp(-model.decision_function(X_eval)))
        return model.predict(X_eval)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    oof_proba = np.zeros(len(y))
    models = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train = y.iloc[train_idx]
        sw_train = sample_weight.iloc[train_idx] if sample_weight is not None else None

        if scaler_needed:
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_val_scaled = scaler.transform(X_val)
        else:
            scaler = None
            X_train_scaled, X_val_scaled = X_train.values, X_val.values

        model = clf_builder()
        if sw_train is None:
            model.fit(X_train_scaled, y_train)
        else:
            try:
                model.fit(X_train_scaled, y_train, sample_weight=sw_train)
            except TypeError:
                # Estimator's fit() does not accept sample_weight: train
                # unweighted, but make the dropped weights visible.
                logger.warning(
                    f"Fold {fold+1}/{n_splits}: estimator rejected sample_weight; fitting unweighted"
                )
                model.fit(X_train_scaled, y_train)

        oof_proba[val_idx] = _positive_scores(model, X_val_scaled)

        # store model + scaler
        models.append({'model': model, 'scaler': scaler})

        logger.info(f"Fold {fold+1}/{n_splits} done")

    # Compute CV AUC over all out-of-fold scores at once
    auc = roc_auc_score(y, oof_proba)
    return oof_proba, models, {'auc': auc}


def compute_metrics(y_true, proba, threshold=0.5):
    """Score binary predictions obtained by thresholding probabilities.

    Rows with ``proba >= threshold`` are predicted positive. Returns a dict with
    float ``auc`` (from the raw probabilities), ``accuracy``, ``precision``,
    ``recall``, ``f1`` (0 where undefined) and ``confusion_matrix`` as nested lists.
    """
    labels = (proba >= threshold).astype(int)

    metrics = {'auc': float(roc_auc_score(y_true, proba))}
    metrics['accuracy'] = float(accuracy_score(y_true, labels))
    for key, scorer in (('precision', precision_score),
                        ('recall', recall_score),
                        ('f1', f1_score)):
        metrics[key] = float(scorer(y_true, labels, zero_division=0))
    metrics['confusion_matrix'] = confusion_matrix(y_true, labels).tolist()
    return metrics


def plot_calibration(y_true, proba, n_bins=10, title=None):
    """Draw a reliability diagram and return its matplotlib Figure.

    Plots the observed positive rate against the mean predicted probability in
    ``n_bins`` bins, plus a dashed diagonal marking perfect calibration.
    """
    frac_positive, mean_predicted = calibration_curve(y_true, proba, n_bins=n_bins)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(mean_predicted, frac_positive, marker='o', linewidth=1)
    ax.plot([0, 1], [0, 1], linestyle='--')
    ax.set(xlabel='Predicted probability', ylabel='Observed frequency')
    if title:
        ax.set_title(title)
    return fig

def plot_rank_order(y_true, proba, n_bins=10, title=None):
    """Bar-plot the fraud rate per score quantile, highest-scoring bucket first.

    Returns ``(fig, table)`` where ``table`` has one row per bucket with
    ``fraud_count``, ``total``, ``avg_proba`` and ``fraud_rate``.
    """
    scored = pd.DataFrame({'y': y_true, 'proba': proba})
    # Rank with method='first' so tied scores get distinct ranks and the
    # quantile edges handed to qcut are distinct.
    scored['decile'] = pd.qcut(scored['proba'].rank(method='first'), q=n_bins, labels=False)

    table = scored.groupby('decile').agg(
        fraud_count=('y', 'sum'),
        total=('y', 'count'),
        avg_proba=('proba', 'mean'),
    )
    table = table.sort_index(ascending=False)  # highest score first
    table['fraud_rate'] = table['fraud_count'] / table['total']

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(range(1, n_bins + 1), table['fraud_rate'])
    ax.set(xlabel='Decile (1 = highest scores)', ylabel='Fraud rate')
    if title:
        ax.set_title(title)
    return fig, table


def compute_business_metric(df_full, proba, model_name, threshold=0.5, booking_col='BOOKED'):
    """Compare a business rate (booking rate) across predicted classes.

    Rows with ``proba >= threshold`` count as predicted positive. ``df_full`` is
    copied, never modified. Returns None (after a warning) when ``booking_col``
    is absent; otherwise a dict with the overall rate, the rate within each
    predicted class (None when that class is empty) and the class sizes.
    """
    scored = df_full.copy()
    scored['_proba'] = proba
    scored['_pred'] = (scored['_proba'] >= threshold).astype(int)

    if booking_col not in scored.columns:
        logger.warning(f"Business column '{booking_col}' not found in data. Skipping business metric.")
        return None

    bookings = scored[booking_col]
    is_pos = scored['_pred'] == 1
    is_neg = scored['_pred'] == 0

    overall = bookings.mean()
    pos_rate = bookings[is_pos].mean()
    neg_rate = bookings[is_neg].mean()

    def _rate_or_none(rate):
        # Mean over an empty group is NaN -> report it as None.
        return None if np.isnan(rate) else float(rate)

    return {
        'model': model_name,
        'overall_booking_rate': float(overall),
        'booking_rate_pred_pos': _rate_or_none(pos_rate),
        'booking_rate_pred_neg': _rate_or_none(neg_rate),
        'n_pred_pos': int(is_pos.sum()),
        'n_pred_neg': int(is_neg.sum()),
    }

In [None]:
# Unified modeling notebook: Logistic Regression, LightGBM, XGBoost + CV + Model Registry + Streamlit dashboard
# Author: Generated for you
# Usage: Run as Jupyter notebook cells. Streamlit app is included at the bottom as a separate runnable block.

# ---------------------------
# CELL 1 - Imports & config
# ---------------------------
import os
import json
import tempfile
from datetime import datetime

import numpy as np
import pandas as pd

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (roc_auc_score, accuracy_score, precision_score,
                             recall_score, f1_score, confusion_matrix)
from sklearn.calibration import calibration_curve

# LightGBM & XGBoost are optional: each alias stays None when its import fails,
# so model builders can check availability before using it.
# NOTE(review): the catch is broader than ImportError, presumably so that a
# native-library load failure also just disables the model — confirm intent.
lgb = None
xgb = None

try:
    import lightgbm as lgb
except Exception:
    pass

try:
    import xgboost as xgb
except Exception:
    pass

# Plotting
import matplotlib.pyplot as plt

# Snowflake Snowpark (user must have snowpark installed and active session configured)
from snowflake.snowpark.context import get_active_session

# Persist models
import joblib

# Logger (simple): INFO-level root configuration plus a module-named logger.
# NOTE(review): this rebinds the global `logger`, replacing the colored
# DemoLogger created at the top of the notebook — confirm that is intended.
import logging
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
)
import joblib, tempfile, os, json
from datetime import datetime
import numpy as np

def ks_statistic(y_true, y_proba):
    """Compute the Kolmogorov–Smirnov statistic between positive and negative scores.

    KS is the maximum over score thresholds of |TPR - FPR|, where rows with a
    score at or above the threshold are predicted positive. Rows with equal
    scores form a single threshold, so the curve is evaluated only after each
    group of tied scores. Evaluating inside a tie group would split rows that
    no threshold can separate, and could overstate KS.

    Args:
        y_true: Binary labels (0/1), array-like.
        y_proba: Scores aligned positionally with ``y_true``.

    Returns:
        float KS in [0, 1]; NaN for empty input. If one class is absent, its
        cumulative rate is taken as 0.
    """
    data = pd.DataFrame({
        'y': np.asarray(y_true, dtype=float),
        'proba': np.asarray(y_proba, dtype=float),
    })
    if data.empty:
        return float('nan')

    # One row per distinct score, highest score first.
    by_score = data.groupby('proba', sort=True)['y'].agg(['sum', 'count']).iloc[::-1]
    cum_pos = by_score['sum'].cumsum()
    cum_neg = (by_score['count'] - by_score['sum']).cumsum()
    total_pos = cum_pos.iloc[-1]
    total_neg = cum_neg.iloc[-1]

    cum_pos_rate = cum_pos / total_pos if total_pos > 0 else 0.0
    cum_neg_rate = cum_neg / total_neg if total_neg > 0 else 0.0
    return float(np.max(np.abs(cum_pos_rate - cum_neg_rate)))

def train_and_register_models(
    session,
    weighted_features,
    target_col='IS_FRAUD',
    weight_col='SAMPLE_WEIGHT',
    exclude_cols=None,
    models_to_run=None,
    db_for_registry='ML_CREDIT',
    models_stage_name='MODELS.ML_MODELS_STAGE',
    registry_table_name='MODELS.MODEL_REGISTRY_COMPARISON',
    metrics_val_table_name='ANALYTICS.MODEL_METRICS_VAL',
    metrics_test_table_name='ANALYTICS.MODEL_METRICS_TEST',
    booking_col='BOOKED',
    created_by='Manideep',
):
    """Train fraud classifiers, upload their artifacts and record their metrics.

    The feature table is split 70/15/15 (stratified on the target,
    ``random_state=42``) into train / validation / test. For every requested
    model this function:

    - fits on train (with per-row sample weights);
    - scores validation and test;
    - uploads a joblib artifact ``{'model', 'scaler', 'feature_columns'}`` to
      the models stage;
    - inserts one metrics row per split plus one registry row.

    Parameters
    ----------
    session : snowflake.snowpark.Session
    weighted_features : snowpark.DataFrame or pandas.DataFrame
        Forwarded, with ``exclude_cols``/``target_col``/``weight_col``, to
        ``prepare_data_from_snowpark``.
    models_to_run : list of str, optional
        Subset of ``'logistic'``, ``'lgbm'``, ``'xgb'`` (default: all three).
        A boosting model whose library could not be imported is skipped with a
        warning; an unknown key raises ``KeyError`` before anything is trained.
    db_for_registry : str
        Database prefix for the stage and every table below.
    models_stage_name, registry_table_name, metrics_val_table_name, metrics_test_table_name : str
        Schema-qualified object names inside ``db_for_registry``.
    booking_col : str
        Currently unused; kept for interface compatibility.
    created_by : str
        Value written to every CREATED_BY column for this run.

    Returns
    -------
    dict
        ``{model_key: {'val_metrics', 'test_metrics', 'artifact_uri'}}`` for
        each model that was trained.

    NOTE(review): LogisticRegression rejects NaN inputs; features that can be
    NULL upstream (unmapped category buckets) would need imputing first.
    """
    if models_to_run is None:
        models_to_run = ['logistic', 'lgbm', 'xgb']

    X, y, sample_weight, _ = prepare_data_from_snowpark(weighted_features, exclude_cols, target_col, weight_col)

    # -----------------------------------
    # Train/Val/Test Split (70 / 15 / 15)
    # -----------------------------------
    # Split row positions rather than the arrays themselves: the partition is
    # identical (it depends only on y and random_state), and a missing weight
    # column (sample_weight is None) needs no special case.
    positions = np.arange(len(y))
    train_pos, temp_pos = train_test_split(positions, test_size=0.3, stratify=y, random_state=42)
    val_pos, test_pos = train_test_split(temp_pos, test_size=0.5, stratify=y.iloc[temp_pos], random_state=42)

    X_train, X_val, X_test = X.iloc[train_pos], X.iloc[val_pos], X.iloc[test_pos]
    y_train, y_val, y_test = y.iloc[train_pos], y.iloc[val_pos], y.iloc[test_pos]
    w_train, w_val, w_test = (
        None if sample_weight is None else sample_weight.iloc[pos]
        for pos in (train_pos, val_pos, test_pos)
    )

    # Builders
    def logistic_builder():
        # sklearn multiplies class_weight with the sample_weight passed to fit().
        # Per-row weights (SAMPLE_WEIGHT is inverse class frequency) already
        # rebalance the classes, so only fall back to class_weight without them;
        # using both would apply the imbalance correction twice.
        if w_train is not None:
            class_weight = None
        else:
            fraud_count = int(y_train.sum())
            non_fraud_count = len(y_train) - fraud_count
            class_weight = {0: 1.0, 1: non_fraud_count / fraud_count} if fraud_count > 0 else None
        return LogisticRegression(max_iter=1000, class_weight=class_weight, random_state=42, solver='lbfgs')

    def lgbm_builder():
        return lgb.LGBMClassifier(n_estimators=500, random_state=42)

    def xgb_builder():
        return xgb.XGBClassifier(
            n_estimators=500, use_label_encoder=False, eval_metric='logloss', random_state=42
        )

    # key -> (builder, needs feature scaling, library importable)
    builders = {
        'logistic': (logistic_builder, True, True),
        'lgbm': (lgbm_builder, False, lgb is not None),
        'xgb': (xgb_builder, False, xgb is not None),
    }

    def _split_metrics(y_true, y_pred_proba, w=None):
        """Sample-weighted metrics at a 0.5 cut-off, plus the unweighted KS."""
        y_pred = (y_pred_proba > 0.5).astype(int)
        return {
            'auc': float(roc_auc_score(y_true, y_pred_proba, sample_weight=w)),
            'accuracy': float(accuracy_score(y_true, y_pred, sample_weight=w)),
            'precision': float(precision_score(y_true, y_pred, average='weighted', sample_weight=w, zero_division=0)),
            'recall': float(recall_score(y_true, y_pred, average='weighted', sample_weight=w, zero_division=0)),
            'f1': float(f1_score(y_true, y_pred, average='weighted', sample_weight=w, zero_division=0)),
            'ks': float(ks_statistic(y_true, y_pred_proba)),
        }

    # Resolve every key up front (KeyError for unknown ones) and drop models
    # whose library is unavailable instead of crashing on `None.XGBClassifier`.
    runnable = []
    for model_key in models_to_run:
        builder, scaler_needed, available = builders[model_key]
        if not available:
            logger.warning(f"Skipping '{model_key}': its library could not be imported.")
            continue
        runnable.append((model_key, builder, scaler_needed))

    results = {}
    if not runnable:
        return results

    stage = f"{db_for_registry}.{models_stage_name}"
    registry_table = f"{db_for_registry}.{registry_table_name}"

    # -----------------------------------
    # CREATE TABLES IF NOT EXIST (once per call, not once per model)
    # -----------------------------------
    for table in (metrics_val_table_name, metrics_test_table_name):
        session.sql(f"""
            CREATE TABLE IF NOT EXISTS {db_for_registry}.{table} (
                MODEL_NAME STRING,
                MODEL_VERSION STRING,
                FRAMEWORK STRING,
                TRAINING_DATE TIMESTAMP_NTZ,
                AUC FLOAT,
                ACCURACY FLOAT,
                PRECISION FLOAT,
                RECALL FLOAT,
                F1_SCORE FLOAT,
                KS FLOAT,
                CREATED_BY STRING
            )
        """).collect()

    session.sql(f"""
    CREATE TABLE IF NOT EXISTS {registry_table} (
      MODEL_NAME STRING,
      MODEL_VER  STRING,
      STAGE_NAME STRING,
      ARTIFACT_URI STRING,
      FRAMEWORK STRING,
      METRICS VARIANT,
      CREATED_AT TIMESTAMP_NTZ,
      CREATED_BY STRING
    )
    """).collect()

    for model_key, builder, scaler_needed in runnable:
        model = builder()

        # Scaling if needed (fitted on train only)
        if scaler_needed:
            scaler = StandardScaler()
            X_train_in = scaler.fit_transform(X_train)
            X_val_in = scaler.transform(X_val)
            X_test_in = scaler.transform(X_test)
        else:
            scaler = None
            X_train_in, X_val_in, X_test_in = X_train.values, X_val.values, X_test.values

        model.fit(X_train_in, y_train, sample_weight=w_train)

        val_metrics = _split_metrics(y_val, model.predict_proba(X_val_in)[:, 1], w_val)
        test_metrics = _split_metrics(y_test, model.predict_proba(X_test_in)[:, 1], w_test)

        # One timestamp per model so artifact file name and version agree.
        trained_at = datetime.utcnow()
        version = trained_at.strftime('%Y.%m.%d.%H%M%S')
        artifact_filename = f"{model_key}_model_{trained_at.strftime('%Y%m%dT%H%M%SZ')}.joblib"

        # -----------------------------------
        # Save Artifact to Stage
        # -----------------------------------
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = os.path.join(tmpdir, artifact_filename)
            joblib.dump({'model': model, 'scaler': scaler, 'feature_columns': X.columns.tolist()}, local_path)
            session.file.put(local_path, f"@{stage}", overwrite=True, auto_compress=False)
        artifact_uri = f"@{stage}/{artifact_filename}"

        model_name = f"vehicle_insurance_fraud_detector_{model_key}"

        # -----------------------------------
        # INSERT INTO METRICS TABLES
        # -----------------------------------
        for table, metrics in ((metrics_val_table_name, val_metrics),
                               (metrics_test_table_name, test_metrics)):
            session.sql(f"""
            INSERT INTO {db_for_registry}.{table}
            VALUES (?, ?, ?, CURRENT_TIMESTAMP(), ?, ?, ?, ?, ?, ?, ?)
            """, params=[
                model_name, version, model_key,
                metrics['auc'], metrics['accuracy'], metrics['precision'],
                metrics['recall'], metrics['f1'], metrics['ks'], created_by,
            ]).collect()

        # -----------------------------------
        # REGISTER MODEL
        # -----------------------------------
        session.sql(f"""
        INSERT INTO {registry_table}
        (MODEL_NAME, MODEL_VER, STAGE_NAME, ARTIFACT_URI, FRAMEWORK, METRICS, CREATED_AT, CREATED_BY)
        SELECT ?, ?, ?, ?, ?, PARSE_JSON(?), CURRENT_TIMESTAMP(), ?
        """, params=[
            model_name, version, 'DEV', artifact_uri, model_key,
            json.dumps({'validation': val_metrics, 'test': test_metrics}), created_by,
        ]).collect()

        results[model_key] = {
            'val_metrics': val_metrics,
            'test_metrics': test_metrics,
            'artifact_uri': artifact_uri,
        }

        logger.info(f"✅ Completed training & registration for {model_key}")

    return results