In [1]:
# Point the MLflow client at the tracking server used by this notebook.
# NOTE(review): assumes a tracking server is already listening at this address.
%env MLFLOW_TRACKING_URI=http://127.0.0.1:5001

env: MLFLOW_TRACKING_URI=http://127.0.0.1:5001


In [2]:
import mlflow

import numpy as np
from sklearn import datasets, metrics
from sklearn.linear_model import ElasticNet
from sklearn.model_selection import train_test_split


def eval_metrics(pred, actual):
    """Score a set of predictions against the ground-truth targets.

    Args:
        pred: Predicted target values.
        actual: True target values, aligned element-wise with ``pred``.

    Returns:
        A ``(rmse, mae, r2)`` tuple: root mean squared error, mean absolute
        error and the R^2 score, each computed with ``sklearn.metrics``.
    """
    # sklearn's metric functions take the ground truth first, predictions second.
    squared_error = metrics.mean_squared_error(actual, pred)
    return (
        np.sqrt(squared_error),
        metrics.mean_absolute_error(actual, pred),
        metrics.r2_score(actual, pred),
    )

In [3]:
# Select the experiment that all runs in this notebook are recorded under
mlflow.set_experiment(experiment_name="wine-quality")
# Turn on MLflow auto-logging for scikit-learn estimators
mlflow.sklearn.autolog()



2024/10/29 13:37:01 INFO mlflow.tracking.fluent: Experiment with name 'wine-quality' does not exist. Creating a new experiment.


In [4]:
# Load scikit-learn's bundled wine dataset as a feature matrix and target vector
X, y = datasets.load_wine(return_X_y=True)
# Fix the seed so the train/test split (and every metric derived from it)
# is reproducible across "Restart Kernel -> Run All"
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)

# Start a run and train a baseline model with default hyperparameters
with mlflow.start_run(run_name="default-params"):
    lr = ElasticNet()
    lr.fit(X_train, y_train)

    # Evaluate on the held-out split and log the test metrics explicitly
    y_pred = lr.predict(X_test)
    rmse, mae, r2 = eval_metrics(y_pred, y_test)
    mlflow.log_metrics(
        {
            # eval_metrics returns the *root* mean squared error, so the key
            # must say so; logging it as "mean_squared_error" mislabels it.
            "root_mean_squared_error_X_test": rmse,
            "mean_absolute_error_X_test": mae,
            "r2_score_X_test": r2,
        }
    )

2024/10/29 13:37:09 INFO mlflow.tracking._tracking_service.client: 🏃 View run default-params at: http://127.0.0.1:5001/#/experiments/180333391132533584/runs/1a875e40a53b479c92c303aa2b71c5a5.
2024/10/29 13:37:09 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5001/#/experiments/180333391132533584.


In [5]:
from scipy.stats import uniform
from sklearn.model_selection import RandomizedSearchCV

lr = ElasticNet()

# Define distributions to pick parameter values from.
# scipy's uniform(loc, scale) samples from [loc, loc + scale].
distributions = dict(
    alpha=uniform(loc=0, scale=10),  # sample alpha uniformly from [0, 10.0]
    l1_ratio=uniform(),  # sample l1_ratio uniformly from [0, 1.0]
)

# Initialize random search instance
clf = RandomizedSearchCV(
    estimator=lr,
    param_distributions=distributions,
    # Optimize for mean absolute error
    scoring="neg_mean_absolute_error",
    # Use 5-fold cross validation
    cv=5,
    # Try 100 samples. Note that MLflow only logs the top 5 runs.
    n_iter=100,
    # Seed the parameter sampling so the search is reproducible on re-run
    random_state=42,
)

# Start a parent run
with mlflow.start_run(run_name="hyperparameter-tuning"):
    search = clf.fit(X_train, y_train)

    # Evaluate the best model on test dataset
    y_pred = clf.best_estimator_.predict(X_test)
    rmse, mae, r2 = eval_metrics(y_pred, y_test)
    mlflow.log_metrics(
        {
            # eval_metrics returns the *root* mean squared error, so the key
            # must say so; logging it as "mean_squared_error" mislabels it.
            "root_mean_squared_error_X_test": rmse,
            "mean_absolute_error_X_test": mae,
            "r2_score_X_test": r2,
        }
    )


2024/10/29 13:37:20 INFO mlflow.sklearn.utils: Logging the 5 best runs, 95 runs will be omitted.
2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🏃 View run upbeat-croc-431 at: http://127.0.0.1:5001/#/experiments/180333391132533584/runs/65e78c1db29a4b7f813dc7ac13c8c933.
2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5001/#/experiments/180333391132533584.
2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🏃 View run unequaled-calf-711 at: http://127.0.0.1:5001/#/experiments/180333391132533584/runs/9eb612e820db4f79889d9c913363c89e.
2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🏃 View run omniscient-grub-582 at: http://127.0.0.1:5001/#/experiments/180333391132533584/runs/e76e30d3bc1c4438aad2f23442d7689b.
2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5001/#/experiments/180333391132533584.
2024/10/29 13:37:21 INFO mlflo

2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🏃 View run agreeable-stag-46 at: http://127.0.0.1:5001/#/experiments/180333391132533584/runs/00ab398d9ff64d318cb6c68438bce0d8.
2024/10/29 13:37:21 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5001/#/experiments/180333391132533584.
