# Use Case: Predicting Future Customer Revenue Using Historical Transaction Data 📊

# 1 - Setup Demo 🛠️
* Import required libraries
* Create a Snowpark session

| Library    | Use |
| -------- | ------- |
| `snowflake.snowpark` | Main Python Developer Framework for Snowflake including the DataFrame-API     |
| `snowflake.ml`    | Snowflake ML specific functions including Feature Store & Model Registry APIs    |
| `snowflake.cortex`    | Snowflake APIs to access Cortex Services (e.g. LLMs)    |
| `helper_functions`  | Demo-specific functions that are not part of any official module    |
| `notebook_copilot`  | Convenience Functions for Snowflake Notebooks. More details [here](google.com).    |

In [None]:
# Helper functions for this demo
from helper_functions.setup_environment import setup_demo
from notebook_copilot.cortex_helper import (
    cortex_helper_ui,
    cortex_helper_visualize_query,
    cortex_helper_explain_column_sql,
    cortex_helper_describe_columns
)
from helper_functions.plotting import plot_inshop_vs_online_revenue, visualize_lineage, compare_two_models
from helper_functions.mlops import train_new_model, simulate_model_performance
from helper_functions.misc import get_function_source_recursively, get_snowsight_url


# Import python packages
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from streamlit import dataframe as sdf
import pandas as pd
import json
import shap
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Import Snowflake packages
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import lit, col
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.metrics import mean_absolute_percentage_error
from snowflake.ml.registry import Registry
from snowflake.ml.monitoring.entities.model_monitor_config import ModelMonitorSourceConfig, ModelMonitorConfig
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.cortex import complete

# Create a session
session = get_active_session()
setup_demo(session)

# 2 - Data Exploration & Visualization

* `session.table()` creates a reference to a table
* `count()`, `order_by()`, `describe()` are dataframe operations
* `describe()` gives us insights into the transaction amounts (e.g. min, average, max, count).

We can see that we have roughly 50K transactions across 350 customers.

In [None]:
transactions_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')

print(f'Number of transactions: {transactions_df.count()}')
print(f'Number of customers: {transactions_df.select("CUSTOMER_ID").distinct().count()}')

print('Transactions Data:')
transactions_df.order_by(col('DATE').desc()).show()

print('Quick Variable Analysis:')
transactions_df.describe().order_by('SUMMARY').show()

### Plotting Data
* For this notebook I developed two convenience functions for you:  
    * `cortex_helper_ui()` which will open a simple user interface (based on Streamlit) to Snowflake Cortex
    * `cortex_helper_visualize_query()` takes a Snowpark or pandas DataFrame plus a prompt and visualizes the answer directly (useful when you already know which DataFrame and question to use)

Both functions utilize Snowflake's [complete()](https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/latest/api/cortex/snowflake.cortex.complete) function to access LLMs natively hosted in Snowflake.

Try asking the following questions:  
* ***What was the overall revenue per channel and month? Use a stacked bar plot and use YY-Monthname for the x-axis.***
* ***What was the total transaction amount per channel? Use a pie chart.***

In [None]:
with st.expander('**Source Code:** cortex_helper_visualize_query()'):
    st.code(get_function_source_recursively(cortex_helper_visualize_query, max_depth=0), language='python')

with st.expander('**Source Code:** cortex_helper_ui()'):
    st.code(get_function_source_recursively(cortex_helper_ui, max_depth=1), language='python')

In [None]:
cortex_helper_ui()

When we plot the distribution of ONLINE vs. IN_SHOP revenue, we can see that 75% of our revenue comes from in-shop transactions.  
A model trained on this data should recognize that IN_SHOP transactions are the major driver of future customer revenue.

In [None]:
cortex_helper_visualize_query(transactions_df, 'What was the total transaction amount per channel? Use a pie chart.')

# 3 - Feature Store & Feature Engineering
The Snowflake Feature Store enables data scientists and ML engineers to create, manage, and utilize machine learning features within machine learning pipelines.  
A feature store consists of feature views, which encapsulate Python or SQL pipelines that transform raw data into one or more related features.  
All features within a feature view are refreshed simultaneously from the source data.

Feature store objects are implemented as Snowflake objects and all feature store objects are therefore subject to Snowflake access control rules.
| Feature Store Object    | Snowflake Object |
| -------- | ------- |
| `FeatureStore` | Schema     |
| `Entity`    | Tag    |
| `FeatureView`  | Dynamic Table or View    |
| `Feature`  | Column in a Dynamic Table or View    |

### Set up the Feature Store
We are creating (or referencing if it already exists) a Feature Store that is stored in the schema `FEATURE_STORE`.  
The `default_warehouse` will be used to refresh features automatically.

In [None]:
fs = FeatureStore(
    session=session, 
    database=session.get_current_database(), 
    name='FEATURE_STORE', 
    default_warehouse='FEATURE_STORE_WH',
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)

### Create a Feature Store Entity
Feature views are organized in the feature store according to the entities to which they apply. An entity is a higher-level abstraction that represents the subject matter of a feature.  
In our example, the main entity is the `CUSTOMER` and the features we will create will be linked to this entity.

In [None]:
# Create a new entity for the Feature Store
entity = Entity(name="CUSTOMER", join_keys=["CUSTOMER_ID"], desc='Unique identifier for customers.')
fs.register_entity(entity)
fs.list_entities().show()

### Develop Features for Customer Transactions

The Snowpark Python API provides analytics functions for easily defining many common feature types, such as windowed aggregations.  
We will use `analytics.time_series_agg()` to quickly generate revenue for the past 1, 2 and 3 months per customer per channel which we will use as features for our machine learning model.

The feature dataframe should have the following columns:
| Column    | Purpose |
| -------- | ------- |
| `CUSTOMER_ID` | Identify relevant rows for the calculated feature (Join-Criteria)     |
| `DATE`    | Allow correct Point-in-Time Joins   |
| `Feature columns`  | Actual features per entity    |  

You can find more functions for quickly generating features here:  
[Common feature and query patterns](https://docs.snowflake.com/en/developer-guide/snowflake-ml/feature-store/examples)
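To make the windowed aggregation more concrete, here is a pure-Python sketch of what a trailing `-1MM`/`-2MM`/`-3MM` revenue sum computes per customer and date. It is an illustration under assumptions, not Snowpark's implementation: the window boundaries (start exclusive, current date inclusive, calendar months with the day clamped to the month's length) are assumed, and the sketch only evaluates windows at dates where a customer has a row instead of at every `sliding_interval` step. The helper names `shift_months` and `trailing_month_sums` are hypothetical.

```python
import calendar
from collections import defaultdict
from datetime import date


def shift_months(d: date, months: int) -> date:
    """Shift a date by whole calendar months, clamping the day to the target month's length."""
    month_index = d.year * 12 + (d.month - 1) + months
    year, month0 = divmod(month_index, 12)
    day = min(d.day, calendar.monthrange(year, month0 + 1)[1])
    return date(year, month0 + 1, day)


def trailing_month_sums(rows, months_list=(1, 2, 3)):
    """rows: iterable of (customer_id, date, revenue), one row per customer and day.

    Returns {(customer_id, date): {n: revenue summed over the trailing n months}}.
    Assumed window: (date shifted back n months, date], i.e. start exclusive, end inclusive.
    """
    by_customer = defaultdict(list)
    for customer_id, day, revenue in rows:
        by_customer[customer_id].append((day, revenue))

    features = {}
    for customer_id, history in by_customer.items():
        for day, _ in history:
            features[(customer_id, day)] = {
                n: sum(r for d, r in history if shift_months(day, -n) < d <= day)
                for n in months_list
            }
    return features


rows = [
    (1, date(2024, 1, 15), 100.0),
    (1, date(2024, 2, 10), 50.0),
    (1, date(2024, 3, 20), 25.0),
]
print(trailing_month_sums(rows)[(1, date(2024, 3, 20))])  # {1: 25.0, 2: 75.0, 3: 175.0}
```

Each longer window contains the shorter ones, which is why the 1-, 2- and 3-month features of the same customer are strongly correlated.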

In [None]:
def col_formatter(input_col, agg, window):
    feature_name = f"{agg.replace('SUM','TOTAL')}_{input_col}_{window.replace('-', 'past_').replace('MM','_MONTHS')}"
    return feature_name

in_shop_transaction_features = (
    transactions_df.filter(col('TRANSACTION_CHANNEL') == 'IN_SHOP')
    .group_by(['CUSTOMER_ID','DATE']).agg(F.sum('TRANSACTION_AMOUNT').as_('REVENUE'))
    .rename({'REVENUE':'REVENUE_IN_SHOP'})
    .analytics.time_series_agg(
        aggs={'REVENUE_IN_SHOP':['SUM']},
        windows=['-1MM','-2MM','-3MM'],
        sliding_interval="1D",
        group_by=['CUSTOMER_ID'],
        time_col='DATE',
        col_formatter=col_formatter
    ).drop(['SLIDING_POINT','REVENUE_IN_SHOP'])
)

online_transaction_features = (
    transactions_df.filter(col('TRANSACTION_CHANNEL') == 'ONLINE')
    .group_by(['CUSTOMER_ID','DATE']).agg(F.sum('TRANSACTION_AMOUNT').as_('REVENUE'))
    .rename({'REVENUE':'REVENUE_ONLINE'})
    .analytics.time_series_agg(
        aggs={'REVENUE_ONLINE':['SUM']},
        windows=['-1MM','-2MM','-3MM'],
        sliding_interval="1D",
        group_by=['CUSTOMER_ID'],
        time_col='DATE',
        col_formatter=col_formatter
    ).drop(['SLIDING_POINT','REVENUE_ONLINE'])
)

In [None]:
online_transaction_features.show()
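Because `col_formatter` is plain Python, the generated feature names can be checked locally. The snippet below repeats the notebook's formatter unchanged and prints the names for the in-shop windows. Note that it emits a lowercase `past`; Snowflake treats unquoted identifiers as case-insensitive and stores them in upper case, which is presumably why the feature appears as `TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS` later in this notebook.

```python
def col_formatter(input_col, agg, window):
    # 'SUM' -> 'TOTAL' and '-1MM' -> 'past_1_MONTHS' (same logic as the notebook's formatter)
    feature_name = f"{agg.replace('SUM','TOTAL')}_{input_col}_{window.replace('-', 'past_').replace('MM','_MONTHS')}"
    return feature_name


for window in ['-1MM', '-2MM', '-3MM']:
    print(col_formatter('REVENUE_IN_SHOP', 'SUM', window))
# TOTAL_REVENUE_IN_SHOP_past_1_MONTHS
# TOTAL_REVENUE_IN_SHOP_past_2_MONTHS
# TOTAL_REVENUE_IN_SHOP_past_3_MONTHS
```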

**Feature Descriptions**  
To avoid writing descriptions manually, we can use `complete()` to have an LLM generate JSON objects containing business descriptions for each feature.  
These descriptions are stored in the Feature Store alongside our features.

In [None]:
feature_descriptions_in_shop_transactions = cortex_helper_describe_columns(in_shop_transaction_features, exclude_columns=['CUSTOMER_ID','DATE'])
feature_descriptions_online_transactions = cortex_helper_describe_columns(online_transaction_features, exclude_columns=['CUSTOMER_ID','DATE'])

st.json(feature_descriptions_in_shop_transactions)
st.json(feature_descriptions_online_transactions)

### Registering Feature Views
The `FeatureView` class accepts a Snowpark DataFrame object that contains the feature transformation logic. This allows you to define your features using any method supported by the Snowpark DataFrame API or Snowflake SQL. You can pass the DataFrame directly to the `FeatureView` constructor.  

Each `FeatureView` is associated with the corresponding `Entity`.  
The `refresh_freq` parameter determines how often the Feature Store checks for new data and updates the features automatically. For demonstration purposes, this value is set to 1 minute, but it should be adjusted based on the specific use case.

In [None]:
# Create Feature View
in_shop_transaction_fv = FeatureView(
    name="IN_SHOP_REVENUE_FEATURES", 
    entities=[entity],
    timestamp_col='DATE',
    feature_df=in_shop_transaction_features, 
    refresh_freq="1 minute",
    refresh_mode='AUTO',
    desc="Features for in-shop transactions",
    overwrite=True
)

# Add descriptions for some features
in_shop_transaction_fv = in_shop_transaction_fv.attach_feature_desc(feature_descriptions_in_shop_transactions)

in_shop_transaction_fv = fs.register_feature_view(
    feature_view=in_shop_transaction_fv, 
    version="V1", 
    block=True,
    overwrite=True
)

# Create Feature View
online_transaction_fv = FeatureView(
    name="ONLINE_REVENUE_FEATURES", 
    entities=[entity],
    timestamp_col='DATE',
    feature_df=online_transaction_features, 
    refresh_freq="1 minute",
    refresh_mode='AUTO',
    desc="Features for online transactions",
    overwrite=True
)

# Add descriptions for some features
online_transaction_fv = online_transaction_fv.attach_feature_desc(feature_descriptions_online_transactions)

online_transaction_fv = fs.register_feature_view(
    feature_view=online_transaction_fv, 
    version="V1", 
    block=True,
    overwrite=True
)

### Discovering Features via Feature Store UI
After creating entities and feature views, you can utilize the [Feature Store User Interface](https://docs.snowflake.com/en/developer-guide/snowflake-ml/feature-store/feature-store-ui) in Snowsight to locate the objects you need.  

Example of the Feature Store UI:  
![Feature Store UI](https://github.com/michaelgorkow/snowflake_simple_mlops/blob/main/resources/feature_store.png?raw=true)

In [None]:
get_snowsight_url(session, 'Link to Feature Store', '#/features/database/SIMPLE_MLOPS_DEMO/store/FEATURE_STORE/entities')

### Discovering Features via Feature Store API

In [None]:
st.markdown('### List of all Feature Views:')
sdf(fs.list_feature_views())

# Retrieve a Feature View
retrieved_feature_view = fs.get_feature_view(name='IN_SHOP_REVENUE_FEATURES',version='V1')

st.markdown('### Feature View Columns:')
sdf(retrieved_feature_view.list_columns())

# Manually refresh a Feature View
fs.refresh_feature_view(retrieved_feature_view)

st.markdown('### Feature View Refresh History:')
sdf(fs.get_refresh_history(retrieved_feature_view).limit(3))

# Explore lineage information
st.markdown('### Feature View Lineage:')
st.json(retrieved_feature_view.lineage(direction='both'))

# Use an LLM and the underlying SQL query to explain how the feature is calculated
sql_explanation = cortex_helper_explain_column_sql(sql=retrieved_feature_view.query, column='TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS')
st.markdown(sql_explanation)

# 4 - Model Training

### Generate the Training Dataset with Features from Feature Store
Our goal is to predict each customer's revenue for the next month based on their transactions from the past three months.  

We have data from January to April 2024. To define our target variable, `NEXT_MONTH_REVENUE`, we sum all transactions from April for each customer. To ensure proper point-in-time feature retrieval and avoid using future data, we only include transaction features up to **March 31, 2024**, and mark this cutoff with the `FEATURE_CUTOFF_DATE` column.  

The DataFrame created in the next cell is a **spine DataFrame**, which acts as a reference table linking customers (`CUSTOMER_ID`) with a timestamp (`FEATURE_CUTOFF_DATE`). It ensures consistent and reproducible feature retrieval from the **Feature Store**.  

Using this spine, you can generate a training dataset with [`generate_dataset()`](https://docs.snowflake.com/en/developer-guide/snowflake-ml/feature-store/modeling#generating-snowflake-datasets-for-training). The Feature Store will automatically retrieve features as they were valid on that date and add them to the dataset.  

A [Snowflake Dataset](https://docs.snowflake.com/en/developer-guide/snowflake-ml/dataset) is a schema-level object designed for machine learning. It stores data in versions, ensuring immutability, efficient access, and compatibility with ML frameworks.
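Conceptually, point-in-time retrieval is an "as of" join: for each spine row, take the latest feature values whose timestamp is on or before the spine's cutoff date. The pure-Python sketch below illustrates that idea on in-memory rows; it is not how the Feature Store implements `generate_dataset()`, and the function name `point_in_time_join` and the row layout are hypothetical.

```python
from datetime import date


def point_in_time_join(spine, feature_rows):
    """Attach to each spine row the latest features with timestamp <= its cutoff date.

    spine:        list of (customer_id, cutoff_date, label)
    feature_rows: list of (customer_id, timestamp, features_dict)
    Customers without any feature row on or before the cutoff get None.
    """
    dataset = []
    for customer_id, cutoff, label in spine:
        # Only this customer's feature rows that were already known at the cutoff date
        known = [(ts, feats) for cid, ts, feats in feature_rows
                 if cid == customer_id and ts <= cutoff]
        latest = max(known, key=lambda item: item[0])[1] if known else None
        dataset.append({
            "CUSTOMER_ID": customer_id,
            "FEATURE_CUTOFF_DATE": cutoff,
            "NEXT_MONTH_REVENUE": label,
            "FEATURES": latest,
        })
    return dataset


feature_rows = [
    (1, date(2024, 3, 15), {"TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS": 80.0}),
    (1, date(2024, 4, 10), {"TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS": 120.0}),  # after the cutoff
]
spine = [(1, date(2024, 3, 31), 200.0), (2, date(2024, 3, 31), 0.0)]
for row in point_in_time_join(spine, feature_rows):
    print(row)
```

The April feature row is ignored for the March 31 cutoff, which is exactly the leakage the spine's `FEATURE_CUTOFF_DATE` guards against. Customer 2 has no history, so it gets no features (`None`) but still keeps its label of 0.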

In [None]:
target_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
target_df = (
    target_df.filter(col('DATE').between('2024-04-01','2024-04-30'))    # Generate Target Variable for April 2024
    .group_by('CUSTOMER_ID')
    .agg(F.sum('TRANSACTION_AMOUNT').as_('NEXT_MONTH_REVENUE'))
)

# Get list of all customers
customers_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()

# Create spine dataframe: customers without April transactions get 0 revenue.
# The cutoff date is added after the join so that every customer row receives it.
spine_df = target_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
spine_df = spine_df.fillna(0, subset=['NEXT_MONTH_REVENUE'])
spine_df = spine_df.with_column('FEATURE_CUTOFF_DATE', F.to_date(lit('2024-03-31')))   # Features until End of March 2024
spine_df.order_by('CUSTOMER_ID').show()

In [None]:
train_dataset = fs.generate_dataset(
    name="SIMPLE_MLOPS_DEMO.FEATURE_STORE.NEXT_MONTH_REVENUE_DATASET",
    spine_df=spine_df,
    features=[in_shop_transaction_fv, online_transaction_fv],
    version="V1",
    spine_timestamp_col="FEATURE_CUTOFF_DATE",
    spine_label_cols=["NEXT_MONTH_REVENUE"],
    include_feature_view_timestamp_col=False,
    desc="Initial Training Dataset"
)

df = train_dataset.read.to_snowpark_dataframe()
df.show()

### Train an XGBoost Model
We randomly split the data, allocating **90% for training** and **10% for validation**.  
The training data is then used to train an **XGBoost regression model** with the `XGBRegressor` from the **Snowflake ML library**.

In [None]:
# Split the data into train and test sets
train_df, test_df = df.random_split(weights=[0.9, 0.1], seed=0)

print(f'Number of samples in train: {train_df.count()}')
print(f'Number of samples in test: {test_df.count()}')

feature_columns = train_df.drop(['CUSTOMER_ID','FEATURE_CUTOFF_DATE','NEXT_MONTH_REVENUE']).columns

xgb_model = XGBRegressor(
    input_cols=feature_columns,
    label_cols=['NEXT_MONTH_REVENUE'],
    output_cols=['NEXT_MONTH_REVENUE_PREDICTION'],
    n_estimators=100,
    learning_rate=0.05,
    random_state=0
)

xgb_model = xgb_model.fit(train_df)

### Evaluate the XGBoost Model
You can immediately use the model’s `predict()` function to generate predictions on the test data.  
Snowflake ML also provides built-in metric functions, such as **Mean Absolute Percentage Error (MAPE)**, for evaluating model performance.  

Additionally, you can convert the model back to its native open-source format using `xgb_model.to_xgboost()`.  
This allows you to access feature importance values, which we visualize to better understand what influences the model’s predictions.  

As shown in the plot, the model correctly identified that **IN_SHOP transactions** are the primary driver of the target variable, `NEXT_MONTH_REVENUE`.
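For reference, MAPE is the average of |actual - predicted| / |actual| over all rows. The sketch below assumes Snowflake ML's metric follows scikit-learn's definition, which returns a fraction rather than a percentage and divides by max(|actual|, machine epsilon). The function name `mape_sketch` is hypothetical and deliberately avoids shadowing the imported `mean_absolute_percentage_error`.

```python
import sys


def mape_sketch(y_true, y_pred):
    """Mean absolute percentage error as a fraction (0.1 == 10%).

    Mirrors scikit-learn's convention (an assumption for Snowflake ML):
    the denominator is max(|y_true|, machine epsilon), so zero actuals
    do not raise an error but produce very large per-row errors.
    """
    eps = sys.float_info.epsilon
    errors = [abs(p - t) / max(abs(t), eps) for t, p in zip(y_true, y_pred)]
    return sum(errors) / len(errors)


print(mape_sketch([100.0, 200.0], [110.0, 180.0]))  # 0.1, i.e. 10%
print(mape_sketch([0.0, 100.0], [10.0, 100.0]))     # huge, driven by the zero actual
```

If the metric follows this convention, the zero-filled `NEXT_MONTH_REVENUE` values in the spine matter: a single customer without April transactions in the test split can dominate the reported MAPE.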

In [None]:
predictions = xgb_model.predict(test_df)
# Analyze results
mape = mean_absolute_percentage_error(
    df=predictions, 
    y_true_col_names="NEXT_MONTH_REVENUE", 
    y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
)

print(f"Mean absolute percentage error: {mape}")

col1, col2 = st.columns(2)
with col1:
    # Plot Feature Importance
    plot_data = pd.DataFrame(
        list(zip(feature_columns, xgb_model.to_xgboost().feature_importances_)), 
        columns=['FEATURE','IMPORTANCE']
    )
    
    fig = px.bar(
        plot_data.sort_values('IMPORTANCE', ascending=False).head(10),
        x="IMPORTANCE",
        y="FEATURE",
        title="Feature Importance",
        labels={"FEATURE": "Feature", "IMPORTANCE": "Importance"},
        orientation="h"
    )
    st.plotly_chart(fig, use_container_width=True)
with col2:
    # Plot Predictions
    fig = px.scatter(
        predictions["NEXT_MONTH_REVENUE", "NEXT_MONTH_REVENUE_PREDICTION"].to_pandas().astype("float64"),
        x="NEXT_MONTH_REVENUE",
        y="NEXT_MONTH_REVENUE_PREDICTION",
        title="Actual vs Predicted Revenue",
        labels={
            "NEXT_MONTH_REVENUE": "Actual Revenue",
            "NEXT_MONTH_REVENUE_PREDICTION": "Predicted Revenue"
        },
        trendline="ols",
        trendline_color_override="red"
    )
    st.plotly_chart(fig, use_container_width=True)

# 5 - Snowflake Model Registry
### Setup Model Registry
After training a model, the first step in operationalizing it and running inference in Snowflake is to **log the model in the Snowflake Model Registry**.  

The **Model Registry** allows you to securely manage models and their metadata in Snowflake, regardless of their origin or type, while also simplifying inference.  
It stores machine learning models as **first-class schema-level objects** within Snowflake.  

By setting `enable_monitoring` to True, the **Model Registry** can also be used for model monitoring, which we will implement in the next step.


In [None]:
# Create reference to model registry
reg = Registry(
    session=session, 
    database_name=session.get_current_database(), 
    schema_name='MODEL_REGISTRY', 
    options={'enable_monitoring':True},
)

### Register Model in Model Registry
The Model Registry's `log_model()` function takes the model object and logs it to the registry.  
The **name** and **version** help ensure the correct model is retrieved for inference.  

Additionally, we log relevant metrics/information, including:  
- **MAPE (Mean Absolute Percentage Error)** calculated on the test dataset  
- **Feature importance values**  
- **FEATURE_CUTOFF_DATE**   

We also specify the following parameters:  

| Variable               | Description  |
|------------------------|-------------|
| `sample_input_data`    | Sample input data used to infer model signatures, serve as background data for explanations, and capture data lineage. |
| `conda_dependencies`   | Specifies model dependencies, such as the XGBoost library. |
| `relax_version`        | Controls whether dependency version constraints may be relaxed. Setting it to `False` (as done here) keeps the exact dependency versions for compatibility and reproducibility. |
| `enable_explainability` | Adds an explainability function to the model, allowing us to better understand its predictions using SHAP values. |

In [None]:
registered_model = reg.log_model(
    xgb_model,
    model_name="CUSTOMER_REVENUE_MODEL",
    version_name='V1',
    metrics={
        'MAPE':mape, 
        'FEATURE_IMPORTANCE':dict(zip(feature_columns, xgb_model.to_xgboost().feature_importances_.astype('float'))),
        "TRAINING_DATA":{'FEATURE_CUTOFF_DATE':'2024-03-31'}
    },
    comment="Model trained using XGBoost to predict revenue per customer for next month.",
    conda_dependencies=['xgboost'],
    sample_input_data=train_df.select(feature_columns).limit(100),
    options={"relax_version": False, "enable_explainability": True}
)

### Operationalize Models
There are multiple ways to operationalize models using Snowflake's Model Registry.  
One simple approach is to use **aliases** for the model. By assigning the alias **`PRODUCTION`**, any inference pipeline referencing this alias will automatically use the correct production-ready model.  

When a new model version is trained and ready for deployment, you can seamlessly update production by **removing the alias from the current model** and **assigning it to the new model**.  
This method ensures that existing ML pipelines remain unchanged, reducing the need for manual updates while maintaining a smooth model deployment process.
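The alias mechanism is plain indirection: pipelines resolve a name such as `PRODUCTION` to whichever version currently carries it. The toy class below only illustrates that idea in memory; `ToyRegistry` and its methods are hypothetical and are not the `snowflake.ml.registry` API. Following the text above, it assumes an alias must be removed from the old version before it can be assigned to a new one.

```python
class ToyRegistry:
    """Hypothetical in-memory illustration of version aliases (not the Snowflake API)."""

    def __init__(self):
        self.versions = {}  # version name -> model object
        self.aliases = {}   # alias -> version name

    def log_model(self, version_name, model):
        self.versions[version_name] = model

    def set_alias(self, version_name, alias):
        if version_name not in self.versions:
            raise KeyError(f"Unknown version {version_name}")
        if alias in self.aliases and self.aliases[alias] != version_name:
            raise ValueError(f"Alias {alias} is already assigned to {self.aliases[alias]}")
        self.aliases[alias] = version_name

    def unset_alias(self, alias):
        self.aliases.pop(alias, None)

    def version(self, name_or_alias):
        # Resolve an alias first, otherwise treat the argument as a version name
        return self.versions[self.aliases.get(name_or_alias, name_or_alias)]


def inference_pipeline(registry):
    # The pipeline only refers to the alias, never to a concrete version
    return registry.version("PRODUCTION")


reg_demo = ToyRegistry()
reg_demo.log_model("V1", "model trained until 2024-03-31")
reg_demo.set_alias("V1", "PRODUCTION")
print(inference_pipeline(reg_demo))

# Promote a new version: move the alias, the pipeline code stays unchanged
reg_demo.log_model("V2", "retrained model")
reg_demo.unset_alias("PRODUCTION")
reg_demo.set_alias("V2", "PRODUCTION")
print(inference_pipeline(reg_demo))
```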

In [None]:
registered_model.set_alias('PRODUCTION')

In [None]:
# Retrieve the production model in your pipelines
production_model = reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION')
production_model.show_metrics()

### Explore Models in the Model Registry UI
The [Model Registry UI](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/snowsight-ui) in Snowsight enables you to discover and explore machine learning models available for use in Snowflake.  

To view a model's details, click on its corresponding row in the Models list.  
The details page provides essential information, including the model's description, tags, and versions.

Example of the Model Registry UI:  
![Model Registry UI](https://github.com/michaelgorkow/snowflake_simple_mlops/blob/main/resources/model_registry_ui.png?raw=true)

In [None]:
get_snowsight_url(session, 'Link to Model Registry', '#/models')

### Model explainability
Since we set `enable_explainability` when registering the model, we can now call the model's auto-generated `explain` function.  
The standard SHAP library is then used to visualize the SHAP values.

**What are SHAP (SHapley Additive exPlanations) values?**  
* SHAP values measure how much each feature contributes to an individual prediction.
* The x-axis represents the mean absolute SHAP value, indicating the magnitude of a feature's impact on the model's predictions.
* The y-axis lists the feature names.
* Longer bars mean the feature has a greater impact on predictions.

**Interpretation**  
What do the values mean?  
For the global feature importance plot on the left, think of it like this:  
On average, a feature (e.g. TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS) shifts the model's output by approximately X units of revenue, where X is the length of its bar.  

More generally, IN_SHOP revenue consistently ranks higher than ONLINE revenue in our case, implying that the model sees in-shop purchases as a stronger signal for predicting future revenue.

On the right, we plot the local feature importance for a single customer.
Individual customers can deviate from the global picture: a customer with a strong record of ONLINE transactions may be influenced much more by online features than by in-shop features.
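The two plots differ only in how SHAP values are aggregated. The pure-Python sketch below, using made-up values for two of this notebook's features, shows the usual convention: global importance is the mean absolute SHAP value per feature across customers, while local importance keeps one customer's signed values. The function names and numbers are illustrative only.

```python
def global_importance(shap_rows, feature_names):
    """Mean absolute SHAP value per feature across all rows (customers)."""
    n_rows = len(shap_rows)
    return {
        name: sum(abs(row[i]) for row in shap_rows) / n_rows
        for i, name in enumerate(feature_names)
    }


def local_importance(shap_row, feature_names):
    """Signed SHAP values of a single row, ordered by absolute magnitude."""
    return sorted(zip(feature_names, shap_row), key=lambda item: abs(item[1]), reverse=True)


demo_feature_names = ["TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS", "TOTAL_REVENUE_ONLINE_PAST_1_MONTHS"]
demo_shap_rows = [
    [40.0, -5.0],  # typical in-shop customer
    [5.0, 25.0],   # customer with a strong online record
    [45.0, 0.0],
]
print(global_importance(demo_shap_rows, demo_feature_names))   # in-shop dominates on average
print(local_importance(demo_shap_rows[1], demo_feature_names)) # online dominates for this customer
```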

In [None]:
# Calculate Shap values
explanations = production_model.run(test_df, function_name="explain")
explanations = explanations.rename({col:col.replace('"""', '').upper() for col in explanations.columns})

shap_columns = [col for col in explanations.columns if '_EXPLANATION' in col]

explanations = explanations.select('CUSTOMER_ID', *shap_columns)
explanations = explanations.to_pandas()

In [None]:
# Plot SHAP values
selected_customer = st.selectbox('Select Customer:', options=explanations['CUSTOMER_ID'].sort_values())
selected_explanation = explanations[explanations['CUSTOMER_ID'] == selected_customer]
feature_names = [c.replace('_EXPLANATION', '') for c in shap_columns]

# Wrap the SHAP values of all test customers into a SHAP-recognized object
shap_exp = shap.Explanation(values=explanations[shap_columns].values.astype(float), feature_names=feature_names)
# Positional row index of the selected customer
selected_idx = explanations.index.get_loc(selected_explanation.index[0])

col1, col2 = st.columns(2)
with col1:
    st.markdown('### Global Feature Importance:')
    shap.plots.bar(shap_exp)  # mean |SHAP| across all test customers
with col2:
    st.markdown(f'### Local Feature Importance for CUSTOMER_ID {int(selected_customer)}:')
    shap.plots.bar(shap_exp[selected_idx])  # signed SHAP values of the selected customer

### Continuous Model Monitoring
Model behavior can change over time due to factors such as **input drift, stale training assumptions, data pipeline issues, hardware and software updates**.

**ML Observability** enables you to monitor the quality of models registered in the **Snowflake Model Registry** across multiple dimensions, including **performance, drift, and volume**.  

To measure drift for model monitoring, we use two tables:  

| Table      | Description  |
|------------|-------------|
| `BASELINE` | Contains a snapshot of data similar to `SOURCE`. It is used as a reference for comparing future feature values and predictions. |
| `SOURCE`   | Stores new predictions, their feature values and, once available, the actual values to monitor. |

In [None]:
# Save baseline predictions
predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
predictions = predictions.with_column('NEXT_MONTH_REVENUE', F.col('NEXT_MONTH_REVENUE').cast('number(38,2)'))
predictions.write.save_as_table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1', mode='overwrite')

### Creating predictions for the next month
We use the trained model on our **April data** to predict each customer's **revenue for May**.  
The `build_feature_df()` function is a helper utility that constructs the **spine DataFrame** and retrieves the correct **point-in-time features** based on the `FEATURE_CUTOFF_DATE`.  
The predictions are then stored in the `SOURCE` table, which we will link to the **model monitor** for tracking and evaluation.

In [None]:
def build_feature_df(session, feature_cutoff_date, feature_views):
    # Initialize the Feature Store.
    fs = FeatureStore(
        session=session, 
        database=session.get_current_database(), 
        name='FEATURE_STORE', 
        default_warehouse=session.get_current_warehouse(),
        creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
    )
    
    # Retrieve the requested feature views (name -> version) from the Feature Store.
    fvs = [fs.get_feature_view(name=feature_view_name,version=feature_view_version) for feature_view_name, feature_view_version in feature_views.items()]
    
    # Create a base (spine) DataFrame containing distinct CUSTOMER_IDs and the feature cutoff date.
    feature_df = session.table(f'{session.get_current_database()}.RETAIL_DATA.CUSTOMERS') \
                        .select('CUSTOMER_ID') \
                        .distinct() \
                        .with_column('FEATURE_CUTOFF_DATE', F.to_date(lit(feature_cutoff_date)))
    
    # Retrieve feature values from the Feature Store for the specified cutoff date.
    feature_df = fs.retrieve_feature_values(
        spine_df=feature_df,
        features=fvs,
        spine_timestamp_col="FEATURE_CUTOFF_DATE"
    )
    
    # Add a placeholder column for NEXT_MONTH_REVENUE
    feature_df = feature_df.with_column('NEXT_MONTH_REVENUE', lit(None).cast('number(38,2)'))
    
    return feature_df

In [None]:
feature_df = build_feature_df(
    session, 
    feature_cutoff_date='2024-04-30', 
    feature_views={'IN_SHOP_REVENUE_FEATURES':'V1', 'ONLINE_REVENUE_FEATURES':'V1'}
)

print('Feature DataFrame:')
feature_df.show()

# Predict May values
predictions = production_model.run(feature_df, function_name='PREDICT')
predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
predictions = predictions.with_column('NEXT_MONTH_REVENUE_PREDICTION', F.col('NEXT_MONTH_REVENUE_PREDICTION').cast('number(38,2)'))
predictions.write.save_as_table(table_name='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1', mode='overwrite')

# Predictions
print('Predictions [column=NEXT_MONTH_REVENUE_PREDICTION]:')
session.table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1').show()

### Creating a Model Monitor  

We are setting up a **model monitor** to continuously calculate and track model performance and drift over time.  

These calculations are based on the **`BASELINE`** and **`SOURCE`** tables created earlier.  
Each model requires its own dedicated **model monitor** to ensure accurate tracking and evaluation.

In [None]:
# Requires snowflake-ml-python 1.7.3 (or later) with the model monitor bugfix
source_config = ModelMonitorSourceConfig(
    source='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1',
    timestamp_column='FEATURE_CUTOFF_DATE',
    id_columns=['CUSTOMER_ID'],
    prediction_score_columns=['NEXT_MONTH_REVENUE_PREDICTION'],
    actual_score_columns=['NEXT_MONTH_REVENUE'],
    baseline='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1'
)

monitor_config = ModelMonitorConfig(
    model_version=reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION'),
    model_function_name='predict',
    background_compute_warehouse_name='COMPUTE_WH',
    refresh_interval='1 minute',
    aggregation_window='1 day'
)

model_monitor = reg.add_monitor(
    name='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_V1',
    source_config=source_config,
    model_monitor_config=monitor_config
)

### Simulating the next month of Customer Transactions
Our model has predicted each customer's **revenue for May 2024** and stored the results in the **`SOURCE`** table.  
Next, we simulate the actual transactions for May and update the **true revenue values** for each customer in the **`SOURCE`** table.  
When the **model monitor** refreshes, it will use these updated values to calculate various **model performance metrics**, including the MAPE.
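Stripped of Snowpark, the update in the next cell is a keyed fill: for every `SOURCE` row with the relevant cutoff date, look up the customer's actual revenue for the following month and default to 0 when the customer made no purchases. The pure-Python sketch below illustrates that logic; the function name `fill_actuals` and the row layout are hypothetical.

```python
from datetime import date


def fill_actuals(source_rows, actual_revenue, cutoff_date):
    """Write actual next-month revenue into source rows for one cutoff date.

    source_rows:    list of dicts with CUSTOMER_ID, FEATURE_CUTOFF_DATE, NEXT_MONTH_REVENUE
    actual_revenue: {customer_id: revenue} for customers that had transactions
    Customers without transactions get 0.0; rows of other cutoff dates stay untouched.
    """
    for row in source_rows:
        if row["FEATURE_CUTOFF_DATE"] == cutoff_date:
            row["NEXT_MONTH_REVENUE"] = actual_revenue.get(row["CUSTOMER_ID"], 0.0)
    return source_rows


demo_source_rows = [
    {"CUSTOMER_ID": 1, "FEATURE_CUTOFF_DATE": date(2024, 4, 30), "NEXT_MONTH_REVENUE": None},
    {"CUSTOMER_ID": 2, "FEATURE_CUTOFF_DATE": date(2024, 4, 30), "NEXT_MONTH_REVENUE": None},
    {"CUSTOMER_ID": 1, "FEATURE_CUTOFF_DATE": date(2024, 5, 31), "NEXT_MONTH_REVENUE": None},
]
fill_actuals(demo_source_rows, {1: 310.5}, date(2024, 4, 30))
print(demo_source_rows)
```

Note that the match key includes the cutoff date, so every customer row needs that date to be filled in, including customers without any transactions in the month.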

In [None]:
# Add new transactions (created as part of the initial demo setup)
new_transactions = session.table('SIMPLE_MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS').filter(col('DATE').between('2024-05-01','2024-05-31'))
new_transactions.write.save_as_table(table_name='SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS', mode='append')

# Calculate actual values
actual_values_df = (
    session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
    .filter(col('DATE').between('2024-05-01','2024-05-31'))
    .group_by(['CUSTOMER_ID'])
    .agg(F.sum('TRANSACTION_AMOUNT').as_('TOTAL_REVENUE'))
)

# Get list of all customers
customers_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()

# Assume 0 revenue for customers without transactions
actual_values_df = actual_values_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
actual_values_df = actual_values_df.fillna(0, subset=['TOTAL_REVENUE'])
# Add the feature cutoff date after the join so that every customer row matches in the update below
actual_values_df = actual_values_df.with_column('DATE', F.to_date(lit('2024-04-30')))

# Update source table from model monitor
source_table = session.table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V1')
source_table.update(
    condition=(
        (source_table['FEATURE_CUTOFF_DATE'] == actual_values_df['DATE']) &
        (source_table['CUSTOMER_ID'] == actual_values_df['CUSTOMER_ID'])
    ),
    assignments={
        "NEXT_MONTH_REVENUE": actual_values_df['TOTAL_REVENUE'],
    },
    source=actual_values_df
)

### Simulate Customer Transactions until 2025-01-31
For convenience, I encapsulated all the logic for simulating future months into the helper function `simulate_model_performance()`.  
We use this function to simulate the model's behavior until January 2025.

In [None]:
start_date = '2024-06-01'
end_date = '2025-01-31'
model_version = 'PRODUCTION'

simulate_model_performance(session, start_date, end_date, model_version, generate_data=True)

### Explore the Model Monitor
Navigate to the Model Monitor and observe the `MAPE` and the drift metrics (e.g. `Wasserstein` distance or difference of means) over the past months.  

You will notice the following:
* Declining Model Performance
    * :arrow_up_small: MAPE (Mean Absolute Percentage Error)
* Feature Drift
    * :arrow_down_small: Difference of means for TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS (lower in-shop revenue)
    * :arrow_up_small: Difference of means for TOTAL_REVENUE_ONLINE_PAST_1_MONTHS (higher online revenue)
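
To make both signals concrete, here is a minimal pure-Python sketch of MAPE and the difference of means on made-up values. It illustrates the definitions only and is not Snowflake's implementation. Note that customers without transactions get an actual revenue of 0 in the cell above, and the percentage error is undefined for them; this sketch simply skips such rows, which may differ from what the monitor does.

```python
def mape(y_true, y_pred):
    """Mean absolute percentage error as a fraction (multiply by 100 for percent).

    Rows with y_true == 0 are skipped because the ratio is undefined there;
    this is a choice made for the sketch, not necessarily what Snowflake does.
    """
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != 0]
    return sum(abs(t - p) / abs(t) for t, p in pairs) / len(pairs)


def difference_of_means(reference, current):
    """Current mean minus reference mean: negative means the feature decreased."""
    return sum(current) / len(current) - sum(reference) / len(reference)


# Made-up numbers, for illustration only
actual = [100.0, 200.0, 50.0]
predicted = [110.0, 180.0, 50.0]
print(mape(actual, predicted))  # (0.10 + 0.10 + 0.00) / 3, roughly 0.0667

in_shop_reference = [120.0, 80.0, 100.0]  # mean 100.0
in_shop_current = [90.0, 60.0, 75.0]      # mean 75.0
print(difference_of_means(in_shop_reference, in_shop_current))  # -25.0
```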

Why is that?  
If we visualize monthly revenue by channel, we can see that online revenue grew while in-shop revenue declined.

Instead of using the built-in UI, you can also query model monitor metrics with the following table functions and build your own visuals:
* [MODEL_MONITOR_PERFORMANCE_METRIC](https://docs.snowflake.com/en/sql-reference/functions/model-monitor-performance-metric)
* [MODEL_MONITOR_DRIFT_METRIC](https://docs.snowflake.com/en/sql-reference/functions/model-monitor-drift-metric)
* [MODEL_MONITOR_STAT_METRIC](https://docs.snowflake.com/en/sql-reference/functions/model-monitor-stat-metric)
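
If you prefer plain SQL over `session.table_function()`, a small builder can produce the equivalent `SELECT * FROM TABLE(...)` statement. This is a hypothetical helper: the argument order `(monitor, metric, [column,] granularity, start, end)` is taken from the `session.table_function()` calls in this notebook, not from an exhaustive reading of the reference pages above, so check those pages for optional arguments.

```python
def monitor_metric_query(function_name, monitor_name, metric, start, end,
                         granularity='1 day', column=None):
    """Build a SELECT over a model monitor table function (hypothetical helper).

    Single quotes inside values are escaped; the function name is inserted
    as-is, so only pass trusted function names.
    """
    args = [monitor_name, metric]
    if column is not None:
        args.append(column)  # drift metrics are computed per feature column
    args += [granularity, start, end]
    literals = ", ".join("'" + str(a).replace("'", "''") + "'" for a in args)
    return f"SELECT * FROM TABLE({function_name}({literals}))"


query = monitor_metric_query('MODEL_MONITOR_PERFORMANCE_METRIC', 'MM_V1', 'MAPE',
                             '2024-01-01', '2024-12-31')
print(query)
# In the notebook you could then run: session.sql(query).show()
```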

In [None]:
get_snowsight_url(session, 'Link to Model Monitor', '#/data/databases/SIMPLE_MLOPS_DEMO/schemas/MODEL_REGISTRY/model/CUSTOMER_REVENUE_MODEL/version/V1/monitors/MM_V1/dashboard')

In [None]:
with st.expander('**Need help deciding on the right metric?**', expanded=False):
    text = """### **Overview of Feature Drift Metrics:**

|                      | **Jensen-Shannon Distance** | **Wasserstein Distance** | **Difference of Means** |
|-----------------------------|----------------|--------------------------|-------------------------|
| **What It Measures**        | Difference in probability distributions | Amount of movement needed to align two distributions | Simple difference between means of two distributions |
| **Intuition**               | Measures how **different** two distributions are (based on KL divergence, but smoothed and symmetric). | Measures the **work needed** to "move" one distribution to match the other. | Measures the shift in the **central tendency** of the feature values. |
| **Range**                   | 0 to 1 (bounded) | 0 to ∞ (can grow indefinitely) | -∞ to ∞ (unbounded) |
| **Interpretability**        | 0 = identical, 1 = completely different | Larger values mean greater distribution shift | Positive = mean has increased, Negative = mean has decreased |
| **Computational Complexity** | Fast, works well with discrete or binned values | Slower in general (an optimal-transport problem); for a single numeric feature it reduces to comparing the two cumulative distributions | Very fast (simple arithmetic) |
| **Small shifts in values**  | May not detect it well if probability distributions overlap a lot. | Captures even small shifts because it looks at the actual distance between values. | Only detects shifts in the mean, not overall distribution changes. |
| **Major changes in shape**  | Captures well if distributions change significantly. | Captures well if mass shifts significantly. | ❌ No, only captures mean changes. |
| **Outliers or extreme shifts** | May be less sensitive if distributions overlap in many places. | More sensitive because it considers the actual movement of values. | Very sensitive to outliers (mean can shift significantly). |
| **Best for categorical distributions** (e.g., customer segments) | ✅ Yes | ❌ No | ❌ No |
| **Best for continuous features** (e.g., age, income) | ❌ No | ✅ Yes | ✅ Yes |
| **Best for detecting gradual numerical shifts** | ❌ No | ✅ Yes | ✅ Yes, but only if the mean is shifting. |
| **Best for interpretable (bounded 0-1) metric** | ✅ Yes | ❌ No | ❌ No |"""
    st.markdown(text)
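
The differences in the table are easiest to see on toy data. The sketch below computes the three drift metrics in pure Python for a feature whose values move down. It is an illustration of the definitions under stated assumptions (Jensen-Shannon on shared histogram bins with log base 2, Wasserstein-1 on equal-size samples); Snowflake's binning and exact formulas may differ, so do not expect it to reproduce the monitor's numbers.

```python
import math


def histogram(values, edges):
    """Share of values per bin [edges[i], edges[i+1]); the last bin is closed.

    Values outside the edges are ignored (a simplification for this sketch).
    """
    counts = [0] * (len(edges) - 1)
    for v in values:
        for i in range(len(edges) - 1):
            last_bin = i == len(edges) - 2
            if edges[i] <= v < edges[i + 1] or (last_bin and v == edges[-1]):
                counts[i] += 1
                break
    total = sum(counts)
    return [c / total for c in counts]


def jensen_shannon_distance(p, q):
    """Square root of the JS divergence with log base 2, so it lies in [0, 1]."""
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(x, y):
        return sum(a * math.log2(a / b) for a, b in zip(x, y) if a > 0)

    return math.sqrt((kl(p, m) + kl(q, m)) / 2)


def wasserstein_1d(a, b):
    """Earth mover's distance between two equal-size 1-D samples."""
    assert len(a) == len(b), "this shortcut only holds for equal sample sizes"
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


reference = [100, 110, 90, 105, 95]  # e.g. in-shop revenue in an earlier month
current = [60, 70, 50, 65, 55]       # same feature, later month
far = [10, 20, 0, 15, 5]             # an even larger shift
edges = [0, 25, 50, 75, 100, 125]

print(sum(current) / len(current) - sum(reference) / len(reference))  # -40.0
print(wasserstein_1d(reference, current), wasserstein_1d(reference, far))  # 40.0 90.0
print(jensen_shannon_distance(histogram(reference, edges), histogram(current, edges)))
print(jensen_shannon_distance(histogram(reference, edges), histogram(far, edges)))
```

Both Jensen-Shannon values come out at 1.0 because the binned distributions no longer overlap, while the Wasserstein distance keeps growing with the size of the shift (40 vs. 90). This is the "bounded vs. unbounded" trade-off from the table.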

In [None]:
# Allowed performance metrics
ALLOWED_PERFORMANCE_METRICS = {
    'ROC_AUC', 'CLASSIFICATION_ACCURACY', 'F1_SCORE', 'MAPE', 'MSE', 
    'RMSE', 'MAE', 'PRECISION', 'RECALL'
}

# Function to retrieve performance metrics for models from the model registry
def get_model_performance_metrics(session, models, metrics, start_date, end_date, aggregation):
    """
    Fetches model performance metrics from the model registry for given models and metrics.
    
    Args:
        session: Active database session.
        models (list): List of models to fetch metrics for.
        metrics (list): List of performance metrics to retrieve.
        start_date (str): Start date for metrics retrieval.
        end_date (str): End date for metrics retrieval.
        aggregation (str): Aggregation window (e.g., '1 day').
    
    Returns:
        DataFrame: Aggregated model performance metrics.
    """
    # Initialize the model registry
    registry = Registry(
        session=session, 
        database_name=session.get_current_database(), 
        schema_name='MODEL_REGISTRY', 
        options={'enable_monitoring': True}
    )
    
    # Validate requested metrics
    invalid_metrics = set(metrics) - ALLOWED_PERFORMANCE_METRICS
    if invalid_metrics:
        raise ValueError(f"Invalid metric(s) found: {invalid_metrics}")
    
    all_models_metrics = []
    
    # Iterate through each model
    for model in models:
        model_name = model.model_name
        model_version_name = model.version_name
        monitor_name = str(registry.get_monitor(model_version=model).name)
        
        model_metrics_dfs = []
        
        # Fetch each metric for the model
        for metric in metrics:
            df_metric = (
                session.table_function(
                    "MODEL_MONITOR_PERFORMANCE_METRIC",
                    lit(monitor_name), lit(metric), lit(aggregation), lit(start_date), lit(end_date)
                )
                .with_column('MODEL_NAME', lit(model_name))
                .with_column('MODEL_VERSION_NAME', lit(model_version_name))
                .with_column('MODEL_MONITOR_NAME', lit(monitor_name))
                .rename({'METRIC_VALUE': metric})
                .select(['MODEL_NAME', 'MODEL_VERSION_NAME', 'MODEL_MONITOR_NAME', 'EVENT_TIMESTAMP', metric])
            )
            model_metrics_dfs.append(df_metric)
        
        # Combine metrics for the model
        model_metrics_df = model_metrics_dfs[0]
        for df in model_metrics_dfs[1:]:
            model_metrics_df = model_metrics_df.join(df, on=['MODEL_NAME', 'MODEL_VERSION_NAME', 'MODEL_MONITOR_NAME', 'EVENT_TIMESTAMP'], how='inner')
        
        all_models_metrics.append(model_metrics_df)
    
    # Combine all models' metrics
    final_df = all_models_metrics[0]
    for df in all_models_metrics[1:]:
        final_df = final_df.union_all(df)
    
    return final_df.order_by(['MODEL_NAME', 'MODEL_VERSION_NAME', 'MODEL_MONITOR_NAME', 'EVENT_TIMESTAMP'])

# Calling the function
model_performance_metrics = get_model_performance_metrics(
    session=session, 
    models=[
        reg.get_model('CUSTOMER_REVENUE_MODEL').version('V1'),
        reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION')
    ], 
    metrics=['MAPE', 'RMSE'], 
    start_date='2024-01-01', 
    end_date='2024-12-31', 
    aggregation='1 day'
)

model_performance_metrics

In [None]:
# Allowed drift metrics
ALLOWED_DRIFT_METRICS = {
    'DIFFERENCE_OF_MEANS', 'JENSEN_SHANNON', 'WASSERSTEIN'
}

# Function to retrieve feature drift metrics for models from the model registry
def get_model_drift_metrics(session, models, metrics, start_date, end_date, aggregation, columns):
    """
    Fetches feature drift metrics from the model monitors of the given model versions.

    Args:
        session: Active Snowpark session.
        models (list): Model versions to fetch metrics for.
        metrics (list): Drift metrics to retrieve (see ALLOWED_DRIFT_METRICS).
        start_date (str): Start date for metrics retrieval.
        end_date (str): End date for metrics retrieval.
        aggregation (str): Aggregation window (e.g., '1 day').
        columns (list): Feature columns to compute drift for.

    Returns:
        DataFrame: One row per model version, column, and timestamp, with one column per metric.
    """
    # Initialize the model registry
    registry = Registry(
        session=session, 
        database_name=session.get_current_database(), 
        schema_name='MODEL_REGISTRY', 
        options={'enable_monitoring': True}
    )
    
    # Validate requested metrics
    invalid_metrics = set(metrics) - ALLOWED_DRIFT_METRICS
    if invalid_metrics:
        raise ValueError(f"Invalid metric(s) found: {invalid_metrics}")
    
    all_models_metrics = []
    
    # Iterate through each model
    for model in models:
        model_name = model.model_name
        monitor_name = str(registry.get_monitor(model_version=model).name)
        
        model_metrics_dfs = []
        
        # Fetch each metric for the model
        for column in columns:
            column_metrics_dfs = []
            for metric in metrics:
                df_metric = (
                    session.table_function(
                        "MODEL_MONITOR_DRIFT_METRIC",
                        lit(monitor_name), lit(metric), lit(column), lit(aggregation), lit(start_date), lit(end_date)
                    )
                    .with_column('MODEL_NAME', lit(model_name))
                    .with_column('MODEL_VERSION_NAME', lit(model.version_name))
                    .with_column('MODEL_MONITOR_NAME', lit(monitor_name))
                    .rename({'METRIC_VALUE': metric})
                    .select(['MODEL_NAME', 'MODEL_VERSION_NAME', 'MODEL_MONITOR_NAME', 'EVENT_TIMESTAMP', 'COLUMN_NAME', metric])
                )
                column_metrics_dfs.append(df_metric)
            column_metrics_df = column_metrics_dfs[0]
            for df in column_metrics_dfs[1:]:
                column_metrics_df = column_metrics_df.join(df, on=['MODEL_NAME', 'MODEL_VERSION_NAME', 'MODEL_MONITOR_NAME', 'EVENT_TIMESTAMP', 'COLUMN_NAME'], how='inner')
            model_metrics_dfs.append(column_metrics_df)

        # Combine metrics for the model
        model_metrics_df = model_metrics_dfs[0]
        for df in model_metrics_dfs[1:]:
            model_metrics_df = model_metrics_df.union_all(df)
    
        all_models_metrics.append(model_metrics_df)
    
    # Combine all models' metrics
    final_df = all_models_metrics[0]
    for df in all_models_metrics[1:]:
        final_df = final_df.union_all(df)
    
    return final_df.order_by(['MODEL_NAME', 'MODEL_VERSION_NAME', 'MODEL_MONITOR_NAME', 'COLUMN_NAME', 'EVENT_TIMESTAMP'])

# Calling the function
model_drift_metrics = get_model_drift_metrics(
    session=session, 
    models=[
        reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION')
    ], 
    metrics=['JENSEN_SHANNON', 'WASSERSTEIN','DIFFERENCE_OF_MEANS'], 
    start_date='2024-01-01', 
    end_date='2024-12-31', 
    aggregation='1 day',
    columns=['TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS','TOTAL_REVENUE_IN_SHOP_PAST_2_MONTHS','TOTAL_REVENUE_IN_SHOP_PAST_3_MONTHS']
)

model_drift_metrics

## Plot Model Performance and Feature Drift Together

In [None]:
#@st.cache_resource
def get_all_models():
    # Fetch and process model data
    all_models = registry.show_models()
    all_models['model_task'] = all_models['name'].apply(lambda x: str(registry.get_model(x).version('default').get_model_task()))
    all_models['versions'] = all_models['versions'].apply(lambda x: ast.literal_eval(x))
    all_models['aliases'] = all_models['aliases'].apply(lambda x: ast.literal_eval(x))
    all_models = all_models.explode('versions')
    all_models = all_models.rename(columns={'versions': 'model_version', 'name': 'model_name'})
    all_models = all_models.sort_values(['model_name', 'created_on', 'model_version'])
    all_models = all_models[['model_name', 'model_version', 'aliases', 'model_task']]
    return all_models

@st.cache_resource
def get_model(model_name, model_version):
    return registry.get_model(model_name).version(model_version)

import ast
import plotly.graph_objects as go

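# Keys must match str(model_version.get_model_task()) as compared in the form below, e.g. 'Task.TABULAR_REGRESSION'
ALLOWED_MODEL_TYPES_METRICS = {
    'Task.TABULAR_BINARY_CLASSIFICATION': ['PRECISION', 'F1_SCORE', 'CLASSIFICATION_ACCURACY', 'ROC_AUC', 'RECALL'],
    'Task.TABULAR_REGRESSION': ['MSE', 'RMSE', 'MAPE', 'MAE']
}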

registry = Registry(
    session=session, 
    database_name=session.get_current_database(), 
    schema_name='MODEL_REGISTRY', 
    options={'enable_monitoring': True}
)

# Fetch and process model data
all_models = get_all_models()

with st.expander('Select Models:', expanded=True):
    selection = st.dataframe(all_models, selection_mode='multi-row', on_select="rerun", hide_index=True, use_container_width=True)

if len(selection['selection']['rows']) == 0:
    st.info('Select models.')
else:
    selected_models = all_models.iloc[selection['selection']['rows']]
    if selected_models['model_task'].nunique() > 1:
        st.error('All selected models must have the same task.')
    else:
        with st.form("my_form"):
            col1, col2 = st.columns(2)
            if selected_models.iloc[0]['model_task'] == 'Task.TABULAR_REGRESSION':
                selected_performance_metric = col1.selectbox('Select Model Performance Metric:', ALLOWED_MODEL_TYPES_METRICS['Task.TABULAR_REGRESSION'])
            else:
                selected_performance_metric = col1.selectbox('Select Model Performance Metric:', ALLOWED_MODEL_TYPES_METRICS['Task.TABULAR_BINARY_CLASSIFICATION'])
            # Sort the set so the options appear in a stable order
            selected_drift_metric = col2.selectbox('Select Model Drift Metric:', sorted(ALLOWED_DRIFT_METRICS))

            models = selected_models.apply(lambda row: get_model(row['model_name'], row['model_version']), axis=1).tolist()
            all_monitors = pd.DataFrame(registry.show_model_monitors())
            all_monitors['model'] = all_monitors['model'].apply(lambda x: ast.literal_eval(x))
            all_monitors['source'] = all_monitors['source'].apply(lambda x: ast.literal_eval(x))
            all_monitors['model_name'] = all_monitors['model'].apply(lambda x: x['model_name'])
            all_monitors['model_version'] = all_monitors['model'].apply(lambda x: x['version_name'])
            all_monitors = selected_models.merge(all_monitors, on=['model_name', 'model_version'], how='inner')
            all_monitors['monitor_columns'] = all_monitors['source'].apply(lambda x: session.table(f"{x['database_name']}.{x['schema_name']}.{x['name']}").columns)
            all_monitors = all_monitors[['model_name', 'model_version', 'monitor_columns']]
            selected_columns = st.multiselect('Select Drift columns:', all_monitors['monitor_columns'].explode().unique())
            submitted = st.form_submit_button("Submit")
            
            if submitted:
                df_model = get_model_performance_metrics(
                    session=session,
                    models=models,
                    metrics=[selected_performance_metric],
                    start_date='2024-01-01',
                    end_date='2024-12-31',
                    aggregation='1 day'
                ).to_pandas()

                df_drift = get_model_drift_metrics(
                    session=session,
                    models=models,
                    metrics=[selected_drift_metric],
                    start_date='2024-01-01',
                    end_date='2024-12-31',
                    aggregation='1 day',
                    columns=selected_columns
                ).to_pandas()

                df_drift["EVENT_TIMESTAMP"] = pd.to_datetime(df_drift["EVENT_TIMESTAMP"])
                df_model["EVENT_TIMESTAMP"] = pd.to_datetime(df_model["EVENT_TIMESTAMP"])

                fig = go.Figure()
                
                for model_version in df_model["MODEL_VERSION_NAME"].unique():
                    df_subset = df_model[df_model["MODEL_VERSION_NAME"] == model_version]
                    fig.add_trace(go.Scatter(
                        x=df_subset["EVENT_TIMESTAMP"],
                        y=df_subset[selected_performance_metric],
                        mode='lines+markers',
                        line=dict(dash='solid', width=2),
                        marker=dict(symbol='diamond', size=12),
                        name=f"{df_subset.iloc[0]['MODEL_NAME']} - {model_version}",
                        yaxis='y1',
                        legendgroup='model_metrics',
                        legendgrouptitle_text='Model Metrics:'
                    ))
                
                for model_version in df_drift["MODEL_VERSION_NAME"].unique():
                    df_subset = df_drift[df_drift["MODEL_VERSION_NAME"] == model_version]
                    for column_name in df_subset["COLUMN_NAME"].unique():
                        df_subsubset = df_subset[df_subset["COLUMN_NAME"] == column_name]
                        fig.add_trace(go.Scatter(
                            x=df_subsubset["EVENT_TIMESTAMP"],
                            y=df_subsubset[selected_drift_metric],
                            mode='lines+markers',
                            line=dict(dash='dot', width=2),
                            marker=dict(symbol='square', size=8),
                            name=f'{column_name}',
                            yaxis='y2',
                            legendgroup=f"{df_subsubset.iloc[0]['MODEL_NAME']} - {df_subsubset.iloc[0]['MODEL_VERSION_NAME']}",
                            legendgrouptitle_text=f"Drift: {df_subsubset.iloc[0]['MODEL_NAME']} - {df_subsubset.iloc[0]['MODEL_VERSION_NAME']}"
                        ))

                fig.update_layout(
                    title="Model Performance & Feature Drift Over Time",
                    xaxis_title="Event Timestamp",
                    xaxis=dict(type='date'),
                    yaxis=dict(title=selected_performance_metric, side="left", showgrid=False),
                    yaxis2=dict(title=selected_drift_metric, overlaying="y", side="right", showgrid=False),
                    legend=dict(orientation="h", yanchor="top", y=-0.2, xanchor="left", x=0.25, traceorder="grouped", itemwidth=30),
                    margin=dict(t=50),
                    legend_tracegroupgap=10,
                    template="plotly_white"
                )
                
                st.plotly_chart(fig)

In [None]:
df_model

In [None]:
cortex_helper_visualize_query(model_drift_metrics, 'Plot the jensen shannon distance. Create one line per unique column name in column_name and make sure event timestamp is sorted.')

In [None]:
df_drift_performance_metric_column = (
    session.table_function(
        "MODEL_MONITOR_DRIFT_METRIC",
        lit('MM_V1'),
        lit('WASSERSTEIN'),
        lit('TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS'),
        lit('1 day'),
        lit('2024-01-01'),
        lit('2024-12-31')
    )
    .rename({'METRIC_VALUE': 'WASSERSTEIN'})
    .with_column('MODEL_NAME', lit('CUSTOMER_REVENUE_MODEL'))
    .with_column('MODEL_MONITOR_NAME', lit('MM_V1'))
    .select('MODEL_NAME','MODEL_MONITOR_NAME','EVENT_TIMESTAMP','COLUMN_NAME','WASSERSTEIN')
)

df_drift_performance_metric_column

In [None]:
plot_inshop_vs_online_revenue(transactions_df)

## Train a new Model Version  

Since **user behavior has changed**, we will train a **new version of our model** using fresh data.  

To streamline this process, I have encapsulated the entire training workflow into the helper function `train_new_model()`, which automates the following steps:  

- **Creates the spine DataFrame**, including the target variable.  
- **Retrieves features** from the Feature Store.  
- **Creates a Snowflake Dataset** from the training data (ensuring reproducibility with a snapshot).  
- **Trains a new XGBoost model**.  
- **Registers the model** in the Snowflake Model Registry.  
- **Sets up a new model monitor** to track performance and drift.  
- **Compares model performance** against the existing production model.  
- **Deploys the new model** if it outperforms the current one by assigning it the **"PRODUCTION"** alias.  
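
`train_new_model()` performs the comparison and promotion internally, and its exact rule is not shown here. As a hypothetical illustration of that last step, the following decision rule assumes lower MAPE is better and requires a minimum relative improvement before the **"PRODUCTION"** alias moves. The 5% threshold and the `should_promote()` name are illustrative assumptions, not the helper's actual settings.

```python
def should_promote(candidate_mape, production_mape, min_relative_improvement=0.05):
    """Return True if the candidate lowers MAPE by at least the given relative margin.

    Hypothetical rule for illustration: the margin guards against promoting
    a model whose advantage is within noise.
    """
    if production_mape == 0:
        return False  # production is already perfect on this metric
    improvement = (production_mape - candidate_mape) / production_mape
    return improvement >= min_relative_improvement


# Made-up MAPE values, for illustration only
print(should_promote(candidate_mape=0.18, production_mape=0.30))  # True (40% lower)
print(should_promote(candidate_mape=0.29, production_mape=0.30))  # False (about 3% lower)
```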

Because the features use a cutoff date of **August 31, 2024** and look back up to three months (**June, July, and August 2024**), the training data reflects the recent shift in customer behavior, so the model should learn that **ONLINE transactions** have become a major driver of customer revenue.
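
The relationship between the cutoff date and a three-month lookback can be sketched as follows. `lookback_window()` is a hypothetical helper; the real aggregation windows are defined in the feature view earlier in the notebook, and this sketch assumes they cover whole calendar months ending at a month-end cutoff.

```python
from datetime import date


def lookback_window(feature_cutoff_date: str, months: int):
    """Return (first_day, last_day) of the `months` calendar months ending at the cutoff.

    Assumes the cutoff is a month end, as in this demo (e.g. '2024-08-31').
    """
    cutoff = date.fromisoformat(feature_cutoff_date)
    year, month = cutoff.year, cutoff.month - (months - 1)
    while month < 1:  # roll back across year boundaries
        month += 12
        year -= 1
    return date(year, month, 1).isoformat(), cutoff.isoformat()


print(lookback_window('2024-08-31', 3))  # ('2024-06-01', '2024-08-31')
print(lookback_window('2024-01-31', 3))  # ('2023-11-01', '2024-01-31')
```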

In [None]:
feature_cutoff_date = '2024-08-31'
target_start_date = '2024-09-01'
target_end_date = '2024-09-30'
model_version = 'V2'

train_new_model(session, feature_cutoff_date, target_start_date, target_end_date, model_version)

### Simulate Model Performance for Model Version V2 until 2025-01-31
Once again, we are simulating **model performance** based on customer transactions up to **January 2025**.  
Make sure to check the **model monitor** to evaluate whether the new model version, trained on more recent data, performs better.  
Additionally, analyze the **feature drift**: you’ll notice that the drift trend for the **V2 model** is much more stable.

In [None]:
start_date = '2024-10-01'
end_date = '2025-01-31'
model_version = 'V2'

simulate_model_performance(session, start_date, end_date, model_version, generate_data=False)

### Comparing the two Model Versions
We have already observed in the model monitor that the new model predicts future customer revenue with a **noticeably lower MAPE**. However, we want to understand **why** this improvement occurred.  

To analyze this, I am plotting the **feature importance** for both models. This reveals that the new model recognizes a **much stronger influence** of past **ONLINE transactions** on future customer revenue.  

Additionally, we can leverage the model's **explainability features**, using **SHAP values**, to further visualize and understand these relationships.
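
Global feature importance in a SHAP bar plot is the mean absolute SHAP value of each feature across all explained rows. A local explanation is simply one row of that matrix. The pure-Python sketch below shows that aggregation on made-up SHAP values; in the cells that follow, `shap.plots.bar` does this for you.

```python
def global_importance(shap_values, feature_names):
    """Mean absolute SHAP value per feature, sorted from most to least important."""
    n_rows = len(shap_values)
    scores = {
        name: sum(abs(row[j]) for row in shap_values) / n_rows
        for j, name in enumerate(feature_names)
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


# Made-up SHAP values for three customers and two features
features = ['TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS', 'TOTAL_REVENUE_ONLINE_PAST_1_MONTHS']
values = [
    [5.0, 20.0],
    [-3.0, 35.0],
    [1.0, -5.0],
]
for name, score in global_importance(values, features):
    print(f"{name}: {score:.2f}")
```

Here the online feature ranks first (20.00) ahead of the in-shop feature (3.00). Taking absolute values matters: a feature that pushes predictions strongly up for some customers and down for others should still count as important.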

In [None]:
compare_two_models(session,'V1','V2')

In [None]:
explanations = registered_model.run(test_df, function_name="explain")
explanations = explanations.rename({col:col.replace('"""', '').upper() for col in explanations.columns})

shap_columns = [col for col in explanations.columns if '_EXPLANATION' in col]

explanations = explanations.select('CUSTOMER_ID', *shap_columns)
explanations = explanations.to_pandas()

col1, col2 = st.columns(2)
with col1:
    st.markdown('### Global Feature Importance')
    # Wrap the SHAP values in a shap.Explanation object so shap's plotting functions accept them
    shap_exp = shap.Explanation(explanations[shap_columns].values, feature_names=[col.replace('_EXPLANATION', '') for col in shap_columns])
    shap.plots.bar(shap_exp)
with col2:
    st.markdown(f'### Local Feature Importance for CUSTOMER_ID {int(explanations.iloc[0]["CUSTOMER_ID"])}:')
    shap.plots.bar(shap_exp[0])

In [None]:
# Explain both model versions on the same inputs: the latest feature snapshot
# (FEATURE_CUTOFF_DATE = 2025-01-31) from the V2 monitor source table
explanation_df = session.table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_TRANS_SOURCE_V2').filter(col('FEATURE_CUTOFF_DATE') == '2025-01-31')

def plot_global_feature_importance(version_name):
    """Run the model's explain function and plot mean |SHAP value| per feature."""
    mv = reg.get_model('CUSTOMER_REVENUE_MODEL').version(version_name)
    explanations = mv.run(explanation_df, function_name="explain")
    explanations = explanations.rename({c: c.replace('"""', '').upper() for c in explanations.columns})
    shap_columns = [c for c in explanations.columns if '_EXPLANATION' in c]
    explanations = explanations.select('CUSTOMER_ID', *shap_columns).to_pandas()
    shap_exp = shap.Explanation(
        explanations[shap_columns].values,
        feature_names=[c.replace('_EXPLANATION', '') for c in shap_columns],
    )
    shap.plots.bar(shap_exp)

col1, col2 = st.columns(2)
with col1:
    st.markdown('### Global Feature Importance: Model V1')
    plot_global_feature_importance('V1')
with col2:
    st.markdown('### Global Feature Importance: Model V2')
    plot_global_feature_importance('V2')

## ML Lineage
Even though you may not have noticed, you’ve been capturing **lineage information** throughout the development of your machine learning pipeline.  

You can retrieve this information with the `session.lineage.trace()` method for further analysis, for example to **visualize the lineage directly in the notebook**.  

Additionally, Snowflake provides a **more user-friendly and interactive UI** that allows you to explore and monitor your machine learning pipeline:  
![ML Lineage in Snowsight](https://github.com/michaelgorkow/snowflake_simple_mlops/blob/main/resources/ml_lineage3.png?raw=true)

As shown, the lineage captures a **comprehensive view** of your pipeline, tracking data transformations and dependencies from the **source tables**, through the **feature view**, the **training dataset**, and ultimately the **registered model** in the Model Registry.
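
`trace.to_pandas()` returns the lineage as rows that describe relationships between objects. Assuming those rows have been reduced to `(source, target)` pairs (the actual column names are not shown in this notebook, so that step is left out), finding everything a model depends on is a breadth-first walk over the edges. The object names below are hypothetical placeholders for the source tables, feature view, dataset, and model described above.

```python
from collections import defaultdict, deque


def upstream(edges, node):
    """Return all objects `node` depends on, nearest first (breadth-first)."""
    parents = defaultdict(list)
    for source, target in edges:
        parents[target].append(source)
    seen, order, queue = {node}, [], deque([node])
    while queue:
        current = queue.popleft()
        for parent in parents[current]:
            if parent not in seen:
                seen.add(parent)
                order.append(parent)
                queue.append(parent)
    return order


# Hypothetical (source, target) edges mirroring the pipeline described above
edges = [
    ('RETAIL_DATA.TRANSACTIONS', 'FEATURE_STORE.CUSTOMER_FEATURES$V1'),
    ('RETAIL_DATA.CUSTOMERS', 'FEATURE_STORE.CUSTOMER_FEATURES$V1'),
    ('FEATURE_STORE.CUSTOMER_FEATURES$V1', 'TRAINING_DATASET$V1'),
    ('TRAINING_DATASET$V1', 'CUSTOMER_REVENUE_MODEL$V1'),
]
print(upstream(edges, 'CUSTOMER_REVENUE_MODEL$V1'))
```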

In [None]:
get_snowsight_url(session, 'Link to Databases', '#/data/databases')

In [None]:
session.sql("SELECT CURRENT_ACCOUNT_NAME()").show()

In [None]:
trace = session.lineage.trace(
    object_name='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.CUSTOMER_REVENUE_MODEL',
    object_version='V1',
    object_domain='model',
    direction='both',
    distance=2
)
trace.show()

In [None]:
visualize_lineage(trace.to_pandas(), short_names=True)