In [7]:
import ast
import json
import warnings

import pandas as pd
from common import get_col_types, get_next_version, get_version_with_highest_accuracy
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col

In [2]:
session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


In [6]:
df = session.table("TITANIC")
df = df.drop(["PASSENGER_ID", "AGE", "DECK", "ALIVE", "ADULT_MALE", "EMBARKED"])
train_df, test_df = df.random_split([0.8, 0.2], seed=42)
train_df.write.save_as_table("TRAIN", mode="overwrite")
test_df.write.save_as_table("TEST", mode="overwrite")
train_df, test_df = session.table("train"), session.table("test")
parameters = {
    "n_estimators": [100, 200, 300, 400, 500],
    "learning_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "max_depth": list(range(3, 5, 1)),
    "min_child_weight": list(range(3, 5, 1)),
}
cat_cols = get_col_types(df, "string")
pipeline = Pipeline(
    steps=[
        (
            "SimpleImputer",
            SimpleImputer(
                input_cols=cat_cols,
                output_cols=cat_cols,
                strategy="most_frequent",
                drop_input_cols=True,
            ),
        ),
        (
            "OneHotEncoder",
            OneHotEncoder(
                input_cols=cat_cols,
                output_cols=cat_cols,
                drop_input_cols=True,
                drop="first",
                handle_unknown="ignore",
            ),
        ),
        (
            "GridSearchCV",
            GridSearchCV(
                estimator=XGBClassifier(random_state=42),
                param_grid=parameters,
                n_jobs=-1,
                scoring="accuracy",
                label_cols="SURVIVED",
            ),
        ),
    ],
)
pipeline.fit(train_df)
result_df = pipeline.predict_proba(train_df)
result_df = result_df.with_column(
    "OUTPUT_SURVIVED", F.col("predict_proba_1").cast(T.LongType())
)
result_df.show()

  success, nchunks, nrows, ci_output = write_pandas(
  success, nchunks, nrows, ci_output = write_pandas(
  success, nchunks, nrows, ci_output = write_pandas(


-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEX_MALE"  |"CLASS_SECOND"  |"CLASS_THIRD"  |"WHO_MAN"  |"WHO_WOMAN"  |"EMBARK_TOWN_QUEENSTOWN"  |"EMBARK_TOWN_SOUTHAMPTON"  |"SURVIVED"  |"PCLASS"  |"SIBSP"  |"PARCH"  |"FARE"   |"ALONE"  |"PREDICT_PROBA_0"     |"PREDICT_PROBA_1"    |"OUTPUT_SURVIVED"  |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|1.0         |0.0             |1.0            |1.0        |0.0          |0.0                       |1.0                        |0           |3         |1        |0        |7.25     |False    |0.9325933456420898    |0.067406632

In [8]:
metrics = {
    "Accuracy": accuracy_score(
        df=result_df,
        y_true_col_names="SURVIVED",
        y_pred_col_names="OUTPUT_SURVIVED",
    ),
    "Precision": precision_score(
        df=result_df,
        y_true_col_names="SURVIVED",
        y_pred_col_names="OUTPUT_SURVIVED",
    ),
    "Recall": recall_score(
        df=result_df,
        y_true_col_names="SURVIVED",
        y_pred_col_names="OUTPUT_SURVIVED",
    ),
    "F1 Score": f1_score(
        df=result_df,
        y_true_col_names="SURVIVED",
        y_pred_col_names="OUTPUT_SURVIVED",
    ),
    "Confusion Matrix": confusion_matrix(
        df=result_df, y_true_col_name="SURVIVED", y_pred_col_name="OUTPUT_SURVIVED"
    ).tolist(),
}

DataFrame.flatten() is deprecated since 0.7.0. Use `DataFrame.join_table_function()` instead.


In [13]:
metrics

{'Accuracy': 0.887931,
 'Precision': 0.8961038961038961,
 'Recall': 0.7931034482758621,
 'F1 Score': 0.8414634146341463,
 'Confusion Matrix': [[411.0, 24.0], [54.0, 207.0]]}

In [14]:
X = train_df.drop("SURVIVED").limit(100)

reg = Registry(session=session)

titanic_model = reg.log_model(
    model_name="TITANIC_PIPE",
    version_name=get_next_version(reg, "TITANIC_PIPE"),
    model=pipeline,
    metrics=metrics,
)

  return next(self.gen)


In [15]:
m = reg.get_model("TITANIC_PIPE")
m.default_version = get_version_with_highest_accuracy(reg, "TITANIC_PIPE")

## Call pipeline from SQL

Show the data is not cleaned before performming inference

In [17]:
test_df = session.table("TEST")
test_df.show()

-------------------------------------------------------------------------------------------------------------
|"SURVIVED"  |"PCLASS"  |"SIBSP"  |"PARCH"  |"FARE"    |"ALONE"  |"SEX"   |"CLASS"  |"WHO"  |"EMBARK_TOWN"  |
-------------------------------------------------------------------------------------------------------------
|1           |3         |0        |0        |7.9250    |True     |FEMALE  |THIRD    |WOMAN  |SOUTHAMPTON    |
|0           |3         |0        |0        |8.4583    |True     |MALE    |THIRD    |MAN    |QUEENSTOWN     |
|0           |1         |0        |0        |51.8625   |True     |MALE    |FIRST    |MAN    |SOUTHAMPTON    |
|0           |3         |0        |0        |8.0500    |True     |MALE    |THIRD    |MAN    |SOUTHAMPTON    |
|1           |3         |0        |0        |8.0292    |True     |FEMALE  |THIRD    |CHILD  |QUEENSTOWN     |
|0           |3         |0        |0        |7.8958    |True     |MALE    |THIRD    |MAN    |SOUTHAMPTON    |
|1        

Run the pipeline

In [20]:
# Copy this code in a snowflake worksheet or run via session.sql
inference_df = session.sql(
    """
select *, TITANIC_PIPE!predict_proba(*):PREDICT_PROBA_1
as surv_pred
from (
select * exclude survived
from test)
            """
)
inference_df.show()

-----------------------------------------------------------------------------------------------------------------------
|"PCLASS"  |"SIBSP"  |"PARCH"  |"FARE"    |"ALONE"  |"SEX"   |"CLASS"  |"WHO"  |"EMBARK_TOWN"  |"SURV_PRED"           |
-----------------------------------------------------------------------------------------------------------------------
|3         |0        |0        |7.9250    |True     |FEMALE  |THIRD    |WOMAN  |SOUTHAMPTON    |0.5756063461303711    |
|3         |0        |0        |8.4583    |True     |MALE    |THIRD    |MAN    |QUEENSTOWN     |0.06476970762014389   |
|1         |0        |0        |51.8625   |True     |MALE    |FIRST    |MAN    |SOUTHAMPTON    |0.07613715529441833   |
|3         |0        |0        |8.0500    |True     |MALE    |THIRD    |MAN    |SOUTHAMPTON    |0.1316869705915451    |
|3         |0        |0        |8.0292    |True     |FEMALE  |THIRD    |CHILD  |QUEENSTOWN     |0.5572702884674072    |
|3         |0        |0        |7.8958  

In [22]:
# Copy this code in a snowflake worksheet or run via session.sql
inference_df = session.sql(
    """
select *, TITANIC_PIPE!predict(*):OUTPUT_SURVIVED
as surv_pred
from (
select * exclude survived
from test)
            """
)
inference_df.show()

--------------------------------------------------------------------------------------------------------------
|"PCLASS"  |"SIBSP"  |"PARCH"  |"FARE"    |"ALONE"  |"SEX"   |"CLASS"  |"WHO"  |"EMBARK_TOWN"  |"SURV_PRED"  |
--------------------------------------------------------------------------------------------------------------
|3         |0        |0        |7.9250    |True     |FEMALE  |THIRD    |WOMAN  |SOUTHAMPTON    |1            |
|3         |0        |0        |8.4583    |True     |MALE    |THIRD    |MAN    |QUEENSTOWN     |0            |
|1         |0        |0        |51.8625   |True     |MALE    |FIRST    |MAN    |SOUTHAMPTON    |0            |
|3         |0        |0        |8.0500    |True     |MALE    |THIRD    |MAN    |SOUTHAMPTON    |0            |
|3         |0        |0        |8.0292    |True     |FEMALE  |THIRD    |CHILD  |QUEENSTOWN     |1            |
|3         |0        |0        |7.8958    |True     |MALE    |THIRD    |MAN    |SOUTHAMPTON    |0            |
|