In [1]:
import os
from urllib.parse import urlparse

import mlflow.sklearn
import numpy as np
import pandas as pd
from mlflow.exceptions import MlflowException
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
from mlflow.tracking import MlflowClient
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

In [2]:
# Load the churn dataset straight from GitHub (requires network access)
df = pd.read_csv(
    "https://raw.githubusercontent.com/erkansirin78/datasets/master/Churn_Modelling.csv"
)

# Features: every column from position 3 up to, but not including, the last one.
# Target: the 'Exited' column.
# NOTE(review): the positional slice assumes the first three columns are not features
# and that 'Exited' is the last column — confirm against the CSV header.
X, y = df.iloc[:, 3:-1], df["Exited"]

# Hold out 25% of the rows for evaluation, with a fixed seed for reproducibility.
# NOTE(review): the pipeline evaluated below is trained in train_churn_model.py; this
# split presumably has to match the one used there, otherwise X_test may overlap the
# training data — verify.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)

# Track, evaluate and register the model with MLflow

In [None]:
# Point MLflow at the tracking server and the S3-compatible artifact store.
# setdefault keeps any value already exported in the environment, so the notebook can
# target a different server without editing this cell; the localhost URLs are only
# the fallback (presumably a local tracking server on 5001 and an S3-compatible
# store on 9000 — confirm against your setup).
# NOTE(review): uploading artifacts to the S3 endpoint will also need credentials
# (e.g. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY), which are not set here.
os.environ.setdefault("MLFLOW_TRACKING_URI", "http://localhost:5001/")
os.environ.setdefault("MLFLOW_S3_ENDPOINT_URL", "http://localhost:9000/")

In [None]:
# Metric helper used by the MLflow run below
def eval_metrics(actual, pred):
    """Score predicted labels against the true labels.

    Args:
        actual: ground-truth class labels.
        pred: predicted class labels, aligned element-wise with ``actual``.

    Returns:
        A ``(accuracy, clf_report)`` tuple: the result of ``accuracy_score`` and the
        text report produced by ``classification_report``.
    """
    return (
        accuracy_score(actual, pred),
        classification_report(actual, pred),
    )

In [None]:
# Name under which the trained pipeline is registered in the MLflow Model Registry
registered_model_name = "ChurnModel"

# Runs started below are logged under this experiment
experiment_name = "FastAPI with MLflow"
mlflow.set_experiment(experiment_name)

In [6]:
# NOTE(review): not referenced by any cell shown here — the model comes from the grid
# search in train_churn_model, so this looks like a leftover; remove or log it.
number_of_trees = 200

In [None]:
# Start MLflow run: evaluate the tuned pipeline on the held-out split, then log its
# hyperparameters, accuracy, classification report and the model itself.
with mlflow.start_run(run_name="churn-rf-sklearn") as run:
    # Pull in the fitted search object from the training script.
    # NOTE(review): Python executes train_churn_model only on its first import; re-running
    # this cell reuses the cached module rather than retraining. Kept inside the run in
    # case the script logs to the active run — confirm.
    from train_churn_model import grid_search

    # Best estimator found by the search (presumably a fitted sklearn Pipeline)
    best_pipeline = grid_search.best_estimator_

    # Predictions on the held-out split from the data cell
    y_pred = best_pipeline.predict(X_test)

    # Evaluation
    accuracy, clf_report = eval_metrics(y_test, y_pred)
    print("Best Parameters: ", grid_search.best_params_)
    print("Accuracy: ", accuracy)
    print("Classification Report:\n", clf_report)

    # Log parameters and metrics
    mlflow.log_params(grid_search.best_params_)
    mlflow.log_metric("accuracy", accuracy)
    # Keep the per-class report with the run instead of only printing it
    mlflow.log_text(clf_report, "classification_report.txt")

    # The Model Registry does not work with a file-based tracking store. A tracking URI
    # without a scheme (e.g. "mlruns" or "./mlruns") is also a local file store, so an
    # empty scheme is treated the same as "file".
    tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
    registry_supported = tracking_url_type_store not in ("file", "")

    # Log the model; registered_model_name=None (log_model's default) skips registration
    mlflow.sklearn.log_model(
        best_pipeline,
        "model",
        registered_model_name=registered_model_name if registry_supported else None,
    )

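(Note: stale output from an earlier regression notebook — model 'AdvertisingRFModel', RMSE/MAE/R2 metrics — it does not match the churn classifier code above; re-run the cell to regenerate.)
Random Forest model number of trees: 200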
  RMSE: 0.6987288024648354
  MAE: 0.5835151515151545
  R2: 0.9810832419633377


Registered model 'AdvertisingRFModel' already exists. Creating a new version of this model...
2022/05/18 10:35:11 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: AdvertisingRFModel, version 3
Created version '3' of model 'AdvertisingRFModel'.


# Optional: work with the Model Registry directly through MlflowClient

In [None]:
# Optional: Interact with the MLflow Model Registry
name = registered_model_name
client = MlflowClient()

# Create the registered model if it doesn't exist yet. The only expected failure is
# "already exists" (log_model in the run cell may already have created it). Any other
# error — e.g. an unreachable tracking server — is re-raised instead of being
# misreported as "already exists".
try:
    client.create_registered_model(name)
except MlflowException as e:
    if e.error_code != "RESOURCE_ALREADY_EXISTS":
        raise
    print(f"Model {name} already exists. Skipping creation.")


In [None]:
# Build the runs:/ URI of the "model" artifact logged by the run above
# (this only builds the URI; no registry call is made here)
model_uri = "runs:/{}/model".format(run.info.run_id)
print("Model URI:", model_uri)

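(Note: stale output — the cell above builds the URI with the artifact path 'model', not 'sklearn-model'; re-run the cell to regenerate.)
runs:/e6f60e9f9c4e413988a6a22610e2be79/sklearn-model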


In [None]:
# Register the run's model as a new version of `name`.
# The runs:/ URI is first resolved to the concrete artifact location, following the
# pattern in MLflow's Model Registry documentation, so the version records the real
# storage path rather than the runs:/ indirection.
# NOTE(review): when the registry is available, the run cell's log_model call has
# already registered this run's model, so this creates a second version that points
# at the same artifacts.
model_src = RunsArtifactRepository.get_underlying_uri(model_uri)
mv = client.create_model_version(name, model_src, run.info.run_id)
print("Model version {} created".format(mv.version))
last_mv = mv.version
print("Latest model version:", last_mv)

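(Note: stale output from the earlier 'AdvertisingRFModel' notebook; re-run the cell to regenerate it for 'ChurnModel'.)
2022/05/18 10:38:28 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: AdvertisingRFModel, version 6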


model version 6 created
6


In [None]:
# Helper to dump the key registry fields of model versions
def print_models_info(models):
    """Print name, version, source run id and current stage for each model version.

    Args:
        models: iterable of model-version objects exposing ``name``, ``version``,
            ``run_id`` and ``current_stage`` (here: the result of
            ``MlflowClient.get_latest_versions``).
    """
    for entry in models:
        print(f"name: {entry.name}")
        print(f"latest version: {entry.version}")
        print(f"run_id: {entry.run_id}")
        print(f"current_stage: {entry.current_stage}")

# Latest version of `name` in the "None" stage, i.e. not yet promoted to a stage.
# NOTE(review): stage-based APIs such as get_latest_versions are deprecated in newer
# MLflow releases — consider model aliases when upgrading.
models = client.get_latest_versions(name, stages=["None"])
print_models_info(models)

# Version number returned by create_model_version above
print(f"Latest version: {last_mv}")