Connect to Snowflake

In [1]:
from snowflake.snowpark import Session
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

# Resolve connection parameters for the "test_conn" entry
# (presumably a named connection profile in local config — confirm).
connection_params = SnowflakeLoginOptions("test_conn")

# Build the Snowpark session from those parameters.
session = Session.builder.configs(connection_params).create()

# Set the session's query tag to identify this notebook's queries.
session.query_tag = "model-registry-1"

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


Load cleaned-up diamonds table

In [2]:
# Reference the prepared diamonds table.
df = session.table("test.diamonds.diamonds_transform_pipeline")

# 90/10 train/test split; the fixed seed keeps the split repeatable.
splits = df.random_split(weights=[0.9, 0.1], seed=0)
train_df, test_df = splits

# Preview the training rows.
train_df.show()

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"CARAT"               |"DEPTH"              |"TABLE_PCT"          |"X"                  |"Y"                  |"Z"                  |"CUT_OE"  |"COLOR_OE"  |"CLARITY_OE"  |"COLOR"  |"CLARITY"  |"PRICE"  |"CUT"      |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|0.006237006237006237  |0.5138888888888888   |0.23076923076923073  |0.3677839851024209   |0.06757215619694397  |0.07641509433962265  |0.0       |1.0         |6.0           |E        |SI2        |326      |IDEAL      |
|0.002079002079002079  |0.46666666666666656  |0.34615384615384615  |0.36219739292364994  |0.06519524617996604  |0.07264150943396

Train a random forest regressor and predict diamond prices

In [3]:
from snowflake.ml.modeling.ensemble import RandomForestRegressor

# Train a random forest regressor that predicts PRICE from the *_OE columns
# (presumably ordinal-encoded categoricals — confirm against the upstream
# pipeline) and the numeric measurement columns.
model = RandomForestRegressor(
    input_cols=["CUT_OE", "COLOR_OE", "CLARITY_OE", "CARAT", "DEPTH", "TABLE_PCT", "X", "Y", "Z"],
    label_cols=["PRICE"],
    output_cols=["PREDICTED_PRICE"],
    # Seed the forest so re-running the notebook on the same training data
    # gives a repeatable model, in line with the seeded random_split that
    # produced train_df / test_df.
    random_state=0,
)
model.fit(train_df)

# Score the held-out rows and show actual vs. predicted price side by side.
pred = model.predict(test_df)
pred.select("PRICE", "PREDICTED_PRICE").show()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


-------------------------------
|"PRICE"  |"PREDICTED_PRICE"  |
-------------------------------
|351      |382.13             |
|353      |396.04             |
|355      |408.96             |
|357      |390.98             |
|554      |551.38             |
|554      |537.03             |
|2757     |2745.75            |
|2759     |2980.49            |
|2759     |2841.06            |
|2762     |2969.94            |
-------------------------------



Register model

In [5]:
from snowflake.ml.registry import Registry

# Open the model registry through the existing session. No database/schema is
# passed explicitly; the trailing note records the ones used in the recorded run.
registry = Registry(session=session)    # database_name="TEST", schema_name="PUBLIC"

# Log the fitted model as version "v2" of the model named RandomForestRegressor
# (listed as RANDOMFORESTREGRESSOR in the registry output), declaring
# scikit-learn as a conda dependency.
#
# SQL noted by the author for the first registration (version V1):
#   CREATE MODEL TEST.PUBLIC.RANDOMFORESTREGRESSOR WITH VERSION V1
#   FROM @TEST.PUBLIC.SNOWPARK_TEMP_STAGE_.../model
# NOTE(review): this cell now logs v2, and the recorded show_models() output
# lists both V1 and V2, so the statement above describes the earlier V1 run —
# confirm which SQL is issued when adding a version to an existing model.
# NOTE(review): re-running this cell unchanged presumably fails once v2
# exists — confirm, and bump version_name before re-running.
model_ref = registry.log_model(
    model,
    model_name="RandomForestRegressor",
    version_name="v2",
    conda_dependencies=["scikit-learn"])

  return next(self.gen)


In [7]:
# List the registered models; the output includes each model's versions and
# its default version name.
registry.show_models()

Unnamed: 0,created_on,name,database_name,schema_name,comment,owner,default_version_name,versions
0,2024-04-30 07:52:14.437000-07:00,RANDOMFORESTREGRESSOR,TEST,PUBLIC,new comment,ACCOUNTADMIN,V1,"[""V1"",""V2""]"
