In [11]:
# Standard library
import os

# Snowpark for Python
import snowflake.snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.functions import udf, sproc

In [12]:
# Connection settings for the Snowpark session.
# Credentials are read from environment variables so that real secrets never
# have to be typed into (or saved with) the notebook. When a variable is not
# set, the redacted placeholder is kept, so this cell still runs and the
# problem surfaces when the session is created rather than as a KeyError here.
cn_params = {
    "user": os.environ.get("SNOWFLAKE_USER", "**********"),
    "password": os.environ.get("SNOWFLAKE_PASSWORD", "*********"),
    "account": os.environ.get("SNOWFLAKE_ACCOUNT", "************"),
    "warehouse": "ANALYSIS_WH",
    "database": "EDWPRODHH",
    "schema": "HERMES",
}

In [13]:
# Open the Snowpark session from the connection settings in cn_params.
snowpark_session = (
    Session.builder
    .configs(cn_params)
    .create()
)

In [14]:
# Session-level package list for code registered through this session.
# scikit-learn is listed explicitly because train_debtor_v1 imports sklearn
# directly; it should not rely on another package happening to pull it in.
snowpark_session.add_packages(
    'snowflake-snowpark-python',
    'xgboost',
    'scikit-learn',
    'pandas',
    'numpy',
    'joblib',
    'cachetools',
)

The version of package xgboost in the local environment is 1.7.5, which does not fit the criteria for the requirement xgboost. Your UDF might not work when the package version is different between the server and your local environment
The version of package cachetools in the local environment is 5.3.0, which does not fit the criteria for the requirement cachetools. Your UDF might not work when the package version is different between the server and your local environment


In [15]:
def train_debtor_v1(session: Session) -> str:
    """Train the debtor regressor and upload the fitted model to the model stage.

    Loads every row of ``EDWPRODHH.PUB_MBUTLER.MASTER_DIALER_MODEL_DEBTOR``
    into pandas, uses ``DOL_COMMISSION_ATTR`` as the target and all remaining
    columns as features, fits an ``XGBRegressor`` on 70% of the rows, scores it
    on the remaining 30%, and uploads the serialized model to ``@prod_models``.

    Args:
        session: Snowpark session used to run the query and upload the model.

    Returns:
        A status message that embeds the R-squared score on the held-out rows.

    Raises:
        ValueError: If the source table returns no rows.
    """
    # Imports stay inside the function so they are resolved wherever the
    # registered procedure actually executes.
    import os

    from joblib import dump
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import r2_score
    from xgboost import XGBRegressor

    # Read data
    df = session.sql("""
    SELECT *
    FROM EDWPRODHH.PUB_MBUTLER.MASTER_DIALER_MODEL_DEBTOR
    """).to_pandas()

    # Fail with a clear message instead of an opaque error from the splitter.
    if df.empty:
        raise ValueError(
            "EDWPRODHH.PUB_MBUTLER.MASTER_DIALER_MODEL_DEBTOR returned no rows; "
            "nothing to train on."
        )

    labels = df["DOL_COMMISSION_ATTR"]
    # NOTE(review): the feature set is "every other column of SELECT *", while
    # the prediction UDF in this notebook builds a frame from a fixed list of
    # eleven columns. Confirm the table's columns match that list.
    features = df.drop(["DOL_COMMISSION_ATTR"], axis=1)

    # Split the data (fixed seed so the split is reproducible)
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.3, random_state=42
    )

    # Define the model
    model = XGBRegressor(
        n_estimators=10, learning_rate=0.1, max_depth=9, n_jobs=16, random_state=42
    )

    # Train the model
    model.fit(X_train, y_train)

    # Make predictions on the held-out rows
    y_pred = model.predict(X_test)

    # Compute R-squared score
    r2score = r2_score(y_test, y_pred)

    # Save the model. The local file name determines the staged file name,
    # which the notebook later imports as train_debtor_v1.joblib.gz.
    dump_path = "/tmp/train_debtor_v1.joblib"
    try:
        dump(model, dump_path)
        session.file.put(dump_path, "@prod_models", overwrite=True)
    finally:
        # Do not leave the serialized model behind once it has been uploaded
        # (or if the upload failed).
        if os.path.exists(dump_path):
            os.remove(dump_path)

    print("R-squared: ", r2score)
    return "Model trained and saved with R-squared: " + str(r2score) + "."


