# **Duration Prediction with MLflow**

In [1]:
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

In [3]:
import mlflow

In [4]:
mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("nyc-taxi-experiment")

<Experiment: artifact_location='/home/kylepaul/notebooks/mlops-zoom-camp-2022/session_2/mlruns/1', creation_time=1685346804195, experiment_id='1', last_update_time=1685346804195, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [5]:
def read_dataframe(filename):
    df = pd.read_parquet(filename)
    df['duration'] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)
    df = df[(df.duration >= 1) & (df.duration <= 60)]
    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    return df

df_train = read_dataframe('./data/green_tripdata_2022-01.parquet')
df_test = read_dataframe('./data/green_tripdata_2022-02.parquet')

In [23]:
df_train

Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,...,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge,duration
0,2,2022-01-01 00:14:21,2022-01-01 00:15:33,N,1.0,42,42,1.0,0.44,3.50,...,0.5,0.00,0.0,,0.3,4.80,2.0,1.0,0.00,1.200000
1,1,2022-01-01 00:20:55,2022-01-01 00:29:38,N,1.0,116,41,1.0,2.10,9.50,...,0.5,0.00,0.0,,0.3,10.80,2.0,1.0,0.00,8.716667
2,1,2022-01-01 00:57:02,2022-01-01 01:13:14,N,1.0,41,140,1.0,3.70,14.50,...,0.5,4.60,0.0,,0.3,23.15,1.0,1.0,2.75,16.200000
3,2,2022-01-01 00:07:42,2022-01-01 00:15:57,N,1.0,181,181,1.0,1.69,8.00,...,0.5,0.00,0.0,,0.3,9.30,2.0,1.0,0.00,8.250000
4,2,2022-01-01 00:07:50,2022-01-01 00:28:52,N,1.0,33,170,1.0,6.26,22.00,...,0.5,5.21,0.0,,0.3,31.26,1.0,1.0,2.75,21.033333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
62490,2,2022-01-31 23:25:00,2022-01-31 23:33:00,,,40,65,,1.40,8.38,...,0.0,1.93,0.0,,0.3,10.61,,,,8.000000
62491,2,2022-01-31 23:52:00,2022-02-01 00:10:00,,,36,61,,2.97,14.92,...,0.0,0.00,0.0,,0.3,15.22,,,,18.000000
62492,2,2022-01-31 23:17:00,2022-01-31 23:36:00,,,75,167,,3.70,16.26,...,0.0,0.00,0.0,,0.3,16.56,,,,19.000000
62493,2,2022-01-31 23:45:00,2022-01-31 23:55:00,,,116,166,,1.88,9.48,...,0.0,2.17,0.0,,0.3,11.95,,,,10.000000


In [6]:
df_test

Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,...,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge,duration
0,2,2022-02-01 00:20:21,2022-02-01 00:24:30,N,1.0,43,238,1.0,1.16,5.50,...,0.5,1.02,0.0,,0.3,7.82,1.0,1.0,0.00,4.150000
1,2,2022-02-01 00:32:26,2022-02-01 00:35:31,N,1.0,166,24,1.0,0.57,4.50,...,0.5,0.00,0.0,,0.3,5.80,2.0,1.0,0.00,3.083333
2,1,2022-02-01 00:17:27,2022-02-01 00:44:44,N,1.0,226,219,1.0,0.00,42.20,...,0.5,0.00,0.0,,0.3,43.00,1.0,1.0,0.00,27.283333
3,2,2022-02-01 00:45:37,2022-02-01 01:27:16,N,1.0,89,83,1.0,16.62,49.00,...,0.5,0.00,0.0,,0.3,50.30,2.0,1.0,0.00,41.650000
4,2,2022-02-01 00:06:46,2022-02-01 00:30:06,N,1.0,7,238,1.0,5.97,21.00,...,0.5,4.50,0.0,,0.3,29.55,1.0,1.0,2.75,23.333333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69394,2,2022-02-28 23:27:00,2022-02-28 23:38:00,,,65,87,,2.61,13.60,...,0.0,3.66,0.0,,0.3,20.31,,,,11.000000
69395,2,2022-02-28 23:59:00,2022-03-01 00:10:00,,,97,231,,2.88,12.07,...,0.0,3.00,0.0,,0.3,18.12,,,,11.000000
69396,2,2022-02-28 23:18:00,2022-02-28 23:27:00,,,74,116,,2.22,10.68,...,0.0,1.22,0.0,,0.3,12.20,,,,9.000000
69397,2,2022-02-28 23:31:00,2022-02-28 23:39:00,,,42,69,,1.59,8.88,...,0.0,0.00,0.0,,0.3,9.18,,,,8.000000


