In [10]:
# Record the Python interpreter version used to execute this notebook (reproducibility).
!python -V

Python 3.10.14


In [13]:
import os
import pickle
import mlflow
import numpy as np
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll import scope
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

# Point MLflow at the tracking server. Setting the MLFLOW_TRACKING_URI
# environment variable overrides the local default, so the notebook can target
# another server without code edits.
TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000")
mlflow.set_tracking_uri(TRACKING_URI)
# Every Hyperopt trial below is logged as its own run in this experiment
# (MLflow creates the experiment on first use).
mlflow.set_experiment("random-forest-hyperopt")


def load_pickle(filename: str):
    """Deserialize and return the object stored in the pickle file ``filename``.

    Security: ``pickle.load`` can execute arbitrary code, so only call this on
    files this project produced itself.
    """
    with open(filename, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj


def run_optimization(data_path: str, num_trials: int, seed: int = 42):
    """Tune a RandomForestRegressor with Hyperopt TPE, logging each trial to MLflow.

    Each trial fits a forest on ``train.pkl`` and scores it on ``val.pkl`` (both
    read from ``data_path`` as ``(X, y)`` pairs). The trial's hyper-parameters and
    validation RMSE are recorded in its own MLflow run under the active experiment.

    Args:
        data_path: Directory containing ``train.pkl`` and ``val.pkl``.
        num_trials: Number of Hyperopt evaluations (``max_evals``).
        seed: Seed for the generator passed to ``fmin`` as ``rstate``, for
            reproducible searches. Defaults to 42, the previously hard-coded value.
    """
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))

    def objective(params):
        with mlflow.start_run():
            mlflow.log_params(params)
            rf = RandomForestRegressor(**params)
            rf.fit(X_train, y_train)
            y_pred = rf.predict(X_val)
            # sqrt(MSE) instead of ``squared=False``: that keyword was deprecated
            # in scikit-learn 1.4 and removed in 1.6.
            rmse = np.sqrt(mean_squared_error(y_val, y_pred))
            mlflow.log_metric("rmse", rmse)

        # fmin minimises 'loss', so a lower validation RMSE is a better trial.
        return {'loss': rmse, 'status': STATUS_OK}

    search_space = {
        # hp.quniform samples floats (e.g. 7.0); scope.int casts them to ints for sklearn.
        'max_depth': scope.int(hp.quniform('max_depth', 1, 20, 1)),
        'n_estimators': scope.int(hp.quniform('n_estimators', 10, 50, 1)),
        'min_samples_split': scope.int(hp.quniform('min_samples_split', 2, 10, 1)),
        'min_samples_leaf': scope.int(hp.quniform('min_samples_leaf', 1, 4, 1)),
        'random_state': 42,  # fixed: trials differ only in the tuned hyper-parameters
    }

    rstate = np.random.default_rng(seed)  # for reproducible results
    fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=num_trials,
        trials=Trials(),
        rstate=rstate,
    )


# Run the hyper-parameter search on the pickled splits in ./output.
HPO_DATA_PATH = "output"
HPO_NUM_TRIALS = 15
run_optimization(HPO_DATA_PATH, HPO_NUM_TRIALS)

2024/09/17 12:20:45 INFO mlflow.tracking.fluent: Experiment with name 'random-forest-hyperopt' does not exist. Creating a new experiment.


  0%|                                                                                                                                                              | 0/15 [00:00<?, ?trial/s, best loss=?]



2024/09/17 12:20:53 INFO mlflow.tracking._tracking_service.client: 🏃 View run honorable-ox-719 at: http://127.0.0.1:5000/#/experiments/1/runs/d19d51fc6ca6481c93bc89be77f5d153.

2024/09/17 12:20:53 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



  7%|████████▊                                                                                                                            | 1/15 [00:08<01:55,  8.27s/trial, best loss: 5.370086069268862]



