Version: 0.0.2  Updated date: 07/05/2024
Conda Environment : py-snowpark_df_ml_fs-1.15.0_v1

# Getting Started with Snowflake Feature Store
This notebook uses a worked use case to show how the Snowflake Feature Store (and Model Registry) can be used to maintain and store features, retrieve them for training, and perform micro-batch inference.

In the development (TRAINING) environment we will
- create FeatureViews in the Feature Store that maintain the required customer-behaviour features.
- use these Features to train a model, and save the model in the Snowflake model-registry.
- plot the clusters for the trained model to verify it visually.

In the production (SERVING) environment we will
- re-create the FeatureViews on production data
- generate an Inference FeatureView that uses the saved model to perform incremental inference

# Feature Engineering & Model Training

In [1]:
%load_ext autoreload
%autoreload 2

#### Notebook Packages

In [2]:
# Python
from time import perf_counter

# ML
import pandas as pd
import xgboost as xgb
from sklearn.compose import ColumnTransformer
from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

# SNOWFLAKE
# Snowpark
from snowflake.ml.data.data_connector import DataConnector
from snowflake.ml.registry import Registry as ModelRegistry
from snowflake.snowpark import Session, Row
from snowflake.ml.dataset import Dataset
from snowflake.ml.dataset import load_dataset
from snowflake.ml.experiment import ExperimentTracking
from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback
from snowflake.ml.model.model_signature import infer_signature
from snowflake.snowpark.context import get_active_session

# Custom
from useful_fns import create_SF_Session



### Setup Snowflake connection and database parameters

In [3]:
# Schemas
tpcxai_schema = 'SERVING'

In [4]:
fs_qs_role, tpcxai_database, tpcxai_serving_schema, session, warehouse_env = create_SF_Session(tpcxai_schema, role="ACCOUNTADMIN")

You might have more than one threads sharing the Session object trying to update sql_simplifier_enabled. Updating this while other tasks are running can potentially cause unexpected behavior. Please update the session configuration before starting the threads.



Connection Established with the following parameters:
User                        : JARCHEN
Role                        : "ACCOUNTADMIN"
Database                    : "TPCXAI_SF0001_QUICKSTART_INC"
Schema                      : "SERVING"
Warehouse                   : "TPCXAI_SF0001_QUICKSTART_WH"
Snowflake version           : 9.37.1
Snowpark for Python version : 1.38.0 



In [5]:
# Create compute pool
def create_compute_pool(name: str, instance_family: str, min_nodes: int = 1, max_nodes: int = 10) -> list[Row]:
    query = f"""
        CREATE COMPUTE POOL IF NOT EXISTS {name}
            MIN_NODES = {min_nodes}
            MAX_NODES = {max_nodes}
            INSTANCE_FAMILY = {instance_family}
    """
    return session.sql(query).collect()

compute_pool = "DEMO_POOL_CPU"
create_compute_pool(compute_pool, "CPU_X64_S")

[Row(status='DEMO_POOL_CPU already exists, statement succeeded.')]
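
The `create_compute_pool` helper above interpolates its arguments directly into the SQL text. Below is a minimal sketch of one way to guard that step, assuming pool and instance-family names are plain unquoted identifiers (letters, digits, underscores and dollar signs, not starting with a digit). `build_compute_pool_sql` and `_IDENTIFIER` are hypothetical names used for illustration; they are not part of this notebook or of Snowpark.

```python
import re

# Assumption: names are unquoted identifiers, so anything else is rejected.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")


def build_compute_pool_sql(
    name: str, instance_family: str, min_nodes: int = 1, max_nodes: int = 10
) -> str:
    """Return the CREATE COMPUTE POOL statement after validating the inputs."""
    for label, value in (("name", name), ("instance_family", instance_family)):
        if not _IDENTIFIER.fullmatch(value):
            raise ValueError(f"{label} is not a plain identifier: {value!r}")
    if not 1 <= min_nodes <= max_nodes:
        raise ValueError("expected 1 <= min_nodes <= max_nodes")
    return (
        f"CREATE COMPUTE POOL IF NOT EXISTS {name} "
        f"MIN_NODES = {min_nodes} "
        f"MAX_NODES = {max_nodes} "
        f"INSTANCE_FAMILY = {instance_family}"
    )


print(build_compute_pool_sql("DEMO_POOL_CPU", "CPU_X64_S"))
```

The returned string could then be passed to `session.sql(...).collect()` in the same way as the query in the cell above.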

## PIPELINE DEVELOPMENT

In [None]:
import math
def create_data_connector(session, dataset_name) -> DataConnector:
    """Load data from Snowflake DataSet"""
    ds = Dataset.load(
        session=session, 
        name=dataset_name
    )
    ds_latest_version = str(ds.list_versions()[-1])
    ds_df = load_dataset(
        session, 
        dataset_name, 
        ds_latest_version
    )

    return DataConnector.from_dataset(ds_df)


def compare_params(input_d, extracted_d):
    ignore_keys = ['callbacks'] # Ignore complex objects
    mismatches = []
    
    for key, val in input_d.items():
        if key in ignore_keys: continue
            
        # Check if key exists in extraction
        if key not in extracted_d:
            mismatches.append(f"Missing key: {key}")
            continue
            
        ex_val = extracted_d[key]
        
        # Handle Float vs Int (63 vs 63.0) and NaNs
        if isinstance(val, (int, float)) and isinstance(ex_val, (int, float)):
            # Check for NaN in both (NaN != NaN in Python, so we must handle explicitly)
            if pd.isna(val) and pd.isna(ex_val):
                continue
            if not math.isclose(val, ex_val):
                mismatches.append(f"{key}: {val} (Input) != {ex_val} (Row)")
        
        # Standard comparison for strings/others
        elif val != ex_val:
            mismatches.append(f"{key}: {val} != {ex_val}")
            
    return mismatches

def generate_train_val_set(dataframe: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Generate train and validation dataset"""
    # Split data
    X = dataframe[['RETURN_RATIO', 'FREQUENCY']]
    y = dataframe["RETURN_ROW_PRICE"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    print("Data split into train and validation sets")

    # Combine features and target for each split
    train_df = pd.concat([X_train, y_train], axis=1)
    val_df = pd.concat([X_test, y_test], axis=1)
    return train_df, val_df

def build_pipeline(**model_params) -> Pipeline:
    """Create pipeline with preprocessors and model"""
    # Define column types
    feature_cols = ['RETURN_RATIO', 'FREQUENCY'] 

    # Create preprocessing steps
    preprocessor = ColumnTransformer(
        transformers=[
            ('NUM', MinMaxScaler(), feature_cols)
        ],
        remainder='passthrough',
    )

    model = xgb.XGBRegressor(**(model_params))

    return Pipeline([("preprocessor", preprocessor), ("regressor", model)])


def evaluate_model(model: Pipeline, X_test: pd.DataFrame, y_test: pd.Series) -> dict[str, float]:
    """Evaluate model performance"""
    # Make predictions
    y_pred = model.predict(X_test)
    # Calculate metrics
    metrics = {
        "mean_absolute_error": mean_absolute_error(y_test, y_pred),
        "mean_absolute_percentage_error": mean_absolute_percentage_error(y_test, y_pred),
        "r2_score": r2_score(y_test, y_pred),
    }

    return metrics


def train():
    from snowflake.ml.modeling import tune
    from snowflake.ml.modeling.tune.search import RandomSearch, BayesOpt
    session = get_active_session()
    # Get tuner context
    tuner_context = tune.get_tuner_context()
    params = tuner_context.get_hyper_params()
    dm = tuner_context.get_dataset_map()
    model_name = params.pop("model_name")
    mr_schema_name = params.pop("mr_schema_name")
    experiment_name = params.pop("experiment_name")
    
    # Initialize experiment tracking for this trial
    exp = ExperimentTracking(session=session, schema_name=mr_schema_name)
    exp.set_experiment(experiment_name)

    run = exp.start_run()
    print(f"Started run: {run.name}")

    # with exp.start_run():
    # Load data
    train_data = dm["train"].to_pandas()
    val_data = dm["val"].to_pandas()

    # Separate features and target
    X_train = train_data.drop('RETURN_ROW_PRICE', axis=1)
    y_train = train_data['RETURN_ROW_PRICE']
    X_val = val_data.drop('RETURN_ROW_PRICE', axis=1)
    y_val = val_data['RETURN_ROW_PRICE']

    # Train model
    sig = infer_signature(X_train, y_train)
    callback = SnowflakeXgboostCallback(
        exp, model_name="name", model_signature=sig
    )
    params['callbacks'] = [callback]

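    # Unpack the hyperparameters so that XGBRegressor receives them as keyword
    # arguments (n_estimators, random_state, callbacks) instead of a single
    # nested `model_params` argument, which would leave them unapplied.
    model = build_pipeline(**params)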
    # Log model parameters with the log_param(...) or log_params(...) methods
    exp.log_params(params)

    print("Training model...", end="")
    start = perf_counter()
    model.fit(X_train, y_train)
    elapsed = perf_counter() - start
    print(f" done! Elapsed={elapsed:.3f}s")

    # Evaluate model
    print("Evaluating model...", end="")
    start = perf_counter()
    metrics = evaluate_model(
        model,
        X_val,
        y_val,
    )
    elapsed = perf_counter() - start
    print(f" done! Elapsed={elapsed:.3f}s")

    # Log model metrics with the log_metric(...) or log_metrics(...) methods
    exp.log_metrics(metrics)

    # Report to the HPO framework (the tuner minimizes mean_absolute_percentage_error)
    tuner_context.report(
        metrics=metrics, 
        model=model
    )
    return {
        "run_name": run.name, 
        "params": params,
        "mean_absolute_error": metrics['mean_absolute_error'],
        "mean_absolute_percentage_error": metrics['mean_absolute_percentage_error'],
        "r2_score": metrics['r2_score'],
        "model": model,
        "X_train": X_train,
        "metrics": metrics
    }
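
`compare_params` above relies on a tolerant numeric comparison because hyperparameters that round-trip through the tuner's results DataFrame can come back as floats (an `n_estimators` of 74 appears as 74.0 in the job results later in this notebook). The following is a minimal standalone sketch of that comparison rule; `numbers_match` is a hypothetical helper used only for illustration.

```python
import math


def numbers_match(a: float, b: float) -> bool:
    """Compare two numbers the way compare_params does for numeric values."""
    # NaN never equals itself, so two NaNs are treated as a match explicitly.
    if math.isnan(a) and math.isnan(b):
        return True
    # math.isclose applies a relative tolerance (1e-09 by default).
    return math.isclose(a, b)


print(numbers_match(74, 74.0))                    # int vs float
print(numbers_match(float("nan"), float("nan")))  # NaN on both sides
print(numbers_match(74, 139))                     # genuinely different values
print(numbers_match(0.1 + 0.2, 0.3))              # float rounding noise
```

Two caveats apply to this style of check: `bool` is a subclass of `int`, so booleans also take the numeric path, and `math.isclose` has an absolute tolerance of 0 by default, so values very close to zero are compared strictly.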


In [None]:
from snowflake.ml.jobs import remote

@remote(compute_pool, stage_name="payload_stage", target_instances=3)
def train_remote(
        source_dataset: str, 
        model_name: str, 
        mr_schema_name: str,
        experiment_name: str
    ):
    from snowflake.ml.modeling import tune
    from snowflake.ml.modeling.tune.search import RandomSearch, BayesOpt

    # Retrieve session from SPCS service context
    session = Session.builder.getOrCreate()

    # Load data
    print("Loading data...", end="", flush=True)
    start = perf_counter()
    dc = create_data_connector(session, dataset_name=source_dataset)
    df = dc.to_pandas()
    elapsed = perf_counter() - start
    print(f" done! Loaded {len(df)} rows, elapsed={elapsed:.3f}s")

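    print("Building train/val data")
    train_df, val_df = generate_train_val_set(df)

    # Split the training frame once more; only X_train is used below, as the
    # sample input when the best model is logged to the registry.
    X = train_df[['RETURN_RATIO', 'FREQUENCY']]
    y = train_df["RETURN_ROW_PRICE"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )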
    # Create DataConnectors
    dataset_map = {
        "train": DataConnector.from_dataframe(session.create_dataframe(train_df)),
        "val": DataConnector.from_dataframe(session.create_dataframe(val_df)),
    }

    # Define search space for XGBoost
    search_space = {
        'mr_schema_name': mr_schema_name,
        'model_name': model_name,
        'experiment_name': experiment_name,
        'n_estimators': tune.randint(50, 200),
        'random_state': 42,
    }

    # Configure tuner
    tuner_config = tune.TunerConfig(
        metric='mean_absolute_percentage_error',
        mode='min',
        search_alg=RandomSearch(),
        num_trials=2,
    )

    # Create tuner
    tuner = tune.Tuner(
        train_func=train,
        search_space=search_space, 
        tuner_config=tuner_config
    )

    print("HPO starting")
    results = tuner.run(dataset_map=dataset_map)

    best_config = results.best_result[0] if isinstance(results.best_result, list) else results.best_result
    best_model = results.best_model[0] if isinstance(results.best_model, list) else results.best_model
    best_config_record = best_config.to_dict(orient='records')[0]
    best_config_dict = {
        str(k).removeprefix('config/'): v 
        for k, v in best_config_record.items() 
        if k.startswith('config/')
    }
    results_df: pd.DataFrame = results.results
    exp = ExperimentTracking(session=session, schema_name=mr_schema_name)
    exp.set_experiment(experiment_name)
    param_cols = [c for c in results_df.columns if str(c).startswith('params/')]

    for index, row in results_df.iterrows():
        run_name = row['run_name']
        exp.start_run(run_name)

        # run = runs.run_name
        metrics = {
            "mean_absolute_error": row['metrics/mean_absolute_error'],
            "mean_absolute_percentage_error": row['metrics/mean_absolute_percentage_error'],
            "r2_score": row['metrics/r2_score'],
        }
        params_series = row[param_cols]
        params_dict = {
            str(k).removeprefix('params/'): v 
            for k, v in params_series.items()
        }
        diffs = compare_params(best_config_dict, params_dict)

        if not diffs:
            # Save model to registry
            print("Logging model to Model Registry...", end="")
            exp.log_model(
                model=best_model, 
                model_name=model_name, 
                metrics=metrics,
                sample_input_data=X_train,
                conda_dependencies=["xgboost"],
            ) # type: ignore
            
        exp.end_run(row['run_name'])
    return {
        "results": results.results,
        "best_config": best_config,
        "best_model": best_model
    }
        

train_job = train_remote(
    source_dataset="TPCXAI_SF0001_QUICKSTART_INC._TRAINING_FEATURE_STORE.UC01_TRAINING",
    model_name = "MODEL_1.UC01_SNOWFLAKEML_RF_REGRESSOR_MODEL",
    mr_schema_name = "MODEL_1",
    experiment_name="MY_EXPERIMENT"
)
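
The tuner results use prefixed column names such as `config/n_estimators` and `params/n_estimators`, and `train_remote` strips those prefixes to recover each trial's hyperparameters. Here is a small sketch of that step on a plain dictionary, using values taken from the job output further down; `strip_prefix` is a hypothetical helper name, and `str.removeprefix` requires Python 3.9 or later.

```python
def strip_prefix(record: dict, prefix: str) -> dict:
    """Keep only the keys that start with `prefix`, with the prefix removed."""
    return {
        key.removeprefix(prefix): value
        for key, value in record.items()
        if key.startswith(prefix)
    }


# One flattened result row (a subset of the columns in the tuner output).
row = {
    "trial_id": "bc057_00000",
    "config/n_estimators": 74,
    "config/random_state": 42,
    "params/n_estimators": 74.0,
}

print(strip_prefix(row, "config/"))
print(strip_prefix(row, "params/"))
```

The two resulting dictionaries differ only in numeric type (74 versus 74.0), which is the case the tolerant comparison in `compare_params` is meant to absorb.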

In [129]:
print(train_job.id)
print(train_job.status)

TPCXAI_SF0001_QUICKSTART_INC.SERVING.TRAIN_REMOTE_3V4OX5VWXBOQ
PENDING


In [130]:
train_job.wait()
train_job.show_logs()

2025-12-02 14:36:41,216 - INFO - Snowflake Connector for Python Version: 3.18.0, Python Version: 3.10.19, Platform: Linux-5.15.196-14.2025103011g9a182a6+snow+aws+5.15+amd64.x86_64-x86_64-with-glibc2.39
2025-12-02 14:36:41,217 - INFO - Connecting to GLOBAL Snowflake domain
2025-12-02 14:36:44,441 - INFO - Snowflake Connector for Python Version: 3.18.0, Python Version: 3.10.19, Platform: Linux-5.15.196-14.2025103011g9a182a6+snow+aws+5.15+amd64.x86_64-x86_64-with-glibc2.39
2025-12-02 14:36:44,442 - INFO - Connecting to GLOBAL Snowflake domain
2025-12-02 14:36:46,666 - INFO - Snowflake Connector for Python Version: 3.18.0, Python Version: 3.10.19, Platform: Linux-5.15.196-14.2025103011g9a182a6+snow+aws+5.15+amd64.x86_64-x86_64-with-glibc2.39
2025-12-02 14:36:46,666 - INFO - Connecting to GLOBAL Snowflake domain
2025-12-02 14:36:48,923 - INFO - Snowflake Connector for Python Version: 3.18.0, Python Version: 3.10.19, Platform: Linux-5.15.196-14.2025103011g9a182a6+snow+aws+5.15+amd64.x86_64-x

In [131]:
train_job.result()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


{'results':    mean_absolute_error  mean_absolute_percentage_error  r2_score  \
 0            23.627853                        0.909369  0.004247   
 1            23.627853                        0.909369  0.004247   
 
    should_checkpoint     trial_id  time_total_s  config/n_estimators  \
 0                NaN  bc057_00000     25.434754                   74   
 1                NaN  bc057_00001     22.779046                  139   
 
    config/random_state                                   config/callbacks  \
 0                   42  [<snowflake.ml.experiment.callback.xgboost.Sno...   
 1                   42  [<snowflake.ml.experiment.callback.xgboost.Sno...   
 
            run_name  ...                                            X_train  \
 0  CHATTY_LEOPARD_4  ...         RETURN_RATIO  FREQUENCY\n0             ...   
 1     HAPPY_ROBIN_2  ...         RETURN_RATIO  FREQUENCY\n0             ...   
 
   params/n_estimators  params/random_state  \
 0                74.0            
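
To help read the metric columns above, here is a minimal pure-Python sketch of the three metrics that `evaluate_model` reports, written from their standard definitions. The sample values are made up for illustration and do not come from the dataset, and the `mape` sketch assumes no target value is zero. scikit-learn returns MAPE as a fraction, so the 0.909 above corresponds to roughly 91%, and an `r2_score` close to 0 means the model does little better than always predicting the mean of the target.

```python
def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mape(y_true, y_pred):
    """Mean absolute percentage error as a fraction (assumes non-zero targets)."""
    return sum(abs(t - p) / abs(t) for t, p in zip(y_true, y_pred)) / len(y_true)


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


# Illustrative values only.
y_true = [100.0, 200.0, 300.0]
y_pred = [110.0, 190.0, 330.0]

print(f"MAE  = {mae(y_true, y_pred):.4f}")
print(f"MAPE = {mape(y_true, y_pred):.4f}")
print(f"R2   = {r2(y_true, y_pred):.4f}")

# Baseline: always predicting the mean of the target gives R2 = 0.
baseline = [200.0, 200.0, 200.0]
print(f"R2 (mean baseline) = {r2(y_true, baseline):.4f}")
```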

## CLEAN UP

In [118]:
# session.close()

In [11]:
from datetime import datetime
from zoneinfo import ZoneInfo
formatted_time = datetime.now(ZoneInfo("Australia/Melbourne")).strftime("%A, %B %d, %Y %I:%M:%S %p %Z")

print(f"The last run time in Melbourne is: {formatted_time}")

The last run time in Melbourne is: Tuesday, December 02, 2025 08:28:41 PM AEDT
