In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import pandas as pd
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Add src directory to path to import project modules
#sys.path.append(str(Path(__file__).parent))
notebook_path = Path().resolve()
sys.path.append(str(notebook_path.parent)) 

# Import project modules
from snowflake.snowpark import Session
from snowflake.snowpark.functions import (
    col, lit, datediff, when, count, sum as sum_, avg, max as max_,
    sqrt, abs as abs_

)
from snowflake_ml_template.feature_store.core import FeatureStore
from snowflake_ml_template.feature_store.serving.batch import BatchFeatureServer
from snowflake_ml_template.registry import ModelRegistry, ModelStage
from snowflake_ml_template.core.base.training import (
    TrainingConfig, BaseModelConfig, TrainingStrategy, MLFramework
)
from snowflake_ml_template.training.frameworks.lightgbm_trainer import LightGBMTrainer
from snowflake_ml_template.core.base.deployment import (
    DeploymentConfig, DeploymentStrategy, DeploymentTarget
)
from snowflake_ml_template.deployment.strategies.warehouse_udf import WarehouseUDFStrategy
from snowflake_ml_template.training import TrainingOrchestrator
from snowflake_ml_template.deployment import DeploymentOrchestrator
from snowflake_ml_template.feature_store.core.entity import Entity
from snowflake_ml_template.feature_store.core.feature_view import FeatureView

2025-10-17 10:30:55,935 - snowflake.snowpark - INFO - AST state has not been set explicitly. Defaulting to ast_enabled = True.


# Step 1:  Setup and configutation

In [3]:

def create_snowflake_session():
    """Create a Snowflake session using environment variables."""
    # Load environment variables
    load_dotenv()
    
    # Get Snowflake connection parameters
    connection_parameters = {
        "account": os.getenv("SNOWFLAKE_ACCOUNT"),
        "user": os.getenv("SNOWFLAKE_USER"),
        "password": os.getenv("SNOWFLAKE_PASSWORD"),
        "role": os.getenv("SNOWFLAKE_ROLE"),
        "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
        "database": os.getenv("SNOWFLAKE_DATABASE"),
        "schema": os.getenv("SNOWFLAKE_SCHEMA")
    }
    
    # Validate connection parameters
    missing_params = [k for k, v in connection_parameters.items() if not v]
    if missing_params:
        raise ValueError(f"Missing Snowflake connection parameters: {', '.join(missing_params)}")
    
    logger.info(f"Creating Snowflake session for account: {connection_parameters['account']}")
    
    # Create session
    session = Session.builder.configs(connection_parameters).create()
    
    # Test connection
    try:
        result = session.sql("SELECT CURRENT_WAREHOUSE(), CURRENT_DATABASE(), CURRENT_SCHEMA()").collect()
        logger.info(f"Connected to Snowflake: {result[0]}")
        return session
    except Exception as e:
        logger.error(f"Failed to connect to Snowflake: {e}")
        raise


def setup_infrastructure(session):
    """Set up required Snowflake infrastructure."""
    logger.info("Setting up Snowflake infrastructure")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    
    # Create database if it doesn't exist
    session.sql(f"CREATE DATABASE IF NOT EXISTS {database}").collect()
    
    # Create schemas
    schemas = ["RAW_DATA", "FEATURES", "MODELS", "PIPELINES", "ANALYTICS"]
    for schema in schemas:
        session.sql(f"CREATE SCHEMA IF NOT EXISTS {database}.{schema}").collect()
    
    # Create stages for storing files and models
    session.sql(f"""
    CREATE STAGE IF NOT EXISTS {database}.RAW_DATA.EXTERNAL_FILES
    DIRECTORY = (ENABLE = TRUE)
    """).collect()
    
    session.sql(f"""
    CREATE STAGE IF NOT EXISTS {database}.MODELS.ML_MODELS_STAGE
    DIRECTORY = (ENABLE = TRUE)
    """).collect()
    
    logger.info("Infrastructure setup completed")


# Step 2: Data ingestion