In [7]:
len(df_train), len(df_test)

(59603, 66097)

In [8]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_test['PU_DO'] = df_test['PULocationID'] + '_' + df_test['DOLocationID']

df_test['PU_DO']

0         43_238
1         166_24
2        226_219
3          89_83
4          7_238
          ...   
69394      65_87
69395     97_231
69396     74_116
69397      42_69
69398    243_100
Name: PU_DO, Length: 66097, dtype: object

In [9]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

dv = DictVectorizer()
train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

test_dicts = df_test[categorical + numerical].to_dict(orient='records')
X_test = dv.transform(test_dicts)

In [10]:
target = 'duration'
y_train = df_train[target].values
y_test = df_test[target].values

In [11]:
from sklearn.metrics import mean_squared_error
model_lr = LinearRegression()
model_lr.fit(X_train, y_train)

y_test_pred = model_lr.predict(X_test)
mean_squared_error(y_test, y_test_pred, squared=False)

6.901396794686184

In [12]:
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, model_lr), f_out)

**Why MLflow and artifact?**
- The function of mlflow artifact is to **log and store any output files that are generated during an mlflow run**. You can use them to keep track of your models, images, text files, or any other type of output. You should use them because they enable you to:
    - Reproduce your experiments by having all the necessary files `in one place`.
    - `Compare and evaluate` different models based on their artifacts
    - `Deploy` your models easily using the MLmodel format and Azure Machine Learning
    - Use your models as inputs for pipelines or `Responsible AI dashboard`
    - `Load` your models using a standard interface regardless of their format
- Artifacts are an essential part of mlflow tracking and can help you **manage your machine learning lifecycle more effectively.**

**What is artifact?**
- An artifact is a file or directory that is logged as part of an `mlflow run`. 
- It can be any type of output, such as a model, an image, a text file, etc. 
- Artifacts are stored in an artifact store, which can be a local file system, a remote server, or a cloud storage service. 
- You can use the `mlflow.artifacts` module to download, load, or log artifacts. For example, the code below logs an artifact with the local path `models/lin_reg.bin` and the artifact path `models_pickle`. This means that the file `models/lin_reg.bin` will be copied to the artifact store under the subdirectory `models_pickle` within the run’s artifact URI.

**Explain the code:**
- It starts a new mlflow run with a **unique ID** and a **default experiment name**.
- It sets a tag with the name `developer` and the value `cristian` for the current run
- It logs two parameters with the names `train-data-path` and `valid-data-path` and the values `./data/green_tripdata_2021-01.csv` and `./data/green_tripdata_2021-02.csv` respectively. These parameters indicate the paths to the training and validation data files
- It assigns a value of `0.1` to a variable called `alpha` and logs it as another parameter with the name `alpha`. This parameter is used to control the **regularization strength** of the Lasso regression model
- It creates an **instance of the Lasso** class with the alpha value and fits it to the training data `X_train` and `y_train`
- It makes **predictions** on the validation data `X_test` using the fitted model and stores them in a variable called `y_test_pred`
- It calculates the root mean squared error `rmse` between the predictions and the actual values `y_test` and logs it as a **metric** with the name `rmse`. This metric is used to evaluate the performance of the model
- It logs an **artifact** with the local path `models/lin_reg.bin` and the artifact path `models_pickle`. This artifact is a **binary file** that contains the trained model