In [16]:
# Register the training function as the stored procedure "train_debtor_v1",
# replacing any existing procedure of that name.
snowpark_session.sproc.register(train_debtor_v1, name="train_debtor_v1", replace=True)

<snowflake.snowpark.stored_procedure.StoredProcedure at 0x1a5f7caa910>

In [17]:
# Run the registered "train_debtor_v1" procedure; the status string it returns
# (including the R-squared score) is shown as this cell's output.
snowpark_session.call("train_debtor_v1")

'Model trained and saved with R-squared: 0.2527832813566191.'

In [18]:

# Attach the staged model file to the session as an import, so the prediction
# UDF below can read it from its import directory.
# NOTE(review): the ".gz" suffix presumably comes from the upload in
# train_debtor_v1 compressing the file — confirm against the stage listing.
snowpark_session.add_import("@edwprodhh.hermes.prod_models/train_debtor_v1.joblib.gz")

In [19]:
# Per-process cache of deserialized models, keyed by file path. The UDF body
# runs once per input row, so without this the model file would be
# decompressed and unpickled again for every row.
_DEBTOR_MODEL_CACHE = {}


def _load_debtor_model(model_file):
    """Return the model stored at ``model_file``, loading it at most once per process."""
    model = _DEBTOR_MODEL_CACHE.get(model_file)
    if model is None:
        from joblib import load

        # A concurrent first call could load twice; that is harmless because
        # both loads produce the same model.
        model = load(model_file)
        _DEBTOR_MODEL_CACHE[model_file] = model
    return model


@udf(
    name="prod_predict_v1_debtor",
    stage_location='@prod_models',
    session=snowpark_session,
    packages=["pandas", "joblib", "scikit-learn", "xgboost"],
    replace=True,
)
def predict_Debtor_rev_v1(inputs: list) -> float:
    """Score one debtor with the model saved by ``train_debtor_v1``.

    Args:
        inputs: The feature values for a single row, in this order:
            ASSIGNED_AMT, DEBT_AGE, EXPERIAN_SCORE, MEDIAN_HOUSEHOLD_INCOME,
            HAS_PREVIOUS_PAYMENT, IS_ONLY_DEBTOR_IN_PACKET, PARKING, TOLL,
            AI, SP, HAS_EMAIL.

    Returns:
        The model's prediction as a built-in ``float``, floored at 0.
    """
    import os
    import sys
    import pandas as pd

    # The import directory is read from the interpreter's -X options under
    # this key.
    IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
    import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]

    # os.path.join works whether or not import_dir ends with a separator.
    model = _load_debtor_model(os.path.join(import_dir, "train_debtor_v1.joblib.gz"))

    feature_columns = [
        'ASSIGNED_AMT',
        'DEBT_AGE',
        'EXPERIAN_SCORE',
        'MEDIAN_HOUSEHOLD_INCOME',
        'HAS_PREVIOUS_PAYMENT',
        'IS_ONLY_DEBTOR_IN_PACKET',
        'PARKING',
        'TOLL',
        'AI',
        'SP',
        'HAS_EMAIL',
    ]
    df = pd.DataFrame([inputs], columns=feature_columns)

    # Coerce these columns to numeric; values that cannot be parsed become NaN.
    for column in ('EXPERIAN_SCORE', 'MEDIAN_HOUSEHOLD_INCOME', 'ASSIGNED_AMT'):
        df[column] = pd.to_numeric(df[column], errors='coerce')

    y_pred = model.predict(df)[0]
    # Convert to a built-in float (matching the annotation) and floor at 0;
    # a NaN prediction stays NaN.
    return max(float(y_pred), 0.0)

The version of package xgboost in the local environment is 1.7.5, which does not fit the criteria for the requirement xgboost. Your UDF might not work when the package version is different between the server and your local environment
