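This notebook trains a scikit-learn `RidgeClassifier` on the breast-cancer dataset, registers it in MLflow, and then scores individual samples from Spark SQL with `ML_PREDICT`. Background reading: [scikit-learn — linear models for classification](https://scikit-learn.org/stable/modules/linear_model.html#classification)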

## Training and Logging

In [1]:
from sklearn.datasets import load_breast_cancer

# Load the dataset as plain NumPy arrays: X is the feature matrix and y the
# class labels (0 or 1 in the samples shown below).
X, y = load_breast_cancer(return_X_y=True)

X.shape, y.shape

((569, 30), (569,))

In [2]:
# Peek at the first and the last sample as (features, label) pairs. For this
# 569-row dataset the last row is index 568. Deriving it from len(X) avoids a
# magic number that would silently go stale if the data changed.
[(X[i], y[i]) for i in (0, len(X) - 1)]

[(array([1.799e+01, 1.038e+01, 1.228e+02, 1.001e+03, 1.184e-01, 2.776e-01,
         3.001e-01, 1.471e-01, 2.419e-01, 7.871e-02, 1.095e+00, 9.053e-01,
         8.589e+00, 1.534e+02, 6.399e-03, 4.904e-02, 5.373e-02, 1.587e-02,
         3.003e-02, 6.193e-03, 2.538e+01, 1.733e+01, 1.846e+02, 2.019e+03,
         1.622e-01, 6.656e-01, 7.119e-01, 2.654e-01, 4.601e-01, 1.189e-01]),
  0),
 (array([7.760e+00, 2.454e+01, 4.792e+01, 1.810e+02, 5.263e-02, 4.362e-02,
         0.000e+00, 0.000e+00, 1.587e-01, 5.884e-02, 3.857e-01, 1.428e+00,
         2.548e+00, 1.915e+01, 7.189e-03, 4.660e-03, 0.000e+00, 0.000e+00,
         2.676e-02, 2.783e-03, 9.456e+00, 3.037e+01, 5.916e+01, 2.686e+02,
         8.996e-02, 6.444e-02, 0.000e+00, 0.000e+00, 2.871e-01, 7.039e-02]),
  1)]

In [3]:
import getpass

import mlflow
from liga.sklearn.mlflow import log_model
from sklearn.linear_model import RidgeClassifier


# Keep MLflow runs and the model registry in a local SQLite file. The same URI
# is handed to Spark in the next section so both sides see one registry.
mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

# Part 1: fit a RidgeClassifier on the full dataset and register it in MLflow.
# Re-running this cell adds a new version to the existing registered model.
with mlflow.start_run() as run:
    model = RidgeClassifier()
    model.fit(X, y)

    # Prefix the registry name with the OS user name (presumably so people
    # sharing one registry do not overwrite each other's models).
    registered_model_name = "{}_ridge_clf".format(getpass.getuser())
    log_model(model, registered_model_name=registered_model_name)


Registered model 'da_ridge_clf' already exists. Creating a new version of this model...
2023/01/09 21:31:17 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_ridge_clf, version 2
Created version '2' of model 'da_ridge_clf'.


## Apply the model to large-scale datasets with Spark

In [4]:
from example import spark
from liga.mlflow.logger import CONF_MLFLOW_TRACKING_URI

# Turn off Arrow-based Spark <-> pandas conversion, which affects the
# toPandas() calls below.
# NOTE(review): the reason is not recorded here; confirm it is still required.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")

# Hand Spark the same MLflow tracking URI the model was registered under.
spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)

# Expose the registered MLflow model to Spark SQL as `mlflow_sklearn_m`.
create_model_sql = f"""
CREATE OR REPLACE MODEL mlflow_sklearn_m USING 'mlflow:///{registered_model_name}';
"""
spark.sql(create_model_sql)

# List the registered Spark models (only the first row is printed).
spark.sql("show models").show(n=1, truncate=False, vertical=False)


23/01/09 21:31:18 WARN Utils: Your hostname, debian resolves to a loopback address: 127.0.1.1; using 192.168.31.194 instead (on interface wlx1cbfce3ffbfe)
23/01/09 21:31:18 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/01/09 21:31:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/01/09 21:31:24 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
23/01/09 21:31:24 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
23/01/09 21:31:24 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
23/01/09 21:31:24 WARN Utils: Service 'SparkUI' could not bind on port 4043. Attempting port 4044.


+----------------+------+----------------------+-------+
|name            |flavor|uri                   |options|
+----------------+------+----------------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_ridge_clf|       |
+----------------+------+----------------------+-------+



In [5]:
# Score the first sample through Spark SQL. The literal values are the same as
# those printed for X[0] above. Scientific-notation literals keep them DOUBLE.
result = spark.sql("""
select ML_PREDICT(mlflow_sklearn_m, array(1.799e+01, 1.038e+01, 1.228e+02, 1.001e+03, 1.184e-01, 2.776e-01,
        3.001e-01, 1.471e-01, 2.419e-01, 7.871e-02, 1.095e+00, 9.053e-01,
        8.589e+00, 1.534e+02, 6.399e-03, 4.904e-02, 5.373e-02, 1.587e-02,
        3.003e-02, 6.193e-03, 2.538e+01, 1.733e+01, 1.846e+02, 2.019e+03,
        1.622e-01, 6.656e-01, 7.119e-01, 2.654e-01, 4.601e-01, 1.189e-01))
"""
)

result.printSchema()
first_sample_prediction = result.toPandas()

# The Spark-side prediction should agree with the in-memory sklearn model on
# the same sample. This assumes the registry's latest version is the model
# logged earlier in this session.
assert first_sample_prediction.iloc[0, 0] == model.predict(X[:1])[0]
first_sample_prediction

root
 |-- mlflow_sklearn_m: integer (nullable = true)



                                                                                

Unnamed: 0,mlflow_sklearn_m
0,0


In [6]:
# Score the last sample through Spark SQL. The literal values are the same as
# those printed for X[568] (the last row) above.
result = spark.sql("""
select ML_PREDICT(mlflow_sklearn_m, array(7.760e+00, 2.454e+01, 4.792e+01, 1.810e+02, 5.263e-02, 4.362e-02,
        0.000e+00, 0.000e+00, 1.587e-01, 5.884e-02, 3.857e-01, 1.428e+00,
        2.548e+00, 1.915e+01, 7.189e-03, 4.660e-03, 0.000e+00, 0.000e+00,
        2.676e-02, 2.783e-03, 9.456e+00, 3.037e+01, 5.916e+01, 2.686e+02,
        8.996e-02, 6.444e-02, 0.000e+00, 0.000e+00, 2.871e-01, 7.039e-02))
"""
)

result.printSchema()
last_sample_prediction = result.toPandas()

# The Spark-side prediction should agree with the in-memory sklearn model on
# the same (last) sample. This assumes the registry's latest version is the
# model logged earlier in this session.
assert last_sample_prediction.iloc[0, 0] == model.predict(X[-1:])[0]
last_sample_prediction

root
 |-- mlflow_sklearn_m: integer (nullable = true)



Unnamed: 0,mlflow_sklearn_m
0,1