In [13]:
with mlflow.start_run():
    mlflow.set_tag("developer", "kyle-paul")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2022-01.parquet")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2022-02.parquet")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    model_lr = Lasso(alpha)
    model_lr.fit(X_train, y_train)

    y_test_pred = model_lr.predict(X_test)
    rmse = mean_squared_error(y_test, y_test_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")

**Overview over xgboost and hyperopt library:**

- **xgboost** is a library that implements `gradient boosting machines`, which are a type of `supervised learning algorithm` that can handle regression, classification, ranking, and other tasks. Xgboost is known for its speed, scalability, and performance.
- **hyperopt** is a library that implements `Bayesian optimization`, which is a method for finding the optimal values of a function over a complex space of inputs. hyperopt can be used for `hyperparameter tuning`, which is the process of finding the best settings for a machine learning model.
- **xgboost** and **hyperopt** can be used together to train and tune gradient boosting models efficiently and effectively. You can use hyperopt to define a search space of hyperparameters for xgboost, such as learning rate, number of trees, depth of trees, etc. Then you can use hyperopt to optimize the loss function of xgboost over the search space using a smart algorithm that learns from previous trials. Then you can use **mlflow** to track and compare the results of different trials and models.

**Explain the code:**
This code uses the `xgboost` and `hyperopt` libraries for training and tuning a **gradient boosting model**.
The code does the following steps:

- It imports the xgboost and hyperopt modules
- It creates two `xgboost DMatrix objects` from the training and test data `X_train, y_train, X_test, y_test`
- It defines an `objective function` that takes a `dictionary of parameters` as input and returns a `dictionary of loss and status` as output
- Inside the objective function, it does the following:
    - It starts a new mlflow run and sets a tag with the name `model` and the value `xgboost`
    - It logs the parameters to mlflow using `mlflow.log_params()`
    - It trains a xgboost model using `xgb.train()` with the given parameters, the training data, and early stopping based on the validation data
    - It makes predictions on the test data using `booster.predict()` and stores them in a variable called `y_test_pred`
    - It calculates the root mean squared error `rmse` between the predictions and the actual values `y_test` and logs it to mlflow using `mlflow.log_metric()`
    - It returns a `dictionary` with the rmse as the loss and STATUS_OK as the status

In [14]:
import xgboost as xgb
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [15]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_test, label=y_test)

