# End-to-End ML Workflow with Observability

This notebook demonstrates a complete machine learning workflow for customer churn detection using Snowflake's ML capabilities. The workflow includes:

1. **Environment Preparation** - Set up the ML environment and dependencies
2. **Initial Data Ingestion** - Load and process the first batch of data
3. **First Model Training** - Train and deploy the initial model
4. **Iterative Model Improvement & Observability** - Monitor the deployed model, keep adding data, and retrain as needed

## Key Features
- Customer churn prediction using sales and feedback data
- Sentiment analysis of customer feedback using Cortex AI
- Feature engineering with Snowflake analytical functions
- Model Registry and Feature Store integration
- Model monitoring and drift detection
- Automated retraining based on performance metrics
- Full ML observability and lineage tracking


# 1. Environment Preparation

Setting up the ML environment, importing necessary packages, and configuring Snowflake connections.


In [27]:
# Import essential packages for ML workflow
import streamlit as st
import pandas as pd
import logging
from datetime import datetime, timedelta
import time
import json

# Snowflake packages
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.types import DecimalType, FloatType, IntegerType, DoubleType, LongType

# Snowflake ML packages
from snowflake.ml.feature_store import FeatureStore, FeatureView, Entity, CreationMode
from snowflake.ml.registry import Registry
from snowflake.ml.model import type_hints
from snowflake.cortex import sentiment
from snowflake.snowpark import Session


# ML packages  
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import f1_score
from xgboost import XGBClassifier

# Configure logging
logger = logging.getLogger("e2e-ml-workflow")
numeric_types = (DecimalType, FloatType, IntegerType, DoubleType, LongType)

print("✅ Packages imported successfully")

# Get the active Snowflake session, or create one if none exists
session = Session.builder.getOrCreate()
print("✅ Snowflake session established")


✅ Packages imported successfully
✅ Snowflake session established


In [28]:
# Environment Configuration
SCHEMA = 'E2E_DEMO'
WAREHOUSE = 'COMPUTE_WH'  # Modify as needed
CHURN_WINDOW = 30  # Days without a purchase before a customer counts as churned

# Create and use dedicated schema
session.sql(f'CREATE OR REPLACE SCHEMA {SCHEMA}').collect()
session.sql(f'USE SCHEMA {SCHEMA}').collect()
session.sql('CREATE OR REPLACE STAGE ML_STAGE').collect()

print(f"✅ Environment configured - Schema: {SCHEMA}")

# Get current context
db = session.get_current_database()
sc = session.get_current_schema()
print(f"📍 Working in Database: {db}, Schema: {sc}")


✅ Environment configured - Schema: E2E_DEMO
📍 Working in Database: "CC_ML_JOBS", Schema: "E2E_DEMO"


In [29]:
# Setup Feature Store and Model Registry
mr_schema = f'{sc}_MODEL_REGISTRY'.replace('"', '')
fs_schema = f'{sc}_FEATURE_STORE'.replace('"', '')

# Drop any existing Model Registry and Feature Store schemas for a clean start
session.sql(f'DROP SCHEMA IF EXISTS {mr_schema}').collect()
session.sql(f'DROP SCHEMA IF EXISTS {fs_schema}').collect()

# Create Model Registry
cs = session.get_current_schema()
session.sql(f'CREATE SCHEMA {mr_schema}').collect()
mr = Registry(session=session, database_name=db, schema_name=mr_schema)
session.sql(f'USE SCHEMA {cs}').collect()

# Create Feature Store
fs = FeatureStore(
    session=session, 
    database=db, 
    name=fs_schema,
    default_warehouse=WAREHOUSE, 
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST
)

print(f"✅ Model Registry created: {mr_schema}")
print(f"✅ Feature Store created: {fs_schema}")


✅ Model Registry created: E2E_DEMO_MODEL_REGISTRY
✅ Feature Store created: E2E_DEMO_FEATURE_STORE


# 2. Initial Data Ingestion

Setting up data tables, staging areas, and ingesting the first batch of data for training.


In [30]:
# Reset the database and schema, since creating the Model Registry and Feature Store can change the session context
session.sql(f'USE DATABASE {db}').collect()
session.sql(f'USE SCHEMA {sc}').collect()


# Create core data tables
print("🔄 Creating core data tables...")

# Create SALES table
session.sql("""
CREATE OR REPLACE TABLE SALES (
    TRANSACTION_ID VARCHAR,
    CUSTOMER_ID VARCHAR,
    TRANSACTION_DATE DATE,
    DISCOUNT_APPLIED BOOLEAN,
    NUM_ITEMS NUMBER,
    PAYMENT_METHOD VARCHAR, 
    TOTAL_AMOUNT FLOAT
)
""").collect()

# Create CUSTOMERS table
session.sql("""
CREATE OR REPLACE TABLE CUSTOMERS (
    CUSTOMER_ID VARCHAR,
    AGE BIGINT,
    CUSTOMER_SEGMENT VARCHAR,
    GENDER VARCHAR,
    LOCATION VARCHAR,
    SIGNUP_DATE DATE
)
""").collect()

# Create FEEDBACK_RAW table
session.sql("""
CREATE OR REPLACE TABLE FEEDBACK_RAW (
    CHAT_DATE DATE,
    COMMENT VARCHAR,
    CUSTOMER_ID VARCHAR,
    FEEDBACK_ID VARCHAR,
    INTERNAL_ID BIGINT
)
""").collect()

# Create stream for processing feedback
session.sql("""
CREATE OR REPLACE STREAM FEEDBACK_RAW_STREAM 
    ON TABLE FEEDBACK_RAW
    APPEND_ONLY = TRUE
""").collect()

# Create table for processed sentiment
session.sql("""
CREATE OR REPLACE TABLE FEEDBACK_SENTIMENT (
    FEEDBACK_ID VARCHAR,
    CHAT_DATE DATE,
    CUSTOMER_ID VARCHAR,
    INTERNAL_ID BIGINT,
    COMMENT VARCHAR,
    SENTIMENT FLOAT
)
""").collect()

print("✅ Core data tables created successfully")
print("   - SALES")
print("   - CUSTOMERS") 
print("   - FEEDBACK_RAW")
print("   - FEEDBACK_RAW_STREAM")
print("   - FEEDBACK_SENTIMENT")


🔄 Creating core data tables...
✅ Core data tables created successfully
   - SALES
   - CUSTOMERS
   - FEEDBACK_RAW
   - FEEDBACK_RAW_STREAM
   - FEEDBACK_SENTIMENT


In [31]:
# Setup staging area and data ingestion tracking
session.sql("""
CREATE OR REPLACE STAGE CSV
DIRECTORY = (ENABLE = TRUE)
URL = 's3://sfquickstarts/vhol_end_2_end_ml_with_observability/';
""").collect()

# Create tracking table for file ingestion
session.sql("""
CREATE OR REPLACE TABLE FILES_INGESTED (
    YEAR INT,
    MONTH INT,
    FILE_TYPE VARCHAR,
    FILE_NAME VARCHAR,
    STAGE_NAME VARCHAR,
    INGESTED BOOLEAN
)
""").collect()

print("✅ Staging area and tracking tables created")


✅ Staging area and tracking tables created


In [32]:
# Data ingestion utility functions
import re
from typing import List, Tuple

def get_year_month_files(session, stage_name: str, file_prefix: str) -> List[Tuple[int, int, str, str]]:
    """Return sorted (year, month, file_name, stage_name) tuples for stage files named <prefix>_<year>_<month>.csv"""
    list_files_query = f"LIST @{stage_name}"
    files = session.sql(list_files_query).collect()
    
    file_names = [file["name"].split("/")[-1] for file in files]
    file_pattern = re.compile(rf"{re.escape(file_prefix)}_(\d+)_(\d+)\.csv")
    
    results = []
    for file_name in file_names:
        match = file_pattern.match(file_name)
        if match:
            year, month = int(match.group(1)), int(match.group(2))
            results.append((year, month, file_name, stage_name))
    
    return sorted(results)

def insert_file_tracking(table, db, sc, files):
    """Register files in FILES_INGESTED as not yet ingested (uses the notebook-level session)"""
    for file in files:
        year, month, file_name, stage_name = file
        # YEAR and MONTH are INT columns, so pass them as numeric literals
        sql_cmd = f"""
            INSERT INTO {db}.{sc}.FILES_INGESTED
            (YEAR, MONTH, FILE_TYPE, FILE_NAME, STAGE_NAME, INGESTED)
            VALUES ({year}, {month}, '{table}', '{file_name}', '{stage_name}', FALSE)
        """
        session.sql(sql_cmd).collect()

def load_into_table(session, table_name, file_name):
    """Load CSV file into Snowflake table"""
    sql_cmd = f""" 
        COPY INTO {table_name}
        FROM {file_name}  
        FILE_FORMAT = (TYPE = 'CSV' FIELD_OPTIONALLY_ENCLOSED_BY='"')  
        ON_ERROR = 'ABORT_STATEMENT';      
    """
    session.sql(sql_cmd).collect()

print("✅ Data ingestion utilities defined")


✅ Data ingestion utilities defined
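
The file discovery step relies on a naming convention of `<prefix>_<year>_<month>.csv`. The sketch below isolates just the parsing logic in plain Python so it can be checked without a Snowflake connection. The helper name `parse_year_month` and the file names are made up for illustration; only the regular expression mirrors the one used in `get_year_month_files`.

```python
import re
from typing import List, Tuple


def parse_year_month(file_names: List[str], file_prefix: str) -> List[Tuple[int, int, str]]:
    """Return sorted (year, month, file_name) for names shaped like <prefix>_<year>_<month>.csv."""
    pattern = re.compile(rf"{re.escape(file_prefix)}_(\d+)_(\d+)\.csv")
    parsed = []
    for name in file_names:
        match = pattern.match(name)
        if match:
            parsed.append((int(match.group(1)), int(match.group(2)), name))
    # Sorting tuples orders by year first, then month
    return sorted(parsed)


# Hypothetical file names, for illustration only
names = [
    "sales_2024_06.csv",
    "customers.csv",
    "sales_2024_05.csv",
    "feedback_raw_2024_05.csv",
]
print(parse_year_month(names, "sales"))
```

Files that do not follow the convention, such as `customers.csv`, are skipped, which is why the customer data is loaded separately by name.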


In [33]:
# Discover and register files for ingestion
stage_name = f"{db}.{sc}.CSV"  # The CSV stage created above in the working schema

# Get available files
sales_files = get_year_month_files(session, stage_name, 'sales')
feedback_files = get_year_month_files(session, stage_name, 'feedback_raw')

# Track files for ingestion
insert_file_tracking('sales', db, sc, sales_files)
insert_file_tracking('feedback_raw', db, sc, feedback_files)

print(f"📋 Found {len(sales_files)} sales files and {len(feedback_files)} feedback files")

# Load customers (static data)
load_into_table(session, f'{db}.{sc}.CUSTOMERS', f'@{stage_name}/customers.csv')
print("✅ Customer data loaded")


📋 Found 8 sales files and 13 feedback files
✅ Customer data loaded


In [34]:
# Data loading functions (used later to load the first 4 months for training)
def copy_next_file(session, db: str, sc: str):
    """Copy the next unprocessed file from staging to tables"""
    files_df = session.table(f'{db}.{sc}.FILES_INGESTED')
    
    # Get next sales file
    sales_file = files_df.filter(
        (F.col("file_type") == 'sales') & (F.col("ingested") == False)
    ).select("year", "month", "file_name", "stage_name").order_by("year", "month").limit(1)
    
    sales_pd = sales_file.to_pandas()
    if sales_pd.empty:
        print("No unprocessed sales files found.")
        return False
    
    # Load sales data
    year, month = int(sales_pd.YEAR[0]), int(sales_pd.MONTH[0])
    file_name, stage_name = sales_pd.FILE_NAME[0], sales_pd.STAGE_NAME[0]
    
    load_into_table(session, f'{db}.{sc}.SALES', f'@{stage_name}/{file_name}')
    
    # Mark sales file as processed
    session.sql(f"""
        UPDATE {db}.{sc}.FILES_INGESTED
        SET INGESTED = TRUE
        WHERE FILE_NAME = '{file_name}' AND FILE_TYPE = 'sales'
    """).collect()
    
    # Load corresponding feedback file
    feedback_file = files_df.filter(
        (F.col("file_type") == 'feedback_raw') & 
        (F.col("ingested") == False) &
        (F.col("YEAR") == year) & 
        (F.col("MONTH") == month)
    ).limit(1)
    
    if feedback_file.count() > 0:
        feedback_pd = feedback_file.to_pandas()
        feedback_name = feedback_pd.FILE_NAME[0]
        load_into_table(session, f'{db}.{sc}.FEEDBACK_RAW', f'@{stage_name}/{feedback_name}')
        
        # Mark feedback file as processed
        session.sql(f"""
            UPDATE {db}.{sc}.FILES_INGESTED
            SET INGESTED = TRUE
            WHERE FILE_NAME = '{feedback_name}' AND FILE_TYPE = 'feedback_raw'
        """).collect()
    
    print(f"📥 Loaded data for {year}-{month:02d}")
    return True

# Sentiment processing function
def process_sentiment():
    """Process sentiment for new feedback using Cortex AI"""
    feedback_stream_df = session.table("FEEDBACK_RAW_STREAM")
    
    if feedback_stream_df.count() > 0:
        cols = ['FEEDBACK_ID', 'CHAT_DATE', 'CUSTOMER_ID', 'INTERNAL_ID', 'COMMENT']
        feedback_sentiment_df = feedback_stream_df.select(cols).with_columns(
            ["SENTIMENT"], [sentiment(F.col("COMMENT"))]
        )
        feedback_sentiment_df.write.mode("append").save_as_table("FEEDBACK_SENTIMENT")
        print("✅ Sentiment processed for new feedback")

print("✅ Data loading functions defined")


✅ Data loading functions defined
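
The selection rule inside `copy_next_file` can be summarized as: take the earliest sales file not yet ingested, then look for an un-ingested feedback file for the same year and month. A minimal plain-Python sketch of that rule follows, assuming the tracking table is represented as a list of dictionaries. The function name `pick_next_files` and the sample rows are hypothetical.

```python
from typing import Dict, List, Optional, Tuple


def pick_next_files(tracking: List[Dict]) -> Optional[Tuple[str, Optional[str]]]:
    """Return (sales_file, feedback_file_or_None) for the earliest un-ingested month."""
    pending_sales = [
        row for row in tracking
        if row["file_type"] == "sales" and not row["ingested"]
    ]
    if not pending_sales:
        return None
    # Earliest pending sales file by (year, month)
    nxt = min(pending_sales, key=lambda row: (row["year"], row["month"]))
    # Matching feedback file for the same period, if one is still pending
    feedback = next(
        (
            row["file_name"] for row in tracking
            if row["file_type"] == "feedback_raw"
            and not row["ingested"]
            and (row["year"], row["month"]) == (nxt["year"], nxt["month"])
        ),
        None,
    )
    return nxt["file_name"], feedback


# Hypothetical tracking rows, for illustration only
tracking = [
    {"year": 2024, "month": 6, "file_type": "sales", "file_name": "sales_2024_06.csv", "ingested": False},
    {"year": 2024, "month": 5, "file_type": "sales", "file_name": "sales_2024_05.csv", "ingested": True},
    {"year": 2024, "month": 6, "file_type": "feedback_raw", "file_name": "feedback_raw_2024_06.csv", "ingested": False},
    {"year": 2024, "month": 7, "file_type": "sales", "file_name": "sales_2024_07.csv", "ingested": False},
]
print(pick_next_files(tracking))
```

A month without a feedback file still loads its sales data, which matches the notebook function returning `True` as long as a sales file was found.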


In [35]:
# Feature Engineering Functions
def create_customer_features(session, db: str, sc: str, cur_date: datetime, table_name: str):
    """Create customer behavioral features for churn prediction"""
    
    # Load data tables
    customers_df = session.table(f'{db}.{sc}.CUSTOMERS')
    sales_df = session.table(f'{db}.{sc}.SALES').filter(F.col("TRANSACTION_DATE") < F.lit(cur_date))
    feedback_df = session.table(f'{db}.{sc}.FEEDBACK_SENTIMENT').filter(F.col("CHAT_DATE") < F.lit(cur_date))
    
    # Sales aggregations by customer
    sales_agg_df = sales_df.group_by("CUSTOMER_ID").agg(
        F.max("TRANSACTION_DATE").alias("LAST_PURCHASE_DATE"),
        F.sum("TOTAL_AMOUNT").alias("TOTAL_CUSTOMER_VALUE")
    )
    
    # Custom column naming for time-series features
    def custom_column_naming(input_col, agg, window):
        return f"{agg}_{input_col}_{window.replace('-', 'PAST_')}"
    
    # Time-series aggregations
    sales_ts_df = sales_df.analytics.time_series_agg(
        time_col="TRANSACTION_DATE",
        aggs={"TOTAL_AMOUNT": ["SUM", "COUNT"]},
        windows=["-7D", "-1MM", "-2MM", "-3MM"],
        sliding_interval="1D",
        group_by=["CUSTOMER_ID"],
        col_formatter=custom_column_naming
    )
    
    # Join sales aggregations
    sales_features_df = sales_agg_df.join(
        sales_ts_df,
        (sales_agg_df.LAST_PURCHASE_DATE == sales_ts_df.TRANSACTION_DATE) &
        (sales_agg_df.CUSTOMER_ID == sales_ts_df.CUSTOMER_ID),
        "left"
    ).select(
        sales_agg_df["CUSTOMER_ID"].alias("CUSTOMER_ID"),
        sales_agg_df["TOTAL_CUSTOMER_VALUE"],
        sales_agg_df["LAST_PURCHASE_DATE"],
        sales_ts_df["SUM_TOTAL_AMOUNT_PAST_7D"],
        sales_ts_df["SUM_TOTAL_AMOUNT_PAST_1MM"],
        sales_ts_df["SUM_TOTAL_AMOUNT_PAST_2MM"],
        sales_ts_df["SUM_TOTAL_AMOUNT_PAST_3MM"],
        sales_ts_df["COUNT_TOTAL_AMOUNT_PAST_7D"].alias("COUNT_ORDERS_PAST_7D"),
        sales_ts_df["COUNT_TOTAL_AMOUNT_PAST_1MM"].alias("COUNT_ORDERS_PAST_1MM"),
        sales_ts_df["COUNT_TOTAL_AMOUNT_PAST_2MM"].alias("COUNT_ORDERS_PAST_2MM"),
        sales_ts_df["COUNT_TOTAL_AMOUNT_PAST_3MM"].alias("COUNT_ORDERS_PAST_3MM")
    )
    
    # Feedback features
    latest_feedback_df = feedback_df.group_by("CUSTOMER_ID").agg(F.max("CHAT_DATE").alias("CHAT_DATE"))
    
    feedback_agg_df = feedback_df.analytics.moving_agg(
        aggs={"SENTIMENT": ["MIN", "AVG"]},
        window_sizes=[2, 3, 4],
        order_by=["CHAT_DATE"],
        group_by=["CUSTOMER_ID"]
    )
    
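    # Join on CHAT_DATE as well so only each customer's most recent feedback row is kept
    feedback_features_df = latest_feedback_df.join(feedback_agg_df, ["CUSTOMER_ID", "CHAT_DATE"], "left").select(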
        latest_feedback_df["CUSTOMER_ID"],
        feedback_agg_df["SENTIMENT_MIN_2"],
        feedback_agg_df["SENTIMENT_MIN_3"],
        feedback_agg_df["SENTIMENT_MIN_4"],
        feedback_agg_df["SENTIMENT_AVG_2"],
        feedback_agg_df["SENTIMENT_AVG_3"],
        feedback_agg_df["SENTIMENT_AVG_4"]
    )
    
    # Combine all features
    features_df = customers_df.join(sales_features_df, "CUSTOMER_ID", "left") \
                             .join(feedback_features_df, "CUSTOMER_ID", "left") \
                             .select(
                                 customers_df["CUSTOMER_ID"],
                                 customers_df["AGE"],
                                 customers_df["GENDER"],
                                 customers_df["LOCATION"],
                                 customers_df["CUSTOMER_SEGMENT"],
                                 sales_features_df["LAST_PURCHASE_DATE"],
                                 feedback_features_df["SENTIMENT_MIN_2"],
                                 feedback_features_df["SENTIMENT_MIN_3"],
                                 feedback_features_df["SENTIMENT_MIN_4"],
                                 feedback_features_df["SENTIMENT_AVG_2"],
                                 feedback_features_df["SENTIMENT_AVG_3"],
                                 feedback_features_df["SENTIMENT_AVG_4"],
                                 sales_features_df["SUM_TOTAL_AMOUNT_PAST_7D"],
                                 sales_features_df["SUM_TOTAL_AMOUNT_PAST_1MM"],
                                 sales_features_df["SUM_TOTAL_AMOUNT_PAST_2MM"],
                                 sales_features_df["SUM_TOTAL_AMOUNT_PAST_3MM"],
                                 sales_features_df["COUNT_ORDERS_PAST_7D"],
                                 sales_features_df["COUNT_ORDERS_PAST_1MM"],
                                 sales_features_df["COUNT_ORDERS_PAST_2MM"],
                                 sales_features_df["COUNT_ORDERS_PAST_3MM"],
                                 F.datediff("day", sales_features_df["LAST_PURCHASE_DATE"], F.lit(cur_date)).alias("DAYS_SINCE_LAST_PURCHASE"),
                                 F.lit(cur_date).alias("TIMESTAMP")
                             ).filter(sales_features_df["LAST_PURCHASE_DATE"].isNotNull()) \
                              .dropDuplicates(["CUSTOMER_ID", "TIMESTAMP"])
    
    # Fill nulls with 0
    fill_columns = [
        "SENTIMENT_MIN_2", "SENTIMENT_MIN_3", "SENTIMENT_MIN_4", 
        "SENTIMENT_AVG_2", "SENTIMENT_AVG_3", "SENTIMENT_AVG_4",
        "SUM_TOTAL_AMOUNT_PAST_7D", "SUM_TOTAL_AMOUNT_PAST_1MM", 
        "SUM_TOTAL_AMOUNT_PAST_2MM", "SUM_TOTAL_AMOUNT_PAST_3MM",
        "COUNT_ORDERS_PAST_7D", "COUNT_ORDERS_PAST_1MM", 
        "COUNT_ORDERS_PAST_2MM", "COUNT_ORDERS_PAST_3MM"
    ]
    
    features_df = features_df.fillna({column: 0 for column in fill_columns})
    
    # Save features
    features_df.write.mode("append").save_as_table(table_name)
    print(f"✅ Features created for {cur_date}")

print("✅ Feature engineering function defined")

✅ Feature engineering function defined
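
The time-series features such as `SUM_TOTAL_AMOUNT_PAST_7D` and `COUNT_ORDERS_PAST_7D` are trailing-window aggregates per customer. The sketch below shows the idea for a single customer in plain Python. It assumes a window that excludes the start date and includes the reference date; the exact boundary handling of Snowpark's `time_series_agg` may differ, and the function name and sample transactions are hypothetical.

```python
from datetime import date, timedelta
from typing import List, Tuple


def trailing_sum_count(
    transactions: List[Tuple[date, float]], as_of: date, days: int
) -> Tuple[float, int]:
    """Sum and count of amounts with as_of - days < transaction_date <= as_of."""
    start = as_of - timedelta(days=days)
    amounts = [amount for tx_date, amount in transactions if start < tx_date <= as_of]
    return float(sum(amounts)), len(amounts)


# Hypothetical transactions for one customer: (transaction_date, total_amount)
tx = [
    (date(2024, 5, 1), 20.0),
    (date(2024, 5, 25), 35.5),
    (date(2024, 5, 30), 10.0),
]
print(trailing_sum_count(tx, date(2024, 5, 30), 7))   # past 7 days
print(trailing_sum_count(tx, date(2024, 5, 30), 30))  # past 30 days
```

In the notebook the same aggregates are computed for every customer and every day, then joined back on each customer's last purchase date so that one row per customer remains.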


In [36]:
# Labeling function for churn detection
def create_churn_labels(session, db: str, sc: str, features_table: str, output_table: str, churn_days: int):
    """Label customers as churned based on future purchase behavior"""
    
    features_df = session.table(f'{db}.{sc}.{features_table}')
    sales_df = session.table(f'{db}.{sc}.SALES')

    sales_filtered = sales_df.select(F.col("CUSTOMER_ID"), F.col("TRANSACTION_DATE"))

    # Find each customer's first transaction after the last purchase recorded in the feature snapshot
    next_transaction_df = features_df.join(
        sales_filtered.select("CUSTOMER_ID", "TRANSACTION_DATE"),
        "CUSTOMER_ID",
        "left"
    ).filter(
        F.col("TRANSACTION_DATE") > F.col("LAST_PURCHASE_DATE")
    ).group_by("CUSTOMER_ID", "TIMESTAMP").agg(
        F.min("TRANSACTION_DATE").alias("NEXT_TRANSACTION_DATE")
    )
    
    # Create labeled dataset
    labeled_df = features_df.join(
        next_transaction_df, 
        ["CUSTOMER_ID", "TIMESTAMP"], 
        "left"
    ).select(
        features_df["*"],
        F.when(
            (F.col("NEXT_TRANSACTION_DATE").is_null()) |
            ((F.col("NEXT_TRANSACTION_DATE") - F.col("LAST_PURCHASE_DATE")) > churn_days),
            1
        ).otherwise(0).alias("CHURNED"),
        F.col("NEXT_TRANSACTION_DATE")
    )
    
    # Save labeled dataset
    labeled_df.write.mode("overwrite").save_as_table(output_table)
    print(f"✅ Labels created for churn window: {churn_days} days")

print("✅ Labeling function defined")

✅ Labeling function defined
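
The labeling rule in `create_churn_labels` reduces to a small decision per customer and snapshot: a customer is marked as churned when no later purchase exists, or when the gap between the last purchase and the next one exceeds the churn window. A minimal plain-Python sketch of that rule, with a hypothetical function name and dates:

```python
from datetime import date
from typing import Optional


def churn_label(last_purchase: date, next_transaction: Optional[date], churn_days: int) -> int:
    """Return 1 (churned) or 0 (retained) following the notebook's labeling rule."""
    if next_transaction is None:
        # No later purchase on record counts as churn
        return 1
    gap_days = (next_transaction - last_purchase).days
    return 1 if gap_days > churn_days else 0


# Hypothetical dates, for illustration only
print(churn_label(date(2024, 5, 20), None, 30))               # no later purchase
print(churn_label(date(2024, 5, 20), date(2024, 6, 10), 30))  # 21-day gap
print(churn_label(date(2024, 5, 20), date(2024, 7, 1), 30))   # 42-day gap
```

Because a missing next purchase counts as churn, customers in the most recent snapshot are labeled before their follow-up sales have been loaded. Labels for a snapshot become more meaningful once later months are ingested and the labeled table is rebuilt.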


In [37]:
# Define table names
features_table = 'CUSTOMER_FEATURES'
labeled_table = 'CUSTOMER_FEATURES_LABELED'

# Drop existing tables for fresh start
session.sql(f'DROP TABLE IF EXISTS {features_table}').collect()
session.sql(f'DROP TABLE IF EXISTS {labeled_table}').collect()

# Reference to the SALES table, used in the loop below to find the latest transaction date
sales_df = session.table("SALES")

# Load first 4 months of data for initial training
print("🔄 Loading initial training data (4 months)...")

for i in range(4):
    start_time = datetime.now()
    
    print(f"\n--- Loading month {i+1}/4 ---")
    if not copy_next_file(session, db, sc):
        break
        
    # Process sentiment for new feedback
    process_sentiment()
    
    # Calculate features as of the latest transaction date loaded so far
    latest_transaction = sales_df.select(F.max(F.col("transaction_date"))).collect()[0][0]

    create_customer_features(session, db, sc, latest_transaction, features_table)

    # Create churn labels
    create_churn_labels(session, db, sc, features_table, labeled_table, CHURN_WINDOW)

    elapsed = datetime.now() - start_time
    print(f"⏱️  Month {i+1} completed in {elapsed.total_seconds():.1f}s")

print("\n✅ Initial data loading completed")

# Check data loaded
sales_count = session.table("SALES").count()
customers_count = session.table("CUSTOMERS").count()
feedback_count = session.table("FEEDBACK_SENTIMENT").count()

print(f"📊 Data Summary:")
print(f"   Sales records: {sales_count:,}")
print(f"   Customers: {customers_count:,}")
print(f"   Feedback records: {feedback_count:,}")

# Check label distribution
label_distribution = session.sql(f"""
    SELECT 
        TIMESTAMP,
        SUM(CASE WHEN CHURNED = 0 THEN 1 ELSE 0 END) AS NOT_CHURNED,
        SUM(CASE WHEN CHURNED = 1 THEN 1 ELSE 0 END) AS CHURNED
    FROM {labeled_table}
    GROUP BY TIMESTAMP
    ORDER BY TIMESTAMP
""").collect()

print("\n📊 Label Distribution by Timestamp:")
for row in label_distribution:
    timestamp = row["TIMESTAMP"]
    not_churned = row["NOT_CHURNED"]
    churned = row["CHURNED"]
    total = not_churned + churned
    churn_rate = churned / total * 100 if total > 0 else 0
    print(f"   {timestamp}: {not_churned} not churned, {churned} churned ({churn_rate:.1f}% churn rate)")

print("✅ Features and labels created")

🔄 Loading initial training data (4 months)...

--- Loading month 1/4 ---
📥 Loaded data for 2024-05
✅ Sentiment processed for new feedback
✅ Features created for 2024-05-31
✅ Labels created for churn window: 30 days
⏱️  Month 1 completed in 23.8s

--- Loading month 2/4 ---
📥 Loaded data for 2024-06
✅ Sentiment processed for new feedback
✅ Features created for 2024-06-30
✅ Labels created for churn window: 30 days
⏱️  Month 2 completed in 24.1s

--- Loading month 3/4 ---
📥 Loaded data for 2024-07
✅ Sentiment processed for new feedback
✅ Features created for 2024-07-31
✅ Labels created for churn window: 30 days
⏱️  Month 3 completed in 25.9s

--- Loading month 4/4 ---
📥 Loaded data for 2024-08
✅ Sentiment processed for new feedback
✅ Features created for 2024-08-31
✅ Labels created for churn window: 30 days
⏱️  Month 4 completed in 25.3s

✅ Initial data loading completed
📊 Data Summary:
   Sales records: 29,540
   Customers: 5,000
   Feedback records: 3,318

📊 Label Distribution by Timesta

# 3. First Model Training

Registering the labeled features in the Feature Store, generating training and validation datasets, training the initial model, and registering it in the Model Registry.


In [38]:
# Setup Feature Store
print("🔄 Setting up Feature Store...")

# Create entity for customers
if "CUSTOMER_ENT" not in json.loads(fs.list_entities().select(F.to_json(F.array_agg("NAME", True))).collect()[0][0]):
    customer_entity = Entity(
        name="CUSTOMER_ENT", 
        join_keys=["CUSTOMER_ID"],
        desc="Primary Key for CUSTOMER"
    )
    fs.register_entity(customer_entity)
else:
    customer_entity = fs.get_entity("CUSTOMER_ENT")

print("✅ Customer entity created")

# Create Feature View
labeled_df = session.table(f'{sc}.{labeled_table}')

fv_name = "FV_CUSTOMER_CHURN"
fv_version = "V_1"

try:
    feature_view = fs.get_feature_view(name=fv_name, version=fv_version)
    print(f"✅ Feature View {fv_name}_{fv_version} already exists")
except Exception:
    feature_view_instance = FeatureView(
        name=fv_name, 
        entities=[customer_entity], 
        feature_df=labeled_df,
        timestamp_col="TIMESTAMP",
        refresh_freq=None,
        desc="Features for customer churn detection"
    )
    
    feature_view = fs.register_feature_view(
        feature_view=feature_view_instance, 
        version=fv_version, 
        block=True
    )
    print(f"✅ Feature View {fv_name}_{fv_version} created")

# Show feature view info
print("\n📋 Feature View registered in Feature Store:")
fs.list_feature_views().show(5)


🔄 Setting up Feature Store...
✅ Customer entity created
✅ Feature View FV_CUSTOMER_CHURN_V_1 created

📋 Feature View registered in Feature Store:
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"NAME"             |"VERSION"  |"DATABASE_NAME"  |"SCHEMA_NAME"           |"CREATED_ON"                |"OWNER"        |"DESC"                                 |"ENTITIES"        |"REFRESH_FREQ"  |"REFRESH_MODE"  |"SCHEDULING_STATE"  |"WAREHOUSE"  |"CLUSTER_BY"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|FV_CUSTOMER_CHURN  |V_1        |CC_ML_JOBS       |E2E_DEMO

In [39]:
# Create training and validation datasets
print("🔄 Creating datasets for training...")

# Get available timestamps
timestamps = session.table(labeled_table).select("TIMESTAMP").distinct().sort("TIMESTAMP").collect()

# Use second timestamp for training, third for validation  
training_timestamp = timestamps[1]["TIMESTAMP"]
validation_timestamp = timestamps[2]["TIMESTAMP"]

print(f"Training timestamp: {training_timestamp}")
print(f"Validation timestamp: {validation_timestamp}")

def create_dataset(fs, feature_view, name, timestamp):
    """Create dataset from Feature Store"""
    spine_df = feature_view.feature_df.filter(
        F.col("TIMESTAMP") == F.lit(timestamp)
    ).group_by('CUSTOMER_ID').agg(F.max('TIMESTAMP').as_('TIMESTAMP'))
    
    dataset = fs.generate_dataset(
        name=name, 
        version='v1',
        spine_df=spine_df, 
        features=[feature_view], 
        spine_timestamp_col='TIMESTAMP'
    )
    
    # Convert to Snowpark DataFrame and handle data types
    dataset_df = dataset.read.to_snowpark_dataframe()
    
    # Cast all numeric columns (decimal, integer, float) to float
    decimal_columns = [field.name for field in dataset_df.schema.fields
                      if isinstance(field.datatype, numeric_types)]
    
    for column_name in decimal_columns:
        dataset_df = dataset_df.with_column(
            column_name,
            F.col(column_name).cast("float")
        )
    
    return dataset_df

# Create datasets
training_dataset = create_dataset(fs, feature_view, 'CHURN_TRAINING', training_timestamp)
validation_dataset = create_dataset(fs, feature_view, 'CHURN_VALIDATION', validation_timestamp)

print(f"✅ Training dataset: {training_dataset.count():,} records")
print(f"✅ Validation dataset: {validation_dataset.count():,} records")


🔄 Creating datasets for training...
Training timestamp: 2024-06-30
Validation timestamp: 2024-07-31




✅ Training dataset: 4,252 records
✅ Validation dataset: 4,535 records


In [40]:
# Train initial model
print("🔄 Training initial churn prediction model...")

# Define feature columns
categorical_cols = ['GENDER', 'LOCATION', 'CUSTOMER_SEGMENT']
numerical_cols = [
    "AGE", "SENTIMENT_MIN_2", "SENTIMENT_MIN_3", "SENTIMENT_MIN_4", 
    "SENTIMENT_AVG_2", "SENTIMENT_AVG_3", "SENTIMENT_AVG_4",
    "SUM_TOTAL_AMOUNT_PAST_7D", "SUM_TOTAL_AMOUNT_PAST_1MM", 
    "SUM_TOTAL_AMOUNT_PAST_2MM", "SUM_TOTAL_AMOUNT_PAST_3MM",
    "COUNT_ORDERS_PAST_7D", "COUNT_ORDERS_PAST_1MM", 
    "COUNT_ORDERS_PAST_2MM", "COUNT_ORDERS_PAST_3MM"
]
feature_cols = categorical_cols + numerical_cols
target_col = "CHURNED"

def train_churn_model(feature_df):
    """Train XGBoost model for churn prediction"""
    
    # Convert to pandas
    train_df = feature_df.to_pandas()
    
    # Split data
    train_data, test_data = train_test_split(train_df, test_size=0.2, random_state=111)
    
    # Create preprocessing pipeline
    preprocessor = ColumnTransformer(
        transformers=[
            ("ordinal", OrdinalEncoder(), categorical_cols),
            ("scaler", StandardScaler(), numerical_cols)
        ]
    )
    
    # Create model pipeline
    pipeline = Pipeline(
        steps=[ 
            ("preprocessor", preprocessor),
            ("model", XGBClassifier(random_state=42))
        ]
    )
    
    # Train model
    X_train = train_data[feature_cols]
    y_train = train_data[target_col]
    
    pipeline.fit(X_train, y_train)
    
    # Evaluate on training set
    train_predictions = pipeline.predict(X_train)
    train_f1 = f1_score(y_train, train_predictions)
    
    # Evaluate on test set
    X_test = test_data[feature_cols]
    y_test = test_data[target_col]
    
    test_predictions = pipeline.predict(X_test)
    test_f1 = f1_score(y_test, test_predictions)
    
    return {
        'model': pipeline,
        'train_f1_score': train_f1,
        'test_f1_score': test_f1
    }

# Train the model
model_result = train_churn_model(training_dataset)

print(f"✅ Model training completed:")
print(f"   Training F1 Score: {model_result['train_f1_score']:.4f}")
print(f"   Test F1 Score: {model_result['test_f1_score']:.4f}")

# Validate on validation dataset
validation_df = validation_dataset.to_pandas()
val_predictions = model_result['model'].predict(validation_df[feature_cols])
val_f1 = f1_score(validation_df[target_col], val_predictions)

print(f"   Validation F1 Score: {val_f1:.4f}")


🔄 Training initial churn prediction model...
✅ Model training completed:
   Training F1 Score: 0.9595
   Test F1 Score: 0.6667
   Validation F1 Score: 0.7670
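
The F1 scores reported above are the harmonic mean of precision and recall for the positive class (`CHURNED = 1`), which is what `f1_score` computes by default for binary labels. The sketch below reimplements the metric in plain Python to make the calculation explicit. The function name and the sample labels are made up for illustration and are not taken from the model output.

```python
from typing import Sequence


def f1_binary(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """F1 score for the positive class (label 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp == 0:
        # No true positives: precision or recall is zero (or undefined), so report 0
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Hypothetical labels: 2 true positives, 1 false positive, 1 false negative
y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0]
print(f"{f1_binary(y_true, y_pred):.4f}")
```

Since F1 ignores true negatives, it is a reasonable choice when churned customers are the minority class and plain accuracy would look high even for a model that rarely predicts churn.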


In [41]:
# Enhanced Model Performance Functions - Using sklearn for direct F1 calculation

def get_model_performance_sklearn(prediction_column, table_name='CUSTOMER_CHURN_PREDICTED_PROD2', limit_latest_timestamp=True):
    """
    Get F1 score using sklearn directly from the prediction table
    
    Parameters:
    - prediction_column: str, one of 'CHURNED_PRED_PROD', 'CHURNED_PRED_BASE', 'CHURNED_PRED_RETRAIN'
    - table_name: str, table containing predictions and true labels (default: 'CUSTOMER_CHURN_PREDICTED_PROD2')
    - limit_latest_timestamp: bool, whether to filter for latest timestamp only (default: True)
    
    Returns:
    - float: F1 score calculated using sklearn
    """
    from sklearn.metrics import f1_score
    
    # Validate prediction column
    valid_columns = ['CHURNED_PRED_PROD', 'CHURNED_PRED_BASE', 'CHURNED_PRED_RETRAIN']
    if prediction_column not in valid_columns:
        raise ValueError(f"prediction_column must be one of {valid_columns}")
    
    try:
        # Get the data from the table
        table_df = session.table(f'{sc}.{table_name}')
        
        if limit_latest_timestamp:
            # Use the second-to-last timestamp, treated here as the latest snapshot with usable labels
            timestamps = session.table(f'{sc}.{labeled_table}').select("TIMESTAMP").distinct().sort("TIMESTAMP").collect()
            latest_timestamp = timestamps[-2]["TIMESTAMP"]
            
            # Filter for latest timestamp
            table_df = table_df.filter(F.col("TIMESTAMP") == latest_timestamp)
        
        # Get only rows where both true labels and predictions are not null
        filtered_df = table_df.filter(
            (F.col("CHURNED").is_not_null()) & 
            (F.col(prediction_column).is_not_null())
        ).select("CHURNED", prediction_column)
        
        # Convert to pandas for sklearn
        pandas_df = filtered_df.to_pandas()
        
        if len(pandas_df) == 0:
            print(f"📍📍 Warning: No data available for {prediction_column}")
            return 1.0  # Default high score if no data
        
        # Calculate F1 score using sklearn
        y_true = pandas_df['CHURNED'].astype(int)
        y_pred = pandas_df[prediction_column].astype(int)
        
        f1 = f1_score(y_true, y_pred)
        
        return f1
        
    except Exception as e:
        print(f"📍📍 Warning: Error calculating F1 score for {prediction_column}: {str(e)}")
        return 1.0  # Default if calculation fails



In [42]:
# Register model in Model Registry
print("🔄 Registering model in Model Registry...")

# Log the trained model
baseline_model = mr.log_model(
    model=model_result['model'],
    model_name="ChurnDetector",
    version_name="baseline",
    conda_dependencies=["snowflake-ml-python", "xgboost", "scikit-learn"],
    sample_input_data=training_dataset.select(feature_cols).limit(100),
    task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
    target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"],                
    comment="Baseline model for customer churn detection"
)

# Set metrics for the model
baseline_model.set_metric(metric_name="train_f1_score", value=model_result['train_f1_score'])
baseline_model.set_metric(metric_name="test_f1_score", value=model_result['test_f1_score'])
baseline_model.set_metric(metric_name="validation_f1_score", value=val_f1)

# Set as default version
session.sql(f'USE SCHEMA {mr_schema}').collect()
session.sql('ALTER MODEL ChurnDetector SET DEFAULT_VERSION = baseline;').collect()
session.sql(f'USE SCHEMA {sc}').collect()

print("✅ Baseline model registered and set as default")
print(f"   Model name: ChurnDetector")
print(f"   Version: baseline")
print(f"   F1 Score: {val_f1:.4f}")


🔄 Registering model in Model Registry...


✅ Baseline model registered and set as default
   Model name: ChurnDetector
   Version: baseline
   F1 Score: 0.7670


# 4. Iterative Model Improvement & Observability

Setting up model monitoring and implementing continuous learning with new data.


In [43]:
# Setup Model Monitoring Infrastructure
print("🔄 Setting up model monitoring infrastructure...")

# Create prediction tables for monitoring - using single table format like CC_RETAIL_4_3_STREAMS_TRAINING
monitoring_tables = ["customer_churn_baseline_predicted", "CUSTOMER_CHURN_PREDICTED_PROD2"]

for table in monitoring_tables:
    session.sql(f"""       
        CREATE OR REPLACE TABLE {table} (
            CUSTOMER_ID VARCHAR(16777216),
            TIMESTAMP TIMESTAMP_NTZ(9),
            GENDER VARCHAR(16777216),
            LOCATION VARCHAR(16777216),
            CUSTOMER_SEGMENT VARCHAR(16777216),
            LAST_PURCHASE_DATE DATE,
            NEXT_TRANSACTION_DATE DATE,
            AGE FLOAT,
            SENTIMENT_MIN_2 FLOAT,
            SENTIMENT_MIN_3 FLOAT,
            SENTIMENT_MIN_4 FLOAT,
            SENTIMENT_AVG_2 FLOAT,
            SENTIMENT_AVG_3 FLOAT,
            SENTIMENT_AVG_4 FLOAT,
            SUM_TOTAL_AMOUNT_PAST_7D FLOAT,
            SUM_TOTAL_AMOUNT_PAST_1MM FLOAT,
            SUM_TOTAL_AMOUNT_PAST_2MM FLOAT,
            SUM_TOTAL_AMOUNT_PAST_3MM FLOAT,
            COUNT_ORDERS_PAST_7D FLOAT,
            COUNT_ORDERS_PAST_1MM FLOAT,
            COUNT_ORDERS_PAST_2MM FLOAT,
            COUNT_ORDERS_PAST_3MM FLOAT,
            DAYS_SINCE_LAST_PURCHASE FLOAT,
            CHURNED FLOAT,
            CHURNED_PRED_PROD FLOAT,
            CHURNED_PRED_BASE FLOAT,
            CHURNED_PRED_RETRAIN FLOAT,
            CHURNED_PRED_PROBABILITY FLOAT,
            VERSION_NAME VARCHAR(50)
        )
    """).collect()

print("✅ Monitoring tables created")

# Create inference function
def run_inference(model, dataset_df, output_table, col_name, is_prod):
    """Run inference with `model` and upsert the results into `output_table` for monitoring.

    `col_name` names the prediction column to fill. `is_prod` is kept for
    call-site compatibility; the function body does not use it.
    """
    
    # Get predictions
    predictions = model.run(dataset_df, function_name="predict")
    predictions = predictions.select([F.col(c).alias(c.replace('"', '')) for c in predictions.columns])
    predictions_df = predictions.rename("output_feature_0", col_name)
    predictions_df = predictions_df.with_column("VERSION_NAME", F.lit(model.version_name))
    # Note: this copies the value returned by `predict`; it is not the output of `predict_proba`
    predictions_df = predictions_df.with_column("CHURNED_PRED_PROBABILITY", F.col(col_name))
    
    # Store in temporary table first
    predictions_df.write.mode("overwrite").save_as_table('TEMP_PREDICTIONS')
    
    # Merge with output table
    output_columns = [field.name for field in session.table(output_table).schema]
    insert_columns = ", ".join(output_columns)
    insert_values = ", ".join([
        f"t.{col}" if col in predictions_df.columns else "NULL" for col in output_columns
    ])
    
    merge_statement = f"""
        MERGE INTO {output_table} o
        USING TEMP_PREDICTIONS t
        ON o.CUSTOMER_ID = t.CUSTOMER_ID AND o.TIMESTAMP = t.TIMESTAMP
        WHEN MATCHED THEN
            UPDATE SET o.{col_name} = t.{col_name},
                       o.VERSION_NAME = t.VERSION_NAME,
                       o.CHURNED_PRED_PROBABILITY = t.CHURNED_PRED_PROBABILITY
        WHEN NOT MATCHED THEN
            INSERT ({insert_columns})
            VALUES ({insert_values})
    """
    
    session.sql(merge_statement).collect()
    print(f"✅ Predictions stored in {output_table}")

print("✅ Inference function defined")


🔄 Setting up model monitoring infrastructure...
✅ Monitoring tables created
✅ Inference function defined
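
The INSERT branch of the MERGE in `run_inference` has to supply a value for every column of the output table, even though each call fills only one prediction column. The sketch below isolates that string-building step in plain Python so it can be inspected without a Snowflake session. The helper name `build_insert_lists` and the shortened column lists are hypothetical.

```python
def build_insert_lists(output_columns, prediction_columns):
    """Return (columns, values) SQL fragments for the INSERT branch of a MERGE.

    Output-table columns that are missing from the prediction frame get NULL.
    """
    available = set(prediction_columns)
    insert_columns = ", ".join(output_columns)
    insert_values = ", ".join(
        f"t.{col}" if col in available else "NULL" for col in output_columns
    )
    return insert_columns, insert_values


# Hypothetical, shortened column lists for illustration only
output_cols = ["CUSTOMER_ID", "TIMESTAMP", "CHURNED", "CHURNED_PRED_BASE", "CHURNED_PRED_PROD"]
pred_cols = ["CUSTOMER_ID", "TIMESTAMP", "CHURNED", "CHURNED_PRED_BASE"]

cols_sql, vals_sql = build_insert_lists(output_cols, pred_cols)
print(cols_sql)
print(vals_sql)
```

Because the column that the prediction frame lacks is inserted as NULL, a later call for a different model version can fill it through the `WHEN MATCHED` branch.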


In [44]:
# Setup Model Monitors
print("🔄 Creating Model Monitors...")
session.sql(f'USE SCHEMA {sc}').collect()

# Populate baseline predictions
baseline_model = mr.get_model("ChurnDetector").version("baseline")
run_inference(baseline_model, training_dataset, 'customer_churn_baseline_predicted', 'CHURNED_PRED_BASE', True)
run_inference(baseline_model, validation_dataset, 'CUSTOMER_CHURN_PREDICTED_PROD2', 'CHURNED_PRED_PROD', True)
run_inference(baseline_model, validation_dataset, 'CUSTOMER_CHURN_PREDICTED_PROD2', 'CHURNED_PRED_BASE', True)
run_inference(baseline_model, validation_dataset, 'CUSTOMER_CHURN_PREDICTED_PROD2', 'CHURNED_PRED_RETRAIN', True)

# Each Model Monitor requires its own model version, so copy the baseline
# into two placeholder versions (PRODMONITOR and RETRAIN)
fake_prod_model = mr.get_model("ChurnDetector").version("baseline")

fake_prod_logged = mr.log_model(
    model=fake_prod_model,
    model_name="ChurnDetector",
    version_name="PRODMONITOR",
)

fake_retrain_model = mr.get_model("ChurnDetector").version("baseline")

fake_retrain_logged = mr.log_model(
    model=fake_retrain_model,
    model_name="ChurnDetector",
    version_name="RETRAIN",
)


# Create Model Monitors
session.sql(f'USE SCHEMA {mr_schema}').collect()

# Monitor for baseline model
session.sql(f"""
    CREATE OR REPLACE MODEL MONITOR Monitor_ChurnDetector_Base
    WITH
        MODEL=ChurnDetector
        VERSION=baseline
        FUNCTION=predict
        SOURCE={sc}.CUSTOMER_CHURN_PREDICTED_PROD2
        BASELINE={sc}.customer_churn_baseline_predicted
        TIMESTAMP_COLUMN=TIMESTAMP
        PREDICTION_CLASS_COLUMNS=(CHURNED_PRED_BASE)  
        ACTUAL_CLASS_COLUMNS=(CHURNED)
        ID_COLUMNS=(CUSTOMER_ID)
        WAREHOUSE={WAREHOUSE}
        REFRESH_INTERVAL='1 min'
        AGGREGATION_WINDOW='1 day';
""").collect()

# Monitor for production model
session.sql(f"""
    CREATE OR REPLACE MODEL MONITOR Monitor_ChurnDetector_Prod
    WITH
        MODEL=ChurnDetector
        VERSION=PRODMONITOR
        FUNCTION=predict
        SOURCE={sc}.CUSTOMER_CHURN_PREDICTED_PROD2
        BASELINE={sc}.customer_churn_baseline_predicted
        TIMESTAMP_COLUMN=TIMESTAMP
        PREDICTION_CLASS_COLUMNS=(CHURNED_PRED_PROD)  
        ACTUAL_CLASS_COLUMNS=(CHURNED)
        ID_COLUMNS=(CUSTOMER_ID)
        WAREHOUSE={WAREHOUSE}
        REFRESH_INTERVAL='1 min'
        AGGREGATION_WINDOW='1 day';
""").collect()

# Monitor for retrained model
session.sql(f"""
    CREATE OR REPLACE MODEL MONITOR Monitor_ChurnDetector_Retrain
    WITH
        MODEL=ChurnDetector
        VERSION=RETRAIN
        FUNCTION=predict
        SOURCE={sc}.CUSTOMER_CHURN_PREDICTED_PROD2
        BASELINE={sc}.customer_churn_baseline_predicted
        TIMESTAMP_COLUMN=TIMESTAMP
        PREDICTION_CLASS_COLUMNS=(CHURNED_PRED_RETRAIN)  
        ACTUAL_CLASS_COLUMNS=(CHURNED)
        ID_COLUMNS=(CUSTOMER_ID)
        WAREHOUSE={WAREHOUSE}
        REFRESH_INTERVAL='1 min'
        AGGREGATION_WINDOW='1 day';
""").collect()

session.sql(f'USE SCHEMA {sc}').collect()

print("✅ Model Monitors created")
print("   - Monitor_ChurnDetector_Base")
print("   - Monitor_ChurnDetector_Prod")
print("   - Monitor_ChurnDetector_Retrain")


🔄 Creating Model Monitors...
✅ Predictions stored in customer_churn_baseline_predicted
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Model Monitors created
   - Monitor_ChurnDetector_Base
   - Monitor_ChurnDetector_Prod
   - Monitor_ChurnDetector_Retrain
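
The three monitor definitions above differ only in the monitor name, the model version, and the prediction column. As a minimal sketch, the statement could be generated from a small spec list instead of being written out three times. The helper `monitor_sql` and the placeholder values `MY_SCHEMA` and `MY_WH` are assumptions for illustration; the clauses are copied from the cell above.

```python
def monitor_sql(monitor_name, version, prediction_column, source_schema, warehouse):
    """Build a CREATE MODEL MONITOR statement with the same clauses as the cell above."""
    return f"""
    CREATE OR REPLACE MODEL MONITOR {monitor_name}
    WITH
        MODEL=ChurnDetector
        VERSION={version}
        FUNCTION=predict
        SOURCE={source_schema}.CUSTOMER_CHURN_PREDICTED_PROD2
        BASELINE={source_schema}.customer_churn_baseline_predicted
        TIMESTAMP_COLUMN=TIMESTAMP
        PREDICTION_CLASS_COLUMNS=({prediction_column})
        ACTUAL_CLASS_COLUMNS=(CHURNED)
        ID_COLUMNS=(CUSTOMER_ID)
        WAREHOUSE={warehouse}
        REFRESH_INTERVAL='1 min'
        AGGREGATION_WINDOW='1 day';
    """


# (monitor name, model version, prediction column), as used in the notebook
monitor_specs = [
    ("Monitor_ChurnDetector_Base", "baseline", "CHURNED_PRED_BASE"),
    ("Monitor_ChurnDetector_Prod", "PRODMONITOR", "CHURNED_PRED_PROD"),
    ("Monitor_ChurnDetector_Retrain", "RETRAIN", "CHURNED_PRED_RETRAIN"),
]

# "MY_SCHEMA" and "MY_WH" are placeholders, not the notebook's settings
statements = [
    monitor_sql(name, version, pred_col, "MY_SCHEMA", "MY_WH")
    for name, version, pred_col in monitor_specs
]

for stmt in statements:
    print(stmt)
    # In the notebook this would be: session.sql(stmt).collect()
```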


In [45]:
# Monitoring and Retraining Functions
# get_model_performance() function removed - now using get_model_performance_sklearn() only

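def update_labels_with_new_data(db: str, sc: str, baseline_table: str,
                                num_days_churn: int):
    """Recompute CHURNED and NEXT_TRANSACTION_DATE in `baseline_table` from the current SALES data."""
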
    # Load baseline features dataset
    baseline_table = f'{db}.{sc}.{baseline_table}'
    
    baseline_df = session.table(baseline_table)

    # Load sales dataset
    sales_df = session.table(f'{db}.{sc}.SALES')

    # Filter sales to retain only customer ID and transaction date
    sales_filtered = sales_df.select(F.col("CUSTOMER_ID"), F.col("TRANSACTION_DATE"))

    # Find the next transaction date for each (CUSTOMER_ID, TIMESTAMP)
    next_transaction_df = (
        baseline_df
        .join(sales_filtered, "CUSTOMER_ID", "left")
        .filter(F.col("TRANSACTION_DATE") > F.col("LAST_PURCHASE_DATE"))
        .group_by(F.col("CUSTOMER_ID"), F.col("TIMESTAMP"))
        .agg(F.min("TRANSACTION_DATE").alias("NEXT_TX_DATE"))
    )

    # Join back with the baseline dataset to compute CHURNED
    final_df = (
        baseline_df
        .join(next_transaction_df, ["CUSTOMER_ID", "TIMESTAMP"], "left")
        .select(
            baseline_df["CUSTOMER_ID"],
            baseline_df["TIMESTAMP"],
            next_transaction_df["NEXT_TX_DATE"],
            F.when(
                next_transaction_df["NEXT_TX_DATE"].is_null() |
                ((next_transaction_df["NEXT_TX_DATE"] - baseline_df["LAST_PURCHASE_DATE"]) > num_days_churn),
                1
            ).otherwise(0).alias("CHURNED")
        )    
        .with_column_renamed("NEXT_TX_DATE", "NEXT_TRANSACTION_DATE")

    )

    final_df.write.mode("overwrite").save_as_table('temp_updates')

    update_statement = f"""
        update {baseline_table} c
        set CHURNED = t.CHURNED,
            NEXT_TRANSACTION_DATE = t.NEXT_TRANSACTION_DATE
        from temp_updates t
        where c.CUSTOMER_ID = t.CUSTOMER_ID AND
            c.TIMESTAMP = t.TIMESTAMP
          
        """

    session.sql(update_statement).collect()
    
    print("✅ Labels updated with new transaction data")

def set_default_model(version_name):
    """Set default model version"""
    session.sql(f'USE SCHEMA {mr_schema}').collect()
    session.sql(f'ALTER MODEL ChurnDetector SET DEFAULT_VERSION = {version_name};').collect()
    session.sql(f'USE SCHEMA {sc}').collect()
    print(f"✅ Default model set to: {version_name}")

print("✅ Monitoring and retraining functions defined")


✅ Monitoring and retraining functions defined
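
The labeling rule inside `update_labels_with_new_data` can be stated compactly: a customer is marked as churned when no later transaction exists, or when the gap between the last purchase and the next transaction exceeds `num_days_churn` days. The plain-Python sketch below applies that rule to a few made-up dates; `churn_label` is a hypothetical helper, and the 30-day window matches the value reported in the run logs.

```python
from datetime import date


def churn_label(last_purchase, next_transaction, num_days_churn):
    """Return 1 if the customer counts as churned, else 0.

    Churned means: no later transaction is known, or the gap to the next
    transaction is strictly greater than num_days_churn days.
    """
    if next_transaction is None:
        return 1
    gap_days = (next_transaction - last_purchase).days
    return 1 if gap_days > num_days_churn else 0


window = 30  # days

# (last purchase, next transaction, expected label); all dates are made up
examples = [
    (date(2024, 10, 1), date(2024, 10, 20), 0),  # back after 19 days
    (date(2024, 10, 1), date(2024, 10, 31), 0),  # exactly 30 days: not churned
    (date(2024, 10, 1), date(2024, 11, 15), 1),  # back after 45 days
    (date(2024, 10, 1), None, 1),                # no later purchase on record
]

for last, nxt, expected in examples:
    print(last, nxt, churn_label(last, nxt, window), expected)
```

Because the comparison is strict, a customer who returns on exactly day 30 is not labeled as churned. Labels can also flip from 1 to 0 once later sales arrive, which is why the function is re-run after each new batch.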


In [46]:
from snowflake.ml import dataset

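def get_inference_dataset():
    """Load the Dataset built for the second-latest TIMESTAMP in the labeled table as a Snowpark DataFrame."""

    ts_inference_tb = session.table(labeled_table).select("TIMESTAMP").distinct().sort("TIMESTAMP").collect()
    
    # [-2] selects the second-latest timestamp (requires at least two distinct timestamps)
    ts_inference = ts_inference_tb[-2]["TIMESTAMP"]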

    date_name = "v_" + str(ts_inference).replace("-", "_")
    ds_name = f'{fs_schema}.CHURN_{date_name}'

    inference_dataset = dataset.load_dataset(session, ds_name, 'v1')

    inference_dataset_sdf = inference_dataset.read.to_snowpark_dataframe()

    return inference_dataset_sdf
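
Two details of `get_inference_dataset` are easy to get wrong: indexing with `[-2]` fails when fewer than two timestamps exist, and `str()` on a `datetime` (as opposed to a `date`) yields a string with a space and colons, which would not match a `CHURN_v_YYYY_MM_DD` dataset name. The sketch below handles both cases in plain Python; `second_latest` and `dataset_name` are hypothetical helpers, and the assumption that dataset names contain only the date part is inferred from the version names shown in the logs.

```python
from datetime import date, datetime


def second_latest(timestamps):
    """Return the second most recent distinct timestamp, or None if there are fewer than two."""
    ordered = sorted(set(timestamps))
    return ordered[-2] if len(ordered) >= 2 else None


def dataset_name(ts):
    """Build a 'CHURN_v_YYYY_MM_DD' name from a date or a datetime."""
    if isinstance(ts, datetime):
        ts = ts.date()  # drop the time part so the name has no spaces or colons
    return "CHURN_v_" + ts.isoformat().replace("-", "_")


stamps = [date(2024, 10, 31), date(2024, 11, 30), date(2024, 12, 20)]
picked = second_latest(stamps)

print(picked, dataset_name(picked))
print(dataset_name(datetime(2024, 11, 30, 0, 0)))
```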

In [50]:
# Main ML Pipeline Function
# Model performance is computed with get_model_performance_sklearn(),
# so the pipeline does not wait for the Model Monitors to refresh

def process_monthly_data():
    """Ingest the next monthly batch, score it, and retrain if production F1 drops below the threshold.

    Returns False when there are no more data files to process, True otherwise.
    """
    start_time = datetime.now()
    
    print("=" * 60)
    print("🔄 Processing new monthly data...")
    
    # Step 1: Load new data
    if not copy_next_file(session, db, sc):
        print("❌ No more data files to process")
        return False
    
    # Step 2: Process sentiment
    process_sentiment()
    
    # Step 3: Create features for new data
    sales_df = session.table("SALES")
    latest_timestamp = sales_df.select(F.max("TRANSACTION_DATE")).collect()[0][0]
    
    print(f"📊 Creating features for timestamp: {latest_timestamp}")
    create_customer_features(session, db, sc, latest_timestamp, features_table)
    
    # Step 4: Update labels with new data
    create_churn_labels(session, db, sc, features_table, labeled_table, CHURN_WINDOW)
    
    latest_labeled_timestamp = session.table(labeled_table).select(F.max("TIMESTAMP")).collect()[0][0]

    # Step 5: Create new dataset
    date_name = f"v_{latest_labeled_timestamp}".replace("-", "_")
    new_dataset = create_dataset(fs, feature_view, f'CHURN_{date_name}', latest_labeled_timestamp)
    
    # Step 6: Run inference with baseline model
    baseline_model = mr.get_model("ChurnDetector").version("baseline")
    run_inference(baseline_model, new_dataset, 'CUSTOMER_CHURN_PREDICTED_PROD2', 'CHURNED_PRED_BASE', False)
    update_labels_with_new_data(db, sc, "CUSTOMER_CHURN_PREDICTED_PROD2", CHURN_WINDOW)

    # Step 7: Evaluate model performance
    # F1 is computed directly with sklearn, so there is no need to wait for the monitors
    print("📈 Model Performance Metrics:")
    
    # Get performance using the sklearn method
    baseline_f1 = get_model_performance_sklearn('CHURNED_PRED_BASE')
    production_f1 = get_model_performance_sklearn('CHURNED_PRED_PROD')
    
    print(f"   Baseline F1:   {baseline_f1:.4f}")
    print(f"   Production F1: {production_f1:.4f}")
    
    # Step 8: Retrain if performance drops
    retrain_threshold = 0.8
    retrained_f1 = 0.0
    
    if production_f1 < retrain_threshold:
        print(f"🚨 Performance dropped below {retrain_threshold}, retraining model...")
        
        # Train new model
        new_model_result = train_churn_model(new_dataset)
        
        # Register new model
        retrained_model = mr.log_model(
            model=new_model_result['model'],
            model_name="ChurnDetector",
            version_name=date_name,
            conda_dependencies=["snowflake-ml-python", "xgboost", "scikit-learn"],
            sample_input_data=new_dataset.select(feature_cols).limit(100),
            task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
            comment=f"Retrained model for {latest_labeled_timestamp}"
        )
        
        retrained_model.set_metric("train_f1_score", new_model_result['train_f1_score'])
        retrained_model.set_metric("test_f1_score", new_model_result['test_f1_score'])
        
        print(f"✅ New model {date_name} trained:")
        print(f"   Train F1: {new_model_result['train_f1_score']:.4f}")
        print(f"   Test F1: {new_model_result['test_f1_score']:.4f}")
        
        # Test new model on validation data
        val_dataset = get_inference_dataset()
        run_inference(retrained_model, val_dataset, 'CUSTOMER_CHURN_PREDICTED_PROD2', 'CHURNED_PRED_RETRAIN', True)
            
        # Get retrained model performance using sklearn method
        retrained_f1 = get_model_performance_sklearn('CHURNED_PRED_RETRAIN')
        print(f"   Validation F1: {retrained_f1:.4f}")

    # Step 9: Choose the best model. The default changes only when one candidate
    # strictly beats both others; otherwise the current default is kept.

    best_model = mr.get_model("ChurnDetector").default.version_name
    
    if baseline_f1 > production_f1 and baseline_f1 > retrained_f1:
        set_default_model('baseline')
        best_model = "BASELINE"
    elif retrained_f1 > baseline_f1 and retrained_f1 > production_f1:
        set_default_model(date_name)
        best_model = date_name

    # Run production inference
    prod_model = mr.get_model("ChurnDetector").default
    run_inference(prod_model, new_dataset, 'CUSTOMER_CHURN_PREDICTED_PROD2', 'CHURNED_PRED_PROD', True)
    
    elapsed = datetime.now() - start_time
    print(f"✅ Monthly processing completed in {elapsed.total_seconds():.1f}s")
    print(f"   Best model: {best_model}")
    print("=" * 60)
    
    return True

print("✅ Finished!")


✅ Finished!
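
The model-selection step in `process_monthly_data` is easier to reason about when separated from the Snowflake calls. The sketch below mirrors its comparisons in a pure function and replays two sets of F1 values taken from the runs logged further down (2024-11 and 2024-12). The function name `choose_default` is hypothetical.

```python
def choose_default(current_default, baseline_f1, production_f1, retrained_f1, retrained_name):
    """Mirror the selection rule: switch only when one candidate strictly beats both others."""
    if baseline_f1 > production_f1 and baseline_f1 > retrained_f1:
        return "baseline"
    if retrained_f1 > baseline_f1 and retrained_f1 > production_f1:
        return retrained_name
    # Ties, or a production model that leads, leave the current default in place
    return current_default


# 2024-11: baseline and production tie, and the retrained model scores lower
nov = choose_default("BASELINE", 0.7441, 0.7441, 0.7101, "v_2024_11_30")

# 2024-12: the retrained model beats both
dec = choose_default("BASELINE", 0.6876, 0.6876, 0.7771, "v_2024_12_20")

print(nov, dec)
```

One consequence of the strict comparisons is that a retrained model which merely equals the production score is never promoted, and when no retraining happens (`retrained_f1` stays at 0.0) only the baseline can displace the current default.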


In [51]:
# Run the continuous learning pipeline
print("🚀 Starting continuous learning pipeline...")
print("This will process remaining data files and retrain models as needed\n")

# Track processing iterations
iteration = 1
continue_processing = True

while continue_processing:
    print(f"\n🔄 Processing Iteration {iteration}")
    
    continue_processing = process_monthly_data()
    if continue_processing:
        iteration += 1
    else:
        print("\n🎯 All data processed!")

print(f"\n✅ Pipeline completed after {iteration-1} iterations")

# Final summary
print("\n" + "="*60)
print("📊 FINAL SUMMARY")
print("="*60)

# Check final data counts
sales_final = session.table("SALES").count()
feedback_final = session.table("FEEDBACK_SENTIMENT").count()
features_final = session.table(labeled_table).count()

print(f"📈 Data Processed:")
print(f"   Sales records: {sales_final:,}")
print(f"   Feedback records: {feedback_final:,}")
print(f"   Feature records: {features_final:,}")

# Check model registry (best-effort: if the lookup fails, report that and continue)
print("\n🤖 Models in Registry:")
try:
    models = session.sql(f"""
        SELECT VERSION_NAME, CREATION_TIME 
        FROM {mr_schema}.INFORMATION_SCHEMA.ML_MODELS 
        WHERE MODEL_NAME = 'CHURNDETECTOR'
        ORDER BY CREATION_TIME
    """).collect()
    
    for model in models:
        print(f"   - {model['VERSION_NAME']} (created: {model['CREATION_TIME']})")
except Exception:
    print("   Model information not available")

print(f"\n🎯 Workflow completed successfully!")
print("="*60)


🚀 Starting continuous learning pipeline...
This will process remaining data files and retrain models as needed


🔄 Processing Iteration 1
🔄 Processing new monthly data...
📥 Loaded data for 2024-10
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2024-10-31
✅ Features created for 2024-10-31
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.7627
   Production F1: 1.0000
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 90.8s
   Best model: BASELINE)

🔄 Processing Iteration 2
🔄 Processing new monthly data...
📥 Loaded data for 2024-11
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2024-11-30
✅ Features created for 2024-11-30
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.7441
   Production F1: 0.7441
🚨 Performance dropped below 0.8, retraining model...


✅ New model v_2024_11_30 trained:
   Train F1: 1.0000
   Test F1: 0.9641
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
   Validation F1: 0.7101
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 140.0s
   Best model: BASELINE)

🔄 Processing Iteration 3
🔄 Processing new monthly data...
📥 Loaded data for 2024-12
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2024-12-20
✅ Features created for 2024-12-20
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.6876
   Production F1: 0.6876
🚨 Performance dropped below 0.8, retraining model...


✅ New model v_2024_12_20 trained:
   Train F1: 1.0000
   Test F1: 0.9978
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
   Validation F1: 0.7771
✅ Default model set to: v_2024_12_20
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 147.0s
   Best model: v_2024_12_20)

🔄 Processing Iteration 4
🔄 Processing new monthly data...
No unprocessed sales files found.
❌ No more data files to process

🎯 All data processed!

✅ Pipeline completed after 3 iterations

📊 FINAL SUMMARY
📈 Data Processed:
   Sales records: 54,837
   Feedback records: 6,685
   Feature records: 34,040

🤖 Models in Registry:
   Model information not available

🎯 Workflow completed successfully!


In [52]:
my_model = mr.get_model("ChurnDetector").default.version_name

my_model

'V_2024_12_20'

In [53]:
# Introduce Data Drift - New Customer Segments and Behaviors
print("🔄 Introducing data drift with new customer segments...")

# Load new customers with different demographic patterns
print("📥 Loading new customers with different demographics...")
load_into_table(session, f'{db}.{sc}.CUSTOMERS', f'@{stage_name}/new_customers.csv')

print("✅ New customers loaded - this introduces demographic drift")

# Discover and register new drift data files
print("🔍 Discovering new data files for drift simulation...")

# Get new sales files (with different purchasing patterns)
new_sales_files = get_year_month_files(session, stage_name, 'new_sales')
insert_file_tracking('sales', db, sc, new_sales_files)

# Get new feedback files (with different sentiment patterns)  
new_feedback_files = get_year_month_files(session, stage_name, 'new_feedback_raw2')
insert_file_tracking('feedback_raw', db, sc, new_feedback_files)

print(f"📊 Registered {len(new_sales_files)} new sales files and {len(new_feedback_files)} new feedback files")
print("   These files contain:")
print("   - Different customer segments (demographic drift)")
print("   - New purchasing patterns (behavioral drift)")  
print("   - Different sentiment distributions (feature drift)")

# Check what files are available for processing
unprocessed_files = session.sql(f"""
    SELECT FILE_TYPE, COUNT(*) as COUNT
    FROM FILES_INGESTED 
    WHERE INGESTED = FALSE
    GROUP BY FILE_TYPE
    ORDER BY FILE_TYPE
""").collect()

print("\n📋 Unprocessed files available:")
for row in unprocessed_files:
    print(f"   - {row['FILE_TYPE']}: {row['COUNT']} files")

print("\n✅ Data drift setup completed - ready to process new data patterns")


🔄 Introducing data drift with new customer segments...
📥 Loading new customers with different demographics...
✅ New customers loaded - this introduces demographic drift
🔍 Discovering new data files for drift simulation...
📊 Registered 5 new sales files and 13 new feedback files
   These files contain:
   - Different customer segments (demographic drift)
   - New purchasing patterns (behavioral drift)
   - Different sentiment distributions (feature drift)

📋 Unprocessed files available:
   - feedback_raw: 18 files
   - sales: 5 files

✅ Data drift setup completed - ready to process new data patterns
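
To make "demographic drift" concrete, one common way to quantify a shift in a single feature is the Population Stability Index (PSI), which compares the binned distribution of a baseline sample with that of a new sample. The sketch below is a plain-Python illustration on made-up ages. It is not a reproduction of what the Snowflake Model Monitors compute, and the function name, bin edges, and data are all assumptions.

```python
import math


def psi(expected, actual, bin_edges, eps=1e-6):
    """Population Stability Index between two samples over shared bins."""

    def shares(values):
        counts = [0] * (len(bin_edges) - 1)
        for v in values:
            for i in range(len(counts)):
                last_bin = i == len(counts) - 1
                # Bins are half-open [lo, hi), except the last one, which includes hi
                upper_ok = v <= bin_edges[i + 1] if last_bin else v < bin_edges[i + 1]
                if bin_edges[i] <= v and upper_ok:
                    counts[i] += 1
                    break
        total = sum(counts)
        # eps keeps empty bins from producing log(0) or a division by zero
        return [max(c / total, eps) for c in counts]

    e, a = shares(expected), shares(actual)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))


# Made-up customer ages: the new sample skews younger
baseline_ages = [25, 32, 38, 41, 45, 47, 52, 55, 58, 63]
new_ages = [19, 21, 22, 24, 26, 27, 33, 35, 48, 55]
edges = [18, 30, 40, 50, 70]

psi_same = psi(baseline_ages, baseline_ages, edges)
psi_new = psi(baseline_ages, new_ages, edges)

print(f"PSI vs itself: {psi_same:.4f}")
print(f"PSI vs new:    {psi_new:.4f}")
```

An often-quoted rule of thumb treats PSI values above roughly 0.25 as a substantial shift, though the cut-off depends on the binning and sample size and should be read as a heuristic.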


In [54]:
# Process Data Drift and Monitor Model Performance
print("🚀 Processing data with drift patterns and monitoring model response...")
print("This will test how well our models adapt to new data distributions\n")

# Track drift processing
drift_iteration = 1
continue_drift_processing = True


while continue_drift_processing:
    print(f"\n🔄 Processing New Drift Data - Iteration {drift_iteration}")
    
    # Process the next batch of drift data
    continue_drift_processing = process_monthly_data()
    
    if continue_drift_processing:
        print(f"✅ Drift iteration {drift_iteration} completed")
        
        drift_iteration += 1
    else:
        print("\n🎯 All New drift data processed!")

print(f"\n✅ Data drift processing completed after {drift_iteration-1} iterations")


🚀 Processing data with drift patterns and monitoring model response...
This will test how well our models adapt to new data distributions


🔄 Processing New Drift Data - Iteration 1
🔄 Processing new monthly data...
📥 Loaded data for 2025-01
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2025-01-31
✅ Features created for 2025-01-31
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.5579
   Production F1: 0.9996
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 88.9s
   Best model: V_2024_12_20)
✅ Drift iteration 1 completed

🔄 Processing New Drift Data - Iteration 2
🔄 Processing new monthly data...
📥 Loaded data for 2025-02
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2025-02-28
✅ Features created for 2025-02-28
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.7216
   Production F1: 0.8948
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 82.6s
   Best model: V_2024_12_20)
✅ Drift iteration 2 completed

🔄 Processing New Drift Data - Iteration 3
🔄 Processing new monthly data...
📥 Loaded data for 2025-03
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2025-03-31
✅ Features created for 2025-03-31
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.7162
   Production F1: 0.9356
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 84.5s
   Best model: V_2024_12_20)
✅ Drift iteration 3 completed

🔄 Processing New Drift Data - Iteration 4
🔄 Processing new monthly data...
📥 Loaded data for 2025-04
✅ Sentiment processed for new feedback
📊 Creating features for timestamp: 2025-04-30
✅ Features created for 2025-04-30
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.7228
   Production F1: 0.9411
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 90.5s
   Best model: V_2024_12_20)
✅ Drift iteration 4 completed

🔄 Processing New Drift Data - Iteration 5
🔄 Processing new monthly data...
📥 Loaded data for 2025-05
📊 Creating features for timestamp: 2025-05-20
✅ Features created for 2025-05-20
✅ Labels created for churn window: 30 days




✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Labels updated with new transaction data
📈 Model Performance Metrics:
   Baseline F1:   0.7128
   Production F1: 0.9504
✅ Predictions stored in CUSTOMER_CHURN_PREDICTED_PROD2
✅ Monthly processing completed in 82.9s
   Best model: V_2024_12_20)
✅ Drift iteration 5 completed

🔄 Processing New Drift Data - Iteration 6
🔄 Processing new monthly data...
No unprocessed sales files found.
❌ No more data files to process

🎯 All New drift data processed!

✅ Data drift processing completed after 5 iterations
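
The log above can be condensed into a small table to see how the two models responded to the drifted data. The sketch below copies the baseline and production F1 values from the five drift iterations and checks them against the retraining threshold of 0.8 used in `process_monthly_data()`. The variable names are illustrative.

```python
RETRAIN_THRESHOLD = 0.8  # same value as retrain_threshold in process_monthly_data()

# (month, baseline F1, production F1), copied from the log above
drift_runs = [
    ("2025-01", 0.5579, 0.9996),
    ("2025-02", 0.7216, 0.8948),
    ("2025-03", 0.7162, 0.9356),
    ("2025-04", 0.7228, 0.9411),
    ("2025-05", 0.7128, 0.9504),
]

# Months in which production F1 fell below the retraining threshold
retrain_months = [month for month, _, prod in drift_runs if prod < RETRAIN_THRESHOLD]

# Production-minus-baseline gap per month
gaps = {month: round(prod - base, 4) for month, base, prod in drift_runs}

for month, base, prod in drift_runs:
    print(f"{month}  baseline={base:.4f}  production={prod:.4f}  gap={gaps[month]:+.4f}")
print("Months that would trigger retraining:", retrain_months)
```

In these runs the production model stayed above the threshold in every drift month, so no retraining was triggered and `V_2024_12_20` remained the default throughout.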


In [55]:
session.sql('''
SELECT DISTINCT TIMESTAMP AS TS, VERSION_NAME
FROM CUSTOMER_CHURN_PREDICTED_PROD2
ORDER BY TS ASC;''').collect()

[Row(TS=datetime.datetime(2024, 7, 31, 0, 0), VERSION_NAME='BASELINE'),
 Row(TS=datetime.datetime(2024, 9, 30, 0, 0), VERSION_NAME='BASELINE'),
 Row(TS=datetime.datetime(2024, 10, 31, 0, 0), VERSION_NAME='V_2024_11_30'),
 Row(TS=datetime.datetime(2024, 11, 30, 0, 0), VERSION_NAME='V_2024_12_20'),
 Row(TS=datetime.datetime(2024, 12, 20, 0, 0), VERSION_NAME='V_2024_12_20'),
 Row(TS=datetime.datetime(2025, 1, 31, 0, 0), VERSION_NAME='V_2024_12_20'),
 Row(TS=datetime.datetime(2025, 2, 28, 0, 0), VERSION_NAME='V_2024_12_20'),
 Row(TS=datetime.datetime(2025, 3, 31, 0, 0), VERSION_NAME='V_2024_12_20'),
 Row(TS=datetime.datetime(2025, 4, 30, 0, 0), VERSION_NAME='V_2024_12_20'),
 Row(TS=datetime.datetime(2025, 5, 20, 0, 0), VERSION_NAME='V_2024_12_20')]

## Summary

This notebook demonstrates a complete end-to-end ML workflow with observability.

### ✅ What We Accomplished

1. **Environment Setup**
   - Configured Snowflake ML environment with Feature Store and Model Registry
   - Created data schemas and staging areas
   - Set up monitoring infrastructure

2. **Data Ingestion & Processing**
   - Loaded customer, sales, and feedback data progressively
   - Used Cortex AI for sentiment analysis of customer feedback
   - Created rich behavioral features using Snowflake analytical functions

3. **Model Training & Registry**
   - Trained XGBoost model for churn prediction
   - Registered model in Model Registry with full lineage
   - Created baseline for future comparisons

4. **Continuous Learning & Monitoring**
   - Implemented automated data ingestion pipeline
   - Set up Model Monitors for performance tracking
   - Created automated retraining based on performance thresholds
   - Maintained model lineage and observability



### 🔍 Key Features Demonstrated

- **Feature Store**: Centralized feature management with versioning
- **Model Registry**: Version control and lifecycle management for models
- **Model Monitoring**: Automatic drift detection and performance tracking
- **Automated Retraining**: Performance-based model updates triggered by drift
- **Performance Evolution Tracking**: Monitor how models respond to changing data
- **Full Observability**: Complete lineage from data to predictions with drift visibility

