In [1]:
import os
import pickle
import click
import mlflow
import numpy as np
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll import scope
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

In [2]:
# Send all MLflow calls in this notebook to the tracking server on localhost:5000.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
# Select the experiment that subsequent runs are logged under
# (MLflow creates it when it does not exist yet, as the log output below shows).
mlflow.set_experiment("random-forest-hyperopt")

# Directory passed to run_optimization; it is expected to hold train.pkl and val.pkl.
data_path = "./output"

2024/05/26 13:15:30 INFO mlflow.tracking.fluent: Experiment with name 'random-forest-hyperopt' does not exist. Creating a new experiment.


In [7]:
def load_pickle(filename: str):
    """Read the file at ``filename`` in binary mode and return the unpickled object.

    The file handle is always closed before returning, including when
    unpickling raises.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only call this on files from a trusted source.
    """
    handle = open(filename, "rb")
    try:
        payload = pickle.load(handle)
    finally:
        handle.close()
    return payload
        
def run_optimization(data_path: str, num_trials: int):
    """Tune a RandomForestRegressor with hyperopt's TPE search, logging every trial to MLflow.

    Args:
        data_path: Directory containing ``train.pkl`` and ``val.pkl``. Each file is
            unpickled into an ``(X, y)`` pair.
        num_trials: Number of evaluations to run (passed to ``fmin`` as ``max_evals``).

    Returns:
        Whatever ``hyperopt.fmin`` returns for the search: the best point found,
        keyed by the hyperparameter labels of the search space. Callers that
        ignore the return value behave exactly as before.
    """
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))

    def objective(params):
        """Fit one forest with ``params``, log it as an MLflow run, and report validation RMSE."""
        with mlflow.start_run():
            mlflow.set_tag("model", "RandomForestRegressor")
            mlflow.log_params(params)

            rf = RandomForestRegressor(**params)
            rf.fit(X_train, y_train)
            y_pred = rf.predict(X_val)

            # `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6;
            # taking the square root of the MSE works on every version.
            # NOTE(review): identical to the old call for a single-output target,
            # which is what this notebook appears to use; confirm if y is 2-D.
            rmse = float(np.sqrt(mean_squared_error(y_val, y_pred)))

            mlflow.log_metric("rmse", rmse)

        return {'loss': rmse, 'status': STATUS_OK}

    # quniform samples floats; scope.int casts them to the ints sklearn expects.
    search_space = {
        'max_depth': scope.int(hp.quniform('max_depth', 1, 20, 1)),
        'n_estimators': scope.int(hp.quniform('n_estimators', 10, 50, 1)),
        'min_samples_split': scope.int(hp.quniform('min_samples_split', 2, 10, 1)),
        'min_samples_leaf': scope.int(hp.quniform('min_samples_leaf', 1, 4, 1)),
        'random_state': 42
    }

    rstate = np.random.default_rng(42)  # for reproducible results
    best_result = fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=num_trials,
        trials=Trials(),
        rstate=rstate
    )
    return best_result