In [16]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_test_pred = booster.predict(valid)
        rmse = mean_squared_error(y_test, y_test_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

**Explain the code:**
This code uses the hyperopt library for hyperparameter tuning with following steps:

- It defines a **search space** for the `xgboost hyperparameters` using the **hp** and **scope** modules from `hyperopt`. The search space is a `dictionary` that maps each hyperparameter name to a distribution or a value. For example:
    - `max_depth` is sampled from a discrete uniform distribution between 4 and 100.
    - `learning_rate` is sampled from a log-uniform distribution between 0.001 and 1.
    - `objective` is fixed to `reg:linear`.
    - The `scope.int` function is used to cast the sampled values to integers when needed.
- It calls the **fmin function** from hyperopt to find the `best hyperparameters` that minimize the objective function defined earlier. The fmin function takes the following arguments:
    - `fn`: the objective function to minimize
    - `space`: the search space for the hyperparameters
    - `algo`: the optimization algorithm to use, in this case `tpe.suggest` which is a **tree-structured Parzen estimator**
    - `max_evals`: the maximum number of evaluations or trials to perform
    - `trials`: an object that stores information about each trial
- It stores the best result in a variable called `best_result`, which is a dictionary that contains the best hyperparameters and the best loss value

In [42]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:8.82007                                          
[1]	validation-rmse:6.65455                                          
[2]	validation-rmse:6.18643                                          
[3]	validation-rmse:6.07380                                          
[4]	validation-rmse:6.02657                                          
[5]	validation-rmse:6.00926                                          
[6]	validation-rmse:5.99143                                          
[7]	validation-rmse:5.97534                                          
[8]	validation-rmse:5.97229                                          
[9]	validation-rmse:5.97670                                          
[10]	validation-rmse:5.97546                                         
[11]	validation-rmse:5.97309                                         
[12]	validation-rmse:5.97164                                         
[13]	validation-rmse:5.97112                                         
[14]	validation-rmse

[96]	validation-rmse:6.03885                                         
[97]	validation-rmse:6.03859                                         
[98]	validation-rmse:6.03819                                         
[99]	validation-rmse:6.03786                                         
[100]	validation-rmse:6.03740                                        
[101]	validation-rmse:6.03688                                        
[102]	validation-rmse:6.03654                                        
[103]	validation-rmse:6.03624                                        
[104]	validation-rmse:6.03602                                        
[105]	validation-rmse:6.03567                                        
[106]	validation-rmse:6.03546                                        
[107]	validation-rmse:6.03530                                        
[108]	validation-rmse:6.03517                                        
[109]	validation-rmse:6.03499                                        
[110]	validation-rms

[330]	validation-rmse:6.01880                                        
[331]	validation-rmse:6.01878                                        
[332]	validation-rmse:6.01872                                        
[333]	validation-rmse:6.01866                                        
[334]	validation-rmse:6.01863                                        
[335]	validation-rmse:6.01865                                        
[336]	validation-rmse:6.01864                                        
[337]	validation-rmse:6.01870                                        
[338]	validation-rmse:6.01850                                        
[339]	validation-rmse:6.01851                                        
[340]	validation-rmse:6.01856                                        
[341]	validation-rmse:6.01864                                        
[342]	validation-rmse:6.01858                                        
[343]	validation-rmse:6.01859                                        
[344]	validation-rms

[564]	validation-rmse:6.01249                                        
[565]	validation-rmse:6.01252                                        
[566]	validation-rmse:6.01248                                        
[567]	validation-rmse:6.01253                                        
[568]	validation-rmse:6.01260                                        
[569]	validation-rmse:6.01263                                        
[570]	validation-rmse:6.01256                                        
[571]	validation-rmse:6.01254                                        
[572]	validation-rmse:6.01257                                        
[573]	validation-rmse:6.01252                                        
[574]	validation-rmse:6.01248                                        
[575]	validation-rmse:6.01252                                        
[576]	validation-rmse:6.01268                                        
[577]	validation-rmse:6.01260                                        
[578]	validation-rms

[39]	validation-rmse:6.15301                                         
[40]	validation-rmse:6.16021                                         
[41]	validation-rmse:6.16795                                         
[42]	validation-rmse:6.17434                                         
[43]	validation-rmse:6.17909                                         
[44]	validation-rmse:6.18259                                         
[45]	validation-rmse:6.18737                                         
[46]	validation-rmse:6.19258                                         
[47]	validation-rmse:6.19752                                         
[48]	validation-rmse:6.19893                                         
[49]	validation-rmse:6.20335                                         
[50]	validation-rmse:6.20341                                         
[51]	validation-rmse:6.21013                                         
[52]	validation-rmse:6.21753                                         
[53]	validation-rmse

[212]	validation-rmse:5.82910                                        
[213]	validation-rmse:5.82921                                        
[0]	validation-rmse:11.48781                                         
[1]	validation-rmse:8.49382                                          
[2]	validation-rmse:7.11929                                          
[3]	validation-rmse:6.51371                                          
[4]	validation-rmse:6.24658                                          
[5]	validation-rmse:6.12319                                          
[6]	validation-rmse:6.06411                                          
[7]	validation-rmse:6.02387                                          
[8]	validation-rmse:6.00891                                          
[9]	validation-rmse:5.99871                                          
[10]	validation-rmse:5.99191                                         
[11]	validation-rmse:5.97829                                         
[12]	validation-rmse

[230]	validation-rmse:5.80641                                        
[231]	validation-rmse:5.80629                                        
[232]	validation-rmse:5.80645                                        
[233]	validation-rmse:5.80636                                        
[234]	validation-rmse:5.80617                                        
[235]	validation-rmse:5.80605                                        
[236]	validation-rmse:5.80666                                        
[237]	validation-rmse:5.80672                                        
[238]	validation-rmse:5.80656                                        
[239]	validation-rmse:5.80683                                        
[240]	validation-rmse:5.80681                                        
[241]	validation-rmse:5.80707                                        
[242]	validation-rmse:5.80660                                        
[243]	validation-rmse:5.80564                                        
[244]	validation-rms

[149]	validation-rmse:5.93891                                        
[150]	validation-rmse:5.93895                                        
[151]	validation-rmse:5.93892                                        
[152]	validation-rmse:5.93935                                        
[153]	validation-rmse:5.93930                                        
[154]	validation-rmse:5.93926                                        
[155]	validation-rmse:5.93975                                        
[156]	validation-rmse:5.93992                                        
[157]	validation-rmse:5.93998                                        
[158]	validation-rmse:5.94017                                        
[159]	validation-rmse:5.94012                                        
[160]	validation-rmse:5.94033                                        
[161]	validation-rmse:5.94040                                        
[0]	validation-rmse:12.46461                                         
[1]	validation-rmse:

[219]	validation-rmse:5.87787                                        
[220]	validation-rmse:5.87751                                        
[221]	validation-rmse:5.87748                                        
[222]	validation-rmse:5.87762                                        
[223]	validation-rmse:5.87737                                        
[224]	validation-rmse:5.87671                                        
[225]	validation-rmse:5.87647                                        
[226]	validation-rmse:5.87502                                        
[227]	validation-rmse:5.87471                                        
[228]	validation-rmse:5.87452                                        
[229]	validation-rmse:5.87419                                        
[230]	validation-rmse:5.87368                                        
[231]	validation-rmse:5.87319                                        
[232]	validation-rmse:5.87363                                        
[233]	validation-rms

[28]	validation-rmse:6.76795                                         
[29]	validation-rmse:6.70495                                         
[30]	validation-rmse:6.64821                                         
[31]	validation-rmse:6.59634                                         
[32]	validation-rmse:6.54894                                         
[33]	validation-rmse:6.50697                                         
[34]	validation-rmse:6.46820                                         
[35]	validation-rmse:6.43261                                         
[36]	validation-rmse:6.40093                                         
[37]	validation-rmse:6.37240                                         
[38]	validation-rmse:6.34618                                         
[39]	validation-rmse:6.32298                                         
[40]	validation-rmse:6.30194                                         
[41]	validation-rmse:6.28275                                         
[42]	validation-rmse

[262]	validation-rmse:6.02154                                        
[263]	validation-rmse:6.02145                                        
[264]	validation-rmse:6.02143                                        
[265]	validation-rmse:6.02147                                        
[266]	validation-rmse:6.02117                                        
[267]	validation-rmse:6.02116                                        
[268]	validation-rmse:6.02131                                        
[269]	validation-rmse:6.02105                                        
[270]	validation-rmse:6.02106                                        
[271]	validation-rmse:6.02109                                        
[272]	validation-rmse:6.02086                                        
[273]	validation-rmse:6.02081                                        
[274]	validation-rmse:6.02085                                        
[275]	validation-rmse:6.02054                                        
[276]	validation-rms

[106]	validation-rmse:5.92014                                        
[107]	validation-rmse:5.92126                                        
[108]	validation-rmse:5.92109                                        
[109]	validation-rmse:5.92104                                        
[110]	validation-rmse:5.92074                                        
[111]	validation-rmse:5.92240                                        
[112]	validation-rmse:5.92247                                        
[113]	validation-rmse:5.92255                                        
[114]	validation-rmse:5.92346                                        
[115]	validation-rmse:5.92412                                        
[116]	validation-rmse:5.92416                                        
[117]	validation-rmse:5.92395                                        
[118]	validation-rmse:5.92382                                        
[119]	validation-rmse:5.92409                                        
[120]	validation-rms

[194]	validation-rmse:5.95882                                        
[195]	validation-rmse:5.95897                                        
[196]	validation-rmse:5.95881                                        
[197]	validation-rmse:5.95889                                        
[198]	validation-rmse:5.95902                                        
[199]	validation-rmse:5.95898                                        
[200]	validation-rmse:5.95898                                        
[201]	validation-rmse:5.95906                                        
[202]	validation-rmse:5.95923                                        
[203]	validation-rmse:5.95934                                        
[204]	validation-rmse:5.95920                                        
[205]	validation-rmse:5.95897                                        
[206]	validation-rmse:5.95901                                        
[207]	validation-rmse:5.95905                                        
[208]	validation-rms

KeyboardInterrupt: 

**Explain the code:**
This code uses the **mlflow** and **xgboost** libraries for training and logging a gradient boosting model with following steps:

- It disables the `mlflow.xgboost.autolog()` function, which is a feature that automatically logs xgboost models and metrics to mlflow. This is done because the code will manually log the parameters, metrics, and models to mlflow later.
- It creates two `xgboost DMatrix` objects from the training and test data `X_train, y_train, X_test, y_test`
- It assigns a dictionary of best hyperparameters to a variable called `best_params`. These hyperparameters were obtained from the **previous hyperopt tuning process**.
- It logs the best hyperparameters to mlflow using `mlflow.log_params()`
- It trains a xgboost model using `xgb.train()` with the best hyperparameters, the training data, and early stopping based on the validation data
- It calculates the root mean squared error `rmse` between the predictions and the actual values `y_test` and logs it to mlflow using `mlflow.log_metric()`
- It saves a preprocessor object `dv` to a binary file called `models/preprocessor.b` using `pickle.dump()`
- It logs the preprocessor file as an `artifact` to mlflow using `mlflow.log_artifact()`, with the artifact path `preprocessor`
- It logs the xgboost model as an artifact to mlflow using `mlflow.xgboost.log_model()`, with the artifact path `models_mlflow`

In [17]:
mlflow.xgboost.autolog(disable=True)

train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_test, label=y_test)

