https://scikit-learn.org/1.1/modules/clustering.html#spectral-clustering

## Training and Logging

In [2]:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
(X.shape, y.shape)

((150, 4), (150,))

In [3]:
(y[0], y[50], y[100])

(0, 1, 2)

In [5]:
import getpass

import mlflow
from rikai_sklearn.mlflow import log_model


mlflow_tracking_uri = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(mlflow_tracking_uri)

# train a model
with mlflow.start_run() as run:
    ####
    # Part 1: Train the model and register it on MLflow
    ####
    from sklearn.cluster import SpectralClustering
    model = SpectralClustering(n_clusters=3, assign_labels='discretize', random_state=0)
    model.fit(X)
    registered_model_name = f"{getpass.getuser()}_spectral"
    log_model(model, registered_model_name=registered_model_name)


Successfully registered model 'da_spectral'.
2022/11/16 16:43:12 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_spectral, version 1
Created version '1' of model 'da_spectral'.


## Apply the model on large scale dataset

In [6]:
from example import spark
from rikai.spark.sql.codegen.mlflow_logger import CONF_MLFLOW_TRACKING_URI
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)
spark.sql(f"""
CREATE OR REPLACE MODEL mlflow_sklearn_m USING 'mlflow:///{registered_model_name}';
"""
)

spark.sql("show models").show(1, vertical=False, truncate=False)


:: loading settings :: url = jar:file:/Users/da/.cache/pants/named_caches/pex_root/installed_wheels/fcaa57f02b90be772d50778078fc41c3660d5a6c43218e45b2c2aef2ec8e9d58/pyspark-3.2.2-py2.py3-none-any.whl/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /Users/da/.ivy2/cache
The jars for the packages stored in: /Users/da/.ivy2/jars
ai.eto#rikai_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-7de99911-967c-4301-a0e8-44f8e510387f;1.0
	confs: [default]
	found ai.eto#rikai_2.12;0.1.14 in central
	found org.xerial.snappy#snappy-java;1.1.8.4 in central
	found com.typesafe.scala-logging#scala-logging_2.12;3.9.4 in central
	found org.slf4j#slf4j-api;1.7.30 in spark-list
	found org.mlflow#mlflow-client;1.21.0 in central
	found org.apache.logging.log4j#log4j-core;2.17.1 in central
:: resolution report :: resolve 217ms :: artifacts dl 7ms
	:: modules in use:
	ai.eto#rikai_2.12;0.1.14 from central in [default]
	com.typesafe.scala-logging#scala-logging_2.12;3.9.4 from central in [default]
	org.apache.logging.log4j#log4j-core;2.17.1 from central in [default]
	org.mlflow#mlflow-client;1.21.0 from central in [default]
	org.slf4j#slf4j-api;1.7.30 from spark-list in [default]
	org.x

+----------------+------+---------------------+-------+
|name            |flavor|uri                  |options|
+----------------+------+---------------------+-------+
|mlflow_sklearn_m|      |mlflow:///da_spectral|       |
+----------------+------+---------------------+-------+



In [7]:
from rikai_sklearn.numpy import array_to_literal

result = spark.sql(f"""
select
  ML_PREDICT(mlflow_sklearn_m, {array_to_literal(X[0])}) as y0,
  ML_PREDICT(mlflow_sklearn_m, {array_to_literal(X[50])}) as y50,
  ML_PREDICT(mlflow_sklearn_m, {array_to_literal(X[100])}) as y100
"""
)

result.printSchema()
result.toPandas()

root
 |-- y0: integer (nullable = true)
 |-- y50: integer (nullable = true)
 |-- y100: integer (nullable = true)



22/11/16 16:43:27 ERROR PythonUDFRunner: Python worker exited unexpectedly (crashed)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/da/.cache/pants/named_caches/pex_root/installed_wheels/fcaa57f02b90be772d50778078fc41c3660d5a6c43218e45b2c2aef2ec8e9d58/pyspark-3.2.2-py2.py3-none-any.whl/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 599, in main
    eval_type = read_int(infile)
  File "/Users/da/.cache/pants/named_caches/pex_root/installed_wheels/fcaa57f02b90be772d50778078fc41c3660d5a6c43218e45b2c2aef2ec8e9d58/pyspark-3.2.2-py2.py3-none-any.whl/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 564, in read_int
    raise EOFError
EOFError

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:556)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:86)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUD

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/Users/da/.cache/pants/named_caches/pex_root/venvs/97296c1a2aa5705046eb53ad742c9a9d598e9ce1/95d120849befe49d6fe9386d83f19789f01f916b/lib/python3.8/site-packages/rikai/spark/sql/codegen/sklearn.py", line 52, in sklearn_inference_udf
    y = [_pickler.dumps(pred) for pred in model.predict(X)]
  File "/private/var/folders/6d/q8_1m19n1jzg6_3lghwylhpc0000gp/T/pants-sandbox-CqFuTr/./rikai_sklearn/models/cluster.py", line 13, in predict
    return self.model.predict(x).tolist()
AttributeError: 'SpectralClustering' object has no attribute 'predict'
