In [1]:
import platform

# Report the interpreter that runs this kernel; a `!python` shell call reports
# whichever `python` is first on PATH, which can be a different installation.
print("Python", platform.python_version())

Python 3.9.7


In [2]:
!pip install pyarrow



In [3]:
!pip install mlflow



# Setup

In [4]:
import os
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

In [5]:
# Shell commands (`!...`) run in a throw-away subshell, so an `export` there can
# never reach this kernel. Set the variable in the kernel's own environment.
os.environ["MLFLOW_TRACKING_URI"] = "sqlite:///mlflow.db"


In [6]:
#from mlflow.tracking import MlflowClient
#MlflowClient
#tracking_uri = MlflowClient.mlflow_tracking_uri

In [7]:
import mlflow

# Runs and experiment metadata are stored in a SQLite file (mlflow.db) relative
# to the working directory.
mlflow.set_tracking_uri("sqlite:///mlflow.db")

# Activate the named experiment, creating it first if it does not exist yet;
# every run started afterwards is appended to it.
mlflow.set_experiment("nyc-taxi-experiment_2")

print("Tracking URI:", mlflow.get_tracking_uri())

Tracking URI: sqlite:///mlflow.db


# Read the Data

In [8]:
# Exploratory first pass over the January 2021 green-taxi trips; the reusable
# version of this logic is the `read_dataframe` function defined further down.
df = pd.read_parquet('../data/green_tripdata_2021-01.parquet')

# Trip duration in minutes, computed with the vectorised .dt accessor rather
# than a Python-level apply over every row.
df['duration'] = (df.lpep_dropoff_datetime - df.lpep_pickup_datetime).dt.total_seconds() / 60

# Keep trips of 1 to 60 minutes (both inclusive). .copy() detaches the filtered
# rows from the unfiltered frame so the column assignment below is a plain write.
df = df[(df.duration >= 1) & (df.duration <= 60)].copy()

categorical = ['PULocationID', 'DOLocationID']
numerical = ['trip_distance']

# Location IDs are labels, not quantities: as strings, DictVectorizer treats
# them as categorical values instead of numbers.
df[categorical] = df[categorical].astype(str)

# Data Preparation

In [9]:
# Baseline design matrix from the exploratory frame: pickup and dropoff IDs as
# separate categoricals plus trip distance. (dv / X_train / y_train are rebuilt
# later with the combined PU_DO feature.)
target = 'duration'
y_train = df[target].values

dv = DictVectorizer()
train_dicts = df[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

In [10]:
def read_dataframe(filename, min_duration=1, max_duration=60):
    """Load a green-taxi trip parquet file and prepare it for modelling.

    Adds a ``duration`` column (minutes between pickup and dropoff), keeps only
    trips whose duration lies in ``[min_duration, max_duration]`` and casts the
    pickup/dropoff location ID columns to strings so they act as categorical
    features.

    Parameters
    ----------
    filename : str or path-like
        Parquet file passed to ``pd.read_parquet``.
    min_duration, max_duration : float, default 1 and 60
        Inclusive bounds, in minutes, on the trip duration.

    Returns
    -------
    pandas.DataFrame
        The filtered frame as an independent copy, so callers can add
        columns to it without chained-assignment issues.
    """
    df = pd.read_parquet(filename)

    # Vectorised timedelta -> minutes; avoids a Python-level apply per row.
    df['duration'] = (df.lpep_dropoff_datetime - df.lpep_pickup_datetime).dt.total_seconds() / 60

    # between() is inclusive on both ends, matching `>= min & <= max`.
    # .copy() detaches the filtered rows from the unfiltered frame.
    df = df[df.duration.between(min_duration, max_duration)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    return df

In [11]:
# January 2021 is the training month, February 2021 the validation month.
df_train, df_val = (
    read_dataframe('../data/green_tripdata_2021-01.parquet'),
    read_dataframe('../data/green_tripdata_2021-02.parquet'),
)

In [12]:
df_train.shape[0], df_val.shape[0]

(73908, 61921)

In [13]:
# Combined pickup/dropoff pair, e.g. "43_151", used as a single categorical.
df_train['PU_DO'] = df_train['PULocationID'].str.cat(df_train['DOLocationID'], sep='_')
df_val['PU_DO'] = df_val['PULocationID'].str.cat(df_val['DOLocationID'], sep='_')

In [14]:
# Final feature set: the combined PU_DO pair (instead of the two separate
# location IDs) plus trip distance.
categorical = ['PU_DO']
numerical = ['trip_distance']

train_dicts = df_train[categorical + numerical].to_dict(orient='records')
val_dicts = df_val[categorical + numerical].to_dict(orient='records')

# The vocabulary is learned from the training month only and then reused
# unchanged for the validation month.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [15]:
# Regression target: trip duration in minutes (computed by read_dataframe).
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [16]:
# Baseline: ordinary least squares on the PU_DO + trip_distance features.
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

# Validation RMSE as sqrt(MSE): the `squared=False` flag is deprecated in
# scikit-learn 1.4 and removed in 1.6.
mean_squared_error(y_val, y_pred) ** 0.5

7.75871520559622

In [17]:
# Persist the fitted vectorizer together with the model; both are needed to
# score new trips. Create the directory first so a fresh checkout works too.
os.makedirs('models', exist_ok=True)
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

# Model - Experiment Tracking with MLflow

In [18]:
with mlflow.start_run(run_name="test_run_1") as run:

    mlflow.set_tag("developer", "frauke")

    # Record which data produced this run; keep these in sync with the
    # read_dataframe calls that built df_train / df_val.
    mlflow.log_param("train-data-path", "../data/green_tripdata_2021-01.parquet")
    mlflow.log_param("valid-data-path", "../data/green_tripdata_2021-02.parquet")

    alpha = 0.01
    mlflow.log_param("alpha", alpha)

    # Note: `lr` now holds a Lasso model, replacing the linear baseline.
    lr = Lasso(alpha=alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)

    # sqrt(MSE) rather than `squared=False`, which scikit-learn deprecated in
    # 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

# Second Model - with Hyperparameter Tuning

In [19]:
!pip install hyperopt



In [20]:
!pip install xgboost



In [21]:
import xgboost as xgb
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [22]:
# XGBoost's native training API consumes DMatrix objects: wrap the feature
# matrices and their targets once and reuse them in every hyperopt trial.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [23]:
def objective(params):
    """Train one XGBoost model for a hyperopt trial and log it to MLflow.

    Each call opens its own MLflow run tagged ``model=xgboost`` and logs
    ``params`` plus the validation RMSE. Reads the module-level ``train`` /
    ``valid`` DMatrix objects and ``y_val``.

    Parameters
    ----------
    params : dict
        XGBoost training parameters sampled by hyperopt from the search space.

    Returns
    -------
    dict
        ``{'loss': rmse, 'status': STATUS_OK}``; hyperopt minimises ``loss``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        # Log the sampled hyperparameters of this trial.
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=100,
            # Validation error is evaluated every round and drives early stopping.
            evals=[(valid, 'validation')],
            # Stop if validation error has not improved for 50 consecutive rounds.
            early_stopping_rounds=50,
        )
        # xgb.train returns the model from the last round, not the best one.
        # Restrict prediction to the trees up to the best iteration so the
        # logged RMSE belongs to the model early stopping selected.
        y_pred = booster.predict(valid, iteration_range=(0, booster.best_iteration + 1))
        # sqrt(MSE) rather than `squared=False` (removed in scikit-learn 1.6).
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [24]:
# define the search space, i.e. the range of parameters for hyperparamter tuning
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0), #[exp(-3), exp(0)] = [0.05, 1]
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

# fmin method tries to minimize the metric
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:12.76520                                                                                                        
[1]	validation-rmse:9.08599                                                                                                         
[2]	validation-rmse:7.64047                                                                                                         
[3]	validation-rmse:7.09458                                                                                                         
[4]	validation-rmse:6.88028                                                                                                         
[5]	validation-rmse:6.78690                                                                                                         
[6]	validation-rmse:6.73811                                                                                                         
[7]	validation-rmse:6.71413                                          

[60]	validation-rmse:6.65294                                                                                                        
[61]	validation-rmse:6.65375                                                                                                        
[62]	validation-rmse:6.65333                                                                                                        
[63]	validation-rmse:6.65301                                                                                                        
[64]	validation-rmse:6.65299                                                                                                        
[65]	validation-rmse:6.65270                                                                                                        
[66]	validation-rmse:6.65299                                                                                                        
[67]	validation-rmse:6.65265                                         

[20]	validation-rmse:6.52786                                                                                                        
[21]	validation-rmse:6.52325                                                                                                        
[22]	validation-rmse:6.52026                                                                                                        
[23]	validation-rmse:6.51825                                                                                                        
[24]	validation-rmse:6.51618                                                                                                        
[25]	validation-rmse:6.51360                                                                                                        
[26]	validation-rmse:6.51226                                                                                                        
[27]	validation-rmse:6.51021                                         

[81]	validation-rmse:6.47195                                                                                                        
[82]	validation-rmse:6.47178                                                                                                        
[83]	validation-rmse:6.47092                                                                                                        
[84]	validation-rmse:6.46990                                                                                                        
[85]	validation-rmse:6.47142                                                                                                        
[86]	validation-rmse:6.46998                                                                                                        
[87]	validation-rmse:6.46939                                                                                                        
[88]	validation-rmse:6.46905                                         

[41]	validation-rmse:6.65686                                                                                                        
[42]	validation-rmse:6.65605                                                                                                        
[43]	validation-rmse:6.65551                                                                                                        
[44]	validation-rmse:6.65499                                                                                                        
[45]	validation-rmse:6.65404                                                                                                        
[46]	validation-rmse:6.65332                                                                                                        
[47]	validation-rmse:6.65273                                                                                                        
[48]	validation-rmse:6.65213                                         

[1]	validation-rmse:9.08878                                                                                                         
[2]	validation-rmse:7.57494                                                                                                         
[3]	validation-rmse:6.98290                                                                                                         
[4]	validation-rmse:6.74228                                                                                                         
[5]	validation-rmse:6.63329                                                                                                         
[6]	validation-rmse:6.58423                                                                                                         
[7]	validation-rmse:6.55092                                                                                                         
[8]	validation-rmse:6.53277                                          

[62]	validation-rmse:6.41668                                                                                                        
[63]	validation-rmse:6.41619                                                                                                        
[64]	validation-rmse:6.41566                                                                                                        
[65]	validation-rmse:6.41538                                                                                                        
[66]	validation-rmse:6.41588                                                                                                        
[67]	validation-rmse:6.41475                                                                                                        
[68]	validation-rmse:6.41418                                                                                                        
[69]	validation-rmse:6.41285                                         

[22]	validation-rmse:9.58634                                                                                                        
[23]	validation-rmse:9.38120                                                                                                        
[24]	validation-rmse:9.19063                                                                                                        
[25]	validation-rmse:9.01352                                                                                                        
[26]	validation-rmse:8.84901                                                                                                        
[27]	validation-rmse:8.69659                                                                                                        
[28]	validation-rmse:8.55530                                                                                                        
[29]	validation-rmse:8.42461                                         

[83]	validation-rmse:6.77559                                                                                                        
[84]	validation-rmse:6.77216                                                                                                        
[85]	validation-rmse:6.76863                                                                                                        
[86]	validation-rmse:6.76521                                                                                                        
[87]	validation-rmse:6.76200                                                                                                        
[88]	validation-rmse:6.75894                                                                                                        
[89]	validation-rmse:6.75622                                                                                                        
[90]	validation-rmse:6.75372                                         

[43]	validation-rmse:6.68217                                                                                                        
[44]	validation-rmse:6.68396                                                                                                        
[45]	validation-rmse:6.68528                                                                                                        
[46]	validation-rmse:6.68884                                                                                                        
[47]	validation-rmse:6.69479                                                                                                        
[48]	validation-rmse:6.69693                                                                                                        
[49]	validation-rmse:6.69756                                                                                                        
[50]	validation-rmse:6.70164                                         

[17]	validation-rmse:6.51501                                                                                                        
[18]	validation-rmse:6.50929                                                                                                        
[19]	validation-rmse:6.50557                                                                                                        
[20]	validation-rmse:6.50135                                                                                                        
[21]	validation-rmse:6.49952                                                                                                        
[22]	validation-rmse:6.49660                                                                                                        
[23]	validation-rmse:6.49394                                                                                                        
[24]	validation-rmse:6.49183                                         

[78]	validation-rmse:6.40958                                                                                                        
[79]	validation-rmse:6.40839                                                                                                        
[80]	validation-rmse:6.40781                                                                                                        
[81]	validation-rmse:6.40671                                                                                                        
[82]	validation-rmse:6.40613                                                                                                        
[83]	validation-rmse:6.40503                                                                                                        
[84]	validation-rmse:6.40472                                                                                                        
[85]	validation-rmse:6.40426                                         

[38]	validation-rmse:6.65599                                                                                                        
[39]	validation-rmse:6.65558                                                                                                        
[40]	validation-rmse:6.65545                                                                                                        
[41]	validation-rmse:6.65854                                                                                                        
[42]	validation-rmse:6.65720                                                                                                        
[43]	validation-rmse:6.65752                                                                                                        
[44]	validation-rmse:6.65795                                                                                                        
[45]	validation-rmse:6.65982                                         

[33]	validation-rmse:6.77009                                                                                                        
[34]	validation-rmse:6.76580                                                                                                        
[35]	validation-rmse:6.76135                                                                                                        
[36]	validation-rmse:6.75778                                                                                                        
[37]	validation-rmse:6.75431                                                                                                        
[38]	validation-rmse:6.75025                                                                                                        
[39]	validation-rmse:6.74701                                                                                                        
[40]	validation-rmse:6.74533                                         

[94]	validation-rmse:6.69771                                                                                                        
[95]	validation-rmse:6.69686                                                                                                        
[96]	validation-rmse:6.69623                                                                                                        
[97]	validation-rmse:6.69562                                                                                                        
[98]	validation-rmse:6.69527                                                                                                        
[99]	validation-rmse:6.69516                                                                                                        
[0]	validation-rmse:18.08780                                                                                                        
[1]	validation-rmse:15.58297                                         

[54]	validation-rmse:6.46538                                                                                                        
[55]	validation-rmse:6.46412                                                                                                        
[56]	validation-rmse:6.46360                                                                                                        
[57]	validation-rmse:6.46261                                                                                                        
[58]	validation-rmse:6.46153                                                                                                        
[59]	validation-rmse:6.46108                                                                                                        
[60]	validation-rmse:6.45977                                                                                                        
[61]	validation-rmse:6.45914                                         

[14]	validation-rmse:6.64292                                                                                                        
[15]	validation-rmse:6.63400                                                                                                        
[16]	validation-rmse:6.63125                                                                                                        
[17]	validation-rmse:6.62894                                                                                                        
[18]	validation-rmse:6.62524                                                                                                        
[19]	validation-rmse:6.62138                                                                                                        
[20]	validation-rmse:6.61757                                                                                                        
[21]	validation-rmse:6.61691                                         

[75]	validation-rmse:6.61039                                                                                                        
[76]	validation-rmse:6.60998                                                                                                        
[77]	validation-rmse:6.61645                                                                                                        
[78]	validation-rmse:6.61815                                                                                                        
[79]	validation-rmse:6.61764                                                                                                        
[80]	validation-rmse:6.61929                                                                                                        
[81]	validation-rmse:6.61934                                                                                                        
[82]	validation-rmse:6.61984                                         

[37]	validation-rmse:6.60895                                                                                                        
[38]	validation-rmse:6.59579                                                                                                        
[39]	validation-rmse:6.58381                                                                                                        
[40]	validation-rmse:6.57362                                                                                                        
[41]	validation-rmse:6.56397                                                                                                        
[42]	validation-rmse:6.55499                                                                                                        
[43]	validation-rmse:6.54717                                                                                                        
[44]	validation-rmse:6.54024                                         

[98]	validation-rmse:6.45132                                                                                                        
[99]	validation-rmse:6.45084                                                                                                        
[0]	validation-rmse:19.11516                                                                                                        
[1]	validation-rmse:17.30604                                                                                                        
[2]	validation-rmse:15.73517                                                                                                        
[3]	validation-rmse:14.37512                                                                                                        
[4]	validation-rmse:13.20309                                                                                                        
[5]	validation-rmse:12.19649                                         

[58]	validation-rmse:6.67467                                                                                                        
[59]	validation-rmse:6.67394                                                                                                        
[60]	validation-rmse:6.67349                                                                                                        
[61]	validation-rmse:6.67271                                                                                                        
[62]	validation-rmse:6.67236                                                                                                        
[63]	validation-rmse:6.67167                                                                                                        
[64]	validation-rmse:6.67122                                                                                                        
[65]	validation-rmse:6.67075                                         

[18]	validation-rmse:8.56717                                                                                                        
[19]	validation-rmse:8.36235                                                                                                        
[20]	validation-rmse:8.17828                                                                                                        
[21]	validation-rmse:8.01426                                                                                                        
[22]	validation-rmse:7.86841                                                                                                        
[23]	validation-rmse:7.73848                                                                                                        
[24]	validation-rmse:7.62225                                                                                                        
[25]	validation-rmse:7.51940                                         

[79]	validation-rmse:6.57674                                                                                                        
[80]	validation-rmse:6.57575                                                                                                        
[81]	validation-rmse:6.57481                                                                                                        
[82]	validation-rmse:6.57385                                                                                                        
[83]	validation-rmse:6.57260                                                                                                        
[84]	validation-rmse:6.57135                                                                                                        
[85]	validation-rmse:6.57052                                                                                                        
[86]	validation-rmse:6.56911                                         

[39]	validation-rmse:6.44063                                                                                                        
[40]	validation-rmse:6.43988                                                                                                        
[41]	validation-rmse:6.44249                                                                                                        
[42]	validation-rmse:6.44262                                                                                                        
[43]	validation-rmse:6.44219                                                                                                        
[44]	validation-rmse:6.44132                                                                                                        
[45]	validation-rmse:6.44124                                                                                                        
[46]	validation-rmse:6.43971                                         

[3]	validation-rmse:12.68039                                                                                                        
[4]	validation-rmse:11.42473                                                                                                        
[5]	validation-rmse:10.41176                                                                                                        
[6]	validation-rmse:9.60207                                                                                                         
[7]	validation-rmse:8.95972                                                                                                         
[8]	validation-rmse:8.45414                                                                                                         
[9]	validation-rmse:8.05893                                                                                                         
[10]	validation-rmse:7.74740                                         

[64]	validation-rmse:6.51433                                                                                                        
[65]	validation-rmse:6.51394                                                                                                        
[66]	validation-rmse:6.51376                                                                                                        
[67]	validation-rmse:6.51302                                                                                                        
[68]	validation-rmse:6.51275                                                                                                        
[69]	validation-rmse:6.51229                                                                                                        
[70]	validation-rmse:6.51175                                                                                                        
[71]	validation-rmse:6.51144                                         

[24]	validation-rmse:8.97437                                                                                                        
[25]	validation-rmse:8.79859                                                                                                        
[26]	validation-rmse:8.63599                                                                                                        
[27]	validation-rmse:8.48529                                                                                                        
[28]	validation-rmse:8.34563                                                                                                        
[29]	validation-rmse:8.21660                                                                                                        
[30]	validation-rmse:8.09726                                                                                                        
[31]	validation-rmse:7.98682                                         

[85]	validation-rmse:6.59358                                                                                                        
[86]	validation-rmse:6.59041                                                                                                        
[87]	validation-rmse:6.58726                                                                                                        
[88]	validation-rmse:6.58428                                                                                                        
[89]	validation-rmse:6.58158                                                                                                        
[90]	validation-rmse:6.57895                                                                                                        
[91]	validation-rmse:6.57649                                                                                                        
[92]	validation-rmse:6.57421                                         

[45]	validation-rmse:6.69260                                                                                                        
[46]	validation-rmse:6.68745                                                                                                        
[47]	validation-rmse:6.68287                                                                                                        
[48]	validation-rmse:6.67851                                                                                                        
[49]	validation-rmse:6.67525                                                                                                        
[50]	validation-rmse:6.67216                                                                                                        
[51]	validation-rmse:6.66955                                                                                                        
[52]	validation-rmse:6.66747                                         

[5]	validation-rmse:6.97887                                                                                                         
[6]	validation-rmse:6.82027                                                                                                         
[7]	validation-rmse:6.72642                                                                                                         
[8]	validation-rmse:6.67468                                                                                                         
[9]	validation-rmse:6.64035                                                                                                         
[10]	validation-rmse:6.61872                                                                                                        
[11]	validation-rmse:6.60216                                                                                                        
[12]	validation-rmse:6.59081                                         

[66]	validation-rmse:6.54257                                                                                                        
[67]	validation-rmse:6.54279                                                                                                        
[68]	validation-rmse:6.54276                                                                                                        
[69]	validation-rmse:6.54263                                                                                                        
[70]	validation-rmse:6.54348                                                                                                        
[71]	validation-rmse:6.54326                                                                                                        
[72]	validation-rmse:6.54350                                                                                                        
[73]	validation-rmse:6.54346                                         

[26]	validation-rmse:6.95173                                                                                                        
[27]	validation-rmse:6.89968                                                                                                        
[28]	validation-rmse:6.85320                                                                                                        
[29]	validation-rmse:6.81237                                                                                                        
[30]	validation-rmse:6.77691                                                                                                        
[31]	validation-rmse:6.74509                                                                                                        
[32]	validation-rmse:6.71690                                                                                                        
[33]	validation-rmse:6.69152                                         

[87]	validation-rmse:6.45654                                                                                                        
[88]	validation-rmse:6.45558                                                                                                        
[89]	validation-rmse:6.45511                                                                                                        
[90]	validation-rmse:6.45412                                                                                                        
[91]	validation-rmse:6.45325                                                                                                        
[92]	validation-rmse:6.45254                                                                                                        
[93]	validation-rmse:6.45189                                                                                                        
[94]	validation-rmse:6.45123                                         

[47]	validation-rmse:6.39645                                                                                                        
[48]	validation-rmse:6.39556                                                                                                        
[49]	validation-rmse:6.39408                                                                                                        
[50]	validation-rmse:6.39264                                                                                                        
[51]	validation-rmse:6.39154                                                                                                        
[52]	validation-rmse:6.39070                                                                                                        
[53]	validation-rmse:6.38939                                                                                                        
[54]	validation-rmse:6.38893                                         

[7]	validation-rmse:7.04299                                                                                                         
[8]	validation-rmse:6.86109                                                                                                         
[9]	validation-rmse:6.73932                                                                                                         
[10]	validation-rmse:6.65640                                                                                                        
[11]	validation-rmse:6.60224                                                                                                        
[12]	validation-rmse:6.55999                                                                                                        
[13]	validation-rmse:6.52840                                                                                                        
[14]	validation-rmse:6.50440                                         

[68]	validation-rmse:6.37553                                                                                                        
[69]	validation-rmse:6.37496                                                                                                        
[70]	validation-rmse:6.37386                                                                                                        
[71]	validation-rmse:6.37370                                                                                                        
[72]	validation-rmse:6.37321                                                                                                        
[73]	validation-rmse:6.37267                                                                                                        
[74]	validation-rmse:6.37288                                                                                                        
[75]	validation-rmse:6.37211                                         

[28]	validation-rmse:6.43689                                                                                                        
[29]	validation-rmse:6.43424                                                                                                        
[30]	validation-rmse:6.43195                                                                                                        
[31]	validation-rmse:6.43003                                                                                                        
[32]	validation-rmse:6.42767                                                                                                        
[33]	validation-rmse:6.42510                                                                                                        
[34]	validation-rmse:6.42331                                                                                                        
[35]	validation-rmse:6.42202                                         

KeyboardInterrupt: 

# Train the Model with a set of selected Parameters

In [None]:
# Hyperparameters selected for the final XGBoost model (regression on squared error).
params = dict(
    learning_rate=0.20905792515510074,
    max_depth=7,
    min_child_weight=0.5241500975917085,
    objective='reg:squarederror',
    reg_alpha=0.13309121698466933,
    reg_lambda=0.11277257081373988,
    seed=42,
)

In [None]:
# Turn on MLflow autologging for XGBoost: the xgb.train call below is tracked
# in the active experiment without explicit mlflow.log_* calls.
mlflow.xgboost.autolog()

booster = xgb.train(
    params=params,                  # hyperparameters from the cell above
    dtrain=train,                   # training DMatrix
    evals=[(valid, 'validation')],  # score the validation DMatrix every round
    num_boost_round=1000,           # upper bound on the number of boosting rounds
    early_stopping_rounds=50,       # stop after 50 rounds without validation improvement
)

# Model Management

Same as above, but also save the trained model as an artifact of the run. Where XGBoost is retrained further below, the number of boosting rounds is lowered to save time.

In [None]:
# Track a Lasso baseline and attach the fitted model to the run as a file artifact.
with mlflow.start_run(run_name="test_run") as run:

    mlflow.set_tag("developer", "frauke")

    # data the feature matrices were built from
    mlflow.log_param("train-data-path", "../data/green_tripdata_2021-01.parquet")
    mlflow.log_param("valid-data-path", "../data/green_tripdata_2021-02.parquet")

    alpha = 0.01
    mlflow.log_param("alpha", alpha)

    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)

    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    # Serialize the model fitted in THIS run, together with the DictVectorizer
    # needed to turn raw records into features, so the logged artifact matches
    # the params/metrics logged above. A dedicated file is used so any
    # models/lin_reg.bin written by other cells is left untouched.
    with open("models/lasso.bin", "wb") as f_out:
        pickle.dump((dv, lr), f_out)

    # save the model as an artifact
    mlflow.log_artifact(local_path="models/lasso.bin")

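A better way to save the model: log it with MLflow's model API (`mlflow.xgboost.log_model`), and log the preprocessor as a separate artifact.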

In [None]:
# disable XGBoost autologging: the next run logs its params, metric, preprocessor
# and model explicitly (presumably to avoid duplicate autologged entries).
mlflow.xgboost.autolog(disable=True)

In [None]:
# Train XGBoost with manual MLflow logging: params, validation RMSE, the
# DictVectorizer preprocessor (file artifact) and the booster (MLflow model).
with mlflow.start_run():
    best_params = {'learning_rate': 0.20905792515510074,
                   'max_depth': 7,
                   'min_child_weight': 0.5241500975917085,
                   'objective': 'reg:squarederror',
                   'reg_alpha': 0.13309121698466933,
                   'reg_lambda': 0.11277257081373988,
                   'seed': 42}
    mlflow.log_params(best_params)

    booster = xgb.train(
                # train with exactly the parameters logged for this run
                params=best_params,
                # training on train data
                dtrain=train,
                # fewer boosting rounds than the autologged run, to save time
                num_boost_round=100,
                # validation is done on validation dataset
                evals=[(valid, 'validation')],
                # stop if validation RMSE has not improved for 50 consecutive rounds
                early_stopping_rounds=50
            )

    # make predictions on the validation set
    y_pred = booster.predict(valid)
    # validation RMSE
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    # save the fitted DictVectorizer; the booster alone cannot featurize raw records
    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    # log_artifact (singular) uploads one file; log_artifacts expects a directory
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")
    # log the booster with the MLflow xgboost flavor so it can be reloaded by URI
    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

# Make Predictions using a saved Model from MLflow

In [58]:
# MLflow run that holds the "models_mlflow" artifact (the logged XGBoost booster).
run_id = 'dbe24b153d444422bc22eb0baed7a9de'
logged_model = f'runs:/{run_id}/models_mlflow'

# Generic, flavor-agnostic loading: returns a PyFuncModel wrapper.
loaded_model = mlflow.pyfunc.load_model(logged_model)

In [59]:
# show the PyFuncModel summary (artifact path, flavor, run id)
loaded_model

mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: dbe24b153d444422bc22eb0baed7a9de

Load the same model with the XGBoost flavor, which returns a native `xgboost.Booster`.

In [62]:
# load the same logged model with the xgboost flavor (native Booster object)
xgboost_model = mlflow.xgboost.load_model(logged_model)

In [63]:
# confirm the loaded object's type
xgboost_model

<xgboost.core.Booster at 0x7fa6de41cca0>

In [64]:
# predict trip durations for the validation DMatrix with the reloaded booster
y_pred = xgboost_model.predict(valid)

In [65]:
# inspect the predicted durations
y_pred

array([17.384949 ,  7.2070913, 21.283058 , ..., 13.594102 ,  7.2070913,
        8.739125 ], dtype=float32)

# Train some more Models

In [None]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

# enable MLflow autologging for scikit-learn estimators
mlflow.sklearn.autolog()

# Fit each regressor with default hyperparameters in its own MLflow run and
# record its validation RMSE for comparison.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):
    print(model_class)
    with mlflow.start_run():

        # data paths: the parquet files the feature matrices were built from
        mlflow.log_param("train-data-path", "../data/green_tripdata_2021-01.parquet")
        mlflow.log_param("valid-data-path", "../data/green_tripdata_2021-02.parquet")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)

        # log the validation rmse as metric
        mlflow.log_metric("rmse", rmse)

<class 'sklearn.ensemble._forest.RandomForestRegressor'>
<class 'sklearn.ensemble._gb.GradientBoostingRegressor'>
<class 'sklearn.ensemble._forest.ExtraTreesRegressor'>
