Connect to Snowflake

In [None]:
from snowflake.snowpark import Session
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

# Resolve the login parameters for the named connection "test_conn", then
# open a Snowpark session with them.
login_options = SnowflakeLoginOptions("test_conn")
session = Session.builder.configs(login_options).create()

# Label this session's queries so they can be told apart from other workloads.
session.query_tag = "model-registry-1"

Load cleaned-up diamonds table

In [None]:
# Reference the already-transformed diamonds table through the session.
df = session.table("test.diamonds.diamonds_transform_pipeline")

# Split 90% / 10% with a fixed seed; the first part is used for training and
# the second for evaluation.
splits = df.random_split(weights=[0.9, 0.1], seed=0)
train_df, test_df = splits

# Preview the training rows.
train_df.show()

Train a random forest regressor and predict diamond prices

In [None]:
from snowflake.ml.modeling.ensemble import RandomForestRegressor

# Configure a random forest that reads the listed feature columns, learns the
# PRICE label, and writes its predictions to PREDICTED_PRICE.
model = RandomForestRegressor(
    input_cols=["CUT_OE", "COLOR_OE", "CLARITY_OE", "CARAT", "DEPTH", "TABLE_PCT", "X", "Y", "Z"],
    label_cols=["PRICE"],
    output_cols=["PREDICTED_PRICE"],
    # Pin the estimator's seed so re-running the notebook on the same training
    # data fits the same forest (the train/test split above is already seeded).
    random_state=0,
)
model.fit(train_df)

# Score the held-out rows and show actual vs. predicted price side by side.
pred = model.predict(test_df)
pred.select("PRICE", "PREDICTED_PRICE").show()

Register model

In [None]:
from snowflake.ml.registry import Registry

# Open the model registry on the current session. No database or schema is
# passed explicitly; per the original note this is expected to resolve to
# database_name="TEST", schema_name="PUBLIC" -- confirm against the session.
registry = Registry(session=session)

# Log the fitted model as version "v1". The statement this is expected to
# issue looks like:
#   CREATE MODEL TEST.PUBLIC.RANDOMFORESTREGRESSOR WITH VERSION V1
#   FROM @TEST.PUBLIC.SNOWPARK_TEMP_STAGE_.../model
model_ref = registry.log_model(
    model,
    model_name="RandomForestRegressor",
    version_name="v1",
    conda_dependencies=["scikit-learn"],
)