## Step 1: Prepare the Training and Evaluation Datasets

In [1]:
import numpy as np

# Ground-truth linear relationship used to generate noise-free targets:
#   y = X @ TRUE_COEF + TRUE_INTERCEPT
# Defined once so the training and evaluation targets cannot drift apart.
TRUE_COEF = np.array([1, 2])
TRUE_INTERCEPT = 3

# prepare training data: 4 samples x 2 features
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = X @ TRUE_COEF + TRUE_INTERCEPT

# prepare evaluation data: 2 held-out samples from the same relationship
X_eval = np.array([[3, 3], [3, 4]])
y_eval = X_eval @ TRUE_COEF + TRUE_INTERCEPT

## Step 2: Launch the Spark Session with the Rikai Extension

In [2]:
from example import spark

:: loading settings :: url = jar:file:/Users/da/.cache/pants/named_caches/pex_root/installed_wheels/2af3ba1a0b98d4936a6b141f1e78958d6eb400c51361fed2a7baa49e97f8a312/pyspark-3.2.0-py2.py3-none-any.whl/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /Users/da/.ivy2/cache
The jars for the packages stored in: /Users/da/.ivy2/jars
ai.eto#rikai_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-88a1fd33-5207-4c44-b6bb-514e84780706;1.0
	confs: [default]
	found ai.eto#rikai_2.12;0.1.14 in central
	found org.xerial.snappy#snappy-java;1.1.8.4 in central
	found com.typesafe.scala-logging#scala-logging_2.12;3.9.4 in central
	found org.slf4j#slf4j-api;1.7.30 in spark-list
	found org.mlflow#mlflow-client;1.21.0 in central
	found org.apache.logging.log4j#log4j-core;2.17.1 in central
:: resolution report :: resolve 225ms :: artifacts dl 8ms
	:: modules in use:
	ai.eto#rikai_2.12;0.1.14 from central in [default]
	com.typesafe.scala-logging#scala-logging_2.12;3.9.4 from central in [default]
	org.apache.logging.log4j#log4j-core;2.17.1 from central in [default]
	org.mlflow#mlflow-client;1.21.0 from central in [default]
	org.slf4j#slf4j-api;1.7.30 from spark-list in [default]
	org.x

## Step 3: Train the Model and Log It Using MLflow

In [3]:
import getpass

import mlflow
import rikai
from sklearn.linear_model import LinearRegression


mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

# Enable autologging for params/metrics, but NOT for the model itself.
# The model is logged explicitly below via rikai.mlflow.sklearn.log_model,
# which attaches the Rikai schema and model_type and registers it. By default,
# sklearn autologging would also log the fitted model under the same artifact
# path ("model"), producing a duplicate/overlapping artifact in the run.
mlflow.sklearn.autolog(log_models=False)

# train a model
model = LinearRegression()
with mlflow.start_run() as run:
    ####
    # Part 1: Train the model and register it on MLflow
    ####
    model.fit(X, y)
    # Evaluate on the held-out set and log the metrics with a "val_" prefix.
    metrics = mlflow.sklearn.eval_and_log_metrics(model, X_eval, y_eval, prefix="val_")

    # Output type the model exposes to Spark SQL; ML_PREDICT results in Step 5
    # come back as a `float` column.
    schema = "float"
    # Prefix with the OS user name, presumably so several users can share one
    # tracking DB without name clashes.
    registered_model_name = f"{getpass.getuser()}_sklearn_lr"
    rikai.mlflow.sklearn.log_model(
        model,
        "model",
        schema,
        registered_model_name=registered_model_name,
        model_type="linear_regression",
    )


Registered model 'da_sklearn_lr' already exists. Creating a new version of this model...
2022/11/07 16:01:16 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_sklearn_lr, version 5
Created version '5' of model 'da_sklearn_lr'.


## Step 4: Create the Model Using the Registered MLflow URI

In [4]:
# Disable Arrow-based Spark <-> pandas conversion for this session.
# NOTE(review): the reason is not stated in this notebook; presumably a
# compatibility workaround for the Rikai prediction path — confirm before
# re-enabling.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")

# Point Rikai's MLflow-backed model registry at the tracking DB used in Step 3.
spark.conf.set("spark.rikai.sql.ml.registry.mlflow.tracking_uri", mlflow_tracking_uri)

# Register a SQL-visible model that resolves to the MLflow registered model.
create_model_sql = f"""
CREATE OR REPLACE MODEL mlflow_sklearn_m USING 'mlflow:///{registered_model_name}';
"""
spark.sql(create_model_sql)

# List the registered SQL models (first row only, untruncated columns).
models_df = spark.sql("show models")
models_df.show(1, vertical=False, truncate=False)


+----------------+------+-----------------------+-------+
|name            |flavor|uri                    |options|
+----------------+------+-----------------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_sklearn_lr|       |
+----------------+------+-----------------------+-------+



## Step 5: Predict Using the Registered Rikai Model

In [5]:
# Build a 100-row feature table (x0 = id, x1 = id + 1) and expose it to SQL.
df = spark.range(100).selectExpr("id as x0", "id+1 as x1")
df.createOrReplaceTempView("tbl_X")

# Run the registered model over every row. Features are packed into an array
# whose order must match the column order of the training matrix X (2 features).
# A plain string is used because there is nothing to interpolate.
result = spark.sql(
    """
select ML_PREDICT(mlflow_sklearn_m, array(x0, x1)) from tbl_X
"""
)

result.printSchema()

root
 |-- mlflow_sklearn_m: float (nullable = true)



In [6]:
# Collect the predictions to the driver as a pandas DataFrame and display it.
predictions_pdf = result.toPandas()
predictions_pdf

                                                                                

Unnamed: 0,mlflow_sklearn_m
0,5.0
1,8.0
2,11.0
3,14.0
4,17.0
...,...
95,290.0
96,293.0
97,296.0
98,299.0