2024/09/17 12:20:56 INFO mlflow.tracking._tracking_service.client: 🏃 View run righteous-sponge-93 at: http://127.0.0.1:5000/#/experiments/1/runs/391c447b71c3420284ff708422746274.

2024/09/17 12:20:56 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 13%|█████████████████▋                                                                                                                   | 2/15 [00:11<01:09,  5.31s/trial, best loss: 5.370086069268862]



2024/09/17 12:20:59 INFO mlflow.tracking._tracking_service.client: 🏃 View run smiling-perch-292 at: http://127.0.0.1:5000/#/experiments/1/runs/346c569d7de5410bad6ef103ef09d4d4.

2024/09/17 12:20:59 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 20%|██████████████████████████▌                                                                                                          | 3/15 [00:14<00:51,  4.28s/trial, best loss: 5.370086069268862]



2024/09/17 12:21:05 INFO mlflow.tracking._tracking_service.client: 🏃 View run grandiose-eel-203 at: http://127.0.0.1:5000/#/experiments/1/runs/034bece5d2a74ecca425eede4767096d.

2024/09/17 12:21:05 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 27%|███████████████████████████████████▍                                                                                                 | 4/15 [00:19<00:51,  4.67s/trial, best loss: 5.357490752366866]



2024/09/17 12:21:09 INFO mlflow.tracking._tracking_service.client: 🏃 View run delicate-grouse-125 at: http://127.0.0.1:5000/#/experiments/1/runs/f69e464e2f2d4e04b5d2c150057c8532.

2024/09/17 12:21:09 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 33%|████████████████████████████████████████████▎                                                                                        | 5/15 [00:23<00:44,  4.48s/trial, best loss: 5.357490752366866]



2024/09/17 12:21:16 INFO mlflow.tracking._tracking_service.client: 🏃 View run unruly-colt-427 at: http://127.0.0.1:5000/#/experiments/1/runs/cfe66753d97d4adb8cec327afee0fcc3.

2024/09/17 12:21:16 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 40%|█████████████████████████████████████████████████████▏                                                                               | 6/15 [00:31<00:50,  5.61s/trial, best loss: 5.354700855292386]



2024/09/17 12:21:24 INFO mlflow.tracking._tracking_service.client: 🏃 View run gentle-bass-527 at: http://127.0.0.1:5000/#/experiments/1/runs/4e134bd8a632426798618f374e88490b.

2024/09/17 12:21:24 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 47%|██████████████████████████████████████████████████████████████                                                                       | 7/15 [00:38<00:48,  6.10s/trial, best loss: 5.354700855292386]



2024/09/17 12:21:27 INFO mlflow.tracking._tracking_service.client: 🏃 View run rare-mink-507 at: http://127.0.0.1:5000/#/experiments/1/runs/a618961f20914a478d93d3f22d5e37fb.

2024/09/17 12:21:27 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 53%|██████████████████████████████████████████████████████████████████████▉                                                              | 8/15 [00:42<00:36,  5.18s/trial, best loss: 5.354700855292386]



2024/09/17 12:21:45 INFO mlflow.tracking._tracking_service.client: 🏃 View run sassy-quail-895 at: http://127.0.0.1:5000/#/experiments/1/runs/6a6c644808d04ce586e505ef612ff10d.

2024/09/17 12:21:45 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 60%|███████████████████████████████████████████████████████████████████████████████▊                                                     | 9/15 [01:00<00:56,  9.41s/trial, best loss: 5.354700855292386]



2024/09/17 12:21:51 INFO mlflow.tracking._tracking_service.client: 🏃 View run zealous-deer-957 at: http://127.0.0.1:5000/#/experiments/1/runs/3965c0bff8ca4a3ebe72f8e98766c3c4.

2024/09/17 12:21:51 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 67%|████████████████████████████████████████████████████████████████████████████████████████                                            | 10/15 [01:06<00:40,  8.16s/trial, best loss: 5.354700855292386]



