In [1]:
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mlflow
from mlflow.models import infer_signature

In [2]:
## Set the Tracking URI
mlflow.set_tracking_uri(uri = 'http://127.0.0.1:5555')

In [3]:
# Load the dataset
X, y = datasets.load_iris(return_X_y = True)

# Split the dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

# Hyperparameter Tuning
params = {'penalty': 'l2', 'solver': 'lbfgs', 'max_iter': 1000, "multi_class": "auto", "random_state": 8888}

# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)

In [4]:
# Prediction on the test set
y_pred = lr.predict(X_test)
y_pred

array([2, 0, 0, 0, 1, 2, 0, 0, 0, 2, 2, 2, 1, 1, 2, 2, 0, 0, 2, 2, 1, 1,
       1, 1, 0, 2, 2, 1, 1, 0])

In [5]:
accuracy = accuracy_score(y_test, y_pred)
print(accuracy)

0.9666666666666667


In [7]:
# MLflow Tracking
mlflow.set_tracking_uri(uri = 'http://127.0.0.1:5050')

# Create a new MLFlow experiment
mlflow.set_experiment("MLFLOW Quickstart")

# Start the MLFlow experiment
with mlflow.start_run():
    mlflow.log_params(params)

    # Log the accuracy metrics
    mlflow.log_metric("accuracy", accuracy)

    # Set a tag that we can use to remind ourselves what this run was for 
    mlflow.set_tag("Training Info", "Basic LR model for iris data")

    # Infer the model signature
    signature = infer_signature(X_train, lr.predict(X_train))

    # Log the model
    model_info = mlflow.sklearn.log_model(
        sk_model = lr,
        artifact_path = "iris_model",
        signature = signature,
        input_example = X_train,
        registered_model_name = "Tracking_quickstart"
    )

2024/12/09 21:53:27 INFO mlflow.tracking.fluent: Experiment with name 'MLFLOW Quickstart' does not exist. Creating a new experiment.
Successfully registered model 'Tracking_quickstart'.
2024/12/09 21:53:30 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: Tracking_quickstart, version 1


🏃 View run able-boar-610 at: http://127.0.0.1:5050/#/experiments/616111320491376902/runs/1ff25e4a5d554b138c36c08d467f6ff2
🧪 View experiment at: http://127.0.0.1:5050/#/experiments/616111320491376902


Created version '1' of model 'Tracking_quickstart'.


## Validate the model before deployment

In [11]:
from mlflow.models import validate_serving_input

model_uri = 'runs:/1ff25e4a5d554b138c36c08d467f6ff2/iris_model'

