Reference: [scikit-learn 1.1 user guide — Random forests](https://scikit-learn.org/1.1/modules/ensemble.html#random-forests)

## Training and Logging

In [1]:
from sklearn.datasets import load_breast_cancer

# Load the breast-cancer data as plain arrays: X is the feature matrix and
# y the class labels. Later cells refer to X and y by these exact names.
X, y = load_breast_cancer(return_X_y=True)

# Display both shapes: (n_samples, n_features) and (n_samples,).
tuple(arr.shape for arr in (X, y))

((569, 30), (569,))

In [2]:
# Show the first and last samples as (features, label) pairs. Negative indexing
# avoids hard-coding the dataset size (569 rows here, so the last index is 568).
[(X[i], y[i]) for i in (0, -1)]

[(array([1.799e+01, 1.038e+01, 1.228e+02, 1.001e+03, 1.184e-01, 2.776e-01,
         3.001e-01, 1.471e-01, 2.419e-01, 7.871e-02, 1.095e+00, 9.053e-01,
         8.589e+00, 1.534e+02, 6.399e-03, 4.904e-02, 5.373e-02, 1.587e-02,
         3.003e-02, 6.193e-03, 2.538e+01, 1.733e+01, 1.846e+02, 2.019e+03,
         1.622e-01, 6.656e-01, 7.119e-01, 2.654e-01, 4.601e-01, 1.189e-01]),
  0),
 (array([7.760e+00, 2.454e+01, 4.792e+01, 1.810e+02, 5.263e-02, 4.362e-02,
         0.000e+00, 0.000e+00, 1.587e-01, 5.884e-02, 3.857e-01, 1.428e+00,
         2.548e+00, 1.915e+01, 7.189e-03, 4.660e-03, 0.000e+00, 0.000e+00,
         2.676e-02, 2.783e-03, 9.456e+00, 3.037e+01, 5.916e+01, 2.686e+02,
         8.996e-02, 6.444e-02, 0.000e+00, 0.000e+00, 2.871e-01, 7.039e-02]),
  1)]

In [4]:
import getpass

import mlflow
from rikai_sklearn.mlflow import log_model
from sklearn.ensemble import RandomForestClassifier


# A local SQLite file backs MLflow tracking and the model registry. The Spark
# section below hands this same URI to Rikai.
mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

with mlflow.start_run() as run:
    # Shallow forest with a fixed seed so re-running the notebook yields the
    # same model. `fit` returns the estimator itself, so build-and-train chains.
    model = RandomForestClassifier(max_depth=2, random_state=0).fit(X, y)

    # Prefix the registry name with the OS user name (presumably so users
    # sharing a registry do not collide), then log and register the model.
    registered_model_name = f"{getpass.getuser()}_random_forest_clf"
    log_model(model, registered_model_name=registered_model_name)


Successfully registered model 'da_random_forest_clf'.
2022/11/15 19:25:03 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_random_forest_clf, version 1
Created version '1' of model 'da_random_forest_clf'.


## Apply the model to a large-scale dataset with Spark SQL

In [5]:
from example import spark
from rikai.spark.sql.codegen.mlflow_logger import CONF_MLFLOW_TRACKING_URI

# NOTE(review): Arrow-based PySpark conversion is disabled here; the reason is
# not visible in this notebook, so confirm before removing it.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
# Point Rikai's MLflow integration at the tracking store used during training.
spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)

# Expose the registered MLflow model to SQL as `mlflow_sklearn_m`.
create_model_sql = f"""
CREATE OR REPLACE MODEL mlflow_sklearn_m USING 'mlflow:///{registered_model_name}';
"""
spark.sql(create_model_sql)

spark.sql("show models").show(1, truncate=False, vertical=False)


:: loading settings :: url = jar:file:/Users/da/.cache/pants/named_caches/pex_root/installed_wheels/fcaa57f02b90be772d50778078fc41c3660d5a6c43218e45b2c2aef2ec8e9d58/pyspark-3.2.2-py2.py3-none-any.whl/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /Users/da/.ivy2/cache
The jars for the packages stored in: /Users/da/.ivy2/jars
ai.eto#rikai_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-abd5cf9b-fba6-4bff-818b-3d9301bdc9c9;1.0
	confs: [default]
	found ai.eto#rikai_2.12;0.1.14 in central
	found org.xerial.snappy#snappy-java;1.1.8.4 in central
	found com.typesafe.scala-logging#scala-logging_2.12;3.9.4 in central
	found org.slf4j#slf4j-api;1.7.30 in spark-list
	found org.mlflow#mlflow-client;1.21.0 in central
	found org.apache.logging.log4j#log4j-core;2.17.1 in central
:: resolution report :: resolve 213ms :: artifacts dl 8ms
	:: modules in use:
	ai.eto#rikai_2.12;0.1.14 from central in [default]
	com.typesafe.scala-logging#scala-logging_2.12;3.9.4 from central in [default]
	org.apache.logging.log4j#log4j-core;2.17.1 from central in [default]
	org.mlflow#mlflow-client;1.21.0 from central in [default]
	org.slf4j#slf4j-api;1.7.30 from spark-list in [default]
	org.x

+----------------+------+------------------------------+-------+
|name            |flavor|uri                           |options|
+----------------+------+------------------------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_random_forest_clf|       |
+----------------+------+------------------------------+-------+



In [6]:
# Score the first sample (X[0]) through Spark SQL's ML_PREDICT.
# The array literal is generated from the data instead of hand-copying the
# printed values. `.17e` keeps full float64 precision, and the exponent form
# makes Spark parse each literal as DOUBLE, as the hand-written `1.799e+01`
# style did. A plain `17.99` would be parsed as DECIMAL.
first_sample_sql = ", ".join(f"{value:.17e}" for value in X[0])
result = spark.sql(f"""
select ML_PREDICT(mlflow_sklearn_m, array({first_sample_sql}))
""")

result.printSchema()
result.toPandas()

root
 |-- mlflow_sklearn_m: integer (nullable = true)



                                                                                

Unnamed: 0,mlflow_sklearn_m
0,0


In [7]:
# Score the last sample (X[-1], index 568 in this dataset) through Spark SQL.
# The array literal is generated from the data instead of hand-copying the
# printed values. `.17e` keeps full float64 precision, and the exponent form
# makes Spark parse each literal as DOUBLE, as the hand-written `7.760e+00`
# style did. A plain `7.76` would be parsed as DECIMAL.
last_sample_sql = ", ".join(f"{value:.17e}" for value in X[-1])
result = spark.sql(f"""
select ML_PREDICT(mlflow_sklearn_m, array({last_sample_sql}))
""")

result.printSchema()
result.toPandas()

root
 |-- mlflow_sklearn_m: integer (nullable = true)



Unnamed: 0,mlflow_sklearn_m
0,1
