Connect to Snowflake

In [None]:
from snowflake.snowpark import Session
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

# Connection parameters are resolved from a named connection rather than written inline.
# NOTE(review): "test_conn" is presumably an entry in a local Snowflake/SnowSQL
# connection config — confirm it exists before running.
CONNECTION_NAME = "test_conn"

login_options = SnowflakeLoginOptions(CONNECTION_NAME)
session = Session.builder.configs(login_options).create()

Load the cleaned-up diamonds table and hold out a 10% test split

In [None]:
# Hold out TEST_FRACTION of the rows for evaluation. random_split takes relative
# weights; the fixed seed is passed so the split is intended to be repeatable.
TEST_FRACTION = 0.1
SPLIT_SEED = 0

df = session.table("test.diamonds.diamonds_transform_pipeline")
# Bind the training portion to a real name instead of `_`: assigning `_` in an
# IPython kernel shadows its output-history variable and stops it from updating.
train_df, test_df = df.random_split(
    weights=[1 - TEST_FRACTION, TEST_FRACTION],
    seed=SPLIT_SEED,
)
test_df.show()

List all models in the current registry

In [None]:
from snowflake.ml.registry import Registry

# Location of the model registry used throughout this notebook.
REGISTRY_DATABASE = "TEST"
REGISTRY_SCHEMA = "PUBLIC"

registry = Registry(
    session=session,
    database_name=REGISTRY_DATABASE,
    schema_name=REGISTRY_SCHEMA,
)

print("Models:")
registry.show_models()  # kept as the cell's last expression so Jupyter displays its return value

Show the functions exposed by the registered model's default version

In [None]:
MODEL_NAME = "RandomForestRegressor"

model = registry.get_model(MODEL_NAME)
# Use the model's default version for the rest of the notebook.
# NOTE(review): confirm the default version is the one you intend to evaluate.
mv = model.default

print("Functions:")
mv.show_functions()  # last expression -> displayed by Jupyter

Predict diamond prices

In [None]:
# Name the function explicitly: ModelVersion.run can only infer it when the model
# exposes exactly one function, so relying on the default is fragile.
pred = mv.run(test_df, function_name="predict")
# Show actual vs. predicted price side by side.
pred.select("PRICE", "PREDICTED_PRICE").show()

Compute the MAPE and record it as a metric on the model version

In [None]:
from snowflake.ml.modeling.metrics import mean_absolute_percentage_error

# Single source of truth for the metric key, so the printed label and the stored
# metric name cannot drift apart.
METRIC_NAME = "MAPE"

# Score the predictions from the previous cell: actual PRICE vs. PREDICTED_PRICE.
mape = mean_absolute_percentage_error(
    df=pred,
    y_true_col_names="PRICE",
    y_pred_col_names="PREDICTED_PRICE",
)
print(f"{METRIC_NAME}: {mape}")

# Attach the score to the model version's metrics, then display them.
mv.set_metric(METRIC_NAME, mape)
mv.show_metrics()