In [8]:
# Launch the hyperparameter search: 100 trials over the data in `data_path`.
run_optimization(data_path=data_path, num_trials=100)

  1%|▋                                                              | 1/100 [00:07<12:36,  7.64s/trial, best loss: 5.370086069268862]




  2%|█▎                                                             | 2/100 [00:08<05:32,  3.40s/trial, best loss: 5.370086069268862]




  3%|█▉                                                             | 3/100 [00:08<03:24,  2.11s/trial, best loss: 5.370086069268862]




  4%|██▌                                                            | 4/100 [00:14<05:27,  3.41s/trial, best loss: 5.357490752366866]




  5%|███▏                                                           | 5/100 [00:16<05:00,  3.16s/trial, best loss: 5.357490752366866]




  6%|███▊                                                           | 6/100 [00:26<08:11,  5.23s/trial, best loss: 5.354695072530291]




  7%|████▍                                                          | 7/100 [00:35<10:14,  6.61s/trial, best loss: 5.354695072530291]




  8%|█████                                                          | 8/100 [00:36<07:23,  4.82s/trial, best loss: 5.354695072530291]




  9%|█████▋                                                         | 9/100 [00:42<08:01,  5.29s/trial, best loss: 5.354695072530291]




 10%|██████▏                                                       | 10/100 [00:47<07:50,  5.23s/trial, best loss: 5.354695072530291]




 11%|██████▊                                                       | 11/100 [00:51<07:07,  4.81s/trial, best loss: 5.335419588556921]




 12%|███████▍                                                      | 12/100 [00:54<06:07,  4.18s/trial, best loss: 5.335419588556921]




 13%|████████                                                      | 13/100 [00:55<04:38,  3.20s/trial, best loss: 5.335419588556921]




 14%|████████▋                                                     | 14/100 [00:58<04:43,  3.29s/trial, best loss: 5.335419588556921]




 15%|█████████▎                                                    | 15/100 [01:04<05:40,  4.01s/trial, best loss: 5.335419588556921]




 16%|█████████▉                                                    | 16/100 [01:08<05:28,  3.91s/trial, best loss: 5.335419588556921]




 17%|██████████▌                                                   | 17/100 [01:13<06:08,  4.44s/trial, best loss: 5.335419588556921]




 18%|███████████▏                                                  | 18/100 [01:21<07:25,  5.43s/trial, best loss: 5.322418787243458]




 19%|███████████▊                                                  | 19/100 [01:23<05:56,  4.40s/trial, best loss: 5.322418787243458]




 20%|████████████▍                                                 | 20/100 [01:25<04:39,  3.49s/trial, best loss: 5.322418787243458]




 21%|█████████████                                                 | 21/100 [01:36<07:51,  5.97s/trial, best loss: 5.320599657074168]




 22%|█████████████▋                                                | 22/100 [01:45<08:52,  6.82s/trial, best loss: 5.320599657074168]




 23%|██████████████▎                                               | 23/100 [01:59<11:32,  9.00s/trial, best loss: 5.320599657074168]




 24%|██████████████▉                                               | 24/100 [02:10<12:08,  9.59s/trial, best loss: 5.320599657074168]




 25%|███████████████▌                                              | 25/100 [02:19<11:35,  9.27s/trial, best loss: 5.320599657074168]




 26%|████████████████                                              | 26/100 [02:29<11:46,  9.55s/trial, best loss: 5.320408749882826]




 27%|████████████████▋                                             | 27/100 [02:33<09:48,  8.07s/trial, best loss: 5.320408749882826]




 28%|█████████████████▎                                            | 28/100 [02:41<09:32,  7.95s/trial, best loss: 5.320408749882826]




 29%|█████████████████▉                                            | 29/100 [02:48<08:56,  7.56s/trial, best loss: 5.320408749882826]




 30%|██████████████████▌                                           | 30/100 [03:02<10:58,  9.40s/trial, best loss: 5.320408749882826]




 31%|███████████████████▏                                          | 31/100 [03:02<07:45,  6.75s/trial, best loss: 5.320408749882826]




 32%|███████████████████▊                                          | 32/100 [03:09<07:38,  6.74s/trial, best loss: 5.320408749882826]




 33%|████████████████████▍                                         | 33/100 [03:14<07:05,  6.35s/trial, best loss: 5.320408749882826]




 34%|█████████████████████                                         | 34/100 [03:23<07:49,  7.12s/trial, best loss: 5.320408749882826]




 35%|█████████████████████▋                                        | 35/100 [03:23<05:30,  5.08s/trial, best loss: 5.320408749882826]




 36%|██████████████████████▎                                       | 36/100 [03:31<06:21,  5.96s/trial, best loss: 5.320408749882826]




 37%|██████████████████████▉                                       | 37/100 [03:45<08:35,  8.18s/trial, best loss: 5.320408749882826]




 38%|███████████████████████▌                                      | 38/100 [03:49<07:14,  7.01s/trial, best loss: 5.320408749882826]




 39%|████████████████████████▏                                     | 39/100 [04:02<08:50,  8.70s/trial, best loss: 5.320408749882826]




 40%|████████████████████████▊                                     | 40/100 [04:06<07:18,  7.30s/trial, best loss: 5.320408749882826]




 41%|█████████████████████████▍                                    | 41/100 [04:12<06:54,  7.03s/trial, best loss: 5.320408749882826]




 42%|██████████████████████████                                    | 42/100 [04:13<05:02,  5.22s/trial, best loss: 5.320408749882826]




 43%|██████████████████████████▋                                   | 43/100 [04:24<06:24,  6.75s/trial, best loss: 5.320408749882826]




 44%|███████████████████████████▎                                  | 44/100 [04:31<06:31,  6.99s/trial, best loss: 5.320408749882826]




 45%|███████████████████████████▉                                  | 45/100 [04:33<05:02,  5.49s/trial, best loss: 5.320408749882826]




 46%|████████████████████████████▌                                 | 46/100 [04:47<07:15,  8.06s/trial, best loss: 5.320408749882826]




 47%|█████████████████████████████▏                                | 47/100 [04:52<06:19,  7.17s/trial, best loss: 5.320408749882826]




 48%|█████████████████████████████▊                                | 48/100 [04:55<05:04,  5.85s/trial, best loss: 5.320408749882826]




 49%|██████████████████████████████▍                               | 49/100 [04:56<03:39,  4.30s/trial, best loss: 5.320408749882826]




 50%|███████████████████████████████                               | 50/100 [05:02<04:03,  4.86s/trial, best loss: 5.320408749882826]




 51%|███████████████████████████████▌                              | 51/100 [05:03<03:07,  3.82s/trial, best loss: 5.320408749882826]




 52%|████████████████████████████████▏                             | 52/100 [05:12<04:08,  5.19s/trial, best loss: 5.320408749882826]




 53%|████████████████████████████████▊                             | 53/100 [05:15<03:37,  4.63s/trial, best loss: 5.320408749882826]




 54%|█████████████████████████████████▍                            | 54/100 [05:18<03:06,  4.06s/trial, best loss: 5.320408749882826]




 55%|██████████████████████████████████                            | 55/100 [05:27<04:19,  5.76s/trial, best loss: 5.320408749882826]




 56%|██████████████████████████████████▋                           | 56/100 [05:40<05:44,  7.82s/trial, best loss: 5.320408749882826]




 57%|███████████████████████████████████▉                           | 57/100 [05:46<05:16,  7.36s/trial, best loss: 5.31689418131154]




 58%|████████████████████████████████████▌                          | 58/100 [05:48<03:57,  5.66s/trial, best loss: 5.31689418131154]




 59%|█████████████████████████████████████▏                         | 59/100 [05:52<03:25,  5.02s/trial, best loss: 5.31689418131154]




 60%|█████████████████████████████████████▊                         | 60/100 [05:56<03:13,  4.84s/trial, best loss: 5.31689418131154]




 61%|██████████████████████████████████████▍                        | 61/100 [06:03<03:29,  5.36s/trial, best loss: 5.31248543850713]




 62%|███████████████████████████████████████                        | 62/100 [06:05<02:48,  4.43s/trial, best loss: 5.31248543850713]




 63%|███████████████████████████████████████▋                       | 63/100 [06:11<02:58,  4.83s/trial, best loss: 5.31248543850713]




 64%|████████████████████████████████████████▎                      | 64/100 [06:11<02:12,  3.67s/trial, best loss: 5.31248543850713]




 65%|████████████████████████████████████████▉                      | 65/100 [06:12<01:36,  2.77s/trial, best loss: 5.31248543850713]




 66%|████████████████████████████████████████▉                     | 66/100 [06:20<02:27,  4.35s/trial, best loss: 5.311879192419245]




 67%|█████████████████████████████████████████▌                    | 67/100 [06:26<02:39,  4.84s/trial, best loss: 5.311879192419245]




 68%|██████████████████████████████████████████▏                   | 68/100 [06:32<02:49,  5.28s/trial, best loss: 5.311879192419245]




 69%|██████████████████████████████████████████▊                   | 69/100 [06:39<02:51,  5.53s/trial, best loss: 5.311879192419245]




 70%|███████████████████████████████████████████▍                  | 70/100 [06:43<02:35,  5.18s/trial, best loss: 5.311879192419245]




 71%|████████████████████████████████████████████                  | 71/100 [06:50<02:44,  5.69s/trial, best loss: 5.311234357303558]




 72%|████████████████████████████████████████████▋                 | 72/100 [06:56<02:42,  5.81s/trial, best loss: 5.311234357303558]




 73%|█████████████████████████████████████████████▎                | 73/100 [07:04<02:56,  6.53s/trial, best loss: 5.311234357303558]




 74%|█████████████████████████████████████████████▉                | 74/100 [07:11<02:53,  6.68s/trial, best loss: 5.311234357303558]




 75%|██████████████████████████████████████████████▌               | 75/100 [07:18<02:45,  6.60s/trial, best loss: 5.311234357303558]




 76%|███████████████████████████████████████████████               | 76/100 [07:25<02:42,  6.79s/trial, best loss: 5.311234357303558]




 77%|███████████████████████████████████████████████▋              | 77/100 [07:31<02:32,  6.65s/trial, best loss: 5.311234357303558]




 78%|████████████████████████████████████████████████▎             | 78/100 [07:36<02:16,  6.22s/trial, best loss: 5.311234357303558]




 79%|████████████████████████████████████████████████▉             | 79/100 [07:43<02:12,  6.31s/trial, best loss: 5.311234357303558]




 80%|█████████████████████████████████████████████████▌            | 80/100 [07:49<02:05,  6.29s/trial, best loss: 5.311234357303558]




 81%|██████████████████████████████████████████████████▏           | 81/100 [07:54<01:52,  5.92s/trial, best loss: 5.311234357303558]




 82%|██████████████████████████████████████████████████▊           | 82/100 [07:58<01:33,  5.22s/trial, best loss: 5.311234357303558]




 83%|███████████████████████████████████████████████████▍          | 83/100 [08:05<01:36,  5.70s/trial, best loss: 5.311234357303558]




 84%|████████████████████████████████████████████████████          | 84/100 [08:08<01:20,  5.05s/trial, best loss: 5.311234357303558]




 85%|████████████████████████████████████████████████████▋         | 85/100 [08:16<01:27,  5.85s/trial, best loss: 5.311234357303558]




 86%|█████████████████████████████████████████████████████▎        | 86/100 [08:19<01:09,  4.98s/trial, best loss: 5.311234357303558]




 87%|█████████████████████████████████████████████████████▉        | 87/100 [08:22<00:57,  4.44s/trial, best loss: 5.311234357303558]




 88%|██████████████████████████████████████████████████████▌       | 88/100 [08:28<00:59,  4.94s/trial, best loss: 5.311234357303558]




 89%|███████████████████████████████████████████████████████▏      | 89/100 [08:33<00:55,  5.01s/trial, best loss: 5.311234357303558]




 90%|███████████████████████████████████████████████████████▊      | 90/100 [08:39<00:52,  5.23s/trial, best loss: 5.311234357303558]




 91%|████████████████████████████████████████████████████████▍     | 91/100 [08:45<00:50,  5.56s/trial, best loss: 5.311234357303558]




 92%|█████████████████████████████████████████████████████████     | 92/100 [08:46<00:32,  4.05s/trial, best loss: 5.311234357303558]




 93%|█████████████████████████████████████████████████████████▋    | 93/100 [08:52<00:32,  4.63s/trial, best loss: 5.311234357303558]




 94%|██████████████████████████████████████████████████████████▎   | 94/100 [08:56<00:26,  4.39s/trial, best loss: 5.311234357303558]




 95%|██████████████████████████████████████████████████████████▉   | 95/100 [08:56<00:16,  3.29s/trial, best loss: 5.311234357303558]




 96%|███████████████████████████████████████████████████████████▌  | 96/100 [08:58<00:11,  2.84s/trial, best loss: 5.311234357303558]




 97%|████████████████████████████████████████████████████████████▏ | 97/100 [09:07<00:13,  4.51s/trial, best loss: 5.311234357303558]




 98%|████████████████████████████████████████████████████████████▊ | 98/100 [09:11<00:08,  4.41s/trial, best loss: 5.311234357303558]




 99%|█████████████████████████████████████████████████████████████▍| 99/100 [09:16<00:04,  4.60s/trial, best loss: 5.311234357303558]




100%|█████████████████████████████████████████████████████████████| 100/100 [09:21<00:00,  5.61s/trial, best loss: 5.311234357303558]



