# MLflow quickstart: inference

This notebook shows how to load a model previously logged to MLflow and use it to make predictions on data in different formats. The notebook includes two examples of applying the model:
* as a scikit-learn model to a pandas DataFrame
* as a PySpark UDF to a Spark DataFrame
  
## Requirements
* This notebook requires Databricks Runtime 6.4 or above, or Databricks Runtime 6.4 ML or above. You can also use a Python 3 cluster running Databricks Runtime 5.5 LTS or Databricks Runtime 5.5 LTS ML.
* If you are using a cluster running Databricks Runtime, you must install MLflow. See "Install a library on a cluster" ([AWS](https://docs.databricks.com/libraries/cluster-libraries.html#install-a-library-on-a-cluster)|[Azure](https://docs.microsoft.com/azure/databricks/libraries/cluster-libraries#--install-a-library-on-a-cluster)|[GCP](https://docs.gcp.databricks.com/libraries/cluster-libraries.html#install-a-library-on-a-cluster)). Select **Library Source** PyPI and enter `mlflow` in the **Package** field.
* If you are using a cluster running Databricks Runtime ML, MLflow is already installed.  

## Prerequisite
* This notebook uses the ElasticNet models from MLflow quickstart part 1: training and logging ([AWS](https://docs.databricks.com/applications/mlflow/tracking-ex-scikit.html#training-quickstart)|[Azure](https://docs.microsoft.com/azure/databricks/applications/mlflow/tracking-ex-scikit#--training-quickstart)|[GCP](https://docs.gcp.databricks.com/applications/mlflow/tracking-ex-scikit.html#training-quickstart)).

## Find and copy the run ID of the run that created the model

Find and copy a run ID associated with an ElasticNet training run from the MLflow quickstart part 1: training and logging notebook. The run ID appears on the run details page; it is a 32-character alphanumeric string shown after the label "**Run**".  

To navigate to the run details page for the MLflow quickstart part 1: training and logging notebook, open that notebook and click **Experiment** in the upper right corner. The Experiments sidebar displays. Do one of the following:

* In the Experiments sidebar, click the icon at the far right of the date and time of the run. The run details page appears in a new tab. 

* Click the square icon with the arrow to the right of **Experiment Runs**. The Experiment page displays in a new tab. This page lists all of the runs associated with this notebook. To display the run details page for a particular run, click the link in the **Start Time** column for that run. 

For more information, see "View notebook experiment" ([AWS](https://docs.databricks.com/applications/mlflow/tracking.html#view-notebook-experiment)|[Azure](https://docs.microsoft.com/azure/databricks/applications/mlflow/tracking#view-notebook-experiment)|[GCP](https://docs.gcp.databricks.com/applications/mlflow/tracking.html#view-notebook-experiment)).

In [0]:
import mlflow

# Replace this value with the run ID you copied in the previous step.
run_id1 = "1eaf6ff2b6aa4985b38e5e7a60656349"
model_uri = f"runs:/{run_id1}/model"
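The cell above builds the model URI with the `runs:/<run-id>/<artifact-path>` scheme. If you paste run IDs by hand, a small check can catch a leftover placeholder or a truncated copy before MLflow tries to load anything. The following is a minimal sketch; `build_model_uri` is a hypothetical helper (not part of MLflow), and it assumes the run ID is the 32-character alphanumeric string described above and that the model was logged under the artifact path `model`.

```python
import re

# Hypothetical helper, not an MLflow API. Assumes run IDs are 32-character
# alphanumeric strings, as described on the run details page.
_RUN_ID_PATTERN = re.compile(r"[A-Za-z0-9]{32}")


def build_model_uri(run_id: str, artifact_path: str = "model") -> str:
    """Return a runs:/ URI for a model logged under `artifact_path` in `run_id`."""
    run_id = run_id.strip()  # tolerate whitespace picked up when copying
    if not _RUN_ID_PATTERN.fullmatch(run_id):
        raise ValueError(f"Not a 32-character run ID: {run_id!r}")
    return f"runs:/{run_id}/{artifact_path}"


model_uri = build_model_uri("1eaf6ff2b6aa4985b38e5e7a60656349")
print(model_uri)  # runs:/1eaf6ff2b6aa4985b38e5e7a60656349/model
```

With this check, a placeholder such as `<run-id1>` raises a `ValueError` right away instead of failing later inside the model-loading call.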



## Load the model as a scikit-learn model
Use the MLflow API to load the model that the run logged to the MLflow tracking server. After loading the model, you can use it just like any other scikit-learn model.

In [0]:
import mlflow.sklearn
model = mlflow.sklearn.load_model(model_uri=model_uri)
# model.coef_
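`mlflow.sklearn.load_model` returns the original scikit-learn estimator, so the loaded model supports the standard estimator attributes and methods. The sketch below shows a few of them on a stand-in: an `ElasticNet` fitted locally on the diabetes data, because the MLflow run itself is not available here. The hyperparameters `alpha=0.01` and `l1_ratio=0.5` are illustrative assumptions, not the values used in part 1.

```python
from sklearn import datasets
from sklearn.linear_model import ElasticNet

# Stand-in for the model loaded from MLflow (assumption): an ElasticNet fitted
# locally on the diabetes data with illustrative hyperparameters.
X, y = datasets.load_diabetes(return_X_y=True)
model = ElasticNet(alpha=0.01, l1_ratio=0.5).fit(X, y)

# The usual scikit-learn attributes and methods are available.
print(model.get_params()["alpha"])  # hyperparameters the model was fitted with
print(model.coef_.shape)            # one coefficient per feature: (10,)
print(model.intercept_)             # fitted intercept
print(model.predict(X[:3]))         # predictions for the first three rows
print(model.score(X, y))            # R^2 on the training data
```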

In [0]:
# Import required libraries
from sklearn import datasets
import numpy as np
import pandas as pd

# Load the diabetes dataset
diabetes = datasets.load_diabetes()
X = diabetes.data
y = diabetes.target

# Combine the features and the label into one pandas DataFrame
Y = np.array([y]).transpose()
d = np.concatenate((X, Y), axis=1)
cols = ['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6', 'progression']
data = pd.DataFrame(d, columns=cols)
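The cell above assembles the DataFrame by hand from NumPy arrays. If your cluster has scikit-learn 0.23 or later (an assumption about the runtime), `load_diabetes(as_frame=True)` returns the data already as a DataFrame whose feature columns use the same names, plus a `target` column that only needs renaming. A minimal sketch:

```python
from sklearn import datasets

# as_frame=True (scikit-learn 0.23+) returns the features and the target
# together in `.frame`; the label column is named "target".
diabetes = datasets.load_diabetes(as_frame=True)
data = diabetes.frame.rename(columns={"target": "progression"})

print(list(data.columns))
print(data.shape)  # (442, 11)
```

Both approaches produce the same 442 rows and 11 columns; the `as_frame` version simply avoids the manual transpose and concatenation.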



In [0]:
display(data.columns)


Index(['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6',
       'progression'],
      dtype='object')

In [0]:
# Get a prediction for a row of the dataset
model.predict(data[0:1].drop(["progression"], axis=1))



Out[5]: array([184.12785252])
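ElasticNet is a linear model, so the value `predict` returns for a single row is the intercept plus the dot product of the coefficients with that row's ten feature values. The sketch below checks this identity on a locally fitted stand-in model (illustrative hyperparameters, not the MLflow run's), which is one way to see what the single number in `Out[5]` is made of.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import ElasticNet

# Same feature/label layout as above (as_frame needs scikit-learn 0.23+).
data = datasets.load_diabetes(as_frame=True).frame.rename(columns={"target": "progression"})
features = data.drop(columns=["progression"])

# Stand-in for the MLflow-loaded model (assumption); hyperparameters are illustrative.
model = ElasticNet(alpha=0.01, l1_ratio=0.5).fit(features, data["progression"])

first_row = features.iloc[0:1]  # keep it 2-D: one row, ten columns
predicted = model.predict(first_row)[0]

# Recompute the prediction by hand: intercept + sum(coef_i * x_i).
manual = model.intercept_ + np.dot(model.coef_, first_row.to_numpy()[0])
print(predicted, manual)
assert np.isclose(predicted, manual)
```

Note the slice `[0:1]` (as in the notebook cell) rather than `[0]`: `predict` expects a 2-D input, one row per sample.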

## Create a PySpark UDF and use it for batch inference
In this section, you use the MLflow API to create a PySpark UDF from the model you logged to MLflow. For more information, see [Export a python_function model as an Apache Spark UDF](https://mlflow.org/docs/latest/models.html#export-a-python-function-model-as-an-apache-spark-udf).

Wrapping the model in a PySpark UDF lets you apply it to a Spark DataFrame, so Spark can compute the predictions in parallel across the cluster.
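Conceptually, Spark passes the UDF the input columns in batches of rows and collects one prediction per row. The following local sketch imitates that pattern with pandas only: it splits a DataFrame into fixed-size batches, calls `predict` on each batch, and concatenates the results in the original row order. It is a simplified illustration, not MLflow's implementation; `predict_in_batches` and `batch_size` are hypothetical names, and the model is a locally fitted stand-in with illustrative hyperparameters.

```python
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import ElasticNet


def predict_in_batches(model, features: pd.DataFrame, batch_size: int = 100) -> pd.Series:
    """Hypothetical helper: score `features` batch by batch, keeping row order."""
    parts = []
    for start in range(0, len(features), batch_size):
        batch = features.iloc[start:start + batch_size]
        # Each batch is a pandas DataFrame, loosely like the input a pandas UDF receives.
        parts.append(pd.Series(model.predict(batch), index=batch.index))
    return pd.concat(parts)


X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
model = ElasticNet(alpha=0.01, l1_ratio=0.5).fit(X, y)  # illustrative stand-in

batched = predict_in_batches(model, X, batch_size=100)
print(len(batched))  # 442, one prediction per row
```

Because each row is scored independently, the batch boundaries do not change the results; in the Spark version, batching is handled by Spark rather than by your code.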

In [0]:
# Create the PySpark UDF
import mlflow.pyfunc
pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri)

2023/07/17 10:18:52 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'


In [0]:
# For this example, create a Spark DataFrame from the original pandas DataFrame
# without the label column.
dataframe = spark.createDataFrame(data.drop(["progression"], axis=1))

Use the Spark DataFrame method `withColumn()` to apply the PySpark UDF to the DataFrame and return a new DataFrame with an added `prediction` column.

In [0]:
display(data.columns)

print(dataframe.columns)
print(data['progression'])


Index(['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6',
       'progression'],
      dtype='object')
['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
0      151.0
1       75.0
2      141.0
3      206.0
4      135.0
       ...  
437    178.0
438    104.0
439    132.0
440    220.0
441     57.0
Name: progression, Length: 442, dtype: float64


In [0]:
from pyspark.sql.functions import struct

# Pack all ten feature columns into a struct and pass it to the UDF.
predicted_df = dataframe.withColumn("prediction", pyfunc_udf(struct(*dataframe.columns)))
display(predicted_df)
print(predicted_df.columns)


age,sex,bmi,bp,s1,s2,s3,s4,s5,s6,prediction
0.0380759064334241,0.0506801187398187,0.0616962065186885,0.0218723549949558,-0.0442234984244464,-0.0348207628376986,-0.0434008456520269,-0.0025922619981828,0.0199084208763183,-0.0176461251598052,184.12785251808447
-0.001882016527791,-0.044641636506989,-0.0514740612388061,-0.0263278347173518,-0.0084487241112169,-0.019163339748222,0.0744115640787594,-0.0394933828740919,-0.0683297436244215,-0.09220404962683,85.44978278638258
0.0852989062966783,0.0506801187398187,0.0444512133365941,-0.0056706105549342,-0.0455994512826475,-0.0341944659141195,-0.0323559322397657,-0.0025922619981828,0.0028637705189401,-0.0259303389894746,157.045618244272
-0.0890629393522603,-0.044641636506989,-0.0115950145052127,-0.0366564467985606,0.0121905687618,0.0249905933641021,-0.0360375700438527,0.0343088588777263,0.0226920225667445,-0.0093619113301358,168.05380875641526
0.005383060374248,-0.044641636506989,-0.0363846922044735,0.0218723549949558,0.0039348516125931,0.0155961395104161,0.0081420836051921,-0.0025922619981828,-0.0319914449413559,-0.0466408735636482,103.41859616060836
-0.0926954778032799,-0.044641636506989,-0.0406959404999971,-0.0194420933298793,-0.0689906498720667,-0.0792878444118122,0.0412768238419757,-0.076394503750001,-0.0411803851880079,-0.0963461565416647,110.11241202517525
-0.0454724779400257,0.0506801187398187,-0.0471628129432825,-0.015999222636143,-0.040095639849843,-0.0248000120604336,0.000778807997017968,-0.0394933828740919,-0.0629129499162512,-0.0383566597339788,92.40504164841695
0.063503675590561,0.0506801187398187,-0.0018947058402846,0.0666296740135272,0.0906198816792644,0.108914381123697,0.0228686348215404,0.0177033544835672,-0.0358167281015492,0.0030644094143683,166.1791865373107
0.0417084448844436,0.0506801187398187,0.0616962065186885,-0.0400993174922969,-0.0139525355440215,0.0062016856567301,-0.0286742944356786,-0.0025922619981828,-0.0149564750249113,0.0113486232440377,161.84443877907745
-0.0709002470971626,-0.044641636506989,0.0390621529671896,-0.0332135761048244,-0.0125765826858204,-0.034507614375909,-0.0249926566315915,-0.0025922619981828,0.0677363261102861,-0.0135040182449705,212.63267845554392


['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6', 'prediction']


In [0]:
predicted_df.select('prediction').show()

+------------------+
|        prediction|
+------------------+
|184.12785251808447|
| 85.44978278638258|
|  157.045618244272|
|168.05380875641526|
|103.41859616060837|
|110.11241202517523|
| 92.40504164841695|
| 166.1791865373107|
|161.84443877907745|
|212.63267845554392|
|105.71240789225261|
|166.24057698675514|
|107.33247660805279|
| 173.9850222285255|
| 101.7398060781732|
|176.85096142404555|
| 183.0160167435623|
|184.21649084620276|
|  117.067105924991|
|113.62697050215566|
+------------------+
only showing top 20 rows



In [0]:
# Print the actual labels to compare with the predictions above
print(data['progression'])

0      151.0
1       75.0
2      141.0
3      206.0
4      135.0
       ...  
437    178.0
438    104.0
439    132.0
440    220.0
441     57.0
Name: progression, Length: 442, dtype: float64
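Printing the labels after the predictions only allows an eyeball comparison. A residual column and a single summary metric make the comparison concrete. The sketch below computes both with pandas and scikit-learn's `mean_absolute_error`; it uses a locally fitted stand-in model with illustrative hyperparameters, so its numbers will differ from those of the model loaded from MLflow.

```python
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import ElasticNet
from sklearn.metrics import mean_absolute_error

X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
model = ElasticNet(alpha=0.01, l1_ratio=0.5).fit(X, y)  # illustrative stand-in

# Put labels and predictions side by side, one row per patient.
comparison = pd.DataFrame({"progression": y, "prediction": model.predict(X)})
comparison["residual"] = comparison["progression"] - comparison["prediction"]

print(comparison.head())
print("MAE:", mean_absolute_error(comparison["progression"], comparison["prediction"]))
```

The mean absolute error is in the same units as `progression`, which makes it easier to interpret than a squared-error metric.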