best_params = {
    'learning_rate': 0.42573168654483656,
    'max_depth': 10,
    'min_child_weight': 3.2482026447444605,
    'objective': 'reg:linear',
    'reg_alpha': 0.04266778935139865,
    'reg_lambda': 0.011203537317262537,
    'seed': 42
}

mlflow.log_params(best_params)

booster = xgb.train(
    params=best_params,
    dtrain=train,
    num_boost_round=1000,
    evals=[(valid, 'validation')],
    early_stopping_rounds=50
)

y_test_pred = booster.predict(valid)
rmse = mean_squared_error(y_test, y_test_pred, squared=False)
mlflow.log_metric("rmse", rmse)

with open("models/preprocessor.b", "wb") as f_out:
    pickle.dump(dv, f_out)
    
mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")
mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:11.48781
[1]	validation-rmse:8.49382
[2]	validation-rmse:7.11929
[3]	validation-rmse:6.51371
[4]	validation-rmse:6.24658
[5]	validation-rmse:6.12319
[6]	validation-rmse:6.06411
[7]	validation-rmse:6.02387
[8]	validation-rmse:6.00891
[9]	validation-rmse:5.99871
[10]	validation-rmse:5.99191
[11]	validation-rmse:5.97829
[12]	validation-rmse:5.97262
[13]	validation-rmse:5.96825
[14]	validation-rmse:5.96507
[15]	validation-rmse:5.96157
[16]	validation-rmse:5.95978
[17]	validation-rmse:5.95782
[18]	validation-rmse:5.95288
[19]	validation-rmse:5.95163
[20]	validation-rmse:5.94989
[21]	validation-rmse:5.94738
[22]	validation-rmse:5.94205
[23]	validation-rmse:5.93992
[24]	validation-rmse:5.93798
[25]	validation-rmse:5.93591
[26]	validation-rmse:5.93527
[27]	validation-rmse:5.93288
[28]	validation-rmse:5.92967
[29]	validation-rmse:5.92766
[30]	validation-rmse:5.92448
[31]	validation-rmse:5.92186
[32]	validation-rmse:5.91963
[33]	validation-rmse:5.91635
[34]	validation-rmse:5.

