Connect to Snowflake

In [1]:
from snowflake.snowpark import Session
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

# Resolve the login options for the "test_conn" connection, then open a
# Snowpark session configured with them.
connection_params = SnowflakeLoginOptions("test_conn")
session = Session.builder.configs(connection_params).create()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


Load the cleaned-up diamonds table and take a test split

In [2]:
# Point at the pre-processed diamonds table.
df = session.table("test.diamonds.diamonds_transform_pipeline")

# Split with a fixed seed and keep only the second (0.1-weight) part as the
# test set; the first (0.9-weight) part is not used in this notebook section.
split_weights = [0.9, 0.1]
test_df = df.random_split(weights=split_weights, seed=0)[1]

# Preview the test rows.
test_df.show()

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"CARAT"               |"DEPTH"              |"TABLE_PCT"          |"X"                  |"Y"                  |"Z"                  |"CUT_OE"  |"COLOR_OE"  |"CLARITY_OE"  |"COLOR"  |"CLARITY"  |"PRICE"  |"CUT"      |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|0.02079002079002079   |0.5472222222222223   |0.3076923076923076   |0.3919925512104283   |0.07249575551782682  |0.08364779874213837  |2.0       |6.0         |5.0           |J        |SI1        |351      |VERY_GOOD  |
|0.02286902286902287   |0.4555555555555555   |0.3653846153846153   |0.40875232774674114  |0.07521222410865874  |0.08238993710691

List all models in the current registry

In [3]:
from snowflake.ml.registry import Registry

# Open the model registry that lives in database TEST, schema PUBLIC.
registry = Registry(session=session, database_name="TEST", schema_name="PUBLIC")

# Label the output, then leave the model listing as the cell's last expression
# so the notebook renders it.
print("Models:")
registry.show_models()

Models:


Unnamed: 0,created_on,name,database_name,schema_name,comment,owner,default_version_name,versions
0,2024-04-30 10:45:27.831000-07:00,RANDOMFORESTCLASSIFIER,TEST,PUBLIC,,ACCOUNTADMIN,V1,"[""V1""]"
1,2024-04-30 07:52:14.437000-07:00,RANDOMFORESTREGRESSOR,TEST,PUBLIC,new comment,ACCOUNTADMIN,V1,"[""V1"",""V2""]"
2,2024-04-30 10:23:54.568000-07:00,XGBCLASSIFIER,TEST,PUBLIC,,ACCOUNTADMIN,V1,"[""V1"",""V2"",""V3""]"
3,2024-04-30 10:47:42.566000-07:00,XGBOOSTER,TEST,PUBLIC,,ACCOUNTADMIN,V1,"[""V1""]"


Show functions for our registered model

In [4]:
# Fetch the registered model by name and take its default version.
model = registry.get_model("RandomForestRegressor")
mv = model.default

# Label the output, then leave the function listing as the cell's last
# expression so the notebook renders it.
print("Functions:")
mv.show_functions()

Functions:


[{'name': 'PREDICT',
  'target_method': 'predict',
  'signature': ModelSignature(
                      inputs=[
                          FeatureSpec(dtype=DataType.DOUBLE, name='CUT_OE'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='COLOR_OE'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='CLARITY_OE'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='CARAT'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='DEPTH'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='TABLE_PCT'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='X'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='Y'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='Z')
                      ],
                      outputs=[
                          FeatureSpec(dtype=DataType.DOUBLE, name='CUT_OE'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='COLOR_OE'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='CLARITY_OE'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='CARAT'),
  		FeatureSpec(dtype=DataType.DOUBLE, name='DEPTH'),
  		FeatureSpec(dtype

Predict diamond prices

In [5]:
# Run the model version against the test split. No function_name is passed;
# per the original note, the function used is 'predict'.
pred = mv.run(test_df)

# Show the actual price next to the predicted price.
price_comparison = pred.select("PRICE", "PREDICTED_PRICE")
price_comparison.show()

-------------------------------
|"PRICE"  |"PREDICTED_PRICE"  |
-------------------------------
|351      |375.35             |
|353      |389.05             |
|355      |408.83             |
|357      |384.73             |
|554      |545.08             |
|554      |526.33             |
|2757     |2804.83            |
|2759     |2995.51            |
|2759     |2760.89            |
|2762     |2918.55            |
-------------------------------



Compute the MAPE metric (storing it on the model version is commented out below)

In [12]:
from snowflake.ml.modeling.metrics import mean_absolute_percentage_error

# Columns of `pred` holding the actual and the predicted price.
actual_col = "PRICE"
predicted_col = "PREDICTED_PRICE"

# Mean absolute percentage error of the predictions against the actual prices.
mape = mean_absolute_percentage_error(
    df=pred,
    y_true_col_names=actual_col,
    y_pred_col_names=predicted_col,
)
print(f"MAPE: {mape}")

# Disabled: calls that would replace and display the "MAPE" metric stored on
# the model version. Left commented out, as in the original cell.
#mv.delete_metric("MAPE")
#mv.set_metric("MAPE", mape)
#mv.show_metrics()

SyntaxError: invalid syntax (804073422.py, line 1)