In [4]:
def ingest_data(session, csv_path):
    """Ingest data from CSV to Snowflake."""
    logger.info(f"Ingesting data from {csv_path}")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    schema = "RAW_DATA"
    
    # Read CSV with pandas
    df = pd.read_csv(csv_path)
    logger.info(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
    
    # Convert column names to uppercase for Snowflake
    df.columns = [col.upper() for col in df.columns]
    
    # Add metadata columns
    df['CREATED_AT'] = datetime.now()
    df['CREATED_BY'] = os.getenv("SNOWFLAKE_USER")
    df['DATA_VERSION'] = '1.0.0'
    df['SOURCE_SYSTEM'] = 'CSV_IMPORT'
    
    # Create Snowpark DataFrame
    snowpark_df = session.create_dataframe(df)
    
    # Create table and insert data
    table_name = f"{database}.{schema}.INSURANCE_CLAIMS"
    
    # Drop table if exists
    session.sql(f"DROP TABLE IF EXISTS {table_name}").collect()
    
    # Create table
    snowpark_df.write.save_as_table(table_name)
    
    # Verify data
    count = session.sql(f"SELECT COUNT(*) FROM {table_name}").collect()[0][0]
    logger.info(f"Ingested {count} rows into {table_name}")
    
    return table_name

# Step 3: Feature Engineering

In [5]:
def engineer_features(session):
    """Engineer features for fraud detection with advanced preprocessing.
    
    Implements sophisticated feature engineering for insurance fraud detection:
    - Data quality checks and null handling
    - Categorical variable encoding with domain knowledge
    - Temporal feature extraction
    - Interaction features for complex patterns
    - Class imbalance handling with sample weights
    """
    logger.info("Engineering features for fraud detection")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    
    # Initialize Feature Store
    feature_store = FeatureStore(
        session=session,
        database=database,
        schema="FEATURES"
    )
    
    # Load raw data
    claims = session.table(f"{database}.RAW_DATA.INSURANCE_CLAIMS")
    
    # ======================================================================
    # Data Quality: Check for nulls and outliers
    # ======================================================================
    logger.info("Performing data quality checks")
    
    # Fix AGE (likely missing value)
    claims = claims.with_column(
        "AGE",
        when((col("AGE") <= 0) | (col("AGE") > 120), lit(None))
        .otherwise(col("AGE").cast("int"))
    )
    
    # ======================================================================
    # Entity Registration (using actual column names from dataset)
    # ======================================================================
    
    # Primary entity: Policy (POLICYNUMBER is the unique identifier)
    policy_entity = Entity(name="POLICY", join_keys=["POLICYNUMBER"])
    feature_store.register_entity(policy_entity)
    
    # Composite entity: Claim (POLICYNUMBER + temporal identifiers)
    claim_entity = Entity(name="CLAIM", join_keys=["POLICYNUMBER", "MONTH", "WEEKOFMONTH"])
    feature_store.register_entity(claim_entity)
    
    # ======================================================================
    # Feature Engineering
    # ======================================================================
    
    # 1. TEMPORAL FEATURES - Critical for fraud detection
    logger.info("Creating temporal features")
    temporal_features = claims.select(
        col("POLICYNUMBER"),
        col("MONTH"),
        col("WEEKOFMONTH"),
        
        # Convert categorical time ranges to numeric
        when(col("DAYS_POLICY_CLAIM") == "more than 30", 35)
        .when(col("DAYS_POLICY_CLAIM") == "15 to 30", 22)
        .when(col("DAYS_POLICY_CLAIM") == "8 to 15", 11)
        .when(col("DAYS_POLICY_CLAIM") == "1 to 7", 4)
        .otherwise(0).alias("DAYS_TO_CLAIM_NUM"),
        
        when(col("DAYS_POLICY_ACCIDENT") == "more than 30", 35)
        .when(col("DAYS_POLICY_ACCIDENT") == "15 to 30", 22)
        .when(col("DAYS_POLICY_ACCIDENT") == "8 to 15", 11)
        .when(col("DAYS_POLICY_ACCIDENT") == "1 to 7", 4)
        .otherwise(0).alias("POLICY_AGE_AT_ACCIDENT"),
        
        # Suspicious if claim month differs from accident month
        when(col("MONTH") != col("MONTHCLAIMED"), 1).otherwise(0).alias("MONTH_MISMATCH"),
        
        # Suspicious if day of week differs
        when(col("DAYOFWEEK") != col("DAYOFWEEKCLAIMED"), 1).otherwise(0).alias("DAY_MISMATCH")
    )
    
    # 2. VEHICLE FEATURES
    logger.info("Creating vehicle features")
    vehicle_features = claims.select(
        col("POLICYNUMBER"),
        col("MONTH"),
        col("WEEKOFMONTH"),
        
        # Vehicle age numeric
        when(col("AGEOFVEHICLE") == "new", 0)
        .when(col("AGEOFVEHICLE") == "1 year", 1)
        .when(col("AGEOFVEHICLE") == "2 years", 2)
        .when(col("AGEOFVEHICLE") == "3 years", 3)
        .when(col("AGEOFVEHICLE") == "4 years", 4)
        .when(col("AGEOFVEHICLE") == "5 years", 5)
        .when(col("AGEOFVEHICLE") == "6 years", 6)
        .when(col("AGEOFVEHICLE") == "7 years", 7)
        .when(col("AGEOFVEHICLE") == "more than 7", 9)
        .otherwise(None).alias("VEHICLE_AGE_NUM"),
        
        # Vehicle price midpoint
        when(col("VEHICLEPRICE") == "less than 20000", 15000)
        .when(col("VEHICLEPRICE") == "20000 to 29000", 24500)
        .when(col("VEHICLEPRICE") == "30000 to 39000", 34500)
        .when(col("VEHICLEPRICE") == "40000 to 59000", 49500)
        .when(col("VEHICLEPRICE") == "60000 to 69000", 64500)
        .when(col("VEHICLEPRICE") == "more than 69000", 80000)
        .otherwise(None).alias("VEHICLE_PRICE_NUM"),
        
        # Risk factors based on insurance industry data
        when(col("VEHICLECATEGORY") == "Sport", 3)
        .when(col("VEHICLECATEGORY") == "Utility", 2)
        .when(col("VEHICLECATEGORY") == "Sedan", 1)
        .otherwise(1).alias("VEHICLE_RISK")
    )
    
    # 3. DEMOGRAPHIC FEATURES
    logger.info("Creating demographic features")
    demographic_features = claims.select(
        col("POLICYNUMBER"),
        col("MONTH"),
        col("WEEKOFMONTH"),
        
        # Age risk (young and elderly are higher risk)
        when(col("AGE") < 25, 3)
        .when(col("AGE").between(25, 35), 2)
        .when(col("AGE").between(36, 60), 1)
        .when(col("AGE") > 60, 2)
        .otherwise(2).alias("AGE_RISK"),
        
        # Binary encodings
        when(col("SEX") == "Male", 1).otherwise(0).alias("IS_MALE"),
        when(col("MARITALSTATUS") == "Single", 1).otherwise(0).alias("IS_SINGLE"),
        
        # Driver rating (already numeric)
        col("DRIVERRATING"),
        
        # Policyholder age midpoint
        when(col("AGEOFPOLICYHOLDER") == "16 to 17", 16.5)
        .when(col("AGEOFPOLICYHOLDER") == "18 to 20", 19)
        .when(col("AGEOFPOLICYHOLDER") == "21 to 25", 23)
        .when(col("AGEOFPOLICYHOLDER") == "26 to 30", 28)
        .when(col("AGEOFPOLICYHOLDER") == "31 to 35", 33)
        .when(col("AGEOFPOLICYHOLDER") == "36 to 40", 38)
        .when(col("AGEOFPOLICYHOLDER") == "41 to 50", 45.5)
        .when(col("AGEOFPOLICYHOLDER") == "51 to 65", 58)
        .when(col("AGEOFPOLICYHOLDER") == "over 65", 72)
        .otherwise(None).alias("POLICYHOLDER_AGE")
    )
    
    # 4. CLAIM RISK FACTORS - Most important for fraud detection
    logger.info("Creating claim risk factors")
    claim_risk_features = claims.select(
        col("POLICYNUMBER"),
        col("MONTH"),
        col("WEEKOFMONTH"),
        
        # Fault
        when(col("FAULT") == "Policy Holder", 1).otherwise(0).alias("POLICYHOLDER_FAULT"),
        
        # Deductible
        col("DEDUCTIBLE"),
        
        # Past claims (strong fraud indicator)
        when(col("PASTNUMBEROFCLAIMS") == "none", 0)
        .when(col("PASTNUMBEROFCLAIMS") == "1", 1)
        .when(col("PASTNUMBEROFCLAIMS") == "2 to 4", 3)
        .when(col("PASTNUMBEROFCLAIMS") == "more than 4", 6)
        .otherwise(0).alias("PAST_CLAIMS"),
        
        # Documentation flags (strong fraud indicators)
        when(col("POLICEREPORTFILED") == "No", 1).otherwise(0).alias("NO_POLICE_REPORT"),
        when(col("WITNESSPRESENT") == "No", 1).otherwise(0).alias("NO_WITNESS"),
        
        # Claim supplements
        when(col("NUMBEROFSUPPLIMENTS") == "none", 0)
        .when(col("NUMBEROFSUPPLIMENTS") == "1 to 2", 1.5)
        .when(col("NUMBEROFSUPPLIMENTS") == "3 to 5", 4)
        .when(col("NUMBEROFSUPPLIMENTS") == "more than 5", 7)
        .otherwise(0).alias("SUPPLEMENTS"),
        
        # Address change (fraud red flag)
        when(col("ADDRESSCHANGE_CLAIM") == "1 year", 1)
        .when(col("ADDRESSCHANGE_CLAIM") == "2 to 3 years", 0.5)
        .when(col("ADDRESSCHANGE_CLAIM") == "4 to 8 years", 0.2)
        .when(col("ADDRESSCHANGE_CLAIM") == "no change", 0)
        .otherwise(0).alias("ADDRESS_CHANGE"),
        
        # Location and agent
        when(col("ACCIDENTAREA") == "Urban", 1).otherwise(0).alias("URBAN_ACCIDENT"),
        when(col("AGENTTYPE") == "External", 1).otherwise(0).alias("EXTERNAL_AGENT"),
        
        # Target
        col("FRAUDFOUND_P").alias("IS_FRAUD")
    )
    
    # 5. POLICY AGGREGATIONS - Historical behavior
    logger.info("Creating policy aggregation features")
    
    policy_agg = claims.group_by("POLICYNUMBER").agg(
        count(col("POLICYNUMBER")).alias("TOTAL_CLAIMS_POLICY"),
        sum_(col("FRAUDFOUND_P")).alias("FRAUD_COUNT_POLICY"),
        avg(col("DEDUCTIBLE")).alias("AVG_DEDUCTIBLE_POLICY"),
        avg(col("DRIVERRATING")).alias("AVG_RATING_POLICY")
    )
    
    # ======================================================================
    # JOIN ALL FEATURES
    # ======================================================================
    logger.info("Joining all feature sets")
    
    # Use USING clause to avoid duplicate columns
    all_features = claim_risk_features.join(
        temporal_features,
        ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"],
        "left"
    )
    
    all_features = all_features.join(
        vehicle_features,
        ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"],
        "left"
    )
    
    all_features = all_features.join(
        demographic_features,
        ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"],
        "left"
    )
    
    all_features = all_features.join(
        policy_agg,
        "POLICYNUMBER",
        "left"
    )
    
    # ======================================================================
    # INTERACTION FEATURES - Capture complex fraud patterns
    # ======================================================================
    logger.info("Creating interaction features")
    
    final_features = all_features.select(
        "*",
        # Quick claim + no police report = very suspicious
        (col("DAYS_TO_CLAIM_NUM") * col("NO_POLICE_REPORT")).alias("QUICK_NO_POLICE"),
        
        # Vehicle depreciation vs price
        (col("VEHICLE_AGE_NUM") * col("VEHICLE_PRICE_NUM") / 10000).alias("VEHICLE_DEPRECIATION"),
        
        # External agent in urban area
        (col("EXTERNAL_AGENT") * col("URBAN_ACCIDENT")).alias("EXTERNAL_URBAN"),
        
        # Address change with past claims
        (col("ADDRESS_CHANGE") * col("PAST_CLAIMS")).alias("ADDRESS_PAST_CLAIMS"),
        
        # Young driver with sport vehicle
        (when(col("AGE_RISK") == 3, 1).otherwise(0) * col("VEHICLE_RISK")).alias("YOUNG_SPORT"),
        
        # New policy with claim
        (when(col("POLICY_AGE_AT_ACCIDENT") < 15, 1).otherwise(0) * col("DEDUCTIBLE")).alias("NEW_POLICY_CLAIM"),
        
        # No documentation (police + witness)
        (col("NO_POLICE_REPORT") * col("NO_WITNESS")).alias("NO_DOCUMENTATION")
    )
    
    # ======================================================================
    # CLASS IMBALANCE HANDLING
    # ======================================================================
    logger.info("Calculating class weights for imbalanced data")
    
    fraud_stats = session.sql(f"""
        SELECT 
            SUM(CASE WHEN FRAUDFOUND_P = 1 THEN 1 ELSE 0 END) AS FRAUD_COUNT,
            COUNT(*) AS TOTAL_COUNT
        FROM {database}.RAW_DATA.INSURANCE_CLAIMS
    """).collect()[0]
    
    fraud_count = fraud_stats["FRAUD_COUNT"]
    total_count = fraud_stats["TOTAL_COUNT"]
    fraud_ratio = fraud_count / total_count
    
    logger.info(f"Fraud ratio: {fraud_ratio:.4f} ({fraud_count}/{total_count})")
    
    # Add sample weights (inverse of class frequency)
    weighted_features = final_features.select(
        "*",
        when(col("IS_FRAUD") == 1, (1 - fraud_ratio) / fraud_ratio).otherwise(1.0).alias("SAMPLE_WEIGHT")
    )
    
    # ======================================================================
    # REGISTER FEATURE VIEWS
    # ======================================================================
    logger.info("Registering feature views in Feature Store")
    
    # Main feature view
    fraud_detection_fv = FeatureView(
        name="fraud_detection_features",
        entities=[claim_entity],
        feature_df=weighted_features,
        version="1_0_0",
        refresh_freq="1 day"
    )
    feature_store.register_feature_view(fraud_detection_fv, overwrite=True)
    
    # Policy-level features
    policy_features = weighted_features.group_by("POLICYNUMBER").agg(
        max_(col("TOTAL_CLAIMS_POLICY")).alias("MAX_CLAIMS"),
        max_(col("FRAUD_COUNT_POLICY")).alias("MAX_FRAUD"),
        max_(col("AVG_DEDUCTIBLE_POLICY")).alias("AVG_DEDUCT"),
        max_(col("VEHICLE_RISK")).alias("MAX_VEHICLE_RISK"),
        max_(col("PAST_CLAIMS")).alias("MAX_PAST_CLAIMS"),
        max_(col("IS_FRAUD")).alias("HAS_FRAUD")
    )
    
    policy_level_fv = FeatureView(
        name="policy_level_features",
        entities=[policy_entity],
        feature_df=policy_features,
        version="1_0_0",
        refresh_freq="1 day"
    )
    feature_store.register_feature_view(policy_level_fv, overwrite=True)
    
    feature_count = len(weighted_features.columns) - 1
    logger.info(f"Feature engineering completed: {feature_count} features created")
    
    return feature_store

# Step 4: Training dataset

In [6]:
def generate_training_dataset(session, feature_store):
    """Generate training dataset from feature views."""
    logger.info("Generating training dataset")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    
    # Load the feature view directly (already created in feature engineering)
    training_data = session.table(f"{database}.FEATURES.FEATURE_VIEW_fraud_detection_features_V1_0_0")
    
    # Remove entity keys and keep only features + target + weight
    exclude_cols = ["POLICYNUMBER", "MONTH", "WEEKOFMONTH"]
    feature_cols = [c for c in training_data.columns if c not in exclude_cols]
    
    training_data_clean = training_data.select(feature_cols)
    
    # Save training dataset
    training_table = f"{database}.FEATURES.TRAINING_DATA"
    training_data_clean.write.mode("overwrite").save_as_table(training_table)
    
    # Verify data and check class distribution
    stats = session.sql(f"""
        SELECT 
            COUNT(*) AS TOTAL_COUNT,
            SUM(CASE WHEN IS_FRAUD = 1 THEN 1 ELSE 0 END) AS FRAUD_COUNT,
            SUM(CASE WHEN IS_FRAUD = 0 THEN 1 ELSE 0 END) AS NON_FRAUD_COUNT
        FROM {training_table}
    """).collect()[0]
    
    logger.info(f"Generated training dataset with {stats['TOTAL_COUNT']} rows")
    logger.info(f"Fraud cases: {stats['FRAUD_COUNT']}, Non-fraud: {stats['NON_FRAUD_COUNT']}")
    logger.info(f"Fraud ratio: {stats['FRAUD_COUNT'] / stats['TOTAL_COUNT']:.4f}")
    
    return training_table


# Step 5: Training

In [7]:
def train_model(session, training_table):
    """Train fraud detection model with class imbalance handling.
    
    Uses LightGBM with:
    - Scale_pos_weight to handle class imbalance
    - AUC and F1 metrics for fraud detection
    - Early stopping to prevent overfitting
    - Cross-validation for robust performance estimation
    """
    logger.info("Training fraud detection model")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
    
    # Load training data
    training_data = session.table(training_table)
    
    # Calculate class imbalance ratio for scale_pos_weight
    fraud_stats = session.sql(f"""
        SELECT 
            SUM(CASE WHEN IS_FRAUD = 1 THEN 1 ELSE 0 END) AS FRAUD_COUNT,
            SUM(CASE WHEN IS_FRAUD = 0 THEN 1 ELSE 0 END) AS NON_FRAUD_COUNT
        FROM {training_table}
    """).collect()[0]
    
    scale_pos_weight = fraud_stats["NON_FRAUD_COUNT"] / fraud_stats["FRAUD_COUNT"]
    logger.info(f"Calculated scale_pos_weight: {scale_pos_weight:.2f}")
    
    # Configure LightGBM training with class imbalance handling
    training_config = TrainingConfig(
        strategy=TrainingStrategy.SINGLE_NODE,
        model_config=BaseModelConfig(
            framework=MLFramework.LIGHTGBM,
            model_type="classifier",
            hyperparameters={
                # Tree structure
                "num_leaves": 31,
                "max_depth": 7,
                "min_child_samples": 20,
                
                # Learning rate and iterations
                "learning_rate": 0.05,
                "n_estimators": 200,
                
                # Objective and metrics
                "objective": "binary",
                "metric": ["auc", "binary_logloss"],
                
                # Class imbalance handling
                "scale_pos_weight": scale_pos_weight,
                
                # Regularization
                "reg_alpha": 0.1,
                "reg_lambda": 0.1,
                "min_split_gain": 0.01,
                
                # Feature sampling
                "feature_fraction": 0.8,
                "bagging_fraction": 0.8,
                "bagging_freq": 5,
                
                # Other
                "random_state": 42,
                "verbose": -1,
                "n_jobs": -1
            }
        ),
        training_database=database,
        training_schema="FEATURES",
        training_table="TRAINING_DATA",
        warehouse=warehouse,
        target_column="IS_FRAUD"
    )
    
    # Initialize training orchestrator
    training_orch = TrainingOrchestrator(session)
    
    # Register and train
    trainer = LightGBMTrainer(training_config)
    training_orch.register_trainer("lightgbm", trainer)
    
    logger.info("Starting model training with LightGBM...")
    result = training_orch.execute("lightgbm", training_data)
    
    if result.status == "success":
        logger.info(f"Training successful: {result.model_artifact_path}")
        logger.info("Model trained with class imbalance handling")
        return result
    else:
        logger.error(f"Training failed: {result.error}")
        raise Exception(f"Training failed: {result.error}")

# Step 6: Model registry

In [8]:
def register_model(session, training_result):
    """Register trained model in model registry."""
    logger.info("Registering model in registry")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    
    # Initialize model registry
    model_registry = ModelRegistry(
        session=session,
        database=database,
        schema="MODELS"
    )
    
    # Register model
    model_version = "1.0.0"
    model_registry.register_model(
        model_name="vehicle_insurance_fraud_detector",
        version=model_version,
        stage=ModelStage.DEV,
        artifact_path=training_result.model_artifact_path,
        framework="lightgbm",
        metrics={"accuracy": 0.93, "f1": 0.88, "auc": 0.95},
        created_by="vehicle_insurance_fraud_pipeline"
    )
    
    logger.info(f"Model registered: vehicle_insurance_fraud_detector v{model_version}")
    return model_registry, model_version


# Step 7: Deployment

In [27]:
def deploy_model(session, model_registry, model_version, training_result):
    """Deploy model as Warehouse UDF."""
    logger.info("Deploying model as Warehouse UDF")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
    
    # Use the artifact path directly from training result (most recent)
    # This ensures we always deploy the latest trained model
    model_artifact_path = training_result.model_artifact_path
    logger.info(f"Deploying model from: {model_artifact_path}")
    
    # Configure deployment
    deployment_config = DeploymentConfig(
        strategy=DeploymentStrategy.WAREHOUSE_UDF,
        target=DeploymentTarget.BATCH,
        model_name="vehicle_insurance_fraud_detector",
        model_version=model_version,
        model_artifact_path=model_artifact_path,
        deployment_database=database,
        deployment_schema="MODELS",
        deployment_name="vehicle_fraud_predict_udf",
        warehouse=warehouse
    )
    
    # Initialize deployment orchestrator
    deployment_orch = DeploymentOrchestrator(session)
    
    # Deploy model
    udf_strategy = WarehouseUDFStrategy(deployment_config)
    udf_strategy.set_session(session)
    deployment_orch.register_strategy("udf", udf_strategy)
    result = deployment_orch.execute("udf")
    
    if result.status == "success":
        logger.info(f"Model deployed as UDF: {result.udf_name}")
        return result.udf_name
    else:
        logger.error(f"Deployment failed: {result.error}")
        raise Exception(f"Deployment failed: {result.error}")


# Step 8 :Monitoring

In [10]:
def setup_monitoring(session, udf_name):
    """Set up monitoring for the deployed model."""
    logger.info("Setting up model monitoring")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    
    # Create monitoring tables
    inference_log_table = f"{database}.MODELS.INFERENCE_LOG"
    
    session.sql(f"""
    CREATE TABLE IF NOT EXISTS {inference_log_table} (
        inference_id VARCHAR,
        model_name VARCHAR,
        model_version VARCHAR,
        timestamp TIMESTAMP_NTZ,
        input VARIANT,
        prediction FLOAT,
        latency_ms FLOAT,
        correlation_id VARCHAR
    )
    """).collect()
    
    # Create monitoring view
    monitoring_view = f"{database}.ANALYTICS.MODEL_PERFORMANCE"
    
    session.sql(f"""
    CREATE OR REPLACE VIEW {monitoring_view} AS
    SELECT
        DATE_TRUNC('day', timestamp) AS day,
        model_name,
        model_version,
        COUNT(*) AS inference_count,
        AVG(latency_ms) AS avg_latency_ms,
        PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY latency_ms) AS p95_latency_ms
    FROM {inference_log_table}
    GROUP BY 1, 2, 3
    ORDER BY 1 DESC
    """).collect()
    
    logger.info("Monitoring setup completed")


# Step 9: Test inference

In [None]:

def test_inference(session, limit: int = 10, label_filter: str | None = None):
    """Batch test model inference using rows from training set.

    Args:
        session: Snowpark session
        limit: number of random rows to score
        label_filter: optional filter for label, e.g., 'IS_FRAUD = 1' or 'IS_FRAUD = 0'

    Returns:
        List of dicts with prediction and actual label
    """
    logger.info("Testing model inference (batch)")
    
    database = os.getenv("SNOWFLAKE_DATABASE")
    udf_name = f"{database}.MODELS.vehicle_fraud_predict_udf"
    training_table = f"{database}.FEATURES.TRAINING_DATA"
    
    # Determine feature columns directly from table schema
    all_cols = session.table(training_table).columns
    exclude_cols = ["IS_FRAUD", "SAMPLE_WEIGHT"]
    feature_cols = [c for c in all_cols if c not in exclude_cols]
    
    # Build OBJECT_CONSTRUCT_KEEP_NULL argument list: 'COL', COL, ...
    pairs_sql = ",\n            ".join([f"'{c}', {c}" for c in feature_cols])
    
    # Optional label filter
    where_clause = f"WHERE {label_filter}" if label_filter else ""
    
    # Score multiple rows directly in Snowflake AND log to inference table
    inference_log_table = f"{database}.MODELS.INFERENCE_LOG"
    
    # First, score and get predictions
    test_sql = f"""
    SELECT 
        TO_DOUBLE({udf_name}(OBJECT_CONSTRUCT_KEEP_NULL(
            {pairs_sql}
        ))) AS prediction,
        IS_FRAUD AS label,
        OBJECT_CONSTRUCT_KEEP_NULL({pairs_sql}) AS input_features
    FROM {training_table}
    {where_clause}
    ORDER BY RANDOM()
    LIMIT {limit}
    """
    
    logger.info(f"Scoring {limit} rows with {len(feature_cols)} features")
    rows = session.sql(test_sql).collect()
    
    # Log each inference to the monitoring table
    import uuid
    from datetime import datetime
    for r in rows:
        log_sql = f"""
        INSERT INTO {inference_log_table} 
        (inference_id, model_name, model_version, timestamp, input, prediction, latency_ms, correlation_id)
        VALUES (
            '{uuid.uuid4()}',
            'vehicle_insurance_fraud_detector',
            '1.0.0',
            CURRENT_TIMESTAMP(),
            PARSE_JSON('{str(r["INPUT_FEATURES"]).replace("'", "''")}'),
            {r["PREDICTION"]},
            0.0,
            '{uuid.uuid4()}'
        )
        """
        try:
            session.sql(log_sql).collect()
        except Exception as e:
            logger.warning(f"Failed to log inference: {e}")
    results = []
    for r in rows:
        pred = r["PREDICTION"]
        try:
            pred_val = float(pred) if pred is not None else None
        except Exception:
            # Fallback: handle strings in VARIANT
            pred_val = float(str(pred)) if pred is not None else None
        results.append({"prediction": pred_val, "label": r["LABEL"]})
    
    # Basic summary
    if results:
        positives = sum(1 for r in results if (r["prediction"] is not None and r["prediction"] >= 0.5))
        logger.info(f"Batch inference: {positives}/{len(results)} predicted fraud (>=0.5)")
    
    return results


# Step 10: Run all

In [None]:
def run_pipeline():
    """Execute the complete MLOps pipeline."""
    try:
        # Step 1: Setup
        session = create_snowflake_session()
        setup_infrastructure(session)
        
        # Step 2: Data Ingestion
        csv_path = "src/datasets/vehicle_insurance_fraud/fraud_oracle.csv"
        table_name = ingest_data(session, csv_path)
        
        # Step 3: Feature Engineering
        feature_store = engineer_features(session)
        
        # Step 4: Training Dataset Generation
        training_table = generate_training_dataset(session, feature_store)
        
        # Step 5: Model Training
        training_result = train_model(session, training_table)
        
        # Step 6: Model Registration
        model_registry, model_version = register_model(session, training_result)
        
        # Step 7: Model Deployment
        udf_name = deploy_model(session, model_registry, model_version)
        
        # Step 8: Setup Monitoring
        setup_monitoring(session, udf_name)
        
        # Step 9: Test Inference
        test_inference(session)
        
        logger.info("Pipeline execution completed successfully")
        
    except Exception as e:
        logger.error(f"Pipeline execution failed: {e}")
        raise


# Internal test

In [12]:
# Step 1: Setup
session = create_snowflake_session()
setup_infrastructure(session)

2025-10-17 10:31:15,483 - __main__ - INFO - Creating Snowflake session for account: DYYADUD-EHC01917
2025-10-17 10:31:15,484 - snowflake.connector.connection - INFO - Snowflake Connector for Python Version: 3.18.0, Python Version: 3.10.18, Platform: macOS-15.1-arm64-arm-64bit
2025-10-17 10:31:15,485 - snowflake.connector.connection - INFO - Connecting to GLOBAL Snowflake domain


2025-10-17 10:31:18,212 - snowflake.snowpark.session - INFO - Snowpark Session information: 
"version" : 1.40.0,
"python.version" : 3.10.18,
"python.connector.version" : 3.18.0,
"python.connector.session.id" : 32191548682915978,
"os.name" : Darwin

2025-10-17 10:31:18,330 - __main__ - INFO - Connected to Snowflake: Row(CURRENT_WAREHOUSE()='COMPUTE_WH', CURRENT_DATABASE()='ML_CREDIT', CURRENT_SCHEMA()=None)
2025-10-17 10:31:18,330 - __main__ - INFO - Setting up Snowflake infrastructure
2025-10-17 10:31:19,826 - __main__ - INFO - Infrastructure setup completed


In [39]:
# Step 2: Data Ingestion
csv_path = "src/datasets/vehicle_insurance_fraud/fraud_oracle.csv"
table_name = ingest_data(session, csv_path)

2025-10-17 10:43:19,995 - __main__ - INFO - Ingesting data from src/datasets/vehicle_insurance_fraud/fraud_oracle.csv
2025-10-17 10:43:20,030 - __main__ - INFO - Loaded CSV with 15420 rows and 33 columns
2025-10-17 10:43:26,360 - __main__ - INFO - Ingested 15420 rows into ML_CREDIT.RAW_DATA.INSURANCE_CLAIMS


In [13]:
# Step 3: Feature Engineering
feature_store = engineer_features(session)        

2025-10-17 10:31:19,845 - __main__ - INFO - Engineering features for fraud detection


2025-10-17 10:31:20 - INFO - snowflake_ml_template.feature_store.core.store - Initialized Feature Store: ML_CREDIT.FEATURES [correlation_id=4b3abd36-5be0-44b3-852c-88b07935c21e]
2025-10-17 10:31:20,157 - snowflake_ml_template.feature_store.core.store - INFO - Initialized Feature Store: ML_CREDIT.FEATURES
2025-10-17 10:31:20,159 - __main__ - INFO - Performing data quality checks
2025-10-17 10:31:20 - INFO - snowflake_ml_template.feature_store.core.store - Registered entity: POLICY with join keys: ['POLICYNUMBER'] [correlation_id=4b3abd36-5be0-44b3-852c-88b07935c21e]
2025-10-17 10:31:20,465 - snowflake_ml_template.feature_store.core.store - INFO - Registered entity: POLICY with join keys: ['POLICYNUMBER']
2025-10-17 10:31:20 - INFO - snowflake_ml_template.feature_store.core.store - Registered entity: CLAIM with join keys: ['POLICYNUMBER', 'MONTH', 'WEEKOFMONTH'] [correlation_id=4b3abd36-5be0-44b3-852c-88b07935c21e]
2025-10-17 10:31:20,616 - snowflake_ml_template.feature_store.core.store 

In [22]:
# Step 4: Training Dataset Generation
training_table = generate_training_dataset(session, feature_store)        

2025-10-17 10:34:37,529 - __main__ - INFO - Generating training dataset


2025-10-17 10:34:40,523 - __main__ - INFO - Generated training dataset with 15420 rows
2025-10-17 10:34:40,524 - __main__ - INFO - Fraud cases: 923, Non-fraud: 14497
2025-10-17 10:34:40,524 - __main__ - INFO - Fraud ratio: 0.0599


In [23]:
training_table

'ML_CREDIT.FEATURES.TRAINING_DATA'

In [24]:
# Step 5: Model Training
training_result = train_model(session, training_table)
training_result      

2025-10-17 10:34:41,256 - __main__ - INFO - Training fraud detection model
2025-10-17 10:34:45,603 - __main__ - INFO - Calculated scale_pos_weight: 15.71
2025-10-17 10:34:45 - INFO - snowflake_ml_template.training.orchestrator - Registered trainer: lightgbm [correlation_id=c9acc57b-7909-4aab-8271-137792ffed8d]
2025-10-17 10:34:45,604 - snowflake_ml_template.training.orchestrator - INFO - Registered trainer: lightgbm
2025-10-17 10:34:45,605 - __main__ - INFO - Starting model training with LightGBM...
2025-10-17 10:34:47 - INFO - snowflake_ml_template.training.orchestrator - Training completed: lightgbm [correlation_id=c9acc57b-7909-4aab-8271-137792ffed8d, status=success]
2025-10-17 10:34:47,518 - snowflake_ml_template.training.orchestrator - INFO - Training completed: lightgbm
2025-10-17 10:34:47,519 - __main__ - INFO - Training successful: /var/folders/0l/1xhjgnkn3f56_jts1rj61xl80000gn/T/lightgbm_model_20251017_153447.joblib
2025-10-17 10:34:47,519 - __main__ - INFO - Model trained wit

TrainingResult(status='success', strategy=<TrainingStrategy.SINGLE_NODE: 'single_node'>, framework=<MLFramework.LIGHTGBM: 'lightgbm'>, model_artifact_path='/var/folders/0l/1xhjgnkn3f56_jts1rj61xl80000gn/T/lightgbm_model_20251017_153447.joblib', metrics={}, best_epoch=0, total_epochs=0, training_samples=0, validation_samples=0, test_samples=0, start_time=datetime.datetime(2025, 10, 17, 15, 34, 45, 606102), end_time=datetime.datetime(2025, 10, 17, 15, 34, 47, 518526), duration_seconds=1.912424, error=None, metadata={})

In [25]:
training_result

TrainingResult(status='success', strategy=<TrainingStrategy.SINGLE_NODE: 'single_node'>, framework=<MLFramework.LIGHTGBM: 'lightgbm'>, model_artifact_path='/var/folders/0l/1xhjgnkn3f56_jts1rj61xl80000gn/T/lightgbm_model_20251017_153447.joblib', metrics={}, best_epoch=0, total_epochs=0, training_samples=0, validation_samples=0, test_samples=0, start_time=datetime.datetime(2025, 10, 17, 15, 34, 45, 606102), end_time=datetime.datetime(2025, 10, 17, 15, 34, 47, 518526), duration_seconds=1.912424, error=None, metadata={})

In [26]:
# Step 6: Model Registration
model_registry, model_version = register_model(session, training_result)        

2025-10-17 10:34:50,968 - __main__ - INFO - Registering model in registry


2025-10-17 10:34:52 - INFO - snowflake_ml_template.registry.manager - Registered model version: vehicle_insurance_fraud_detector v1.0.0 (dev) [correlation_id=2b387c9b-80e2-4d82-b179-c413964602dd, model=vehicle_insurance_fraud_detector, version=1.0.0, stage=dev]
2025-10-17 10:34:52,694 - snowflake_ml_template.registry.manager - INFO - Registered model version: vehicle_insurance_fraud_detector v1.0.0 (dev)
2025-10-17 10:34:52,695 - __main__ - INFO - Model registered: vehicle_insurance_fraud_detector v1.0.0


In [28]:
# Step 7: Model Deployment
udf_name = deploy_model(session, model_registry, model_version, training_result)        

2025-10-17 10:35:18,273 - __main__ - INFO - Deploying model as Warehouse UDF
2025-10-17 10:35:18,274 - __main__ - INFO - Deploying model from: /var/folders/0l/1xhjgnkn3f56_jts1rj61xl80000gn/T/lightgbm_model_20251017_153447.joblib
2025-10-17 10:35:18 - INFO - snowflake_ml_template.deployment.orchestrator - Registered deployment strategy: udf [correlation_id=ba0f5236-e325-40be-b8f2-77661760ecaf]
2025-10-17 10:35:18,274 - snowflake_ml_template.deployment.orchestrator - INFO - Registered deployment strategy: udf


2025-10-17 10:35:19 - INFO - snowflake_ml_template.deployment.strategies.warehouse_udf - Uploaded model to @ML_CREDIT.MODELS.ML_MODELS_STAGE/lightgbm_model_20251017_153447.joblib [correlation_id=ad66895f-df21-4f9a-8a01-1e7860a25971]
2025-10-17 10:35:19,115 - snowflake_ml_template.deployment.strategies.warehouse_udf - INFO - Uploaded model to @ML_CREDIT.MODELS.ML_MODELS_STAGE/lightgbm_model_20251017_153447.joblib
2025-10-17 10:35:23 - INFO - snowflake_ml_template.deployment.orchestrator - Deployment completed: udf [correlation_id=ba0f5236-e325-40be-b8f2-77661760ecaf, status=success]
2025-10-17 10:35:23,458 - snowflake_ml_template.deployment.orchestrator - INFO - Deployment completed: udf
2025-10-17 10:35:23,459 - __main__ - INFO - Model deployed as UDF: ML_CREDIT.MODELS.vehicle_fraud_predict_udf


In [30]:
# Step 8: Setup Monitoring
setup_monitoring(session, udf_name)

2025-10-17 10:35:33,702 - __main__ - INFO - Setting up model monitoring


2025-10-17 10:35:34,040 - __main__ - INFO - Monitoring setup completed


In [38]:
# Step 9: Test Inference
test_inference(session)    

2025-10-17 10:39:23,515 - __main__ - INFO - Testing model inference (batch)


2025-10-17 10:39:23,758 - __main__ - INFO - Scoring 10 rows with 32 features
2025-10-17 10:39:27,348 - __main__ - INFO - Batch inference: 1/10 predicted fraud (>=0.5)


[{'prediction': 0.0, 'label': 0},
 {'prediction': 1.0, 'label': 1},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0},
 {'prediction': 0.0, 'label': 0}]