[273]	validation-rmse:5.80539
[274]	validation-rmse:5.80529
[275]	validation-rmse:5.80498
[276]	validation-rmse:5.80546
[277]	validation-rmse:5.80544
[278]	validation-rmse:5.80517
[279]	validation-rmse:5.80513
[280]	validation-rmse:5.80623
[281]	validation-rmse:5.80634
[282]	validation-rmse:5.80635
[283]	validation-rmse:5.80648
[284]	validation-rmse:5.80674
[285]	validation-rmse:5.80643
[286]	validation-rmse:5.80646
[287]	validation-rmse:5.80613
[288]	validation-rmse:5.80598
[289]	validation-rmse:5.80618
[290]	validation-rmse:5.80634
[291]	validation-rmse:5.80624
[292]	validation-rmse:5.80611
[293]	validation-rmse:5.80612
[294]	validation-rmse:5.80618
[295]	validation-rmse:5.80627
[296]	validation-rmse:5.80604
[297]	validation-rmse:5.80600
[298]	validation-rmse:5.80610
[299]	validation-rmse:5.80603
[300]	validation-rmse:5.80600
[301]	validation-rmse:5.80591
[302]	validation-rmse:5.80559
[303]	validation-rmse:5.80566
[304]	validation-rmse:5.80599
[305]	validation-rmse:5.80723
[306]	vali



