## Setup Demo / Data Generation

In [None]:
from snowflake.snowpark.context import get_active_session
from src.setup_environment import setup_demo
# Fetch the currently active Snowpark session.
session = get_active_session()

# Run the project's demo bootstrap from src.setup_environment.
# NOTE(review): which objects setup_demo creates is not visible in this file.
setup_demo(session)

In [None]:
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# (Re)create the schemas and the stage used by the demo. Every statement is
# CREATE OR REPLACE, so re-running this cell replaces the existing objects.
# The FUNCTIONS stage is created right after its parent schema _DATA_GENERATION.
# Each schema is created exactly once (FEATURE_STORE used to be issued twice).
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO.RETAIL_DATA').collect()
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO._DATA_GENERATION').collect()
session.sql('CREATE OR REPLACE STAGE MLOPS_DEMO._DATA_GENERATION.FUNCTIONS').collect()
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO.FEATURE_STORE').collect()
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO.MODEL_REGISTRY').collect()

In [None]:
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import Session
# Fetch the currently active Snowpark session; later code reads the global `session`.
session = get_active_session()

In [None]:
from snowflake.core import Root, CreateMode
from snowflake.core.schema import Schema
from snowflake.core.stage import Stage
from snowflake.core.warehouse import Warehouse
from snowflake.snowpark.types import (
    FloatType,
    IntegerType,
    StringType,
    ArrayType,
    StructField,
    StructType,
)
from snowflake.snowpark.functions import udtf, lit, col
import pandas as pd
import numpy as np
from scipy.stats import truncnorm

def generate_supermarket_revenue_data(start_date='2024-01-01', end_date='2025-01-31', base_revenue=20000):
    """Simulate one revenue figure per calendar day.

    A day's revenue is ``base_revenue`` multiplied by a weekday factor, a
    month factor and a noise factor. The noise is drawn from NumPy's global
    random state, one draw per day in date order, so calling
    ``np.random.seed`` beforehand makes the result reproducible.

    Args:
        start_date: First day to generate (inclusive); any value accepted by
            ``pd.date_range``. Defaults to 2024-01-01.
        end_date: Last day to generate (inclusive). Defaults to 2025-01-31.
        base_revenue: Revenue of an ordinary Monday-Friday outside the
            boosted months, before noise is applied.

    Returns:
        A pandas DataFrame with the columns ``DATE`` and ``REVENUE`` and one
        row per day in the range (empty if the range contains no days).
    """
    # Saturday (5) sells most, Sunday (6) slightly less than a normal day.
    weekday_factors = {5: 1.5, 6: 0.9}
    # June, July, August and December get a seasonal boost.
    boosted_months = (6, 7, 8, 12)

    records = []
    for date in pd.date_range(start=start_date, end=end_date, freq='D'):
        # weekday(): Monday=0, ..., Sunday=6
        weekday_factor = weekday_factors.get(date.weekday(), 1.0)
        month_factor = 1.15 if date.month in boosted_months else 1.0

        # Multiplicative noise centred on 1 with a small spread.
        noise_factor = np.random.normal(loc=1, scale=0.05)

        revenue = base_revenue * weekday_factor * month_factor * noise_factor
        # Clamp so an extreme noise draw can never produce negative revenue.
        records.append((date, max(0, revenue)))

    return pd.DataFrame(records, columns=['DATE', 'REVENUE'])

class GenerateTransactions:
    """Handler that splits one revenue figure into individual transactions.

    ``process`` keeps emitting ``(customer_id, amount, channel)`` tuples until
    the emitted amounts have used up the revenue it was given.
    """

    def get_norm_value(self, min, max, mean, std_dev):
        """Draw one sample from a normal distribution truncated to [min, max]."""
        # truncnorm takes its bounds as offsets from the mean in units of std_dev.
        lower = (min - mean) / std_dev
        upper = (max - mean) / std_dev
        sample = truncnorm.rvs(lower, upper, loc=mean, scale=std_dev, size=1)
        return sample[0]

    def process(self, revenue, in_shop_online):
        """Yield transactions until ``revenue`` is used up.

        ``in_shop_online`` is passed to ``np.random.choice`` as the
        probabilities of the channels ``IN_SHOP`` and ``ONLINE`` (in that order).
        The last transaction may overshoot the remaining revenue.
        """
        customer_id = 0
        remaining = revenue
        while remaining > 0:
            # The spending profile (min, max, mean, std_dev) is picked from the
            # id value *before* it is advanced below.
            if 0 <= customer_id < 100:
                profile = (5, 25, 20, 5)
            elif 100 <= customer_id < 200:
                profile = (25, 50, 40, 5)
            elif 200 <= customer_id < 300:
                profile = (50, 100, 80, 10)
            else:
                profile = (100, 150, 120, 10)
            transaction_amount = np.round(self.get_norm_value(*profile), 2)
            transaction_channel = np.random.choice(['IN_SHOP','ONLINE'], p=in_shop_online)
            remaining = remaining - transaction_amount
            # Emitted ids run 1..350 and then start over at 1.
            # NOTE(review): on wrap-around the amount drawn from the top profile
            # is emitted for id 1 - confirm this is intended.
            customer_id = 1 if customer_id == 350 else customer_id + 1
            yield (customer_id, transaction_amount, transaction_channel)

def setup_objects(session:Session):
    """Create (or replace) the demo schemas, stage, warehouse and transaction UDTF.

    Everything is created with ``CreateMode.or_replace`` / ``replace=True``,
    so existing objects of the same name are replaced.
    """
    root = Root(session)
    demo_db = root.databases["MLOPS_DEMO"]

    # Schemas, in the order the demo has always created them.
    for schema_name in ("RETAIL_DATA", "_DATA_GENERATION", "FEATURE_STORE", "MODEL_REGISTRY"):
        demo_db.schemas.create(schema=Schema(name=schema_name), mode=CreateMode.or_replace)

    # Stage that is passed as stage_location for the permanent UDTF below.
    demo_db.schemas['_DATA_GENERATION'].stages.create(stage=Stage(name="FUNCTIONS"), mode=CreateMode.or_replace)
    root.warehouses.create(Warehouse(name='FEATURE_STORE_WH', warehouse_size='MEDIUM'), mode=CreateMode.or_replace)

    # One output row per generated transaction.
    transaction_row = StructType([
        StructField("CUSTOMER_ID", IntegerType()),
        StructField("TRANSACTION_AMOUNT", FloatType()),
        StructField("TRANSACTION_CHANNEL", StringType())])

    # Inputs: the revenue to split and the [IN_SHOP, ONLINE] probabilities.
    session.udtf.register(
        GenerateTransactions,
        name='MLOPS_DEMO._DATA_GENERATION.GENERATE_TRANSACTIONS',
        stage_location='MLOPS_DEMO._DATA_GENERATION.FUNCTIONS',
        is_permanent=True,
        replace=True,
        output_schema=transaction_row,
        input_types=[FloatType(), ArrayType()],
        packages=["numpy","scipy"]
    )

def generate_data(session: Session):
    """Set up the demo objects and fill the generated and demo tables.

    Daily revenue is simulated locally, expanded into transactions by the
    GENERATE_TRANSACTIONS table function and written to the staging table
    ``_DATA_GENERATION._TRANSACTIONS``. Days before 2024-06-01 use channel
    probabilities [0.75, 0.25], later days use [0.25, 0.75].
    """
    setup_objects(session)

    transactions_udtf = 'MLOPS_DEMO._DATA_GENERATION.GENERATE_TRANSACTIONS'
    staging_table = 'MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS'

    daily_revenue = session.create_dataframe(generate_supermarket_revenue_data())
    before_switch = daily_revenue.filter(col('DATE') < lit('2024-06-01'))
    from_switch = daily_revenue.filter(col('DATE') >= lit('2024-06-01'))

    # The first write replaces the staging table, the second one adds to it.
    in_shop_heavy = before_switch.join_table_function(transactions_udtf, col('REVENUE'), lit([0.75,0.25])).drop('REVENUE')
    in_shop_heavy.write.save_as_table(table_name=staging_table, mode='overwrite')
    online_heavy = from_switch.join_table_function(transactions_udtf, col('REVENUE'), lit([0.25,0.75])).drop('REVENUE')
    online_heavy.write.save_as_table(table_name=staging_table, mode='append')

    # Demo tables: transactions up to 2024-04-30 only, and the distinct customer ids.
    initial_transactions = session.table(staging_table).filter(col('DATE') <= lit('2024-04-30'))
    initial_transactions.write.save_as_table(table_name='MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS', mode='overwrite')
    customer_ids = session.table(staging_table).select('CUSTOMER_ID').distinct().order_by('CUSTOMER_ID')
    customer_ids.write.save_as_table(table_name='MLOPS_DEMO.RETAIL_DATA.CUSTOMERS', mode='overwrite')

In [None]:
# Create the demo objects and write the generated tables (see generate_data).
generate_data(session)

In [None]:
# Snowflake Snowpark imports
from snowflake.snowpark import Session 
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import udtf, lit, col
from snowflake.snowpark.types import (
    FloatType,
    IntegerType,
    StringType,
    ArrayType,
    StructField,
    StructType,
)
from snowflake.ml.registry import Registry
from snowflake.ml.modeling.metrics import mean_absolute_percentage_error
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.ml.modeling.xgboost import XGBRegressor

# Third-party imports
import streamlit as st
import random
import numpy as np
import pandas as pd
from scipy.stats import truncnorm
import time
import plotly.graph_objects as go
import plotly.express as px
import re
from datetime import datetime, timedelta
import calendar
from tabulate import tabulate
import json
import networkx as nx

# Seed NumPy's global random state so the locally generated revenue is reproducible.
# NOTE(review): presumably this does not affect random draws made inside the
# registered UDTF, which runs outside this process - confirm.
np.random.seed(42)

# Fetch the currently active Snowpark session; the helpers below read the global `session`.
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# (Re)create the schemas and the stage used by the demo. Every statement is
# CREATE OR REPLACE, so re-running this cell replaces the existing objects.
# The FUNCTIONS stage is created right after its parent schema _DATA_GENERATION.
# Each schema is created exactly once (FEATURE_STORE used to be issued twice).
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO.RETAIL_DATA').collect()
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO._DATA_GENERATION').collect()
session.sql('CREATE OR REPLACE STAGE MLOPS_DEMO._DATA_GENERATION.FUNCTIONS').collect()
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO.FEATURE_STORE').collect()
session.sql('CREATE OR REPLACE SCHEMA MLOPS_DEMO.MODEL_REGISTRY').collect()

def generate_supermarket_revenue_data(start_date='2024-01-01', end_date='2025-01-31', base_revenue=20000):
    """Simulate one revenue figure per calendar day.

    A day's revenue is ``base_revenue`` multiplied by a weekday factor, a
    month factor and a noise factor. The noise is drawn from NumPy's global
    random state, one draw per day in date order, so calling
    ``np.random.seed`` beforehand makes the result reproducible.

    Args:
        start_date: First day to generate (inclusive); any value accepted by
            ``pd.date_range``. Defaults to 2024-01-01.
        end_date: Last day to generate (inclusive). Defaults to 2025-01-31.
        base_revenue: Revenue of an ordinary Monday-Friday outside the
            boosted months, before noise is applied.

    Returns:
        A pandas DataFrame with the columns ``DATE`` and ``REVENUE`` and one
        row per day in the range (empty if the range contains no days).
    """
    # Saturday (5) sells most, Sunday (6) slightly less than a normal day.
    weekday_factors = {5: 1.5, 6: 0.9}
    # June, July, August and December get a seasonal boost.
    boosted_months = (6, 7, 8, 12)

    records = []
    for date in pd.date_range(start=start_date, end=end_date, freq='D'):
        # weekday(): Monday=0, ..., Sunday=6
        weekday_factor = weekday_factors.get(date.weekday(), 1.0)
        month_factor = 1.15 if date.month in boosted_months else 1.0

        # Multiplicative noise centred on 1 with a small spread.
        noise_factor = np.random.normal(loc=1, scale=0.05)

        revenue = base_revenue * weekday_factor * month_factor * noise_factor
        # Clamp so an extreme noise draw can never produce negative revenue.
        records.append((date, max(0, revenue)))

    return pd.DataFrame(records, columns=['DATE', 'REVENUE'])

# Simulate the daily revenue with pandas and hand it over as a Snowpark DataFrame.
revenue_df = session.create_dataframe(generate_supermarket_revenue_data())

@udtf(
    name="MLOPS_DEMO._DATA_GENERATION.GENERATE_TRANSACTIONS",
    stage_location='MLOPS_DEMO._DATA_GENERATION.FUNCTIONS',
    is_permanent=True,
    replace=True,
    # One output row per generated transaction.
    output_schema=StructType(
        [
            StructField("CUSTOMER_ID", IntegerType()),
            StructField("TRANSACTION_AMOUNT", FloatType()),
            StructField("TRANSACTION_CHANNEL", StringType()),
        ]
    ),
    # Inputs: the revenue to split and the [IN_SHOP, ONLINE] probabilities.
    input_types=[FloatType(), ArrayType()],
    packages=["numpy","scipy"],
)
class GenerateTransactions:
    """Handler that splits one revenue figure into individual transactions.

    ``process`` keeps emitting ``(customer_id, amount, channel)`` tuples until
    the emitted amounts have used up the revenue it was given.
    """

    def get_norm_value(self, min, max, mean, std_dev):
        """Draw one sample from a normal distribution truncated to [min, max]."""
        # truncnorm takes its bounds as offsets from the mean in units of std_dev.
        lower = (min - mean) / std_dev
        upper = (max - mean) / std_dev
        sample = truncnorm.rvs(lower, upper, loc=mean, scale=std_dev, size=1)
        return sample[0]

    def process(self, revenue, in_shop_online):
        """Yield transactions until ``revenue`` is used up.

        ``in_shop_online`` is passed to ``np.random.choice`` as the
        probabilities of the channels ``IN_SHOP`` and ``ONLINE`` (in that order).
        The last transaction may overshoot the remaining revenue.
        """
        customer_id = 0
        remaining = revenue
        while remaining > 0:
            # The spending profile (min, max, mean, std_dev) is picked from the
            # id value *before* it is advanced below.
            if 0 <= customer_id < 100:
                profile = (5, 25, 20, 5)
            elif 100 <= customer_id < 200:
                profile = (25, 50, 40, 5)
            elif 200 <= customer_id < 300:
                profile = (50, 100, 80, 10)
            else:
                profile = (100, 150, 120, 10)
            transaction_amount = np.round(self.get_norm_value(*profile), 2)
            transaction_channel = np.random.choice(['IN_SHOP','ONLINE'], p=in_shop_online)
            remaining = remaining - transaction_amount
            # Emitted ids run 1..350 and then start over at 1.
            # NOTE(review): on wrap-around the amount drawn from the top profile
            # is emitted for id 1 - confirm this is intended.
            customer_id = 1 if customer_id == 350 else customer_id + 1
            yield (customer_id, transaction_amount, transaction_channel)


# Days before 2024-06-01 are expanded with [0.75, 0.25] as the [IN_SHOP, ONLINE]
# probabilities, days from 2024-06-01 on with [0.25, 0.75].
revenue_in_shop = revenue_df.filter(col('DATE') < lit('2024-06-01'))
revenue_online = revenue_df.filter(col('DATE') >= lit('2024-06-01'))

# Expand each day's revenue into transactions; the first write replaces the
# staging table, the second one adds to it.
(
    revenue_in_shop
    .join_table_function(GenerateTransactions(col('REVENUE'), lit([0.75,0.25])))
    .drop('REVENUE')
    .write.save_as_table(table_name='MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS', mode='overwrite')
)
(
    revenue_online
    .join_table_function(GenerateTransactions(col('REVENUE'), lit([0.25,0.75])))
    .drop('REVENUE')
    .write.save_as_table(table_name='MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS', mode='append')
)

# Demo tables: transactions up to 2024-04-30 only ...
(
    session.table('MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS')
    .filter(col('DATE') <= lit('2024-04-30'))
    .write.save_as_table(table_name='MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS', mode='overwrite')
)
# ... and the distinct customer ids found in all generated transactions.
(
    session.table('MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS')
    .select('CUSTOMER_ID')
    .distinct()
    .order_by('CUSTOMER_ID')
    .write.save_as_table(table_name='MLOPS_DEMO.RETAIL_DATA.CUSTOMERS', mode='overwrite')
)

def is_feature_store_updated(timestamp_df):
    """Report whether every feature-store refresh group ended after a reference time.

    Reads INFORMATION_SCHEMA.DYNAMIC_TABLE_REFRESH_HISTORY through the global
    ``session``, keeps entries whose qualified name starts with
    ``MLOPS_DEMO.FEATURE_STORE.`` and takes the latest REFRESH_END_TIME per
    (NAME, STATE). ``timestamp_df`` must provide a ``TIMESTAMP`` column; it is
    cross-joined onto those rows for the comparison.

    Returns:
        False if at least one (NAME, STATE) row has UPDATED == False,
        otherwise True (including when there are no rows at all).

    NOTE(review): because the grouping includes STATE, an older row in a
    different state for the same table also counts - confirm that is intended.
    """
    history = session.table_function('INFORMATION_SCHEMA.DYNAMIC_TABLE_REFRESH_HISTORY')
    feature_store_history = history.filter(col('QUALIFIED_NAME').startswith('MLOPS_DEMO.FEATURE_STORE.'))

    latest_refreshes = (
        feature_store_history
        .select('NAME','STATE','REFRESH_END_TIME')
        .order_by(col('REFRESH_END_TIME').desc())
        .group_by('NAME','STATE')
        .agg(F.max('REFRESH_END_TIME').as_('REFRESH_END_TIME'))
    )
    seconds_since = F.datediff('second', col('REFRESH_END_TIME'),F.current_timestamp())
    compared = (
        latest_refreshes
        .with_column('SECONDS_SINCE_LAST_REFRESH', seconds_since)
        .cross_join(timestamp_df)
        .with_column('UPDATED', col('REFRESH_END_TIME') > col('TIMESTAMP'))
    )

    stale_rows = compared.filter(col('UPDATED') == False).count()
    return stale_rows == 0

def wait_until_feature_store_updated(interval=3, timeout=None):
    """Poll until is_feature_store_updated reports True for a timestamp taken now.

    The reference timestamp is captured once (and cached) before polling
    starts; progress is printed on a single, continuously rewritten line.

    Args:
        interval: Seconds to sleep between two polls.
        timeout: Optional upper bound in seconds for the whole wait. ``None``
            (the default) waits indefinitely, which is the previous behavior.

    Raises:
        TimeoutError: If ``timeout`` is given and that many seconds pass
            without the feature store being reported as updated.
    """
    ts = session.range(1).with_column('TIMESTAMP', F.current_timestamp()).drop('ID').cache_result()
    # Monotonic clock: the elapsed time cannot jump when the wall clock is adjusted.
    start_time = time.monotonic()
    while not is_feature_store_updated(ts):
        elapsed = time.monotonic() - start_time
        if timeout is not None and elapsed >= timeout:
            raise TimeoutError(f"Feature Store was not updated within {timeout} seconds.")
        print(f"\rWaiting for updated Feature Store ... ({int(elapsed)} seconds.)", end="", flush=True)
        time.sleep(interval)

def last_day_of_month(year: int, month: int) -> datetime:
    """Return midnight of the final calendar day of ``month`` in ``year``."""
    # monthrange yields (weekday of the 1st, number of days in the month).
    _, days_in_month = calendar.monthrange(year, month)
    return datetime(year=year, month=month, day=days_in_month)

def generate_last_days(start_date: str, end_date: str):
    """List the month-end dates from the month of ``start_date`` up to ``end_date``.

    Both arguments are ``YYYY-MM-DD`` strings. The first candidate is the last
    day of the start month; candidates are kept while they are not later than
    ``end_date``. The result is a list of ``YYYY-MM-DD`` strings.
    """
    first = datetime.strptime(start_date, "%Y-%m-%d")
    limit = datetime.strptime(end_date, "%Y-%m-%d")

    month_ends = []
    year, month = first.year, first.month
    while True:
        month_end = last_day_of_month(year, month)
        if month_end > limit:
            break
        month_ends.append(month_end.strftime("%Y-%m-%d"))
        # Step to the following month, rolling over into the next year after December.
        if month == 12:
            year, month = year + 1, 1
        else:
            month += 1

    return month_ends

def get_feature_df(session, feature_cutoff_date):
    """Build a prediction input frame: one row per customer with feature values.

    Every customer id from ``MLOPS_DEMO.RETAIL_DATA.CUSTOMERS`` is paired with
    ``feature_cutoff_date`` as the spine timestamp, the values of all listed
    feature views (version ``V1``) are retrieved for that spine, and an empty
    (NULL) ``NEXT_MONTH_REVENUE`` column is appended.
    """
    feature_store = FeatureStore(
        session=session,
        database=session.get_current_database(),
        name='FEATURE_STORE',
        default_warehouse=session.get_current_warehouse(),
        creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
    )

    # Resolve every listed feature view by name, always in version 'V1'.
    view_name_rows = feature_store.list_feature_views().select('NAME').to_pandas().values
    feature_views = []
    for name_row in view_name_rows:
        feature_views.append(feature_store.get_feature_view(name_row[0], 'V1'))

    customers = session.table('MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()
    spine = customers.with_column('FEATURE_CUTOFF_DATE', F.to_date(lit(feature_cutoff_date)))

    with_features = feature_store.retrieve_feature_values(
        spine_df=spine,
        features=feature_views,
        spine_timestamp_col="FEATURE_CUTOFF_DATE"
    )
    # The target column is added as NULL here.
    return with_features.with_column('NEXT_MONTH_REVENUE', lit(None).cast('number(38,2)'))

def simulate_model_performance(session, start_date, end_date, model_version, generate_data=False):
    """Append predictions of one model version to its monitor source table and back-fill actuals.

    Steps, in order:
      1. Optionally copy the generated transactions between ``start_date`` and
         ``end_date`` into RETAIL_DATA.TRANSACTIONS and wait for the feature store.
      2. For every month end in the range, predict with the registered model
         version and append the result to ``MM_TRANS_SOURCE_<model_version>``.
      3. Aggregate the actual revenue per customer and month and write it into
         the NEXT_MONTH_REVENUE column of that source table.

    Args:
        session: Snowpark session used for all table access.
        start_date: Start of the simulated period, a ``YYYY-MM-DD`` string
            (it is parsed by generate_last_days).
        end_date: End of the simulated period, same format.
        model_version: Version name of CUSTOMER_REVENUE_MODEL; also used as the
            suffix of the source table name.
        generate_data: If equal to True, step 1 is executed.

    Returns:
        None.
    """
    if generate_data == True:
        # add transactions
        new_transactions = session.table('MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS').filter(col('DATE').between(start_date,end_date))
        new_transactions.write.save_as_table(table_name='MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS', mode='append')
        
        # wait for feature store
        wait_until_feature_store_updated()
        # Finish the progress line that the wait loop keeps rewriting.
        print('')

    # Retrieve model from model registry
    reg = Registry(
        session=session, 
        database_name=session.get_current_database(), 
        schema_name='MODEL_REGISTRY', 
        options={'enable_monitoring':True},
    )
    registered_model = reg.get_model('CUSTOMER_REVENUE_MODEL').version(model_version)
    
    # Generate predictions, one batch per month end in the range
    for date in generate_last_days(start_date, end_date):
        # The leading F below is the f-string prefix, not the snowpark functions alias.
        # The message is printed before the predictions for that date are written.
        print(F'Generated predictions with features until: {date}.')
        feature_df = get_feature_df(session, feature_cutoff_date=date)
    
        # Predict values and cast the columns before appending to the source table
        predictions = registered_model.run(feature_df, function_name='PREDICT')
        predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
        predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
        predictions.write.save_as_table(table_name=f'MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_{model_version}', mode='append')

    # Add actual values: revenue per customer and calendar month.
    # DATE is set to the day before the first of that month, i.e. the last day
    # of the previous month, which is what the update below compares with
    # FEATURE_CUTOFF_DATE.
    actual_values_df = (
        session.table('MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
        .filter(col('DATE').between(start_date, end_date))
        .with_column('YEAR', F.year(col('DATE')))
        .with_column('MONTH', F.month(col('DATE')))
        .group_by(['CUSTOMER_ID','YEAR','MONTH'])
        .agg(F.sum('TRANSACTION_AMOUNT').cast('decimal(38,2)').as_('REVENUE'))
        .with_column('DATE', F.date_add(F.date_from_parts(col('YEAR'),col('MONTH'),lit(1)), lit(-1)))
        .drop(['YEAR','MONTH'])
    )
    
    # Get list of all customers
    customers_df = session.table('MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()
    
    # Assume 0 revenue for customers without transactions
    # NOTE(review): only REVENUE is filled; rows added by the outer join have no
    # DATE, so presumably they cannot satisfy the date equality in the update
    # below - confirm.
    actual_values_df = actual_values_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
    actual_values_df = actual_values_df.fillna(0,subset='REVENUE')

    # Update source table from model monitor: rows are matched on cutoff date
    # and customer id, and NEXT_MONTH_REVENUE receives the aggregated revenue.
    source_table = session.table(f'MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_{model_version}')
    update_result = source_table.update(
        condition=(
            (source_table['FEATURE_CUTOFF_DATE'] == actual_values_df['DATE']) &
            (source_table['CUSTOMER_ID'] == actual_values_df['CUSTOMER_ID'])
        ),
        assignments={
            "NEXT_MONTH_REVENUE": actual_values_df['REVENUE'],
        },
        source=actual_values_df
    )
    print(f'Updated {update_result.rows_updated} rows in source table of model monitor.')
    return

def evaluate_against_production_model(session, new_model_version, test_df):
    """Compare a new model version with the one aliased PRODUCTION and promote it if better.

    Both versions of CUSTOMER_REVENUE_MODEL predict on ``test_df`` and are
    scored with MAPE (NEXT_MONTH_REVENUE vs. NEXT_MONTH_REVENUE_PREDICTION).
    If the new version has the strictly lower MAPE, the current production
    version is re-aliased to DEPRECATED and the new version becomes PRODUCTION.
    """
    reg = Registry(
        session=session,
        database_name=session.get_current_database(),
        schema_name='MODEL_REGISTRY',
        options={'enable_monitoring':True},
    )

    def _load_and_score(version_name):
        # Look the version up, predict on the shared test set and compute its MAPE.
        model_version = reg.get_model('CUSTOMER_REVENUE_MODEL').version(version_name)
        scored = model_version.run(test_df, function_name='PREDICT')
        mape = mean_absolute_percentage_error(
            df=scored,
            y_true_col_names="NEXT_MONTH_REVENUE",
            y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
        )
        return model_version, mape

    production_model, production_model_mape = _load_and_score('PRODUCTION')
    development_model, development_model_mape = _load_and_score(new_model_version)
    print(development_model_mape, production_model_mape)

    # Guard clause: keep the current production model unless the new one is strictly better.
    if not (development_model_mape < production_model_mape):
        print("Existing production model has a lower MAPE compared to the developed model.")
        return

    print(f"New model with version {new_model_version} has a lower MAPE compared to current production model.")
    comparison_rows = [
        ['PRODUCTION', production_model_mape],
        [new_model_version, development_model_mape],
    ]
    mape_values_df = pd.DataFrame(comparison_rows, columns=['VERSION','MAPE'])
    print(tabulate(mape_values_df, headers='keys', tablefmt='grid'))

    # Swap the aliases: old production -> DEPRECATED, new version -> PRODUCTION.
    production_model.unset_alias('PRODUCTION')
    production_model.set_alias('DEPRECATED')
    development_model.set_alias('PRODUCTION')

def train_new_model(session, feature_cutoff_date, target_start_date, target_end_date, model_version):
    """Train, register and evaluate a new model version and set up its model monitor.

    Steps, in order:
      1. Build a spine (customer, cutoff date, next-month revenue) and generate
         a feature-store dataset from it.
      2. Split 90/10, fit an XGBRegressor and compute MAPE on the test split.
      3. Log the model as ``model_version`` of CUSTOMER_REVENUE_MODEL.
      4. Compare it with the PRODUCTION version (evaluate_against_production_model).
      5. Write the baseline and source tables for the model monitor and create
         the monitor via SQL.

    Args:
        session: Snowpark session used for all table, registry and SQL access.
        feature_cutoff_date: Date up to which features are used for training;
            stored as the spine timestamp.
        target_start_date: Start of the period whose revenue is the label.
        target_end_date: End of the label period; also used as the feature
            cutoff for the first batch of monitor source predictions.
        model_version: Version name; also used as suffix of the dataset
            version, table names and monitor name.
            NOTE(review): it is interpolated directly into SQL and table
            names, so it must be a valid identifier part - confirm callers.

    Returns:
        None.
    """
    fs = FeatureStore(
        session=session, 
        database=session.get_current_database(), 
        name='FEATURE_STORE', 
        default_warehouse=session.get_current_warehouse(),
        creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
    )
    # Resolve every listed feature view by name, always in version 'V1'.
    fvs = [fs.get_feature_view(n[0], 'V1') for n in fs.list_feature_views().select('NAME').to_pandas().values]
    # Create dataset for training
    print('Creating training dataset...')
    # Label: revenue per customer summed over the target period.
    target_df = session.table('MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
    target_df = (
        target_df.filter(col('DATE').between(target_start_date,target_end_date))
        .group_by('CUSTOMER_ID')
        .agg(F.sum('TRANSACTION_AMOUNT').as_('NEXT_MONTH_REVENUE'))
        .with_column('FEATURE_CUTOFF_DATE', F.to_date(lit(feature_cutoff_date)))
    )
    # Outer join with all customers; only NEXT_MONTH_REVENUE is filled with 0 for
    # customers that had no transactions in the target period.
    # NOTE(review): FEATURE_CUTOFF_DATE is not filled for those rows - confirm
    # how the dataset generation treats them.
    customers_df = session.table('MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()
    spine_df = target_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
    spine_df = spine_df.fillna(0, subset='NEXT_MONTH_REVENUE')
    # NOTE(review): the description text is fixed and does not depend on the dates passed in.
    train_dataset = fs.generate_dataset(
        name="MLOPS_DEMO.FEATURE_STORE.NEXT_MONTH_REVENUE_DATASET",
        spine_df=spine_df,
        features=fvs,
        version=model_version,
        spine_timestamp_col="FEATURE_CUTOFF_DATE",
        spine_label_cols=["NEXT_MONTH_REVENUE"],
        include_feature_view_timestamp_col=False,
        desc="Training Dataset from September 2024"
    )
    df = train_dataset.read.to_snowpark_dataframe()
    print('Training dataset created.')
    # Train new model
    print(f'Training new model with version {model_version}...')
    # Fixed seed so the 90/10 split is the same on every run.
    train_df, test_df = df.random_split(weights=[0.9, 0.1], seed=0)
    # Everything except the id, the timestamp and the label is a model input.
    feature_columns = train_df.drop(['CUSTOMER_ID','FEATURE_CUTOFF_DATE','NEXT_MONTH_REVENUE']).columns
    xgb_model = XGBRegressor(
        input_cols=feature_columns,
        label_cols=['NEXT_MONTH_REVENUE'],
        output_cols=['NEXT_MONTH_REVENUE_PREDICTION'],
        n_estimators=100,
        learning_rate=0.05,
        random_state=0
    )
    xgb_model = xgb_model.fit(train_df)

    # Evaluate model
    print(f'Evaluating model...')
    predictions = xgb_model.predict(test_df)
    mape = mean_absolute_percentage_error(
        df=predictions, 
        y_true_col_names="NEXT_MONTH_REVENUE", 
        y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
    )
    
    # Register new model version
    reg = Registry(
        session=session, 
        database_name=session.get_current_database(), 
        schema_name='MODEL_REGISTRY', 
        options={'enable_monitoring':True},
    )
    
    # Metrics stored with the version: test MAPE, per-feature importance and the training cutoff date.
    registered_model = reg.log_model(
        xgb_model,
        model_name="CUSTOMER_REVENUE_MODEL",
        version_name=model_version,
        metrics={
            'MAPE':mape, 
            'FEATURE_IMPORTANCE':dict(zip(feature_columns, xgb_model.to_xgboost().feature_importances_.astype('float'))),
            'TRAINING_DATA':{'FEATURE_CUTOFF_DATE':feature_cutoff_date}
        },
        comment="Model trained using XGBoost to predict revenue per customer for next month.",
        conda_dependencies=['xgboost'],
        sample_input_data=train_df.select(feature_columns).limit(10),
        options={"relax_version": False, "enable_explainability": True}
    )
    print(f'Registered new model with version {model_version} in model registry.')

    # Evaluate model and set new model into production if better than old model
    # (that helper looks up the version aliased 'PRODUCTION').
    evaluate_against_production_model(session, model_version, test_df)

    # Save baseline predictions (the test-split predictions computed above)
    predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
    predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
    predictions = predictions.with_column('NEXT_MONTH_REVENUE', F.col('NEXT_MONTH_REVENUE').cast('number(38,2)'))
    predictions.write.save_as_table(f'MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_{model_version}', mode='overwrite')
    print(f'Baseline table for model monitor created with predictions until {feature_cutoff_date}.')
    
    # Create predictions for next month to create model monitor source table
    # We can use former target_end_date
    feature_df = get_feature_df(session, feature_cutoff_date=target_end_date)
    predictions = registered_model.run(feature_df, function_name='PREDICT')
    predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
    predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
    predictions.write.save_as_table(table_name=f'MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_{model_version}', mode='overwrite')
    # NOTE(review): the message mentions a date range, while the predictions above
    # use target_end_date as their single feature cutoff - confirm the wording.
    print(f'Source table for model monitor created with predictions between {target_start_date} and {target_end_date}.')

    # Create the monitor over the two tables written above.
    # NOTE(review): the warehouse name COMPUTE_WH is fixed in the statement - confirm it exists in the target account.
    session.sql(f"""
    CREATE OR REPLACE MODEL MONITOR MLOPS_DEMO.MODEL_REGISTRY.MM_{model_version} WITH
        MODEL=MLOPS_DEMO.MODEL_REGISTRY.CUSTOMER_REVENUE_MODEL VERSION={model_version} FUNCTION=PREDICT
        SOURCE=MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_{model_version}
        BASELINE=MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_{model_version},
        TIMESTAMP_COLUMN='FEATURE_CUTOFF_DATE'
        ID_COLUMNS=('CUSTOMER_ID')
        PREDICTION_SCORE_COLUMNS=('NEXT_MONTH_REVENUE_PREDICTION')
        ACTUAL_SCORE_COLUMNS=('NEXT_MONTH_REVENUE')
        WAREHOUSE=COMPUTE_WH
        REFRESH_INTERVAL='1 minute'
        AGGREGATION_WINDOW='1 day'""").collect()
    print(f'Model monitor created.')

    # Enable once 1.7.3 with bugfix is available
    # source_config = ModelMonitorSourceConfig(
    #     source=f'MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_{model_version}',
    #     timestamp_column='FEATURE_CUTOFF_DATE',
    #     id_columns=['CUSTOMER_ID'],
    #     prediction_score_columns=['NEXT_MONTH_REVENUE_PREDICTION'],
    #     actual_score_columns=['NEXT_MONTH_REVENUE'],
    #     baseline=f'MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_{model_version}'
    # )
    
    # monitor_config = ModelMonitorConfig(
    #     model_version=reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION'),
    #     model_function_name='predict',
    #     background_compute_warehouse_name='COMPUTE_WH',
    #     refresh_interval='1 minute',
    #     aggregation_window='1 day'
    # )
    
    # reg.add_monitor(
    #     name=f'MLOPS_DEMO.MODEL_REGISTRY.MM_{model_version}',
    #     source_config=source_config,
    #     model_monitor_config=monitor_config
    # )
    return

def compare_two_models(session, version_name_1, version_name_2):
    """Show the top-10 feature importances of two model versions side by side.

    Both versions of CUSTOMER_REVENUE_MODEL are loaded from the registry first;
    afterwards one horizontal bar chart per version is rendered into two
    Streamlit columns (first version left, second version right).
    """
    reg = Registry(
        session=session,
        database_name=session.get_current_database(),
        schema_name='MODEL_REGISTRY',
        options={'enable_monitoring':True},
    )

    versions = (version_name_1, version_name_2)

    # Load each version and pair its input columns with the XGBoost importances.
    importance_tables = []
    for version_name in versions:
        loaded_model = reg.get_model('CUSTOMER_REVENUE_MODEL').version(version_name).load()
        input_columns = loaded_model.input_cols
        importances = loaded_model.to_xgboost().feature_importances_
        importance_tables.append(
            pd.DataFrame(list(zip(input_columns, importances)), columns=['FEATURE','IMPORTANCE'])
        )

    # One chart per version, each inside its own column.
    for panel, version_name, importance_table in zip(st.columns(2), versions, importance_tables):
        with panel:
            top_features = importance_table.sort_values('IMPORTANCE', ascending=False).head(10)
            fig = px.bar(
                top_features,
                x="IMPORTANCE",
                y="FEATURE",
                title=f"Feature Importance Model {version_name}",
                labels={"FEATURE": "Feature", "IMPORTANCE": "Importance"},
                orientation="h"
            )
            st.plotly_chart(fig, use_container_width=True)

# Function that extracts the actual Python code returned by mistral
def extract_python_code(text):
    """Return the body of the first ```python fenced block found in ``text``.

    Surrounding whitespace of the body is stripped. If ``text`` contains no
    such block, a fixed explanatory message is returned instead.
    """
    # DOTALL lets "." span newlines; the lazy ".*?" stops at the first closing fence.
    fenced_block = re.compile(r"```python(.*?)```", re.DOTALL)
    found = fenced_block.search(text)
    if found is None:
        return "No Python code found in the input string."
    return found.group(1).strip()

def plot_inshop_vs_online_revenue(transactions_df):
    """Render a stacked bar chart of each channel's share of monthly revenue.

    ``transactions_df`` is a Snowpark DataFrame; the columns DATE,
    TRANSACTION_CHANNEL and TRANSACTION_AMOUNT are read from it. Revenue is
    summed per month and channel in Snowflake, converted to percentages of
    the monthly total in pandas and drawn with Plotly inside Streamlit.
    """
    # Aggregate in Snowflake and pull only the small result into pandas.
    month_start = F.date_trunc("month", F.col("DATE"))
    monthly = (
        transactions_df.with_column("MONTH", month_start)
        .group_by("MONTH", "TRANSACTION_CHANNEL")
        .agg(F.sum("TRANSACTION_AMOUNT").alias("TOTAL_REVENUE"))
        .to_pandas()
    )

    # Month as 'YYYY-MM' text.
    monthly["MONTH"] = monthly["MONTH"].dt.strftime('%Y-%m')

    # Replace the absolute revenue by its percentage of the month's total.
    monthly["monthly_total"] = monthly.groupby("MONTH")["TOTAL_REVENUE"].transform("sum")
    monthly["TOTAL_REVENUE"] = monthly["TOTAL_REVENUE"] / monthly["monthly_total"] * 100

    def _as_percent_label(share):
        # share is on a 0-100 scale; the format spec expects a 0-1 fraction.
        return f"{share/100:.0%}"

    bar_labels = monthly['TOTAL_REVENUE'].apply(_as_percent_label)
    axis_labels = {
        "TOTAL_REVENUE": "Percentage of Revenue",
        "MONTH": "Month",
        "TRANSACTION_CHANNEL": "Transaction Channel"
    }
    fig = px.bar(
        monthly,
        x="MONTH",
        y="TOTAL_REVENUE",
        color="TRANSACTION_CHANNEL",
        barmode="stack",
        labels=axis_labels,
        text=bar_labels,
        title="Monthly Revenue Distribution by Transaction Channel (Normalized to 100%)"
    )

    # Percent sign on the y ticks, one "Mon YYYY" tick per month on the x axis.
    fig.update_layout(yaxis={"ticksuffix": "%"})
    fig.update_xaxes(dtick="M1", tickformat="%b %Y")

    st.plotly_chart(fig, use_container_width=True)

import json
import networkx as nx
import plotly.graph_objects as go
import pandas as pd
import numpy as np

def visualize_lineage(df: pd.DataFrame, short_names: bool = False, initial_zoom: float = 1.0):
    """
    Render a lineage graph in Streamlit from a lineage DataFrame.

    Columns read from the DataFrame:
      - SOURCE_OBJECT (JSON string with a "name" and, optionally, a "domain")
      - TARGET_OBJECT (JSON string with a "name" and, optionally, a "domain")
      - DISTANCE (an integer: the number of steps from the ultimate target)
    A DIRECTION column may be present but is not read.

    The ultimate target is taken from row 0's TARGET_OBJECT and is assigned distance 0.
    DISTANCE is applied to the row's source object; a non-ultimate target gets
    DISTANCE - 1. When a node appears in several rows, its smallest distance wins.

    NOTE(review): the distance logic assumes upstream rows (source feeds target).
    How downstream rows are laid out has not been verified - confirm before
    relying on it for traces with direction='both'.

    Parameters:
      df: pandas DataFrame containing the lineage information. An empty frame
          shows an informational message instead of a chart.
      short_names: If True, node labels are shortened to the last dot-separated part.
      initial_zoom: A scale factor for the initial zoom level (default 1.0). Values > 1 zoom in;
                    values < 1 zoom out. Must be greater than 0.

    Raises:
      ValueError: if initial_zoom is not greater than 0.

    Nodes are arranged in vertical columns by distance (with nodes farthest from the target on the left).
    Each node is colored based on its domain, and a legend is added for the node colors.
    """
    if initial_zoom <= 0:
        raise ValueError("initial_zoom must be greater than 0.")
    if df.empty:
        # Nothing to draw; row 0 is needed to determine the ultimate target.
        st.info("No lineage information available to visualize.")
        return

    G = nx.DiGraph()

    def _upsert_node(node_id, domain, distance):
        # Keep the smaller distance when a node shows up in more than one row.
        if node_id in G.nodes:
            G.nodes[node_id]["distance"] = min(G.nodes[node_id]["distance"], distance)
        else:
            G.add_node(node_id, domain=domain, distance=distance)

    # --- Build the graph ---------------------------------------------------------
    ultimate_target = json.loads(df.iloc[0]["TARGET_OBJECT"])
    ultimate_target_id = ultimate_target["name"]
    G.add_node(ultimate_target_id, domain=ultimate_target.get("domain", "Unknown"), distance=0)

    for idx, row in df.iterrows():
        # Source object: DISTANCE is its distance from the ultimate target.
        try:
            src_obj = json.loads(row["SOURCE_OBJECT"])
        except (TypeError, ValueError) as e:
            print(f"Error parsing SOURCE_OBJECT at row {idx}: {e}")
            continue
        src_id = src_obj.get("name")
        if src_id is None:
            # A nameless node cannot be sorted or labelled later on.
            print(f"Skipping row {idx}: SOURCE_OBJECT has no name.")
            continue
        _upsert_node(src_id, src_obj.get("domain", "Unknown"), row["DISTANCE"])

        # Target object: one step closer to the ultimate target than its source.
        try:
            tgt_obj = json.loads(row["TARGET_OBJECT"])
        except (TypeError, ValueError) as e:
            print(f"Error parsing TARGET_OBJECT at row {idx}: {e}")
            continue
        tgt_id = tgt_obj.get("name")
        if tgt_id is None:
            print(f"Skipping row {idx}: TARGET_OBJECT has no name.")
            continue
        tgt_distance = 0 if tgt_id == ultimate_target_id else row["DISTANCE"] - 1
        _upsert_node(tgt_id, tgt_obj.get("domain", "Unknown"), tgt_distance)

        G.add_edge(src_id, tgt_id)

    # --- Compute layout positions ------------------------------------------------
    # One vertical column per distance; the ultimate target (d=0) is rightmost.
    max_distance = max(data["distance"] for _, data in G.nodes(data=True))

    distance_groups = {}  # distance -> list of node ids.
    for node, data in G.nodes(data=True):
        distance_groups.setdefault(data["distance"], []).append(node)

    pos = {}
    for d, nodes in distance_groups.items():
        nodes_sorted = sorted(nodes)  # alphabetical order keeps the layout stable.
        n = len(nodes_sorted)
        # Spread the column's nodes evenly around y = 0, top to bottom.
        y_positions = np.linspace((n - 1) / 2, -(n - 1) / 2, n)
        x = max_distance - d
        for node, y in zip(nodes_sorted, y_positions):
            pos[node] = (x, y)

    # --- Determine axis ranges based on initial_zoom -----------------------------
    # The graph always contains the ultimate target, so pos is never empty here.
    margin_factor = 1.2  # keep nodes away from the very edge of the plot.

    def _axis_range(values):
        lo, hi = min(values), max(values)
        center = (lo + hi) / 2
        half = ((hi - lo) / 2) * margin_factor / initial_zoom
        if half == 0:
            # All nodes share this coordinate; use a unit window instead of a zero-width range.
            half = 1 / initial_zoom
        return [center - half, center + half]

    x_range = _axis_range([p[0] for p in pos.values()])
    y_range = _axis_range([p[1] for p in pos.values()])

    # --- Create Plotly traces for edges ------------------------------------------
    # A single trace holds all edges; None separates the individual line segments.
    edge_x = []
    edge_y = []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # --- Create Plotly traces for nodes ------------------------------------------
    # Color per domain (customize as needed); unknown domains fall back to grey.
    color_map = {
        "MODEL": "#FF5733",         # reddish
        "DATASET": "#33C3FF",        # blueish
        "TABLE": "#33FF57",          # greenish
        "FEATURE_VIEW": "#FF33F6",   # magenta-ish
    }

    # One trace per domain so that each domain gets its own legend entry.
    domain_nodes = {}
    for node, data in G.nodes(data=True):
        domain = data.get("domain", "Unknown")
        label = node.split('.')[-1] if short_names else node
        bucket = domain_nodes.setdefault(domain, {"x": [], "y": [], "text": []})
        x, y = pos[node]
        bucket["x"].append(x)
        bucket["y"].append(y)
        bucket["text"].append(label)

    node_traces = [
        go.Scatter(
            x=values["x"],
            y=values["y"],
            mode='markers+text',
            name=domain,  # legend entry will show the domain name.
            text=values["text"],
            textposition="bottom center",
            hoverinfo='text',
            marker=dict(
                showscale=False,
                color=color_map.get(domain, "#CCCCCC"),
                size=30,
                line_width=2
            )
        )
        for domain, values in domain_nodes.items()
    ]

    # --- Create and show the figure ----------------------------------------------
    fig = go.Figure(
        data=[edge_trace] + node_traces,
        layout=go.Layout(
            # Explicit title.text / title.font instead of the deprecated
            # `titlefont_size` shorthand, which recent Plotly releases no longer accept.
            title=dict(text="Lineage Visualization", font=dict(size=16)),
            showlegend=True,
            hovermode='closest',
            margin=dict(b=20, l=20, r=20, t=40),
            xaxis=dict(range=x_range, showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(range=y_range, showgrid=False, zeroline=False, showticklabels=False)
        )
    )
    st.plotly_chart(fig, use_container_width=True)
    

print('Demo Setup finished.')

## 1 - Setup Environment

In [None]:
# Import python packages
import plotly.graph_objects as go
import plotly.express as px
import streamlit as st
import pandas as pd
import json
from tabulate import tabulate

# Import Snowflake packages
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import lit, col
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.metrics import mean_absolute_percentage_error
from snowflake.ml.registry import Registry
from snowflake.ml.monitoring.entities.model_monitor_config import ModelMonitorSourceConfig, ModelMonitorConfig
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.cortex import Complete

# Create a session
session = get_active_session()

## 2 - Data Exploration & Visualization

In [None]:
# Reference the transactions table and print a first overview of its contents.
transactions_df = session.table('MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')

transaction_count = transactions_df.count()
print(f'Number of transactions: {transaction_count}')

# Most recent transactions first.
print('Transactions Data:')
newest_first = transactions_df.order_by(col('DATE').desc())
newest_first.show()

# Summary statistics, ordered by the name of the statistic.
print('Quick Variable Analysis:')
summary_stats = transactions_df.describe().order_by('SUMMARY')
summary_stats.show()

### Plotting Data

In [None]:
# Ask a Cortex LLM to write the Snowpark + Plotly code for a monthly revenue
# bar chart, then run the generated code directly in this notebook.
model = 'mistral-large2'
prompt = f"""
I have a Snowpark Dataframe called transactions_df with the following columns: {transactions_df.columns}
Write code using Snowpark Python to aggregate the data showing the total monthly revenue (TOTAL_REVENUE) from all channels and month (MONTH).
Afterwards use the data to create a plotly bar chart to show total revenue per month. For the x-axis use dtick="M1".
Make sure to use the container-width for the plotly chart.
Only return the code to transform the dataframe and plot the data using Plotly in Streamlit.
"""
try:
    result = Complete(model, prompt)
    # Keep only the Python code from the model's reply.
    result = extract_python_code(result)
    # SECURITY NOTE(review): exec() runs model-generated code with this notebook's
    # session and privileges. Acceptable for a demo; do not reuse this pattern
    # with untrusted prompts or in production code.
    exec(result)
except Exception as e:
    # Generated code may be invalid or fail at runtime; show the error instead of
    # aborting the notebook. The next cell holds a hand-written backup chart.
    st.error(e)

In [None]:
# BACKUP
# Hand-written version of the monthly revenue chart, in case the LLM-generated code fails.
month_start = F.date_trunc("month", F.col("DATE"))
monthly_revenue_df = (
    transactions_df
    .with_column("MONTH", month_start)
    .group_by("MONTH")
    .agg(F.sum("TRANSACTION_AMOUNT").as_("TOTAL_REVENUE"))
    .to_pandas()
)

# Bar chart of total revenue per month
fig = px.bar(
    monthly_revenue_df,
    x="MONTH",
    y="TOTAL_REVENUE",
    title="Total Revenue per Month",
    labels={"MONTH": "Month", "TOTAL_REVENUE": "Total Revenue"},
)

# One tick per month, labelled like "Jan 2024"
fig.update_xaxes(dtick="M1", tickformat="%b %Y")

st.plotly_chart(fig, use_container_width=True)

In [None]:
# Show how monthly revenue splits between the transaction channels.
plot_inshop_vs_online_revenue(transactions_df)

## 3 - Feature Store & Feature Engineering

### Setup the Feature Store

In [None]:
# Open the feature store named FEATURE_STORE in the session's current database,
# creating it if it does not exist yet. The session's current warehouse is
# passed as the feature store's default warehouse.
fs = FeatureStore(
    session=session, 
    database=session.get_current_database(), 
    name='FEATURE_STORE', 
    default_warehouse=session.get_current_warehouse(),
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)

### Create a Feature Store Entity "CUSTOMER"

In [None]:
# Create the CUSTOMER entity (join key: CUSTOMER_ID), register it in the
# feature store, and list the registered entities as a check.
entity = Entity(name="CUSTOMER", join_keys=["CUSTOMER_ID"], desc='Unique identifier for customers.')
fs.register_entity(entity)
fs.list_entities().show()

### Add Transaction Features about Customers

In [None]:
def col_formatter(input_col, agg, window):
    """
    Build the feature column name for a windowed aggregation.

    'SUM' in the aggregation name becomes 'TOTAL', a '-' in the window becomes
    'past_' and 'MM' becomes '_MONTHS'; the parts are joined as
    <agg>_<input_col>_<window>.
    """
    agg_label = agg.replace('SUM', 'TOTAL')
    window_label = window.replace('-', 'past_').replace('MM', '_MONTHS')
    return f"{agg_label}_{input_col}_{window_label}"

def _channel_revenue_features(channel, revenue_col):
    """
    Build rolling revenue features for one transaction channel.

    Daily revenue per customer is computed for the given channel and then
    summed over the past 1, 2 and 3 months with a one-day sliding interval.
    Only the key columns and the generated feature columns are kept.
    """
    daily_revenue = (
        transactions_df.filter(col('TRANSACTION_CHANNEL') == channel)
        .group_by(['CUSTOMER_ID', 'DATE'])
        .agg(F.sum('TRANSACTION_AMOUNT').as_('REVENUE'))
        .rename({'REVENUE': revenue_col})
    )
    rolling_revenue = daily_revenue.analytics.time_series_agg(
        aggs={revenue_col: ['SUM']},
        windows=['-1MM', '-2MM', '-3MM'],
        sliding_interval="1D",
        group_by=['CUSTOMER_ID'],
        time_col='DATE',
        col_formatter=col_formatter
    )
    # Drop the helper column and the raw daily revenue; keep keys and features.
    return rolling_revenue.drop(['SLIDING_POINT', revenue_col])


in_shop_transaction_features = _channel_revenue_features('IN_SHOP', 'REVENUE_IN_SHOP')

online_transaction_features = _channel_revenue_features('ONLINE', 'REVENUE_ONLINE')

In [None]:
# Spot check: in-shop features of customer 1, most recent dates first.
in_shop_transaction_features.filter(col('CUSTOMER_ID') == 1).order_by(col('DATE').desc()).show()

In [None]:
# Use an LLM to generate a short business description for every feature column
model = 'mistral-large2'


def _parse_llm_json(llm_text):
    """
    Parse the JSON object contained in an LLM reply.

    Models sometimes wrap the JSON in Markdown fences or add a sentence around
    it even when told not to, so only the text between the first '{' and the
    last '}' is parsed.

    Raises:
      ValueError: if the reply contains no JSON object or the object is
                  malformed (json.JSONDecodeError is a ValueError).
    """
    start = llm_text.find('{')
    end = llm_text.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f'LLM response does not contain a JSON object: {llm_text!r}')
    return json.loads(llm_text[start:end + 1])


def _describe_features(features_df):
    """Ask the LLM for descriptions of all columns except CUSTOMER_ID and DATE; return them as a dict."""
    feature_columns = features_df.drop('CUSTOMER_ID','DATE').columns
    prompt = f'Return a JSON string with column names as keys and a short business description as values. The columns are: {feature_columns}. Do not wrap the json codes in JSON markers.'
    llm_response = Complete(model, prompt, stream=False)
    return _parse_llm_json(llm_response)


feature_descriptions_in_shop_transactions = _describe_features(in_shop_transaction_features)
feature_descriptions_online_transactions = _describe_features(online_transaction_features)

st.json(feature_descriptions_in_shop_transactions)
st.json(feature_descriptions_online_transactions)

In [None]:
def _register_revenue_feature_view(name, feature_df, desc, feature_descriptions):
    """
    Create a feature view for the CUSTOMER entity and register it as version V1.

    The view uses DATE as its timestamp column and is refreshed automatically
    every minute. The per-feature descriptions are attached before the view is
    registered; the registered feature view is returned.
    """
    # Create Feature View
    feature_view = FeatureView(
        name=name,
        entities=[entity],
        timestamp_col='DATE',
        feature_df=feature_df,
        refresh_freq="1 minute",
        refresh_mode='AUTO',
        desc=desc,
        overwrite=True
    )

    # Add descriptions for the features
    feature_view = feature_view.attach_feature_desc(feature_descriptions)

    return fs.register_feature_view(
        feature_view=feature_view,
        version="V1",
        block=True,
        overwrite=True
    )


in_shop_transaction_fv = _register_revenue_feature_view(
    name="IN_SHOP_REVENUE_FEATURES",
    feature_df=in_shop_transaction_features,
    desc="Features for in-shop transactions",
    feature_descriptions=feature_descriptions_in_shop_transactions,
)

online_transaction_fv = _register_revenue_feature_view(
    name="ONLINE_REVENUE_FEATURES",
    feature_df=online_transaction_features,
    desc="Features for online transactions",
    feature_descriptions=feature_descriptions_online_transactions,
)

## 4 - Model Training

### Generate the Training Dataset with Features from Feature Store

In [None]:
# Target: predict each customer's total revenue for April 2024,
# using only features available until the end of March 2024.
target_df = (
    session.table('MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
    .filter(col('DATE').between('2024-04-01','2024-04-30'))    # Target variable: revenue in April 2024
    .group_by('CUSTOMER_ID')
    .agg(F.sum('TRANSACTION_AMOUNT').as_('NEXT_MONTH_REVENUE'))
)

# Get list of all customers
customers_df = session.table('MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()

# Create spine dataframe: one row per customer.
# Customers without April transactions get a revenue of 0. The cutoff date is
# added AFTER the outer join so that these customers also carry it; adding it
# before the join would leave their FEATURE_CUTOFF_DATE NULL.
spine_df = (
    target_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
    .fillna(0, subset='NEXT_MONTH_REVENUE')
    .with_column('FEATURE_CUTOFF_DATE', F.to_date(lit('2024-03-31')))   # Features until end of March 2024
)
spine_df.order_by('CUSTOMER_ID').show()

In [None]:
# Generate a named, versioned dataset from the spine and both revenue feature views.
# FEATURE_CUTOFF_DATE is the spine timestamp used to look up feature values
# (presumably a point-in-time lookup - confirm against the Feature Store docs),
# NEXT_MONTH_REVENUE is marked as the label, and the feature views' own
# timestamp columns are left out of the result.
train_dataset = fs.generate_dataset(
    name="MLOPS_DEMO.FEATURE_STORE.NEXT_MONTH_REVENUE_DATASET",
    spine_df=spine_df,
    features=[in_shop_transaction_fv, online_transaction_fv],
    version="V1",
    spine_timestamp_col="FEATURE_CUTOFF_DATE",
    spine_label_cols=["NEXT_MONTH_REVENUE"],
    include_feature_view_timestamp_col=False,
    desc="Initial Training Dataset"
)

# Read the dataset back as a Snowpark DataFrame; `df` is used for training below.
df = train_dataset.read.to_snowpark_dataframe()
df.show()

### Train an XGBoost Model

In [None]:
# Reproducible 90/10 split into train and test sets
train_df, test_df = df.random_split(weights=[0.9, 0.1], seed=0)

for split_name, split_df in (('train', train_df), ('test', test_df)):
    print(f'Number of samples in {split_name}: {split_df.count()}')

# Every column that is neither a key, the cutoff date nor the label is a feature
non_feature_columns = ['CUSTOMER_ID', 'FEATURE_CUTOFF_DATE', 'NEXT_MONTH_REVENUE']
feature_columns = train_df.drop(non_feature_columns).columns

# Configure the regressor and fit it on the training split
xgb_model = XGBRegressor(
    input_cols=feature_columns,
    label_cols=['NEXT_MONTH_REVENUE'],
    output_cols=['NEXT_MONTH_REVENUE_PREDICTION'],
    n_estimators=100,
    learning_rate=0.05,
    random_state=0
).fit(train_df)

### Evaluate the XGBoost Model

In [None]:
# Score the held-out test split; the prediction is written to NEXT_MONTH_REVENUE_PREDICTION.
predictions = xgb_model.predict(test_df)
# Analyze results
# NOTE(review): the spine zero-fills NEXT_MONTH_REVENUE for customers without
# transactions; a percentage error is not well defined for actual values of 0 -
# confirm how the metric treats such rows.
mape = mean_absolute_percentage_error(
    df=predictions, 
    y_true_col_names="NEXT_MONTH_REVENUE", 
    y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
)

print(f"Mean absolute percentage error: {mape}")

# Two charts side by side: feature importance (left), actual vs. predicted (right).
col1, col2 = st.columns(2)
with col1:
    # Plot Feature Importance
    # Pair each input column with the importance reported by the native XGBoost model.
    plot_data = pd.DataFrame(
        list(zip(feature_columns, xgb_model.to_xgboost().feature_importances_)), 
        columns=['FEATURE','IMPORTANCE']
    )
    
    # Horizontal bar chart of the ten most important features.
    fig = px.bar(
        plot_data.sort_values('IMPORTANCE', ascending=False).head(10),
        x="IMPORTANCE",
        y="FEATURE",
        title="Feature Importance",
        labels={"FEATURE": "Feature", "IMPORTANCE": "Importance"},
        orientation="h"
    )
    st.plotly_chart(fig, use_container_width=True)
with col2:
    # Plot Predictions
    # Both columns are cast to float64 in pandas before plotting; a red OLS
    # trendline is drawn through the points.
    fig = px.scatter(
        predictions["NEXT_MONTH_REVENUE", "NEXT_MONTH_REVENUE_PREDICTION"].to_pandas().astype("float64"),
        x="NEXT_MONTH_REVENUE",
        y="NEXT_MONTH_REVENUE_PREDICTION",
        title="Actual vs Predicted Revenue",
        labels={
            "NEXT_MONTH_REVENUE": "Actual Revenue",
            "NEXT_MONTH_REVENUE_PREDICTION": "Predicted Revenue"
        },
        trendline="ols",
        trendline_color_override="red"
    )
    st.plotly_chart(fig, use_container_width=True)

In [None]:
# Save baseline predictions
# The test-set predictions are stored as the baseline table referenced by the
# model monitor. The cutoff date is cast to a timestamp and both revenue
# columns to NUMBER(38,2) before the table is (over)written.
predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
predictions = predictions.with_column('NEXT_MONTH_REVENUE', F.col('NEXT_MONTH_REVENUE').cast('number(38,2)'))
predictions.write.save_as_table('MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1', mode='overwrite')

## 5 - Snowflake Model Registry
### Setup Model Registry

In [None]:
# Create reference to model registry
# The registry lives in schema MODEL_REGISTRY of the session's current
# database and is opened with the monitoring option enabled.
reg = Registry(
    session=session, 
    database_name=session.get_current_database(), 
    schema_name='MODEL_REGISTRY', 
    options={'enable_monitoring':True},
)

### Register Model in Model Registry

In [None]:
# Log the trained model as version V1 of CUSTOMER_REVENUE_MODEL.
# Stored metrics: the test MAPE, the per-feature importances (cast to plain
# floats) and the feature cutoff date of the training data.
# A ten-row sample of the feature columns is passed as sample input, and
# explainability is enabled so that the "explain" function can be called later.
registered_model = reg.log_model(
    xgb_model,
    model_name="CUSTOMER_REVENUE_MODEL",
    version_name='V1',
    metrics={
        'MAPE':mape, 
        'FEATURE_IMPORTANCE':dict(zip(feature_columns, xgb_model.to_xgboost().feature_importances_.astype('float'))),
        "TRAINING_DATA":{'FEATURE_CUTOFF_DATE':'2024-03-31'}
    },
    comment="Model trained using XGBoost to predict revenue per customer for next month.",
    conda_dependencies=['xgboost'],
    sample_input_data=train_df.select(feature_columns).limit(10),
    options={"relax_version": False, "enable_explainability": True}
)

In [None]:
# Set this model version as PRODUCTION
# The alias lets the version be looked up by name instead of by version number.
registered_model.set_alias('PRODUCTION')

In [None]:
# Compute per-row feature attributions for the test split with the model's "explain" function.
explanations = registered_model.run(test_df, function_name="explain")
# Normalize the column names (strip the '"""' quoting, upper-case) and keep
# only the explanation columns. The loop variable is called `name` so that it
# does not shadow the imported Snowpark `col` function.
explanations = explanations.rename({name: name.replace('"""', '').upper() for name in explanations.columns})
explanations = explanations.select([name for name in explanations.columns if '_EXPLANATION' in name])
explanations = explanations.to_pandas()

import shap
# Wrap the values in SHAP's public Explanation class (instead of reaching into
# the private shap._explanation module) so that shap.plots can render them.
shap_exp = shap.Explanation(explanations.values, feature_names=explanations.columns)
shap.plots.bar(shap_exp)

In [None]:
# Trace the lineage of model version V1 in both directions, up to two steps away.
trace = session.lineage.trace(
    object_name='MLOPS_DEMO.MODEL_REGISTRY.CUSTOMER_REVENUE_MODEL',
    object_version='V1',
    object_domain='model',
    direction='both',
    distance=2
)
trace.show()

In [None]:
# Draw the lineage trace as a graph; node labels show only the last part of each object name.
visualize_lineage(trace.to_pandas(), short_names=True)

### Continuous Model Monitoring

In [None]:
# Build the feature rows for a cutoff at the end of April 2024.
# get_feature_df is defined outside this section of the notebook.
feature_df = get_feature_df(session, feature_cutoff_date='2024-04-30')
feature_df.show()

# Predict May values
# The registered model scores the feature rows. The result is cast to the same
# column types as the baseline table and (over)written as the source table
# that the model monitor reads.
predictions = registered_model.run(feature_df, function_name='PREDICT')
predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
predictions.write.save_as_table(table_name='MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1', mode='overwrite')

### Create a Model Monitor

In [None]:
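# Python API variant of the monitor setup, kept disabled for now.
# Enable it once version 1.7.3 (which contains the bugfix) is available;
# until then the monitor MM_V1 is created by the SQL cell below.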
# source_config = ModelMonitorSourceConfig(
#     source='MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE',
#     timestamp_column='FEATURE_CUTOFF_DATE',
#     id_columns=['CUSTOMER_ID'],
#     prediction_score_columns=['NEXT_MONTH_REVENUE_PREDICTION'],
#     actual_score_columns=['NEXT_MONTH_REVENUE'],
#     baseline='MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1'
# )

# monitor_config = ModelMonitorConfig(
#     model_version=reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION'),
#     model_function_name='predict',
#     background_compute_warehouse_name='COMPUTE_WH',
#     refresh_interval='1 minute',
#     aggregation_window='1 day'
# )

# reg.add_monitor(
#     name='MLOPS_DEMO.MODEL_REGISTRY.MM_V1',
#     source_config=source_config,
#     model_monitor_config=monitor_config
# )

In [None]:
-- Create (or replace) the model monitor MM_V1 for version V1 of CUSTOMER_REVENUE_MODEL.
-- It compares the rows in MM_TRANS_SOURCE_V1 against the baseline table
-- MM_REVENUE_BASELINE_V1, refreshes every minute and aggregates per day.
-- NOTE(review): BASELINE is the only parameter followed by a comma - confirm
-- that the statement is accepted as written.
CREATE OR REPLACE MODEL MONITOR MLOPS_DEMO.MODEL_REGISTRY.MM_V1 WITH
    MODEL=MLOPS_DEMO.MODEL_REGISTRY.CUSTOMER_REVENUE_MODEL VERSION=V1 FUNCTION=PREDICT
    SOURCE=MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1
    BASELINE=MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1,
    TIMESTAMP_COLUMN='FEATURE_CUTOFF_DATE'
    ID_COLUMNS=('CUSTOMER_ID')
    PREDICTION_SCORE_COLUMNS=('NEXT_MONTH_REVENUE_PREDICTION')
    ACTUAL_SCORE_COLUMNS=('NEXT_MONTH_REVENUE')
    WAREHOUSE=COMPUTE_WH
    REFRESH_INTERVAL='1 minute'
    AGGREGATION_WINDOW='1 day'

In [None]:
# Add new transactions
# NOTE: this appends the May 2024 rows; running the cell twice duplicates them.
new_transactions = session.table('MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS').filter(col('DATE').between('2024-05-01','2024-05-31'))
new_transactions.write.save_as_table(table_name='MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS', mode='append')

# Calculate actual values: total revenue per customer in May 2024
actual_values_df = (
    session.table('MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
    .filter(col('DATE').between('2024-05-01','2024-05-31'))
    .group_by(['CUSTOMER_ID'])
    .agg(F.sum('TRANSACTION_AMOUNT').as_('TOTAL_REVENUE'))
)

# Get list of all customers
customers_df = session.table('MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()

# Assume 0 revenue for customers without transactions.
# DATE (the feature cutoff the actuals belong to) is added AFTER the outer
# join: adding it before would leave it NULL for customers without May
# transactions, so the update below would never match their rows and their
# actual revenue would stay empty instead of becoming 0.
actual_values_df = (
    actual_values_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
    .fillna(0, subset='TOTAL_REVENUE')
    .with_column('DATE', F.to_date(lit('2024-04-30')))
)

# Update source table from model monitor: write the actual revenue next to the
# prediction made for the same customer and cutoff date
source_table = session.table('MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1')
source_table.update(
    condition=(
        (source_table['FEATURE_CUTOFF_DATE'] == actual_values_df['DATE']) &
        (source_table['CUSTOMER_ID'] == actual_values_df['CUSTOMER_ID'])
    ),
    assignments={
        "NEXT_MONTH_REVENUE": actual_values_df['TOTAL_REVENUE'],
    },
    source=actual_values_df
)

## Simulate the rest of the year

In [None]:
# Simulate model V1 from June 2024 through January 2025.
# simulate_model_performance is defined outside this section; generate_data=True
# presumably makes it add the new transactions for each period - confirm.
start_date = '2024-06-01'
end_date = '2025-01-31'
model_version = 'V1'

simulate_model_performance(session, start_date, end_date, model_version, generate_data=True)

## Explore the Model Monitor
Navigate to the Model Monitor and observe the `MAPE` and `Difference of means` metrics for the last months.  

You will notice the following:
* Declining Model Performance
    * :arrow_up_small: MAPE (Mean Absolute Percentage Error)
* Feature Drift
    * :arrow_down_small: Difference of means for TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS (less in-shop transaction volume)
    * :arrow_up_small: Difference of means for TOTAL_REVENUE_ONLINE_PAST_1_MONTHS (more online transaction volume)

If we visualize the monthly revenue distribution, we can see that online revenue grew while in-shop revenue declined.

In [None]:
# Plot the channel split again after the simulation.
# NOTE(review): transactions_df is a Snowpark reference to the TRANSACTIONS table,
# so the chart presumably includes the rows appended since - confirm.
plot_inshop_vs_online_revenue(transactions_df)

## Train a new version
Given that the user behavior changed, we'll train a new version of our model with fresh data.

In [None]:
# Train model version V2: features until the end of August 2024,
# target window September 2024.
# train_new_model is defined outside this section of the notebook.
feature_cutoff_date = '2024-08-31'
target_start_date = '2024-09-01'
target_end_date = '2024-09-30'
model_version = 'V2'

train_new_model(session, feature_cutoff_date, target_start_date, target_end_date, model_version)

In [None]:
# Compare model versions V1 and V2 (compare_two_models is defined outside this section).
compare_two_models(session,'V1','V2')

In [None]:
# Simulate model V2 from October 2024 through January 2025.
# generate_data=False: presumably the transactions for this period already
# exist from the V1 simulation above - confirm.
start_date = '2024-10-01'
end_date = '2025-01-31'
model_version = 'V2'

simulate_model_performance(session, start_date, end_date, model_version, generate_data=False)