2024/09/17 12:21:55 INFO mlflow.tracking._tracking_service.client: 🏃 View run clean-bat-182 at: http://127.0.0.1:5000/#/experiments/1/runs/32a06cdc50ef49afa28a1fd28e8b2b09.

2024/09/17 12:21:55 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 73%|████████████████████████████████████████████████████████████████████████████████████████████████▊                                   | 11/15 [01:10<00:27,  6.99s/trial, best loss: 5.335419588556921]



2024/09/17 12:22:00 INFO mlflow.tracking._tracking_service.client: 🏃 View run smiling-yak-842 at: http://127.0.0.1:5000/#/experiments/1/runs/bfe2aaf6f2f042eb819ce7f40fba4e31.

2024/09/17 12:22:00 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▌                          | 12/15 [01:15<00:18,  6.26s/trial, best loss: 5.335419588556921]



2024/09/17 12:22:03 INFO mlflow.tracking._tracking_service.client: 🏃 View run powerful-sheep-148 at: http://127.0.0.1:5000/#/experiments/1/runs/820d3b5a16354c218e3d5ac3b42863fb.

2024/09/17 12:22:03 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 87%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                 | 13/15 [01:18<00:10,  5.33s/trial, best loss: 5.335419588556921]



2024/09/17 12:22:07 INFO mlflow.tracking._tracking_service.client: 🏃 View run bustling-lamb-248 at: http://127.0.0.1:5000/#/experiments/1/runs/9d4599bc994a4b81a1894d5ba3ace89d.

2024/09/17 12:22:07 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏        | 14/15 [01:22<00:04,  4.95s/trial, best loss: 5.335419588556921]



2024/09/17 12:22:12 INFO mlflow.tracking._tracking_service.client: 🏃 View run resilient-owl-895 at: http://127.0.0.1:5000/#/experiments/1/runs/33f90205d00d4a1ba2a7bcc4badaf4b3.

2024/09/17 12:22:12 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [01:27<00:00,  5.81s/trial, best loss: 5.335419588556921]


In [2]:
import os

from mlflow.tracking import MlflowClient

# Client used by the query cells below. By default it reads the sqlite store
# directly. NOTE(review): this looks like the store behind the tracking server
# used above (the run ids it returns match the server logs), so confirm.
# Set MLFLOW_TRACKING_URI to avoid depending on this machine-specific absolute path.
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "sqlite:////Users/user/notebooks/mlflow.db"
)
client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

In [63]:
from mlflow.entities import ViewType

# Look the experiment up by name rather than hard-coding its numeric id (1).
# Ids depend on the order in which experiments were created in the store.
hpo_experiment = client.get_experiment_by_name("random-forest-hyperopt")
if hpo_experiment is None:
    raise ValueError("Experiment 'random-forest-hyperopt' not found in the tracking store")

# Up to five active HPO trials with validation RMSE below 7, best first.
runs = client.search_runs(
    experiment_ids=[hpo_experiment.experiment_id],
    filter_string="metrics.rmse < 7",
    run_view_type=ViewType.ACTIVE_ONLY,
    max_results=5,
    order_by=["metrics.rmse ASC"],
)

for run in runs:
    print(f"run id: {run.info.run_id}, rmse: {run.data.metrics['rmse']:.4f}")

run id: cd7fd20c31674fecaef6675882829cfe, rmse: 5.3354
run id: 70ba89b5b32d423b9d4d8f77eab3a0ba, rmse: 5.3354
run id: e38b013bbbc344da979b84eec832d2e3, rmse: 5.3354
run id: 44d5fcafed904b30b3696cbe37bc0d90, rmse: 5.3547
run id: 1750797a8c4140698939f309de808e7b, rmse: 5.3547


In [11]:
from mlflow.entities import ViewType

