Reference: [scikit-learn 1.1 user guide — K-means clustering](https://scikit-learn.org/1.1/modules/clustering.html#k-means)

## Training and Logging

In [1]:
from sklearn.datasets import load_iris

# Feature matrix X and label vector y for the iris dataset; the shapes are
# displayed so the reader can confirm what was loaded.
X, y = load_iris(return_X_y=True)
X.shape, y.shape

((150, 4), (150,))

In [2]:
# Labels of the rows scored later through Spark: one from each block of 50.
tuple(y[[0, 50, 100]])

(0, 1, 2)

In [3]:
import getpass

import mlflow
from liga.sklearn.mlflow import log_model
from sklearn.cluster import KMeans


# Local SQLite-backed MLflow store. The Spark session in the next section is
# pointed at this same URI so it can resolve the registered model.
mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

# Train a 3-cluster KMeans model on the iris features and register it in the
# MLflow model registry under a per-user name ("<os user>_kmeans").
with mlflow.start_run() as run:
    # random_state plus an explicit n_init keep the fit reproducible:
    # scikit-learn 1.4 changed the n_init default from 10 to "auto", which
    # would otherwise silently change the fitted centroids (and cluster ids)
    # depending on the installed sklearn version.
    model = KMeans(n_clusters=3, n_init=10, random_state=0)
    model.fit(X)
    registered_model_name = f"{getpass.getuser()}_kmeans"
    log_model(model, registered_model_name=registered_model_name)


Successfully registered model 'da_kmeans'.
2023/02/16 15:43:52 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_kmeans, version 1
Created version '1' of model 'da_kmeans'.


## Apply the model to a large-scale dataset with Spark SQL

In [4]:
from example import spark
from liga.mlflow import CONF_MLFLOW_TRACKING_URI

# Session settings: Arrow-based pandas conversion is turned off, and Liga is
# told to use the same MLflow tracking store the model was registered in above.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)

# Expose the registered MLflow model to Spark SQL as `mlflow_sklearn_m`.
create_model_sql = (
    "\nCREATE OR REPLACE MODEL mlflow_sklearn_m "
    f"LOCATION 'mlflow:///{registered_model_name}';\n"
)
spark.sql(create_model_sql)

spark.sql("show models").show(1, vertical=False, truncate=False)


2023-02-16 15:43:52,691 INFO Rikai (__init__.py:127): setting spark.sql.extensions to net.xmacs.liga.spark.RikaiSparkSessionExtensions
2023-02-16 15:43:52,691 INFO Rikai (__init__.py:127): setting spark.driver.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true
2023-02-16 15:43:52,691 INFO Rikai (__init__.py:127): setting spark.executor.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true
2023-02-16 15:43:52,692 INFO Rikai (__init__.py:127): setting spark.jars to https://github.com/liga-ai/liga/releases/download/v0.2.2/liga-spark321-assembly_2.12-0.2.2.jar
23/02/16 15:43:53 WARN Utils: Your hostname, tubi resolves to a loopback address: 127.0.1.1; using 192.168.31.32 instead (on interface wlp0s20f3)
23/02/16 15:43:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR,

+----------------+------+-------------------+-------+
|name            |flavor|uri                |options|
+----------------+------+-------------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_kmeans|       |
+----------------+------+-------------------+-------+



In [5]:
from liga.numpy.sql import literal

# One representative row per block of 50 iris samples (their labels are shown
# by the `y[0], y[50], y[100]` cell above).
sample_indices = (0, 50, 100)

# Build one ML_PREDICT projection per sample, aliased y<index>. `literal`
# turns the numpy row into something Spark SQL can parse (presumably an array
# literal of the feature values — confirm against liga.numpy.sql).
projections = ",\n  ".join(
    f"ML_PREDICT(mlflow_sklearn_m, {literal(X[idx])}) as y{idx}"
    for idx in sample_indices
)
result = spark.sql(f"\nselect\n  {projections}\n")

result.printSchema()
predictions = result.toPandas()

# KMeans cluster ids are an arbitrary permutation of the iris classes, so the
# predictions need not equal y. Instead, check that the model served through
# Spark agrees with the in-memory model trained above.
expected = model.predict(X[list(sample_indices)])
assert predictions.iloc[0].tolist() == expected.tolist(), (predictions, expected)
predictions

root
 |-- y0: integer (nullable = true)
 |-- y50: integer (nullable = true)
 |-- y100: integer (nullable = true)



                                                                                

Unnamed: 0,y0,y50,y100
0,1,0,2
