Reference: [scikit-learn 1.1 user guide — Principal component analysis (PCA)](https://scikit-learn.org/1.1/modules/decomposition.html#principal-component-analysis-pca)

## Training and Logging

In [1]:
from sklearn.datasets import load_iris

# Load the iris dataset: `X` holds the feature matrix, `y` the class labels.
iris = load_iris()
X, y = iris.data, iris.target
X.shape, y.shape

((150, 4), (150,))

In [2]:
# Peek at the first two samples as (features, label) pairs.
list(zip(X[:2], y[:2]))

[(array([5.1, 3.5, 1.4, 0.2]), 0), (array([4.9, 3. , 1.4, 0.2]), 0)]

In [3]:
from sklearn.decomposition import PCA

# Project the iris features onto their first two principal components.
pca = PCA(n_components=2)
# fit_transform fits the model and projects X in a single call,
# replacing the equivalent two-step pca.fit(X).transform(X).
X_r = pca.fit_transform(X)
X_r[0]

array([-2.68412563,  0.31939725])

In [4]:
import getpass

import mlflow
from liga.sklearn.mlflow import log_model
from sklearn.decomposition import PCA


# Local SQLite-backed MLflow tracking store; later cells reuse this URI.
mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

# Train a 2-component PCA inside an MLflow run and register it under a
# per-user name ("<user>_pca"). Re-running this cell registers a new
# version of the same model rather than failing.
with mlflow.start_run() as run:
    model = PCA(n_components=2)
    # PCA is unsupervised: its fit() ignores labels, so pass only the features.
    model.fit(X)
    registered_model_name = f"{getpass.getuser()}_pca"
    log_model(model, registered_model_name=registered_model_name)


Registered model 'da_pca' already exists. Creating a new version of this model...
2023/01/09 21:32:59 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_pca, version 2
Created version '2' of model 'da_pca'.


## Apply the model to a large-scale dataset

In [5]:
from example import spark
from liga.mlflow.logger import CONF_MLFLOW_TRACKING_URI

# Session settings, applied in order: disable Arrow-based PySpark conversion,
# then point liga at the same MLflow tracking store used for training
# (presumably so `mlflow:///` model URIs resolve against it — confirm).
spark_settings = {
    "spark.sql.execution.arrow.pyspark.enabled": "false",
    CONF_MLFLOW_TRACKING_URI: mlflow_tracking_uri,
}
for conf_key, conf_value in spark_settings.items():
    spark.conf.set(conf_key, conf_value)

# Expose the registered MLflow model to Spark SQL as `mlflow_sklearn_m`.
create_model_sql = f"""
CREATE OR REPLACE MODEL mlflow_sklearn_m USING 'mlflow:///{registered_model_name}';
"""
spark.sql(create_model_sql)

# List registered SQL models (only the first row is displayed).
spark.sql("show models").show(1, vertical=False, truncate=False)


23/01/09 21:33:00 WARN Utils: Your hostname, debian resolves to a loopback address: 127.0.1.1; using 192.168.31.194 instead (on interface wlx1cbfce3ffbfe)
23/01/09 21:33:00 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/01/09 21:33:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/01/09 21:33:08 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
23/01/09 21:33:08 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
23/01/09 21:33:08 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
23/01/09 21:33:08 WARN Utils: Service 'SparkUI' could not bind on port 4043. Attempting port 4044.
23/01/09

+----------------+------+----------------+-------+
|name            |flavor|uri             |options|
+----------------+------+----------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_pca|       |
+----------------+------+----------------+-------+



In [6]:
import numpy as np
from liga.numpy.sql import literal

# Score the first iris sample through Spark SQL using the registered model.
result = spark.sql(f"""
select
  ML_PREDICT(mlflow_sklearn_m, {literal(X[0])})
"""
)

result.printSchema()
predictions = result.toPandas()

# Sanity check: Spark inference must agree with the in-memory model trained
# above. The schema reports array<float> (float32), so use a tolerance
# suited to single precision rather than exact equality.
np.testing.assert_allclose(
    predictions.loc[0, "mlflow_sklearn_m"],
    model.transform(X[:1])[0],
    rtol=1e-5,
    atol=1e-6,
)
predictions

root
 |-- mlflow_sklearn_m: array (nullable = true)
 |    |-- element: float (containsNull = true)



                                                                                

Unnamed: 0,mlflow_sklearn_m
0,"[-2.6841256618499756, 0.3193972408771515]"