<mlflow.models.model.ModelInfo at 0x7f1998705570>

In [22]:
logged_model = 'runs:/8264fb3eb8fb47f69ce47e08af58922b/models_mlflow'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)



In [23]:
loaded_model

mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: 8264fb3eb8fb47f69ce47e08af58922b

In [24]:
xgboost_model = mlflow.xgboost.load_model(logged_model)



In [25]:
xgboost_model

<xgboost.core.Booster at 0x7f19cdc4b820>

In [27]:
y_pred = xgboost_model.predict(valid)
y_pred[:10]

array([ 6.8129897,  4.1037087, 26.378426 , 36.658863 , 27.697872 ,
        9.381057 , 19.713657 ,  4.2832828, 14.368863 ,  5.8005676],
      dtype=float32)

**Explain the code:**
This code uses the `mlflow` and `sklearn` libraries for training and logging different regression models with following steps:

- It imports sklearn modules
- It enables the `mlflow.sklearn.autolog()` function, which is a feature that automatically logs sklearn models and metrics to mlflow
- It defines a `loop` over four model classes from sklearn: `RandomForestRegressor`, `GradientBoostingRegressor`, `ExtraTreesRegressor`, and `LinearSVR`
- Inside the loop, it does the following for each model class:
    - It starts a new mlflow run
    - It logs two parameters with the names `train-data-path` and `valid-data-path` and the values `./data/green_tripdata_2021-01.csv` and `./data/green_tripdata_2021-02.csv` respectively. These parameters indicate the paths to the training and validation data files
    - It logs an artifact with the local path `models/preprocessor.b` and the artifact path `preprocessor`. This artifact is a binary file that contains the preprocessor object `dv` that was saved earlier
    - It creates an instance of the model class with default hyperparameters and fits it to the training data `X_train, y_train`
    - It makes predictions on the validation data (X_test) using the fitted model and stores them in a variable called y_pred
    - It calculates the root mean squared error `rmse` between the predictions and the actual values (y_test) and logs it to mlflow using `mlflow.log_metric()`

In [None]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run(nested=True):

        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_test_pred = mlmodel.predict(X_test)
        rmse = mean_squared_error(y_test, y_test_pred, squared=False)
        mlflow.log_metric("rmse", rmse)