## Step 1: Prepare the Training and Evaluation Dataset

In [1]:
from sklearn.datasets import load_diabetes

# Load the diabetes regression data as plain arrays (features X, target y)
# and echo both shapes as a quick sanity check.
X, y = load_diabetes(return_X_y=True)
X.shape, y.shape

((442, 10), (442,))

In [2]:
# Peek at the first two (feature-row, target) pairs.
list(zip(X[:2], y[:2]))

[(array([ 0.03807591,  0.05068012,  0.06169621,  0.02187239, -0.0442235 ,
         -0.03482076, -0.04340085, -0.00259226,  0.01990749, -0.01764613]),
  151.0),
 (array([-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872,
         -0.01916334,  0.07441156, -0.03949338, -0.06833155, -0.09220405]),
  75.0)]

## Step 2: Launch the Spark Session with the Rikai Extension

In [3]:
# Importing `spark` from the local `example` module starts the Spark session
# used by the rest of this notebook. The Ivy log below shows the Rikai jar
# (ai.eto#rikai_2.12) being resolved, so the session presumably has the Rikai
# SQL extension enabled.
# NOTE(review): `example` is project-local — confirm where it is defined.
from example import spark

23/01/06 09:33:52 WARN Utils: Your hostname, debian resolves to a loopback address: 127.0.1.1; using 192.168.31.194 instead (on interface wlx1cbfce3ffbfe)
23/01/06 09:33:52 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
:: loading settings :: url = jar:file:/home/da/.cache/pants/named_caches/pex_root/installed_wheels/8f254bc20b539246427b2913639b8a0258db76ab54ba91fbbebb8dc8c36183c1/pyspark-3.3.1-py2.py3-none-any.whl/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/da/.ivy2/cache
The jars for the packages stored in: /home/da/.ivy2/jars
ai.eto#rikai_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-5aa137f5-e60a-4940-888e-94344c67fcdc;1.0
	confs: [default]
	found ai.eto#rikai_2.12;0.2.0-SNAPSHOT in local-ivy-cache
	found org.xerial.snappy#snappy-java;1.1.8.4 in central
	found com.typesafe.scala-logging#scala-logging_2.12;3.9.4 in central
	found org.slf4j#slf4j-api;1.7.30 in spark-list
	found org.mlflow#mlflow-client;1.21.0 in central
	found org.apache.logging.log4j#log4j-core;2.17.1 in central
downloading /home/da/.ivy2/local/ai.eto/rikai_2.12/0.2.0-SNAPSHOT/jars/rikai-spark331-assembly_2.12.jar ...
	[SUCCESSFUL ] ai.eto#rikai_2.12;0.2.0-SNAPSHOT!rikai-spark331-assembly_2.12.jar (10ms)
downloading /home/da/.ivy2/local/ai.eto/rikai_2.12/0.2.0-SNAPSHOT/jars/rikai_2.12.jar ...
	[SUCCESSFUL ] ai.eto#rikai_2.12;0.2.0-SNAPSHOT!rikai_2.12.jar (1ms)
downloading https://repo1.maven.

23/01/06 09:34:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/01/06 09:34:23 WARN SparkContext: The path file:///home/da/.ivy2/jars/ai.eto_rikai_2.12-0.2.0-SNAPSHOT.jar has been added already. Overwriting of added paths is not supported in the current version.


## Step 3: Train the Model and Log It with MLflow

In [4]:
import getpass

import mlflow
from liga.sklearn.mlflow import log_model
from sklearn.linear_model import LinearRegression

# MLflow tracking data goes to a local SQLite file (mlruns.db). Step 4 passes
# this same URI to Spark so Rikai can resolve the registered model.
mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

model = LinearRegression()
with mlflow.start_run() as run:
    # Fit the regressor on the full dataset loaded in Step 1.
    model.fit(X, y)

    # Log the fitted model under a per-user name (e.g. "da_sklearn_lr").
    # Step 4 refers to it as mlflow:///<registered_model_name>.
    registered_model_name = f"{getpass.getuser()}_sklearn_lr"
    log_model(model, registered_model_name)


2023/01/06 09:35:03 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2023/01/06 09:35:03 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Running upgrade  -> 451aebb31d03, add metric step
INFO  [alembic.runtime.migration] Running upgrade 451aebb31d03 -> 90e64c465722, migrate user column to tags
INFO  [alembic.runtime.migration] Running upgrade 90e64c465722 -> 181f10493468, allow nulls for metric values
INFO  [alembic.runtime.migration] Running upgrade 181f10493468 -> df50e92ffc5e, Add Experiment Tags Table
INFO  [alembic.runtime.migration] Running upgrade df50e92ffc5e -> 7ac759974ad8, Update run tags with larger limit
INFO  [alembic.runtime.migration] Running upgrade 7ac759974ad8 -> 89d4b8295536, create latest metrics table
INFO  [89d4b8295536_create_latest_metrics_table_py] Migration complete!
INFO  

## Step 4: Create a Rikai Model from the Registered MLflow URI

In [5]:
from rikai.spark.sql.codegen.mlflow_logger import CONF_MLFLOW_TRACKING_URI

# NOTE(review): Arrow-based pandas conversion is switched off here; this
# notebook does not say why — confirm it is still required.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
# Point Rikai at the same MLflow tracking store that Step 3 logged to.
spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)

# Bind the registered MLflow model to the SQL model name `mlflow_sklearn_m`.
create_model_sql = f"""
CREATE OR REPLACE MODEL mlflow_sklearn_m USING 'mlflow:///{registered_model_name}';
"""
spark.sql(create_model_sql)

# List registered models (first row only, columns untruncated).
spark.sql("show models").show(1, vertical=False, truncate=False)


+----------------+------+-----------------------+-------+
|name            |flavor|uri                    |options|
+----------------+------+-----------------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_sklearn_lr|       |
+----------------+------+-----------------------+-------+



## Step 5: Predict Using the Registered Rikai Model

In [6]:
from rikai_sklearn.numpy import array_to_literal

# Inline the first sample's features into the query (array_to_literal
# presumably renders the numpy row as a SQL literal) and score it with the
# model created in Step 4.
first_sample_sql = f"""
select
  ML_PREDICT(mlflow_sklearn_m, {array_to_literal(X[0])})
"""
result = spark.sql(first_sample_sql)

result.printSchema()
result.toPandas()

root
 |-- mlflow_sklearn_m: float (nullable = true)



                                                                                

Unnamed: 0,mlflow_sklearn_m
0,206.116684


In [7]:
# Same prediction query as above, this time for the second sample.
second_sample_sql = f"""
select  ML_PREDICT(mlflow_sklearn_m, {array_to_literal(X[1])})
"""
spark.sql(second_sample_sql).toPandas()

Unnamed: 0,mlflow_sklearn_m
0,68.07103
