# Use Case: Predicting Future Customer Revenue Using Historical Transaction Data 📊

# 1 - Setup Demo 🛠️
* Import required libraries
* Create a Snowpark session

| Library    | Use |
| -------- | ------- |
| `snowflake.snowpark` | Main Python Developer Framework for Snowflake including the DataFrame-API     |
| `snowflake.ml`    | Snowflake ML specific functions including Feature Store & Model Registry APIs    |
| `snowflake.cortex`    | Snowflake APIs to access Cortex Services (e.g. LLMs)    |
| `notebook_extras`  | Convenience Functions for Snowflake Notebooks. More details [here](google.com).    |
| `demo_extras`  | Demo-specific functions (Data Generation, Use Case flow, etc.)     |

In [None]:
# Helper functions for this demo
from demo_extras.flow import Demoflow
from notebook_extras.cortex import CortexPilot
from notebook_extras.model_registry import ModelRegistryHelper
from notebook_extras.lineage import LineageHelper
from notebook_extras.misc import get_snowsight_url


# Import python packages
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from streamlit import dataframe as sdf
import pandas as pd
import numpy as np
import shap
import warnings
import logging
from opentelemetry import trace
from snowflake import telemetry
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Import Snowflake packages
from snowflake.snowpark import Session
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import lit, col, sproc
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.metrics import mean_absolute_percentage_error
from snowflake.ml.registry import Registry
from snowflake.ml.monitoring.entities.model_monitor_config import ModelMonitorSourceConfig, ModelMonitorConfig
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)

demo_flow = Demoflow()
demo_flow.setup()

session = get_active_session()

# 2 - Data Exploration & Visualization

* `session.table()` creates a reference to a table
* `count()`, `order_by()`, `describe()` are dataframe operations
* `describe()` gives us insights into the transaction amounts (e.g. min, average, max, count).

We can see that we have roughly 50K transactions across 350 customers.

In [None]:
transactions_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')

print(f'Number of transactions: {transactions_df.count()}')
print(f'Number of customers: {transactions_df.select("CUSTOMER_ID").distinct().count()}')

print('Transactions Data:')
transactions_df.order_by(col('DATE').desc()).show()

print('Quick Variable Analysis:')
transactions_df.describe().order_by('SUMMARY').show()

### Plotting Data
You can use libraries such as plotly or matplotlib to visualize your data. However, instead of coding the plots manually, we'll leverage GenAI models hosted natively in Snowflake to automatically generate the visualizations.

* This notebook comes with your own personal 🤖 **CortexPilot** powered by Cortex LLMs   
    * `ui_plotting()` -> UI-driven plotting with GenAI
    * `f_cortex_helper_visualize_query()` -> function that receives a Snowpark or Pandas Dataframe and a prompt (in case you already know the dataframe and query)

Both functions utilize Snowflake's [complete()](https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/latest/api/cortex/snowflake.cortex.complete) function to access LLMs natively hosted in Snowflake.

Try asking the following questions:  
* ***What was the overall revenue per channel and month? Use a stacked bar plot and use YY-Monthname for the x-axis. Make sure the x-axis is ordered.***
* ***What was the total transaction amount per channel? Use a pie chart.***

In [None]:
# Get an instance of CortexPilot
my_pilot = CortexPilot(llm='mistral-large2')

In [None]:
# Open the UI
my_pilot.ui_plotting()

When we plot the distribution of ONLINE vs. IN_SHOP revenue, we can see that 75% of our revenue comes from in-shop transactions.  
A model trained on this data should recognize that IN_SHOP transactions are the major driver of future customer revenue.

In [None]:
my_pilot.f_cortex_helper_visualize_query(transactions_df, 'What was the total transaction amount per channel? Use a pie chart.')

# 3 - Feature Store & Feature Engineering
The Snowflake Feature Store enables data scientists and ML engineers to create, manage, and utilize machine learning features within machine learning pipelines.  
A feature store consists of feature views, which encapsulate Python or SQL pipelines that transform raw data into one or more related features.  
All features within a feature view are refreshed simultaneously from the source data.

Feature store objects are implemented as Snowflake objects, so all of them are subject to Snowflake access control rules.

| Feature Store Object    | Snowflake Object |
| -------- | ------- |
| `FeatureStore` | Schema     |
| `Entity`    | Tag    |
| `FeatureView`  | Dynamic Table or View    |
| `Feature`  | Column in a Dynamic Table or View    |

### Set Up the Feature Store
We are creating (or referencing if it already exists) a Feature Store that is stored in the schema `FEATURE_STORE`.  
The `default_warehouse` will be used to refresh features automatically.

In [None]:
fs = FeatureStore(
    session=session, 
    database='SIMPLE_MLOPS_DEMO', 
    name='FEATURE_STORE', 
    default_warehouse='FEATURE_STORE_WH',
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)

### Create a Feature Store Entity
Feature views are organized in the feature store according to the entities to which they apply. An entity is a higher-level abstraction that represents the subject matter of a feature.  
In our example, the main entity is the `CUSTOMER` and the features we will create will be linked to this entity.

In [None]:
# Create a new entity for the Feature Store
entity = Entity(name="CUSTOMER", join_keys=["CUSTOMER_ID"], desc='Unique identifier for customers.')
fs.register_entity(entity)
fs.list_entities().show()

### Develop Features for Customer Transactions

The Snowpark Python API provides analytics functions for easily defining many common feature types, such as windowed aggregations.  
We will use `analytics.time_series_agg()` to quickly generate revenue for the past 1, 2 and 3 months per customer per channel which we will use as features for our machine learning model.

The feature dataframe should have the following columns:

| Column    | Purpose |
| -------- | ------- |
| `CUSTOMER_ID` | Identify relevant rows for the calculated feature (Join-Criteria)     |
| `DATE`    | Allow correct Point-in-Time Joins   |
| `Feature columns`  | Actual features per entity    |  

You can find more functions for quickly generating features here:  
[Common feature and query patterns](https://docs.snowflake.com/en/developer-guide/snowflake-ml/feature-store/examples)
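
To make the windowed features concrete, here is a minimal pure-Python sketch of what a trailing-window sum computes for a single customer. It is not the Snowpark implementation: the helper `trailing_sum`, the sample rows, the approximation of one month as 30 days, and the window boundary convention (start exclusive, end inclusive) are all assumptions made for illustration.

```python
from datetime import date, timedelta

def trailing_sum(rows, as_of, days):
    """Sum amounts whose date falls in the window (as_of - days, as_of]."""
    start = as_of - timedelta(days=days)
    return sum(amount for day, amount in rows if start < day <= as_of)

# Hypothetical daily revenue for one customer: (date, revenue)
rows = [
    (date(2024, 1, 15), 100.0),
    (date(2024, 2, 20), 50.0),
    (date(2024, 3, 10), 25.0),
]

as_of = date(2024, 3, 31)
# Assumption: one month is approximated as 30 days
features = {
    f"TOTAL_REVENUE_PAST_{months}_MONTHS": trailing_sum(rows, as_of, 30 * months)
    for months in (1, 2, 3)
}
print(features)
```

Each wider window includes everything in the narrower ones, which is why the three features are usually strongly correlated.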

In [None]:
def col_formatter(input_col, agg, window):
    feature_name = f"{agg.replace('SUM','TOTAL')}_{input_col}_{window.replace('-', 'past_').replace('MM','_MONTHS')}"
    return feature_name

in_shop_transaction_features = (
    transactions_df.filter(col('TRANSACTION_CHANNEL') == 'IN_SHOP')
    .group_by(['CUSTOMER_ID','DATE']).agg(F.sum('TRANSACTION_AMOUNT').as_('REVENUE'))
    .rename({'REVENUE':'REVENUE_IN_SHOP'})
    .analytics.time_series_agg(
        aggs={'REVENUE_IN_SHOP':['SUM']},
        windows=['-1MM','-2MM','-3MM'],
        sliding_interval="1D",
        group_by=['CUSTOMER_ID'],
        time_col='DATE',
        col_formatter=col_formatter
    ).drop(['SLIDING_POINT','REVENUE_IN_SHOP'])
)

online_transaction_features = (
    transactions_df.filter(col('TRANSACTION_CHANNEL') == 'ONLINE')
    .group_by(['CUSTOMER_ID','DATE']).agg(F.sum('TRANSACTION_AMOUNT').as_('REVENUE'))
    .rename({'REVENUE':'REVENUE_ONLINE'})
    .analytics.time_series_agg(
        aggs={'REVENUE_ONLINE':['SUM']},
        windows=['-1MM','-2MM','-3MM'],
        sliding_interval="1D",
        group_by=['CUSTOMER_ID'],
        time_col='DATE',
        col_formatter=col_formatter
    ).drop(['SLIDING_POINT','REVENUE_ONLINE'])
)

In [None]:
online_transaction_features.show()

**Feature Descriptions**  
To avoid manually writing descriptions, we can use `complete()` to have an LLM generate JSON files containing business descriptions.  
The CortexPilot also offers a convenient function `f_describe_columns()` based on the complete() function.  
These descriptions are stored in the Feature Store alongside our features.

In [None]:
feature_descriptions_in_shop_transactions = my_pilot.f_describe_columns(in_shop_transaction_features, exclude_columns=['CUSTOMER_ID','DATE'])
feature_descriptions_online_transactions = my_pilot.f_describe_columns(online_transaction_features, exclude_columns=['CUSTOMER_ID','DATE'])

st.json(feature_descriptions_in_shop_transactions)
st.json(feature_descriptions_online_transactions)

### Registering Feature Views
The `FeatureView` class accepts a Snowpark DataFrame object that contains the feature transformation logic. This allows you to define your features using any method supported by the Snowpark DataFrame API or Snowflake SQL. You can pass the DataFrame directly to the `FeatureView` constructor.  

Each `FeatureView` is associated with the corresponding `Entity`.  
The `refresh_freq` parameter determines how often the Feature Store checks for new data and updates the features automatically. For demonstration purposes, this value is set to 1 minute, but it should be adjusted based on the specific use case.

In [None]:
# Create Feature View
in_shop_transaction_fv = FeatureView(
    name="IN_SHOP_REVENUE_FEATURES", 
    entities=[entity],
    timestamp_col='DATE',
    feature_df=in_shop_transaction_features, 
    refresh_freq="1 minute",
    refresh_mode='AUTO',
    desc="Features for in-shop transactions",
    overwrite=True
)

# Add descriptions for features
in_shop_transaction_fv = in_shop_transaction_fv.attach_feature_desc(feature_descriptions_in_shop_transactions)

in_shop_transaction_fv = fs.register_feature_view(
    feature_view=in_shop_transaction_fv, 
    version="V1", 
    block=True,
    overwrite=True
)

# Create Feature View
online_transaction_fv = FeatureView(
    name="ONLINE_REVENUE_FEATURES", 
    entities=[entity],
    timestamp_col='DATE',
    feature_df=online_transaction_features, 
    refresh_freq="1 minute",
    refresh_mode='AUTO',
    desc="Features for online transactions",
    overwrite=True
)

# Add descriptions for features
online_transaction_fv = online_transaction_fv.attach_feature_desc(feature_descriptions_online_transactions)

online_transaction_fv = fs.register_feature_view(
    feature_view=online_transaction_fv, 
    version="V1", 
    block=True,
    overwrite=True
)

### Discovering Features via Feature Store UI
After creating entities and feature views, you can utilize the [Feature Store User Interface](https://docs.snowflake.com/en/developer-guide/snowflake-ml/feature-store/feature-store-ui) in Snowsight to locate the objects you need.  

Example of the Feature Store UI:  
![text](https://github.com/michaelgorkow/snowflake_simple_mlops/blob/main/resources/feature_store.png?raw=true)

In [None]:
get_snowsight_url(session, 'Link to Feature Store', '#/features/database/SIMPLE_MLOPS_DEMO/store/FEATURE_STORE/entities')

### Discovering Features via Feature Store API
If you don't want to use the UI or want to develop workflows based on data in the feature store you can use the [Feature Store APIs](https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/latest/feature_store).  

CortexPilot can also help you understand how specific columns in a SQL query are calculated. In this example, we ask how a feature in the Feature Store is calculated, but it works for any SQL query.  
Simply call `f_explain_column_sql` with the column name and the SQL query.

In [None]:
st.markdown('### List of all Feature Views:')
sdf(fs.list_feature_views())

# Retrieve a Feature View
retrieved_feature_view = fs.get_feature_view(name='IN_SHOP_REVENUE_FEATURES',version='V1')

st.markdown('### Feature View Columns:')
sdf(retrieved_feature_view.list_columns())

# Manually refresh a Feature View
fs.refresh_feature_view(retrieved_feature_view)

st.markdown('### Feature View Refresh History:')
sdf(fs.get_refresh_history(retrieved_feature_view).order_by(col('REFRESH_END_TIME').desc()).limit(3))

# Explore lineage information
st.markdown('### Feature View Lineage:')
st.json(retrieved_feature_view.lineage(direction='both'))

# Use an LLM and the underlying SQL query to explain how the feature is calculated
st.markdown('### LLM Explanation for a Feature in the Feature Store:')
sql_explanation = my_pilot.f_explain_column_sql(column='TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS', sql_query=retrieved_feature_view.query)
st.markdown(sql_explanation)

# 4 - Model Training

### Generate the Training Dataset with Features from Feature Store
Our goal is to predict each customer's revenue for the next month based on their transactions from the past three months.  

We have data from January to April 2024. To define our target variable, `NEXT_MONTH_REVENUE`, we sum all transactions from April for each customer. To ensure proper point-in-time feature retrieval and avoid using future data, we only include transaction features up to **April 1st, 2024**, and mark this cutoff with the `FEATURE_CUTOFF_DATE` column.  

The DataFrame built in the next cell is a **spine DataFrame**, which acts as a reference table linking customers (`CUSTOMER_ID`) with a timestamp (`FEATURE_CUTOFF_DATE`). It ensures consistent and reproducible feature retrieval in a **feature store**.  

Using this spine, you can generate a training dataset with [`generate_dataset()`](https://docs.snowflake.com/en/developer-guide/snowflake-ml/feature-store/modeling#generating-snowflake-datasets-for-training). The Feature Store will automatically retrieve features as they were valid on that date and add them to the dataset.  

A [Snowflake Dataset](https://docs.snowflake.com/en/developer-guide/snowflake-ml/dataset) is a schema-level object designed for machine learning. It stores data in versions, ensuring immutability, efficient access, and compatibility with ML frameworks.
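
The point-in-time retrieval described above can be sketched in plain Python. This is only an illustration of the idea, not the Feature Store's implementation: the helper `point_in_time_value`, the sample feature history, and the inclusive comparison (`timestamp <= cutoff`) are assumptions.

```python
from datetime import date

def point_in_time_value(feature_rows, customer_id, cutoff):
    """Return the latest feature value for customer_id with timestamp <= cutoff."""
    candidates = [
        (ts, value)
        for cid, ts, value in feature_rows
        if cid == customer_id and ts <= cutoff
    ]
    if not candidates:
        return None  # no feature value was known at the cutoff
    return max(candidates, key=lambda item: item[0])[1]

# Hypothetical feature history: (customer_id, timestamp, feature value)
feature_rows = [
    (1, date(2024, 3, 30), 120.0),
    (1, date(2024, 4, 1), 150.0),
    (1, date(2024, 4, 15), 900.0),  # after the cutoff, must not leak into training
    (2, date(2024, 4, 10), 40.0),   # customer 2 has no value before the cutoff
]

# Spine: one row per customer with the same cutoff date
spine = [(1, date(2024, 4, 1)), (2, date(2024, 4, 1))]
training_rows = [
    (cid, cutoff, point_in_time_value(feature_rows, cid, cutoff))
    for cid, cutoff in spine
]
print(training_rows)
```

The value recorded on April 15th is ignored for customer 1 because it was not yet known at the cutoff, which is exactly the leakage the spine timestamp is meant to prevent.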

In [None]:
target_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
target_df = (
    target_df.filter(col('DATE').between('2024-04-02','2024-05-01'))    # Generate Target Variable for April 2024
    .group_by('CUSTOMER_ID')
    .agg(F.sum('TRANSACTION_AMOUNT').as_('NEXT_MONTH_REVENUE'))
)

# Get list of all customers
customers_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()

# Create spine dataframe
spine_df = target_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
spine_df = spine_df.fillna(0, subset='NEXT_MONTH_REVENUE')
# Add the cutoff after the join so customers without April transactions also get one
spine_df = spine_df.with_column('FEATURE_CUTOFF_DATE', F.to_date(lit('2024-04-01')))   # Features until End of March 2024
spine_df.order_by('CUSTOMER_ID').show()

In [None]:
train_dataset = fs.generate_dataset(
    name="SIMPLE_MLOPS_DEMO.FEATURE_STORE.NEXT_MONTH_REVENUE_DATASET",
    spine_df=spine_df,
    features=[in_shop_transaction_fv, online_transaction_fv],
    version="V1",
    spine_timestamp_col="FEATURE_CUTOFF_DATE",
    spine_label_cols=["NEXT_MONTH_REVENUE"],
    include_feature_view_timestamp_col=False,
    desc="Initial Training Dataset"
)

df = train_dataset.read.to_snowpark_dataframe()
df.show()

### Train an XGBoost Model
We randomly split the data, allocating **90% for training** and **10% for testing**.  
The training data is then used to train an **XGBoost regression model** with the `XGBRegressor` from the **Snowflake ML library**.

In [None]:
# Split the data into train and test sets
train_df, test_df = df.random_split(weights=[0.9, 0.1], seed=0)

print(f'Number of samples in train: {train_df.count()}')
print(f'Number of samples in test: {test_df.count()}')

feature_columns = train_df.drop(['CUSTOMER_ID','FEATURE_CUTOFF_DATE','NEXT_MONTH_REVENUE']).columns

xgb_model = XGBRegressor(
    input_cols=feature_columns,
    label_cols=['NEXT_MONTH_REVENUE'],
    output_cols=['NEXT_MONTH_REVENUE_PREDICTION'],
    n_estimators=100,
    learning_rate=0.05,
    random_state=0
)

xgb_model = xgb_model.fit(train_df)

### Evaluate the XGBoost Model
You can immediately use the model’s `predict()` function to generate predictions on the test data.  
Snowflake ML also provides built-in metric functions, such as **Mean Absolute Percentage Error (MAPE)**, for evaluating model performance.  

Additionally, you can convert the model back to its native open-source format using `xgb_model.to_xgboost()`.  
This allows you to access feature importance values, which we visualize to better understand what influences the model’s predictions.  

As shown in the plot, the model correctly identified that **IN_SHOP transactions** are the primary driver of the target variable, `NEXT_MONTH_REVENUE`.
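
For reference, MAPE is the mean of the absolute errors divided by the absolute actual values. The sketch below follows that common definition in plain Python. The function name `mape`, the sample numbers, and the epsilon guard against division by zero are assumptions for illustration and are not taken from the Snowflake ML source.

```python
import sys

def mape(y_true, y_pred):
    """Mean absolute percentage error, returned as a fraction (0.1 means 10%)."""
    eps = sys.float_info.epsilon  # guard against division by zero
    errors = [abs(t - p) / max(abs(t), eps) for t, p in zip(y_true, y_pred)]
    return sum(errors) / len(errors)

# Hypothetical actual and predicted revenue for three customers
actual = [100.0, 200.0, 400.0]
predicted = [110.0, 180.0, 400.0]
print(mape(actual, predicted))
```

Under this definition, a customer with an actual revenue of 0 and a non-zero prediction contributes an extremely large error, so rows with zero revenue can dominate the metric.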

In [None]:
predictions = xgb_model.predict(test_df)
# Analyze results
mape = mean_absolute_percentage_error(
    df=predictions, 
    y_true_col_names="NEXT_MONTH_REVENUE", 
    y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
)

st.info(f"**Mean absolute percentage error:** {mape:.5f}")

col1, col2 = st.columns(2)
with col1:
    # Plot Feature Importance
    plot_data = pd.DataFrame(
        list(zip(feature_columns, xgb_model.to_xgboost().feature_importances_)), 
        columns=['FEATURE','IMPORTANCE']
    )
    
    fig = px.bar(
        plot_data.sort_values('IMPORTANCE', ascending=True).tail(10),  # keep the 10 most important features
        x="IMPORTANCE",
        y="FEATURE",
        title="Feature Importance",
        labels={"FEATURE": "Feature", "IMPORTANCE": "Importance"},
        orientation="h"
    )
    fig.update_layout(title_font=dict(size=20, family="Arial", color="black"))
    st.plotly_chart(fig, use_container_width=True)
with col2:
    # Plot Predictions
    fig = px.scatter(
        predictions["NEXT_MONTH_REVENUE", "NEXT_MONTH_REVENUE_PREDICTION"].to_pandas().astype("float64"),
        x="NEXT_MONTH_REVENUE",
        y="NEXT_MONTH_REVENUE_PREDICTION",
        title="Actual vs. Predicted Revenue",
        labels={
            "NEXT_MONTH_REVENUE": "Actual Revenue",
            "NEXT_MONTH_REVENUE_PREDICTION": "Predicted Revenue"
        },
        trendline="ols",
        trendline_color_override="red"
    )
    fig.update_layout(title_font=dict(size=20, family="Arial", color="black"))
    st.plotly_chart(fig, use_container_width=True)

# 5 - Snowflake Model Registry
### Set Up the Model Registry
After training a model, the first step in operationalizing it and running inference in Snowflake is to **log the model in the Snowflake Model Registry**.  

The **Model Registry** allows you to securely manage models and their metadata in Snowflake, regardless of their origin or type, while also simplifying inference.  
It stores machine learning models as **first-class schema-level objects** within Snowflake.  

By setting `enable_monitoring` to True, the **Model Registry** can also be used for model monitoring, which we will implement in the next step.


In [None]:
# Create reference to model registry
reg = Registry(
    session=session, 
    database_name='SIMPLE_MLOPS_DEMO', 
    schema_name='MODEL_REGISTRY', 
    options={'enable_monitoring':True},
)

### Register Model in Model Registry
The Model Registry's `log_model()` function takes the model object and logs it to the registry.  
The **name** and **version** help ensure the correct model is retrieved for inference.  

Additionally, we log relevant metrics/information, including:  
- **MAPE (Mean Absolute Percentage Error)** calculated on the test dataset  
- **FEATURE_CUTOFF_DATE**   

We also specify the following parameters:  

| Variable               | Description  |
|------------------------|-------------|
| `sample_input_data`    | Sample input data used to infer model signatures, serve as background data for explanations, and capture data lineage. |
| `conda_dependencies`   | Specifies model dependencies, such as the XGBoost library. |
| `relax_version`        | Controls whether dependency version constraints are relaxed. Setting it to `False`, as we do here, pins the exact dependency versions for compatibility and reproducibility. |
| `enable_explainability` | Adds an explainability function to the model, allowing us to better understand its predictions using SHAP values. |

In [None]:
registered_model = reg.log_model(
    xgb_model,
    model_name="CUSTOMER_REVENUE_MODEL",
    version_name='V1',
    metrics={
        'MAPE':mape, 
        "TRAINING_DATA":{'FEATURE_CUTOFF_DATE':'2024-04-01'}
    },
    comment="Model trained using XGBoost to predict revenue per customer for next month.",
    conda_dependencies=['xgboost'],
    sample_input_data=train_df.select(feature_columns).limit(100),
    options={"relax_version": False, "enable_explainability": True}
)

### Operationalize Models
There are multiple ways to operationalize models using Snowflake's Model Registry.  
One simple approach is to use **aliases** for the model. By assigning the alias **`PRODUCTION`**, any inference pipeline referencing this alias will automatically use the correct production-ready model.  

When a new model version is trained and ready for deployment, you can seamlessly update production by **removing the alias from the current model** and **assigning it to the new model**.  
This method ensures that existing ML pipelines remain unchanged, reducing the need for manual updates while maintaining a smooth model deployment process.
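
The alias pattern can be illustrated with a small in-memory stand-in. `ToyRegistry` and its methods are hypothetical and only mimic the idea; they are not the Snowflake Model Registry API. The toy also re-points the alias in a single step, whereas the text above describes removing the alias and then assigning it.

```python
class ToyRegistry:
    """Hypothetical in-memory stand-in for a model registry with aliases."""

    def __init__(self):
        self.versions = {}  # version name -> model object
        self.aliases = {}   # alias -> version name

    def log_model(self, version_name, model):
        self.versions[version_name] = model

    def set_alias(self, alias, version_name):
        if version_name not in self.versions:
            raise KeyError(f"unknown version: {version_name}")
        # Re-pointing the alias replaces the previous target
        self.aliases[alias] = version_name

    def resolve(self, name):
        # Accept either an alias or a concrete version name
        version_name = self.aliases.get(name, name)
        return self.versions[version_name]

registry = ToyRegistry()
registry.log_model("V1", lambda x: x * 1.0)
registry.log_model("V2", lambda x: x * 1.1)

def pipeline(value):
    # The pipeline only knows the alias, never a concrete version
    return registry.resolve("PRODUCTION")(value)

registry.set_alias("PRODUCTION", "V1")
before = pipeline(100)
registry.set_alias("PRODUCTION", "V2")
after = pipeline(100)
print(before, after)
```

The `pipeline` function is never edited, yet its output changes once the alias points to the new version.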

In [None]:
registered_model.set_alias('PRODUCTION')

In [None]:
# Retrieve the production model in your pipelines like this
production_model = reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION')
production_model.show_metrics()

### Explore Models in the Model Registry UI
The [Model Registry UI](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/snowsight-ui) in Snowsight enables you to discover and explore machine learning models available for use in Snowflake.  

To view a model's details, click on its corresponding row in the Models list.  
The details page provides essential information, including the model's description, tags, and versions.

Example of the Model Registry UI:  
![text](https://github.com/michaelgorkow/snowflake_simple_mlops/blob/main/resources/model_registry_ui.png?raw=true)

In [None]:
get_snowsight_url(session, 'Link to Model Registry', '#/models')

### **Model Explainability**
Since we set `enable_explainability` to `True` during model registration, we can now use the auto-generated `explain` function to compute SHAP values for each feature.

#### **What are SHAP (SHapley Additive exPlanations) values?**  
SHAP values provide a game-theoretic approach to interpreting machine learning predictions by fairly attributing contributions to each feature. They help explain both **global feature importance** and **individual predictions**, showing how each feature increases or decreases the model’s output. **Positive SHAP values** indicate features that push the prediction higher, while **negative values** lower it relative to the model’s baseline prediction.

Once computed, we can convert the SHAP values into a native [`shap.Explanation`](https://shap.readthedocs.io/en/latest/generated/shap.Explanation.html) object, which includes:

| Variable        | Description |
|----------------|-------------|
| `values`       | Contribution of each feature to the prediction (output from Snowflake’s `explain` function). |
| `base_values`  | Expected model output (typically the mean prediction). |
| `data`         | Feature values (used for visualization). |
| `feature_names` | Names of the features (optional but recommended). |

You can then leverage SHAP's built-in functions for visualization and deeper analysis.
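
To see why SHAP values and the base value add up to the prediction, the sketch below computes exact Shapley values for a tiny two-feature function by averaging marginal contributions over all feature orderings. This is a brute-force illustration under simplifying assumptions: the toy `model`, the single all-zero `baseline` (instead of a background dataset), and the helper `shapley_values` are hypothetical, and this is not how the `shap` library or Snowflake's `explain` function computes values for tree models.

```python
from itertools import permutations

def shapley_values(f, x, baseline):
    """Exact Shapley values: average marginal contribution over all orderings."""
    n = len(x)
    phi = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)  # start with every feature "absent"
        prev = f(current)
        for i in order:
            current[i] = x[i]     # switch feature i to its actual value
            new = f(current)
            phi[i] += new - prev  # marginal contribution of feature i
            prev = new
    return [v / len(orderings) for v in phi]

def model(features):
    # Toy model with an interaction term
    a, b = features
    return 2 * a + 3 * b + a * b

x = [1.0, 2.0]
baseline = [0.0, 0.0]
phi = shapley_values(model, x, baseline)
base_value = model(baseline)
print(phi, base_value + sum(phi), model(x))
```

The interaction term `a * b` is split evenly between the two features, and the base value plus the sum of the contributions reproduces the model output for `x`.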

In [None]:
# Calculate SHAP values for the predictions
explanations = production_model.run(predictions, function_name="explain")
explanations = explanations.rename({col:col.replace('"""', '').upper() for col in explanations.columns})
shap_columns = [col for col in explanations.columns if '_EXPLANATION' in col]
explanations = explanations.to_pandas()

# Create the native shap Explanation object
shap_exp = shap.Explanation(
    values = explanations[shap_columns].values,
    base_values = np.full((len(explanations),), explanations['NEXT_MONTH_REVENUE_PREDICTION'].mean()),
    data = explanations[feature_columns].values,
    feature_names=feature_columns
)

### Global Explainability using SHAP

The **left plot** is a SHAP **summary plot**, which displays the impact of each feature on the model’s output. Each dot represents a single instance, with the color indicating the feature value (blue = low, pink = high). Features at the top are the most influential, and the x-axis shows whether they push predictions higher (positive SHAP values) or lower (negative SHAP values).  

The **right plot** is a SHAP **violin plot**, which shows the distribution of SHAP values for each feature. The width of the violin indicates the density of SHAP values, helping visualize how much a feature’s impact varies across different predictions. Both plots highlight that recent revenue metrics (e.g., `TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS`) strongly influence the model, with high revenue values generally increasing predictions.

In [None]:
col1, col2 = st.columns(2)
with col1:
    shap.summary_plot(shap_exp)
with col2:
    shap.plots.violin(shap_exp)

### Local Explainability using SHAP

On the left you see a **SHAP waterfall plot** which explains how a model arrived at a specific prediction by showing feature contributions. It starts with the **expected value** (baseline prediction) and adjusts it based on **SHAP values** of individual features. **Blue bars** represent features that **lowered** the prediction, while **red bars** indicate those that **increased** it. The final predicted value is obtained by sequentially adding these contributions to the baseline. This plot helps identify which features had the most impact and whether they pushed the prediction up or down.

On the right, we plot the distribution of **online vs. in-shop transactions** for this specific customer.

In [None]:
import matplotlib.pyplot as plt

selected_customer = st.selectbox('Select Customer:', explanations.sort_values(by='CUSTOMER_ID')['CUSTOMER_ID'].values)
index = explanations.index[explanations['CUSTOMER_ID'] == selected_customer][0]
col1, col2 = st.columns([0.6,0.4])
with col1:
    st.subheader('SHAP Waterfall Plot')
    fig, ax = plt.subplots()
    ax.set_title("", fontsize=16)
    plt.sca(ax)
    shap.plots.waterfall(shap_exp[index], show=False)
    plt.close()
    st.pyplot(fig)
with col2:
    st.subheader('Customer Transactions')
    demo_flow.get_customer_revenue_plot(transactions_df, int(index))

### Continuous Model Monitoring
Model behavior can change over time due to factors such as **input drift, stale training assumptions, data pipeline issues, and hardware or software updates**.

**ML Observability** enables you to monitor the quality of models registered in the **Snowflake Model Registry** across multiple dimensions, including **performance, drift, and volume**.  

To measure drift for model monitoring, we use two tables:  

| Table      | Description  |
|------------|-------------|
| `BASELINE` | Contains a snapshot of data similar to `SOURCE`. It is used as a reference for comparing future feature values and predictions. |
| `SOURCE`   | Stores future predictions and feature values for monitoring. |

In [None]:
# Save baseline predictions
predictions = predictions.with_column('FEATURE_CUTOFF_DATE', F.col('FEATURE_CUTOFF_DATE').cast('timestamp'))
predictions.write.save_as_table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1', mode='overwrite')
predictions.write.save_as_table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_V1', mode='overwrite')

### Creating predictions for the next month
We now use the trained model on our **April data** to predict each customer's **revenue for May**.  
The workflow looks like this:  
1. Build a spine DataFrame
2. Retrieve Features from Feature Store 
3. Generate predictions

Now, let’s take this a step further with a **real-world challenge**:  
Imagine you didn’t train the model yourself. How would you determine which input features are required and where to source them?

The answer is simple:  
Query the **automatically captured lineage** information in Snowflake!  
The model is directly linked to the dataset used for training, which in turn is connected to the relevant Feature Views.
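
Conceptually, the lineage lookup is a walk up a dependency graph from the model to its feature views. The sketch below shows that idea on a hand-written graph. The `upstream` dictionary, the `domain:name/version` labels, and the helper `find_upstream` are hypothetical and do not reflect how Snowflake stores or returns lineage.

```python
# Hypothetical upstream lineage: object -> objects it was built from
upstream = {
    "model:CUSTOMER_REVENUE_MODEL/V1": ["dataset:NEXT_MONTH_REVENUE_DATASET/V1"],
    "dataset:NEXT_MONTH_REVENUE_DATASET/V1": [
        "feature_view:IN_SHOP_REVENUE_FEATURES/V1",
        "feature_view:ONLINE_REVENUE_FEATURES/V1",
    ],
    "feature_view:IN_SHOP_REVENUE_FEATURES/V1": ["table:TRANSACTIONS"],
    "feature_view:ONLINE_REVENUE_FEATURES/V1": ["table:TRANSACTIONS"],
}

def find_upstream(start, domain):
    """Collect all upstream nodes of the given domain, without duplicates."""
    found, seen, stack = [], set(), [start]
    while stack:
        node = stack.pop()
        for parent in upstream.get(node, []):
            if parent in seen:
                continue
            seen.add(parent)
            if parent.startswith(domain + ":"):
                found.append(parent)
            stack.append(parent)
    return sorted(found)

feature_views = find_upstream("model:CUSTOMER_REVENUE_MODEL/V1", "feature_view")
print(feature_views)
```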

In [None]:
feature_views = production_model.lineage(direction='upstream')[0].lineage(domain_filter=['feature_view'], direction='upstream')
[fv.name for fv in feature_views]

In [None]:
def get_feature_df(model, feature_cutoff_date):
    # Use lineage information to retrieve the feature views of this model
    feature_views = model.lineage(direction='upstream')[0].lineage(domain_filter=['feature_view'], direction='upstream')
    
    # Create the spine dataframe containing all customers
    spine_df = (
        session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.CUSTOMERS')
        .select('CUSTOMER_ID')
        .with_column('FEATURE_CUTOFF_DATE', F.to_date(lit(feature_cutoff_date)))
    )

    # Retrieve feature values from the Feature Store for the specified cutoff date.
    feature_df = fs.retrieve_feature_values(
        spine_df=spine_df,
        features=feature_views,
        spine_timestamp_col="FEATURE_CUTOFF_DATE"
    )
    return feature_df

In [None]:
feature_df = get_feature_df(
    production_model, 
    feature_cutoff_date='2024-05-01', 
)

predictions = production_model.run(feature_df, function_name='PREDICT')
predictions.write.save_as_table(table_name='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_V1', mode='append', column_order='name')

# View predictions
print('Predictions [column=NEXT_MONTH_REVENUE_PREDICTION]:')
session.table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_V1').show()

### Creating a Model Monitor  

We are setting up a **model monitor** to continuously calculate and track model performance and drift over time.  

These calculations are based on the **`BASELINE`** and **`SOURCE`** tables created earlier.  
Each model requires its own dedicated **model monitor** to ensure accurate tracking and evaluation.

In [None]:
source_config = ModelMonitorSourceConfig(
    source='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_V1',
    timestamp_column='FEATURE_CUTOFF_DATE',
    id_columns=['CUSTOMER_ID'],
    prediction_score_columns=['NEXT_MONTH_REVENUE_PREDICTION'],
    actual_score_columns=['NEXT_MONTH_REVENUE'],
    baseline='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_V1'
)

monitor_config = ModelMonitorConfig(
    model_version=reg.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION'),
    model_function_name='predict',
    background_compute_warehouse_name='COMPUTE_WH',
    refresh_interval='1 minute',
    aggregation_window='1 day'
)

model_monitor = reg.add_monitor(
    name='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_V1',
    source_config=source_config,
    model_monitor_config=monitor_config
)

### Simulating the next month of Customer Transactions
Our model has predicted each customer's **revenue for May 2024** and stored the results in the **`SOURCE`** table.  
Next, we simulate the actual transactions for May and update the **true revenue values** for each customer in the **`SOURCE`** table.  
When the **model monitor** refreshes, it will use these updated values to calculate various **model performance metrics**, including the MAPE.

In [None]:
# Add new transactions (created as part of the initial demo setup)
new_transactions = session.table('SIMPLE_MLOPS_DEMO._DATA_GENERATION._TRANSACTIONS').filter(col('DATE').between('2024-05-02','2024-06-01'))
new_transactions.write.save_as_table(table_name='SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS', mode='append')

# Calculate actual values
actual_values_df = (
    session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
    .filter(col('DATE').between('2024-05-02','2024-06-01'))
    .group_by(['CUSTOMER_ID'])
    .agg(F.sum('TRANSACTION_AMOUNT').as_('TOTAL_REVENUE'))
)

# Get list of all customers
customers_df = session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()

# Assume 0 revenue for customers without transactions
actual_values_df = actual_values_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
actual_values_df = actual_values_df.fillna(0, subset='TOTAL_REVENUE')

# Add the cutoff date after the join so customers without transactions also get it
# (otherwise their DATE is NULL and the update condition below never matches)
actual_values_df = actual_values_df.with_column('DATE', F.to_date(lit('2024-05-01')))

# Update source table from model monitor
source_table = session.table('SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_V1')
source_table.update(
    condition=(
        (source_table['FEATURE_CUTOFF_DATE'] == actual_values_df['DATE']) &
        (source_table['CUSTOMER_ID'] == actual_values_df['CUSTOMER_ID'])
    ),
    assignments={
        "NEXT_MONTH_REVENUE": actual_values_df['TOTAL_REVENUE'],
    },
    source=actual_values_df
)

## Simulate Customer Transactions until February 2025
For convenience, I encapsulated all the logic for simulating future months into the helper function `simulate_model_performance()`.  
We use this function to simulate the model's behavior until February 2025.

In [None]:
start_date = '2024-06-01'
end_date = '2025-01-01'
demo_flow.simulate_model_performance(production_model, start_date, end_date, generate_data=True)

## Explore the Model Monitor
Navigate to the Model Monitor and observe the `MAPE` and `Jensen-Shannon Distance`  for the last months.  

You will notice the following:
* Declining Model Performance
    * :chart_with_upwards_trend: MAPE (Mean Absolute Percentage Error)
* Feature Drift
    * :chart_with_upwards_trend: Distance for TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS (lower in-shop transaction volume)
    * :chart_with_downwards_trend: Difference for TOTAL_REVENUE_ONLINE_PAST_1_MONTHS (higher online transaction volume)

Why is that?  
If we visualize the monthly revenue distribution, we can see that online revenue grew while in-shop revenue declined.

Instead of using the builtin UI, you can also query model monitor metrics using the following table functions and build your own visuals:
* [MODEL_MONITOR_PERFORMANCE_METRIC](https://docs.snowflake.com/en/sql-reference/functions/model-monitor-performance-metric)
* [MODEL_MONITOR_DRIFT_METRIC](https://docs.snowflake.com/en/sql-reference/functions/model-monitor-drift-metric)
* [MODEL_MONITOR_STAT_METRIC](https://docs.snowflake.com/en/sql-reference/functions/model-monitor-stat-metric)
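
If you build your own visuals, it helps to know what the drift metrics compute. The following is a minimal, standard-library sketch of three of them. The function names and toy numbers are assumptions made for illustration, and the model monitor's own implementation (for example, how it bins values) may differ.

```python
import math


def difference_of_means(baseline, current):
    """Mean of the current sample minus mean of the baseline sample."""
    return sum(current) / len(current) - sum(baseline) / len(baseline)


def wasserstein_1d(baseline, current):
    """1-D Wasserstein distance for two samples of equal size."""
    if len(baseline) != len(current):
        raise ValueError("this sketch only handles samples of equal size")
    pairs = zip(sorted(baseline), sorted(current))
    return sum(abs(b - c) for b, c in pairs) / len(baseline)


def jensen_shannon_distance(p, q):
    """Jensen-Shannon distance (base-2 logs) between two discrete distributions."""
    if len(p) != len(q):
        raise ValueError("both distributions need the same number of bins")
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl_divergence(a, b):
        # Terms with a probability of 0 contribute nothing to the sum
        return sum(ai * math.log2(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    divergence = (kl_divergence(p, m) + kl_divergence(q, m)) / 2
    return math.sqrt(max(divergence, 0.0))


# Toy samples for illustration only (not taken from the demo data)
baseline_revenue = [10, 20, 30, 40]
current_revenue = [20, 30, 40, 50]
print(difference_of_means(baseline_revenue, current_revenue))
print(wasserstein_1d(baseline_revenue, current_revenue))

# Toy channel shares (in-shop, online) before and after a shift
print(jensen_shannon_distance([0.7, 0.3], [0.4, 0.6]))
```

With base-2 logarithms the Jensen-Shannon distance stays between 0 (identical distributions) and 1 (no overlap), which is why it is listed as the bounded option.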

In [None]:
get_snowsight_url(session, 'Link to Model Monitor', '#/data/databases/SIMPLE_MLOPS_DEMO/schemas/MODEL_REGISTRY/model/CUSTOMER_REVENUE_MODEL/version/V1/monitors/MM_V1/dashboard')

In [None]:
with st.expander('**Need help choosing the right metric?**', expanded=False):
    text = """### **Overview of Feature Drift Metrics:**

|                      | **Jensen-Shannon Distance** | **Wasserstein Distance** | **Difference of Means** |
|-----------------------------|----------------|--------------------------|-------------------------|
| **What It Measures**        | Difference in probability distributions | Amount of movement needed to align two distributions | Simple difference between means of two distributions |
| **Intuition**               | Measures how **different** two distributions are (based on KL divergence, but smoothed and symmetric). | Measures the **work needed** to "move" one distribution to match the other. | Measures the shift in the **central tendency** of the feature values. |
| **Range**                   | 0 to 1 (bounded) | 0 to ∞ (can grow indefinitely) | -∞ to ∞ (unbounded) |
| **Interpretability**        | 0 = identical, 1 = completely different | Larger values mean greater distribution shift | Positive = mean has increased, Negative = mean has decreased |
| **Computational Complexity** | Faster, works well with discrete values | Slower, requires solving an optimization problem | Very fast (simple arithmetic) |
| **Small shifts in values**  | May not detect it well if probability distributions overlap a lot. | Captures even small shifts because it looks at the actual distance between values. | Only detects shifts in the mean, not overall distribution changes. |
| **Major changes in shape**  | Captures well if distributions change significantly. | Captures well if mass shifts significantly. | ❌ No, only captures mean changes. |
| **Outliers or extreme shifts** | May be less sensitive if distributions overlap in many places. | More sensitive because it considers the actual movement of values. | Very sensitive to outliers (mean can shift significantly). |
| **Best for categorical distributions** (e.g., customer segments) | ✅ Yes | ❌ No | ❌ No |
| **Best for continuous features** (e.g., age, income) | ❌ No | ✅ Yes | ✅ Yes |
| **Best for detecting gradual numerical shifts** | ❌ No | ✅ Yes | ✅ Yes, but only if the mean is shifting. |
| **Best for interpretable (bounded 0-1) metric** | ✅ Yes | ❌ No | ❌ No |"""
    st.markdown(text)

### Query the model monitor
With the built-in table functions, you can query metrics from the model monitor directly for further analysis and visualization.
Additionally, this notebook includes a small helper for the model registry that lets you navigate your registered models, visualize their metrics, and compare multiple models in a single graph.

The outputs of these table functions can also be used to create alerts and trigger automated tasks. Example use cases include:
* Sending an email notification to your ML engineer if a model's performance drops below a predefined threshold.
* Initiating an automated retraining of a model with fresh data using [Snowflake Tasks](https://docs.snowflake.com/en/user-guide/tasks-intro).
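
The decision logic behind such an alert can be very small. The sketch below assumes you have already pulled a chronological list of metric values (for example, monthly MAPE) from the model monitor. The function name `exceeds_threshold`, the threshold, and the toy values are hypothetical.

```python
def exceeds_threshold(metric_values, threshold, consecutive=2):
    """Return True if the most recent `consecutive` values are all above `threshold`."""
    if consecutive < 1:
        raise ValueError("consecutive must be at least 1")
    recent = metric_values[-consecutive:]
    return len(recent) == consecutive and all(value > threshold for value in recent)


# Toy MAPE values in chronological order, for illustration only
monthly_mape = [0.12, 0.31, 0.35]

if exceeds_threshold(monthly_mape, threshold=0.25):
    print("Model performance degraded: notify the ML engineer or trigger retraining.")
```

Requiring several consecutive breaches is one way to avoid reacting to a single noisy month; whether that is appropriate depends on your use case.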

In [None]:
session.table_function(
    "MODEL_MONITOR_DRIFT_METRIC",
    lit('MM_V1'),
    lit('WASSERSTEIN'),
    lit('TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS'),
    lit('1 day'),
    lit('2024-01-01'),
    lit('2024-12-31')
).show()

In [None]:
model_registry_helper = ModelRegistryHelper(session, reg)

In [None]:
model_registry_helper.plot_model_performance()

### Why is the new model performing better?
Let's ask CortexPilot what changed.

The bar chart displays the overall revenue per month, categorized by transaction channel (Online and In-Shop). It shows a notable increase in online transactions starting in June.

In [None]:
my_pilot.f_cortex_helper_visualize_query(
    transactions_df, 
    'What was the overall revenue per channel and month? Use a stacked bar plot and use YY-Monthname for the x-axis. Make sure the x-axis is ordered by month.'
)

## Train a new Model Version  

Since **user behavior has changed**, we will train a **new version of our model** using fresh data.  

To streamline this process, I have encapsulated the entire training workflow into the helper function `train_new_model()`, which automates the following steps:  

- **Creates the spine DataFrame**, including the target variable.  
- **Retrieves features** from the Feature Store.  
- **Creates a Snowflake Dataset** from the training data (ensuring reproducibility with a snapshot).  
- **Trains a new XGBoost model**.  
- **Registers the model** in the Snowflake Model Registry.  
- **Sets up a new model monitor** to track performance and drift.  
- **Compares model performance** against the existing production model.  
- **Deploys the new model** if it outperforms the current one by assigning it the **"PRODUCTION"** alias.  
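
The last two steps boil down to a simple rule: move the alias only when the candidate has the lower error. Here is a minimal in-memory sketch of that rule. The dictionary stands in for the registry's alias bookkeeping and the function name `promote_if_better` is hypothetical; the actual helper works against the Snowflake Model Registry.

```python
def promote_if_better(aliases, candidate_version, candidate_mape, production_mape):
    """Return an updated alias mapping (alias -> version name).

    The candidate takes over the PRODUCTION alias only if its MAPE is lower.
    """
    updated = dict(aliases)
    if candidate_mape < production_mape:
        previous = updated.get("PRODUCTION")
        if previous is not None:
            updated["DEPRECATED"] = previous
        updated["PRODUCTION"] = candidate_version
    return updated


# Toy MAPE values for illustration only
print(promote_if_better({"PRODUCTION": "V1"}, "V2", candidate_mape=0.10, production_mape=0.30))
```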

Since the features look back three months from the cutoff date of **September 1st, 2024**, the training data covers **June, July, and August 2024**. The model should therefore recognize that **ONLINE transactions** have become a major driver of customer revenue.
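
The date windows for a training run follow directly from the feature cutoff date. The sketch below derives them with the standard library. The function names and the three-month lookback default are assumptions for illustration, and the month arithmetic assumes the cutoff falls on the first day of a month.

```python
from datetime import date, timedelta


def add_months(d, months):
    """Shift a date by whole months (assumes the day exists in the target month)."""
    month_index = d.year * 12 + (d.month - 1) + months
    return d.replace(year=month_index // 12, month=month_index % 12 + 1)


def training_windows(feature_cutoff, lookback_months=3):
    """Derive the feature lookback start and the target window from the cutoff date."""
    return {
        "feature_window_start": add_months(feature_cutoff, -lookback_months),
        "feature_cutoff_date": feature_cutoff,
        "target_start_date": feature_cutoff + timedelta(days=1),
        "target_end_date": add_months(feature_cutoff, 1),
    }


for name, value in training_windows(date(2024, 9, 1)).items():
    print(f"{name}: {value.isoformat()}")
```

For a cutoff of 2024-09-01 this gives a target window from 2024-09-02 to 2024-10-01.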

In [None]:
import logging
from opentelemetry import trace

class ModelTrainer():
    def __init__(self, session):
        self.session = session
        self.registry = Registry(
            session=session, 
            database_name='SIMPLE_MLOPS_DEMO',
            schema_name='MODEL_REGISTRY', 
            options={'enable_monitoring': True},
        )
        self.fs = FeatureStore(
            session=session, 
            database='SIMPLE_MLOPS_DEMO', 
            name='FEATURE_STORE', 
            default_warehouse=session.get_current_warehouse(),
            creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
        )
        self.logger = logging.getLogger("logger.ModelTrainer")
        self.tracer = trace.get_tracer("tracer.ModelTrainer")

    def train_new_model(self, feature_views: dict, feature_cutoff_date: str, target_start_date: str, target_end_date: str, model_version: str):
        train_df, test_df, feature_columns = self.prepare_data(feature_views, feature_cutoff_date, target_start_date, target_end_date, model_version)
        model = self.train(train_df, feature_columns)
        mape, predictions = self.evaluate_model(model, test_df)
        registered_model = self.register_new_model(model, model_version, train_df, feature_columns, feature_cutoff_date, mape)
        self.create_model_monitor(registered_model, model_version, predictions)
        self.evaluate_against_production_model(registered_model, test_df, mape)

    def prepare_data(self, feature_views: dict, feature_cutoff_date: str, target_start_date: str, target_end_date: str, model_version: str):
        with self.tracer.start_as_current_span("Data Preparation"):
            feature_views = [self.fs.get_feature_view(fv,feature_views[fv]) for fv in feature_views]
            target_df = self.session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.TRANSACTIONS')
            target_df = (
                target_df.filter(col('DATE').between(target_start_date,target_end_date))
                .group_by('CUSTOMER_ID')
                .agg(F.sum('TRANSACTION_AMOUNT').as_('NEXT_MONTH_REVENUE'))
            )
            
            # Outer join so customers without transactions get a target of 0
            customers_df = self.session.table('SIMPLE_MLOPS_DEMO.RETAIL_DATA.CUSTOMERS').select('CUSTOMER_ID').distinct()
            spine_df = target_df.join(customers_df, on=['CUSTOMER_ID'], how='outer')
            spine_df = spine_df.fillna(0, subset='NEXT_MONTH_REVENUE')
            # Add the cutoff date after the join so every customer row has a spine timestamp
            spine_df = spine_df.with_column('FEATURE_CUTOFF_DATE', F.to_date(lit(feature_cutoff_date)))
    
            train_dataset = self.fs.generate_dataset(
                name="SIMPLE_MLOPS_DEMO.FEATURE_STORE.NEXT_MONTH_REVENUE_DATASET",
                spine_df=spine_df,
                features=feature_views,
                version=model_version,
                spine_timestamp_col="FEATURE_CUTOFF_DATE",
                spine_label_cols=["NEXT_MONTH_REVENUE"],
                include_feature_view_timestamp_col=False,
                desc=f"Training dataset from {feature_cutoff_date}"
            )
            
            df = train_dataset.read.to_snowpark_dataframe()
            train_df, test_df = df.random_split(weights=[0.9, 0.1], seed=0)
            feature_columns = train_df.drop(['CUSTOMER_ID', 'FEATURE_CUTOFF_DATE', 'NEXT_MONTH_REVENUE']).columns
            self.logger.info('Training dataset created.')
            return train_df, test_df, feature_columns
        
    def train(self, train_df, feature_columns):
        with self.tracer.start_as_current_span("Model Fitting"):
            model = XGBRegressor(
                input_cols=feature_columns,
                label_cols=['NEXT_MONTH_REVENUE'],
                output_cols=['NEXT_MONTH_REVENUE_PREDICTION'],
                n_estimators=100,
                learning_rate=0.05,
                random_state=0
            )
            model = model.fit(train_df)
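            # Use the fitted model (the undefined name xgb_model raised a NameError) and cast NumPy floats to plain floats
            feature_importance = {name: float(score) for name, score in zip(feature_columns, model.to_xgboost().feature_importances_)}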
            telemetry.add_event("model_training", {"feature_importance": feature_importance})
            self.logger.info('Successfully trained a new model.')
            return model

    def evaluate_model(self, model, test_df):
        with self.tracer.start_as_current_span("Model Evaluation"):
            predictions = model.predict(test_df)
            mape = mean_absolute_percentage_error(
                df=predictions, 
                y_true_col_names="NEXT_MONTH_REVENUE", 
                y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
            )
            telemetry.add_event("model_evaluation", {"metric": "mape", "value": mape})
            self.logger.info(f'New model has a MAPE of {mape}.')
            return mape, predictions

    def register_new_model(self, model, model_version, train_df, feature_columns, feature_cutoff_date, mape):
        with self.tracer.start_as_current_span("Model Registration"):
            registered_model = self.registry.log_model(
                model,
                model_name="CUSTOMER_REVENUE_MODEL",
                version_name=model_version,
                metrics={
                    'MAPE': mape, 
                    'TRAINING_DATA': {'FEATURE_CUTOFF_DATE': feature_cutoff_date}
                },
                comment="Model trained using XGBoost to predict revenue per customer for next month.",
                conda_dependencies=['xgboost'],
                sample_input_data=train_df.select(feature_columns).limit(10),
                options={"relax_version": False, "enable_explainability": True}
            )
            self.logger.info(f'Registered new model with version {model_version} in model registry.')
            return registered_model

    def create_model_monitor(self, registered_model, model_version, predictions):
        with self.tracer.start_as_current_span("Model Monitor Creation"):
            predictions.write.save_as_table(f'SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_{model_version}', mode='overwrite')
            predictions.write.save_as_table(f'SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_{model_version}', mode='overwrite')
            
            source_config = ModelMonitorSourceConfig(
                source=f'SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_SOURCE_{model_version}',
                baseline=f'SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_REVENUE_BASELINE_{model_version}',
                timestamp_column='FEATURE_CUTOFF_DATE',
                id_columns=['CUSTOMER_ID'],
                prediction_score_columns=['NEXT_MONTH_REVENUE_PREDICTION'],
                actual_score_columns=['NEXT_MONTH_REVENUE'],
            )
            
            monitor_config = ModelMonitorConfig(
                model_version=registered_model,
                model_function_name='predict',
                background_compute_warehouse_name='COMPUTE_WH',
                refresh_interval='1 minute',
                aggregation_window='1 day'
            )
            
            model_monitor = self.registry.add_monitor(
                name=f'SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.MM_{model_version}',
                source_config=source_config,
                model_monitor_config=monitor_config
            )

    def evaluate_against_production_model(self, registered_model, test_df, mape):
        with self.tracer.start_as_current_span("Model Deployment"):
            production_model = self.registry.get_model('CUSTOMER_REVENUE_MODEL').version('PRODUCTION')
            production_model_predictions = production_model.run(test_df, function_name='PREDICT')
            production_model_mape = mean_absolute_percentage_error(
                df=production_model_predictions, 
                y_true_col_names="NEXT_MONTH_REVENUE", 
                y_pred_col_names="NEXT_MONTH_REVENUE_PREDICTION"
            )
            
            if mape < production_model_mape:
                self.logger.info("New model has a lower MAPE compared to current production model.")
                self.logger.info("New model will be put into production by setting its alias to PRODUCTION.")
                
                # Update model aliases:
                production_model.unset_alias('PRODUCTION')
                production_model.set_alias('DEPRECATED')
                registered_model.set_alias('PRODUCTION')
            else:
                self.logger.info("Existing production model has a lower MAPE compared to the developed model.")
                self.logger.info("New model is not automatically set into production.")

### Deploy the training pipeline
Our pipeline is fairly simple, but for demo purposes we will create a directed acyclic graph (DAG) for it.  
This DAG consists of two Stored Procedures that encapsulate the logic for each step and allow detailed monitoring.  

1. Training a new model with fresh data
2. Sending Notifications

By registering the functions as Stored Procedures, they become available outside of this notebook to everybody with the right privileges.
Your co-developers can then easily integrate your work into their pipelines, even if they are developing in other IDEs or other languages (e.g. SQL).

Snowflake also automatically captures logs, metrics, and traces for your Stored Procedures, UDFs, etc. in [Event Tables](https://docs.snowflake.com/en/developer-guide/logging-tracing/event-table-setting-up).  
**Note:** You need to [enable telemetry collection](https://docs.snowflake.com/en/developer-guide/logging-tracing/logging-tracing-enabling).

In [None]:
@sproc(
    name='TRAIN_CUSTOMER_REVENUE_MODEL',
    is_permanent=True,
    replace=True,
    stage_location='SIMPLE_MLOPS_DEMO.PUBLIC.PIPELINES',
    packages=['snowflake-snowpark-python','snowflake-ml-python==1.7.4','snowflake-telemetry-python'],
    imports=['@SIMPLE_MLOPS_DEMO.PUBLIC.GITHUB_REPOSITORY_SNOWFLAKE_SIMPLE_MLOPS/branches/main/src/demo_extras/model_trainer.py'],
    execute_as='caller'
)
def train_new_model(session: Session, feature_views: dict, feature_cutoff_date: str, target_start_date: str, target_end_date: str, model_version: str) -> str:
    from model_trainer import ModelTrainer
    logger = logging.getLogger("logger.ModelTrainer")
    tracer = trace.get_tracer("tracer.ModelTrainer")
    logger.info('Starting model training.')
    with tracer.start_as_current_span("XGBoost Pipeline"):
        model_trainer = ModelTrainer(session)
        model_trainer.train_new_model(feature_views, feature_cutoff_date, target_start_date, target_end_date, model_version)
    logger.info('Model training finished.')
    return 'Successfully trained and registered a new model.'

In [None]:
feature_views = {'IN_SHOP_REVENUE_FEATURES':'V1', 'ONLINE_REVENUE_FEATURES':'V1'}
feature_cutoff_date = '2024-09-01'
target_start_date = '2024-09-02'
target_end_date = '2024-10-01'
model_version = 'V2'

train_new_model(session, feature_views, feature_cutoff_date, target_start_date, target_end_date, model_version)

In [None]:
#model_trainer = ModelTrainer(session)
#
#feature_views = {'IN_SHOP_REVENUE_FEATURES':'V1', 'ONLINE_REVENUE_FEATURES':'V1'}
#feature_cutoff_date = '2024-09-01'
#target_start_date = '2024-09-02'
#target_end_date = '2024-10-01'
#model_version = 'V2'
#
#model_trainer.train_new_model(feature_views, feature_cutoff_date, target_start_date, target_end_date, model_version)

### Simulate Model performance for Model Version V2 until 2025-01-31
Once again, we are simulating **model performance** based on customer transactions up to **February 2025**.  
Make sure to check the **model monitor** to evaluate whether the new model version trained on more recent data performs better.  
Additionally, analyze the **feature drift**, where you’ll notice that the trend for the **V2 model** is much more stable.

In [None]:
model = reg.get_model('CUSTOMER_REVENUE_MODEL').version('V2')
start_date = '2024-10-01'
end_date = '2025-01-01'
demo_flow.simulate_model_performance(model, start_date, end_date)

In [None]:
model_registry_helper.plot_model_performance(update_data=True)

### Comparing the two Model Versions
We have already observed that the new model provides **significantly better predictions** for future customer revenue. However, we want to gain deeper insights into **why** this improvement occurred.  

To analyze this, we generate a SHAP summary plot for each model on the latest data. The comparison reveals that the new model recognizes a **much stronger influence** of past **ONLINE transactions** on future customer revenue.  
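
A summary plot ranks features by how strongly they move the prediction on average. The sketch below reproduces that ranking for a small matrix of per-row feature contributions. The function name is hypothetical and the contribution values are made up for illustration; they are not output from either model version.

```python
def mean_abs_contribution(feature_names, contributions):
    """Average absolute contribution per feature, sorted from most to least influential."""
    totals = [0.0] * len(feature_names)
    for row in contributions:
        for index, value in enumerate(row):
            totals[index] += abs(value)
    ranked = [(name, total / len(contributions)) for name, total in zip(feature_names, totals)]
    return sorted(ranked, key=lambda item: item[1], reverse=True)


features = ["TOTAL_REVENUE_IN_SHOP_PAST_1_MONTHS", "TOTAL_REVENUE_ONLINE_PAST_1_MONTHS"]

# Made-up per-customer contributions (one row per customer, one column per feature)
toy_contributions = [
    [4.0, -10.0],
    [-2.0, 14.0],
]

for name, score in mean_abs_contribution(features, toy_contributions):
    print(f"{name}: {score:.1f}")
```

Comparing these rankings between two model versions shows which features gained or lost influence.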

In [None]:
model_registry_helper.update_registry_data()

shap_exp1 = model_registry_helper.get_model_explanations(
    reg.get_model('CUSTOMER_REVENUE_MODEL').version('V1'), feature_columns=feature_columns, feature_cutoff_date='2025-01-01'
)

shap_exp2 = model_registry_helper.get_model_explanations(
    reg.get_model('CUSTOMER_REVENUE_MODEL').version('V2'), feature_columns=feature_columns, feature_cutoff_date='2025-01-01'
)

col1, col2 = st.columns(2)
with col1:
    shap.summary_plot(shap_exp1)
with col2:
    shap.summary_plot(shap_exp2)

## ML Lineage
Even though you may not have noticed, you’ve been capturing **lineage information** throughout the development of your machine learning pipeline.  

You can retrieve this information using the built-in function `lineage.trace()` for further analysis.  
For example, you can use this data to **visualize the lineage directly in the notebook**.  

Additionally, Snowflake provides a **more user-friendly and interactive UI** that allows you to explore and monitor your machine learning pipeline:  
![text](https://github.com/michaelgorkow/snowflake_simple_mlops/blob/main/resources/ml_lineage3.png?raw=true)

As shown, the lineage captures a **comprehensive view** of your pipeline, tracking data transformations and dependencies from the **source tables**, through the **feature view**, the **training dataset**, and ultimately the **registered model** in the Model Registry.
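
Conceptually, lineage is a directed graph of source and target objects. The sketch below walks such a graph upstream from the model. The edge list is a simplified, hypothetical stand-in for the lineage rows, and the function name is an assumption.

```python
from collections import defaultdict, deque


def upstream_objects(edges, start):
    """Breadth-first walk from `start` towards its sources; returns objects in visit order."""
    parents = defaultdict(list)
    for source, target in edges:
        parents[target].append(source)

    seen = {start}
    queue = deque([start])
    order = []
    while queue:
        node = queue.popleft()
        for parent in parents[node]:
            if parent not in seen:
                seen.add(parent)
                order.append(parent)
                queue.append(parent)
    return order


# Hypothetical, simplified edges (source -> target) for illustration only
edges = [
    ("TRANSACTIONS", "ONLINE_REVENUE_FEATURES"),
    ("TRANSACTIONS", "IN_SHOP_REVENUE_FEATURES"),
    ("ONLINE_REVENUE_FEATURES", "NEXT_MONTH_REVENUE_DATASET"),
    ("IN_SHOP_REVENUE_FEATURES", "NEXT_MONTH_REVENUE_DATASET"),
    ("NEXT_MONTH_REVENUE_DATASET", "CUSTOMER_REVENUE_MODEL"),
]

print(upstream_objects(edges, "CUSTOMER_REVENUE_MODEL"))
```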

In [None]:
get_snowsight_url(session, 'Link to Lineage View', '#/data/databases/SIMPLE_MLOPS_DEMO/schemas/MODEL_REGISTRY/model/CUSTOMER_REVENUE_MODEL/version/V2/lineage')

In [None]:
# Named lineage_df to avoid shadowing the `trace` module imported from opentelemetry
lineage_df = session.lineage.trace(
    object_name='SIMPLE_MLOPS_DEMO.MODEL_REGISTRY.CUSTOMER_REVENUE_MODEL',
    object_version='V2',
    object_domain='model',
    direction='both',
    distance=2
)
lineage_df.show()

In [None]:
lineage_helper = LineageHelper()
lineage_helper.visualize_lineage(lineage_df.to_pandas(), short_names=True)