# The model is logged with an input example. MLflow converts
# it into the serving payload format for the deployed model endpoint,
# and saves it to 'serving_input_payload.json'
serving_payload = """{
  "inputs": [
    [
      4.4,
      3.2,
      1.3,
      0.2
    ],
    [
      5.5,
      2.6,
      4.4,
      1.2
    ],
    [
      6.4,
      3.2,
      4.5,
      1.5
    ],
    [
      7.7,
      3.8,
      6.7,
      2.2
    ],
    [
      6.5,
      3.0,
      5.8,
      2.2
    ],
    [
      6.3,
      2.5,
      5.0,
      1.9
    ],
    [
      5.8,
      2.7,
      5.1,
      1.9
    ],
    [
      6.7,
      3.3,
      5.7,
      2.1
    ],
    [
      6.1,
      2.9,
      4.7,
      1.4
    ],
    [
      5.5,
      4.2,
      1.4,
      0.2
    ],
    [
      6.3,
      2.3,
      4.4,
      1.3
    ],
    [
      6.9,
      3.1,
      4.9,
      1.5
    ],
    [
      5.0,
      3.6,
      1.4,
      0.2
    ],
    [
      6.4,
      2.8,
      5.6,
      2.1
    ],
    [
      5.4,
      3.7,
      1.5,
      0.2
    ],
    [
      6.7,
      3.0,
      5.0,
      1.7
    ],
    [
      6.5,
      2.8,
      4.6,
      1.5
    ],
    [
      5.6,
      2.5,
      3.9,
      1.1
    ],
    [
      5.8,
      2.7,
      4.1,
      1.0
    ],
    [
      6.4,
      2.7,
      5.3,
      1.9
    ],
    [
      5.1,
      3.5,
      1.4,
      0.2
    ],
    [
      6.6,
      3.0,
      4.4,
      1.4
    ],
    [
      6.1,
      2.6,
      5.6,
      1.4
    ],
    [
      5.5,
      3.5,
      1.3,
      0.2
    ],
    [
      4.8,
      3.0,
      1.4,
      0.3
    ],
    [
      6.1,
      3.0,
      4.6,
      1.4
    ],
    [
      7.9,
      3.8,
      6.4,
      2.0
    ],
    [
      6.3,
      2.7,
      4.9,
      1.8
    ],
    [
      5.0,
      3.3,
      1.4,
      0.2
    ],
    [
      4.5,
      2.3,
      1.3,
      0.3
    ],
    [
      4.9,
      3.1,
      1.5,
      0.1
    ],
    [
      5.2,
      2.7,
      3.9,
      1.4
    ],
    [
      5.7,
      3.0,
      4.2,
      1.2
    ],
    [
      6.2,
      2.8,
      4.8,
      1.8
    ],
    [
      6.0,
      2.2,
      5.0,
      1.5
    ],
    [
      5.8,
      2.7,
      5.1,
      1.9
    ],
    [
      6.3,
      2.9,
      5.6,
      1.8
    ],
    [
      4.9,
      2.5,
      4.5,
      1.7
    ],
    [
      4.4,
      3.0,
      1.3,
      0.2
    ],
    [
      5.1,
      3.3,
      1.7,
      0.5
    ],
    [
      5.7,
      3.8,
      1.7,
      0.3
    ],
    [
      7.7,
      2.8,
      6.7,
      2.0
    ],
    [
      6.3,
      3.3,
      6.0,
      2.5
    ],
    [
      5.1,
      2.5,
      3.0,
      1.1
    ],
    [
      5.1,
      3.8,
      1.6,
      0.2
    ],
    [
      5.1,
      3.8,
      1.5,
      0.3
    ],
    [
      4.6,
      3.2,
      1.4,
      0.2
    ],
    [
      6.3,
      3.3,
      4.7,
      1.6
    ],
    [
      6.1,
      2.8,
      4.7,
      1.2
    ],
    [
      4.8,
      3.0,
      1.4,
      0.1
    ],
    [
      5.7,
      2.9,
      4.2,
      1.3
    ],
    [
      5.4,
      3.0,
      4.5,
      1.5
    ],
    [
      6.0,
      3.0,
      4.8,
      1.8
    ],
    [
      4.9,
      2.4,
      3.3,
      1.0
    ],
    [
      6.0,
      2.9,
      4.5,
      1.5
    ],
    [
      6.7,
      2.5,
      5.8,
      1.8
    ],
    [
      5.0,
      3.5,
      1.3,
      0.3
    ],
    [
      6.9,
      3.1,
      5.4,
      2.1
    ],
    [
      5.7,
      4.4,
      1.5,
      0.4
    ],
    [
      4.7,
      3.2,
      1.3,
      0.2
    ],
    [
      4.8,
      3.1,
      1.6,
      0.2
    ],
    [
      7.4,
      2.8,
      6.1,
      1.9
    ],
    [
      5.9,
      3.0,
      5.1,
      1.8
    ],
    [
      7.2,
      3.2,
      6.0,
      1.8
    ],
    [
      5.6,
      3.0,
      4.5,
      1.5
    ],
    [
      7.0,
      3.2,
      4.7,
      1.4
    ],
    [
      5.7,
      2.8,
      4.1,
      1.3
    ],
    [
      5.3,
      3.7,
      1.5,
      0.2
    ],
    [
      6.4,
      2.9,
      4.3,
      1.3
    ],
    [
      6.7,
      3.1,
      5.6,
      2.4
    ],
    [
      5.4,
      3.4,
      1.5,
      0.4
    ],
    [
      5.1,
      3.8,
      1.9,
      0.4
    ],
    [
      5.8,
      2.7,
      3.9,
      1.2
    ],
    [
      7.7,
      2.6,
      6.9,
      2.3
    ],
    [
      5.6,
      2.9,
      3.6,
      1.3
    ],
    [
      6.9,
      3.1,
      5.1,
      2.3
    ],
    [
      6.9,
      3.2,
      5.7,
      2.3
    ],
    [
      6.7,
      3.1,
      4.7,
      1.5
    ],
    [
      5.2,
      4.1,
      1.5,
      0.1
    ],
    [
      5.4,
      3.9,
      1.7,
      0.4
    ],
    [
      6.3,
      2.5,
      4.9,
      1.5
    ],
    [
      5.9,
      3.0,
      4.2,
      1.5
    ],
    [
      5.1,
      3.7,
      1.5,
      0.4
    ],
    [
      4.3,
      3.0,
      1.1,
      0.1
    ],
    [
      5.1,
      3.5,
      1.4,
      0.3
    ],
    [
      5.0,
      3.4,
      1.6,
      0.4
    ],
    [
      5.4,
      3.4,
      1.7,
      0.2
    ],
    [
      7.2,
      3.6,
      6.1,
      2.5
    ],
    [
      6.0,
      3.4,
      4.5,
      1.6
    ],
    [
      6.8,
      3.0,
      5.5,
      2.1
    ],
    [
      6.4,
      3.2,
      5.3,
      2.3
    ],
    [
      6.5,
      3.0,
      5.5,
      1.8
    ],
    [
      6.8,
      2.8,
      4.8,
      1.4
    ],
    [
      4.7,
      3.2,
      1.6,
      0.2
    ],
    [
      5.2,
      3.5,
      1.5,
      0.2
    ],
    [
      7.1,
      3.0,
      5.9,
      2.1
    ],
    [
      5.8,
      2.8,
      5.1,
      2.4
    ],
    [
      6.2,
      2.9,
      4.3,
      1.3
    ],
    [
      5.0,
      2.3,
      3.3,
      1.0
    ],
    [
      5.0,
      2.0,
      3.5,
      1.0
    ],
    [
      5.1,
      3.4,
      1.5,
      0.2
    ],
    [
      6.5,
      3.0,
      5.2,
      2.0
    ],
    [
      5.7,
      2.5,
      5.0,
      2.0
    ],
    [
      6.0,
      2.2,
      4.0,
      1.0
    ],
    [
      5.2,
      3.4,
      1.4,
      0.2
    ],
    [
      5.6,
      3.0,
      4.1,
      1.3
    ],
    [
      5.5,
      2.4,
      3.7,
      1.0
    ],
    [
      4.6,
      3.4,
      1.4,
      0.3
    ],
    [
      6.4,
      2.8,
      5.6,
      2.2
    ],
    [
      5.6,
      2.8,
      4.9,
      2.0
    ],
    [
      5.5,
      2.4,
      3.8,
      1.1
    ],
    [
      4.6,
      3.1,
      1.5,
      0.2
    ],
    [
      6.3,
      2.8,
      5.1,
      1.5
    ],
    [
      4.9,
      3.6,
      1.4,
      0.1
    ],
    [
      6.4,
      3.1,
      5.5,
      1.8
    ],
    [
      6.2,
      3.4,
      5.4,
      2.3
    ],
    [
      5.8,
      4.0,
      1.2,
      0.2
    ],
    [
      5.9,
      3.2,
      4.8,
      1.8
    ],
    [
      6.6,
      2.9,
      4.6,
      1.3
    ],
    [
      4.9,
      3.1,
      1.5,
      0.2
    ]
  ]
}"""

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 309.53it/s]


