In [1]:
import argparse
import os
import pickle

import mlflow
from hyperopt import hp, space_eval
from hyperopt.pyll import scope
from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

HPO_EXPERIMENT_NAME = "random-forest-hyperopt"
EXPERIMENT_NAME = "random-forest-best-models"

mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.sklearn.autolog()

SPACE = {
    'max_depth': scope.int(hp.quniform('max_depth', 1, 20, 1)),
    'n_estimators': scope.int(hp.quniform('n_estimators', 10, 50, 1)),
    'min_samples_split': scope.int(hp.quniform('min_samples_split', 2, 10, 1)),
    'min_samples_leaf': scope.int(hp.quniform('min_samples_leaf', 1, 4, 1)),
    'random_state': 42
}

2022/05/29 23:40:02 INFO mlflow.tracking.fluent: Experiment with name 'random-forest-best-models' does not exist. Creating a new experiment.


In [6]:
def load_pickle(filename):
    with open(filename, "rb") as f_in:
        return pickle.load(f_in)

In [2]:
def train_and_log_model(data_path, params):
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_valid, y_valid = load_pickle(os.path.join(data_path, "valid.pkl"))
    X_test, y_test = load_pickle(os.path.join(data_path, "test.pkl"))

    with mlflow.start_run():
        params = space_eval(SPACE, params)
        rf = RandomForestRegressor(**params)
        rf.fit(X_train, y_train)

        # evaluate model on the validation and test sets
        valid_rmse = mean_squared_error(y_valid, rf.predict(X_valid), squared=False)
        mlflow.log_metric("valid_rmse", valid_rmse)
        test_rmse = mean_squared_error(y_test, rf.predict(X_test), squared=False)
        mlflow.log_metric("test_rmse", test_rmse)

In [8]:
def run(data_path, log_top):

    client = MlflowClient()

    # retrieve the top_n model runs and log the models to MLflow
    experiment = client.get_experiment_by_name(HPO_EXPERIMENT_NAME)
    runs = client.search_runs(
        experiment_ids=experiment.experiment_id,
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=log_top,
        order_by=["metrics.rmse ASC"]
    )
    for run in runs:
        train_and_log_model(data_path=data_path, params=run.data.params)

    # select the model with the lowest test RMSE
    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    best_runs = client.search_runs(        
        experiment_ids=experiment.experiment_id,
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=log_top,
        order_by=["metrics.rmse ASC"])
    
    return best_runs

In [18]:
best_runs = run(data_path = "./output", log_top = 10)

In [29]:
best_runs[0].info

<RunInfo: artifact_uri='./mlflowruns/3/d8a08b43b2354d778b5203a60b47749d/artifacts', end_time=1653865656975, experiment_id='3', lifecycle_stage='active', run_id='d8a08b43b2354d778b5203a60b47749d', run_uuid='d8a08b43b2354d778b5203a60b47749d', start_time=1653865652665, status='FINISHED', user_id='jabarnett'>

In [30]:
dir(best_runs[0])

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_data',
 '_get_properties_helper',
 '_info',
 '_properties',
 'data',
 'from_dictionary',
 'from_proto',
 'info',
 'to_dictionary',
 'to_proto']

In [35]:
best_runs[0].data.metrics

{'training_score': 0.7501052942223351,
 'training_rmse': 5.780324753018479,
 'valid_rmse': 6.687538752784537,
 'test_rmse': 6.605413364774572,
 'training_r2_score': 0.7501052942223351,
 'training_mae': 3.9431812051896853,
 'training_mse': 33.41215425035814}

In [36]:
[i.data.metrics['test_rmse'] for i in best_runs]

[6.605413364774572,
 6.566445845484492,
 6.590336535010927,
 6.551613800771846,
 6.576883688294563,
 6.937440606755121,
 6.624796724443255,
 6.623414115210141,
 6.552297734408109,
 6.592055907932963]