# Up to five runs of experiment 2, including deleted ones (ViewType.ALL), whose
# min_samples_split param equals 2. The filter compares against the quoted
# string '2'.
# NOTE(review): experiment id 2 is hard-coded; confirm which experiment it is.
# Ordering by the filtered param is a no-op because every match has the same value.
runs = client.search_runs(
    experiment_ids=2,
    run_view_type=ViewType.ALL,
    filter_string="params.min_samples_split = '2'",
    order_by=["params.min_samples_split ASC"],
    max_results=5,
)

for run in runs:
    print(
        "run_id: {}, min_samples_split: {}".format(
            run.info.run_id, run.data.params["min_samples_split"]
        )
    )

run_id: 562b67d2f501432ea2db9902fc618229, min_samples_split: 2


In [26]:
import os
import pickle
import click
import mlflow

from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

HPO_EXPERIMENT_NAME = "random-forest-hyperopt"  # experiment holding the Hyperopt trials
EXPERIMENT_NAME = "random-forest-best-models"  # experiment receiving the retrained candidates
# Hyper-parameters copied from each HPO run into the retrained forest.
RF_PARAMS = ['max_depth', 'n_estimators', 'min_samples_split', 'min_samples_leaf', 'random_state']

# Setting MLFLOW_TRACKING_URI overrides the local default server.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000"))
mlflow.set_experiment(EXPERIMENT_NAME)
# Autologging logs the fitted estimator on every fit(). run_register_model
# expects it under the "model" artifact path.
mlflow.sklearn.autolog()


def load_pickle(filename):
    """Return the object unpickled from the file at ``filename``.

    NOTE: unpickling can execute arbitrary code; only load files produced by
    this project.
    """
    with open(filename, "rb") as stream:
        payload = pickle.load(stream)
    return payload


def train_and_log_model(data_path, params):
    """Retrain a RandomForestRegressor with the given hyper-parameters and log it to MLflow.

    Args:
        data_path: Directory containing ``train.pkl``, ``val.pkl`` and ``test.pkl``,
            each unpickling to an ``(X, y)`` pair.
        params: Mapping containing every key in ``RF_PARAMS``. Values may be
            strings (as in ``run.data.params``) and are converted to ``int``.

    The fit runs inside a new MLflow run of the active experiment. Validation and
    test RMSE are logged as ``val_rmse`` and ``test_rmse``.
    """
    import numpy as np  # local import: this cell does not import numpy itself

    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))
    X_test, y_test = load_pickle(os.path.join(data_path, "test.pkl"))

    with mlflow.start_run():
        # Go through float so both "10" and "10.0" parse, then cast to the ints
        # RandomForestRegressor expects.
        new_params = {param: int(float(params[param])) for param in RF_PARAMS}

        rf = RandomForestRegressor(**new_params)
        rf.fit(X_train, y_train)

        # Evaluate on the validation and test sets. sqrt(MSE) replaces
        # ``squared=False``, which was deprecated in scikit-learn 1.4 and removed in 1.6.
        val_rmse = np.sqrt(mean_squared_error(y_val, rf.predict(X_val)))
        mlflow.log_metric("val_rmse", val_rmse)
        test_rmse = np.sqrt(mean_squared_error(y_test, rf.predict(X_test)))
        mlflow.log_metric("test_rmse", test_rmse)