array([0, 1, 1, 2, 2, 2, 2, 2, 1, 0, 1, 1, 0, 2, 0, 2, 1, 1, 1, 2, 0, 1,
       2, 0, 0, 1, 2, 2, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 2, 2, 1,
       0, 0, 0, 1, 1, 0, 1, 1, 2, 1, 1, 2, 0, 2, 0, 0, 0, 2, 2, 2, 1, 1,
       1, 0, 1, 2, 0, 0, 1, 2, 1, 2, 2, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 2,
       1, 2, 2, 2, 1, 0, 0, 2, 2, 1, 1, 1, 0, 2, 2, 1, 0, 1, 1, 0, 2, 2,
       1, 0, 2, 0, 2, 2, 0, 2, 1, 0])

## Load the model back for prediction as a generic python function model

In [8]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)

iris_features_name = datasets.load_iris().feature_names

result = pd.DataFrame(X_test, columns = iris_features_name)
result['actual_class'] = y_test
result['predicted_class'] = predictions

  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 1250.22it/s] 


In [10]:
result[:5]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),actual_class,predicted_class
0,6.7,3.0,5.2,2.3,2,2
1,5.0,3.0,1.6,0.2,0,0
2,4.8,3.4,1.9,0.2,0,0
3,4.9,3.0,1.4,0.2,0,0
4,5.5,2.3,4.0,1.3,1,1


## Model Registry
The MLflow Model Registry component is a centralized model store, set of APIs, and UI, to collaboratively manage the full lifecycle of an MLflow Mode. It provides model lineage (which MLflow experiment and run produced the model), model versioning, model aliasing, model tagging, and annotations.

In [None]:
# Inferencing from model from model registry

import mlflow.sklearn
model_name = 'Tracking_quickstart'
model_version = 'latest'

model_uri = f"models:/{model_name}/{model_version}"

model = mlflow.sklearn.load_model(model_uri)
model

  latest = client.get_latest_versions(name, None if stage is None else [stage])
Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 212.28it/s]