def run_register_model(data_path: str, top_n: int):
    """Retrain the ``top_n`` best HPO candidates and register the best one.

    Steps:
      1. Take the ``top_n`` active runs of ``HPO_EXPERIMENT_NAME`` with the lowest
         ``rmse`` and retrain each one via ``train_and_log_model``.
      2. Pick the active run of ``EXPERIMENT_NAME`` with the lowest ``test_rmse``.
         This search also sees runs left by earlier invocations.
      3. Register that run's ``model`` artifact as ``rf-best-model``.

    Raises:
        ValueError: if ``top_n`` < 1, an experiment does not exist, or a
            search returns no runs.
    """
    if top_n < 1:
        raise ValueError(f"top_n must be >= 1, got {top_n}")

    client = MlflowClient()

    # Retrieve the top_n HPO runs and retrain/log a model for each.
    hpo_experiment = client.get_experiment_by_name(HPO_EXPERIMENT_NAME)
    if hpo_experiment is None:
        raise ValueError(f"Experiment {HPO_EXPERIMENT_NAME!r} not found")
    candidate_runs = client.search_runs(
        experiment_ids=[hpo_experiment.experiment_id],
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=top_n,
        order_by=["metrics.rmse ASC"],
    )
    if not candidate_runs:
        raise ValueError(f"No active runs found in experiment {HPO_EXPERIMENT_NAME!r}")
    for run in candidate_runs:
        train_and_log_model(data_path=data_path, params=run.data.params)

    # Select the model with the lowest test RMSE (only the first result is needed).
    best_experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if best_experiment is None:
        raise ValueError(f"Experiment {EXPERIMENT_NAME!r} not found")
    best_runs = client.search_runs(
        experiment_ids=[best_experiment.experiment_id],
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=1,
        order_by=["metrics.test_rmse ASC"],
    )
    if not best_runs:
        raise ValueError(f"No active runs found in experiment {EXPERIMENT_NAME!r}")

    # Register the best model.
    run_id = best_runs[0].info.run_id
    model_uri = f"runs:/{run_id}/model"
    mlflow.register_model(model_uri, name="rf-best-model")


# Retrain the best HPO candidates on the pickled splits in ./output and register
# the one with the lowest test RMSE.
REGISTER_DATA_PATH = "output"
REGISTER_TOP_N = 5
run_register_model(data_path=REGISTER_DATA_PATH, top_n=REGISTER_TOP_N)



2024/09/17 14:23:23 INFO mlflow.tracking._tracking_service.client: 🏃 View run rebellious-frog-136 at: http://127.0.0.1:5000/#/experiments/4/runs/24d1858a457740f9a0f967ce419c7172.
2024/09/17 14:23:23 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/4.
2024/09/17 14:23:29 INFO mlflow.tracking._tracking_service.client: 🏃 View run welcoming-snail-412 at: http://127.0.0.1:5000/#/experiments/4/runs/0ece247281dc4804a6a3e2f1c2680f68.
2024/09/17 14:23:29 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/4.
2024/09/17 14:23:38 INFO mlflow.tracking._tracking_service.client: 🏃 View run flawless-duck-888 at: http://127.0.0.1:5000/#/experiments/4/runs/97fd3f09be1b40ac9bc5aa57d22e6d6f.
2024/09/17 14:23:38 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/4.
2024/09/17 14:23:47 INFO mlflow.tracking._tracking_service.client: 🏃 View run car

In [29]:
from mlflow.entities import ViewType  # imported here so the cell does not rely on earlier cells

# Up to five active retrained models with test RMSE below 7, best first.
experiment = client.get_experiment_by_name('random-forest-best-models')
if experiment is None:
    raise ValueError("Experiment 'random-forest-best-models' not found in the tracking store")
runs = client.search_runs(
    experiment_ids=[experiment.experiment_id],
    filter_string="metrics.test_rmse < 7",
    run_view_type=ViewType.ACTIVE_ONLY,
    max_results=5,
    order_by=["metrics.test_rmse ASC"],
)

for run in runs:
    print(f"run id: {run.info.run_id}, test_rmse: {run.data.metrics['test_rmse']:.4f}")

run id: 0ece247281dc4804a6a3e2f1c2680f68, test_rmse: 5.5674
run id: 24d1858a457740f9a0f967ce419c7172, test_rmse: 5.5674
run id: 36c88e26072948b587112b06b5b72b99, test_rmse: 5.5853
run id: 97fd3f09be1b40ac9bc5aa57d22e6d6f, test_rmse: 5.5853
run id: 5ec021897fa04c81b2c18320e2a77a88, test_rmse: 5.5921
