In [1]:
# Record the interpreter version this notebook was executed with.
!python -V

Python 3.9.19


In [2]:
import pandas as pd

In [3]:
import pickle

In [4]:
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [6]:
import mlflow

# Store run metadata in a local SQLite file and group every run below under one experiment.
mlflow.set_tracking_uri(uri="sqlite:///mlflow.db")
mlflow.set_experiment(experiment_name="nyc-taxi-experiment")

<Experiment: artifact_location='/workspaces/mlops-zoomcamp-mausul/02-experiment-tracking/mlruns/1', creation_time=1716274678962, experiment_id='1', last_update_time=1716274678962, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [7]:
# List the working directory; later cells read from ./data and write to ./models.
!ls

data			   mlflow-exploration.ipynb  mlruns  requirements.txt
duration-prediction.ipynb  mlflow.db		     models


In [21]:
# Numeric feature columns used as model inputs.
numerical = ["trip_distance"]

In [7]:
def read_dataframe(filename):
    """Load one parquet file of green-taxi trips and prepare it for modelling.

    Parameters
    ----------
    filename : str or path-like
        Location of the parquet file; passed straight to ``pd.read_parquet``.
        The file must provide the columns ``lpep_pickup_datetime``,
        ``lpep_dropoff_datetime``, ``PULocationID`` and ``DOLocationID``.

    Returns
    -------
    pandas.DataFrame
        A new frame with an added ``duration`` column (trip length in minutes),
        restricted to trips lasting between 1 and 60 minutes inclusive, and with
        both location-ID columns cast to ``str``.
    """
    df = pd.read_parquet(filename)

    # Vectorised timedelta -> minutes conversion (no Python-level call per row).
    trip_time = df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    df['duration'] = trip_time.dt.total_seconds() / 60

    categorical = ['PULocationID', 'DOLocationID']
    in_range = (df.duration >= 1) & (df.duration <= 60)

    # ``astype`` returns a new frame, so the string casts are never written into
    # a filtered slice of ``df`` (the pattern that raises SettingWithCopyWarning).
    return df.loc[in_range].astype({column: str for column in categorical})

In [8]:
# The 2021-01 file is used for training, the 2021-02 file for validation.
df_train, df_val = (
    read_dataframe(path)
    for path in (
        './data/green_tripdata_2021-01.parquet',
        './data/green_tripdata_2021-02.parquet',
    )
)

In [9]:
# Number of trips in each split.
df_train.shape[0], df_val.shape[0]

(73908, 61921)

In [10]:
# Join pickup and drop-off location IDs into one combined feature, e.g. "130_205".
df_train['PU_DO'] = df_train.PULocationID + '_' + df_train.DOLocationID
df_val['PU_DO'] = df_val.PULocationID + '_' + df_val.DOLocationID

In [12]:
# Quick look at the combined pickup/drop-off feature.
df_val['PU_DO'].head(n=5)

0    130_205
1    152_244
2     152_48
3    152_241
4      75_42
Name: PU_DO, dtype: object

In [11]:
# Feature columns fed to the vectorizer.
categorical = ['PU_DO']  # alternative: ['PULocationID', 'DOLocationID'] as separate features
numerical = ['trip_distance']

dv = DictVectorizer()

# Turn each row into a {column: value} record for the vectorizer.
train_dicts = df_train.loc[:, categorical + numerical].to_dict(orient='records')
val_dicts = df_val.loc[:, categorical + numerical].to_dict(orient='records')

# Fit on the training split only, then apply the same mapping to validation.
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [12]:
# Regression target: the trip duration computed in read_dataframe.
target = 'duration'
y_train, y_val = (split[target].values for split in (df_train, df_val))

In [29]:
with mlflow.start_run():
    # Baseline: plain linear regression on the vectorised features.
    lr = LinearRegression()
    lr.fit(X_train, y_train)
    y_pred = lr.predict(X_val)

    # Take the square root by hand: the `squared=False` shortcut was deprecated
    # in scikit-learn 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    # Log the score; previously it was computed and discarded, leaving the run empty.
    mlflow.log_metric("rmse", rmse)

# Display the validation RMSE as the cell output.
rmse



7.758715208009878

In [31]:
# Persist the vectorizer together with the model so both can be reloaded as a pair.
with open('models/lin_reg.bin', mode='wb') as model_file:
    pickle.dump((dv, lr), model_file)

In [16]:
with mlflow.start_run():
    mlflow.set_tag("developer", "Abu Tyeb Azad Mausul")

    # Record which files the model was trained and validated on. Both keys use
    # the same hyphenated style (the training key was "train_data-path" before).
    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
    mlflow.log_param("val-data-path", "./data/green_tripdata_2021-02.parquet")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)

    lr = Lasso(alpha=alpha)
    lr.fit(X_train, y_train)
    y_pred = lr.predict(X_val)

    # Square root taken by hand: `squared=False` was deprecated in
    # scikit-learn 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)



In [14]:
import xgboost as xgb

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

**Flow for using hyperopt: define an objective function --> specify the search_space dict --> pass both to fmin**

In [15]:
# Wrap the feature matrices and targets in XGBoost DMatrix objects.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [39]:
def objective(params):
    """Train one XGBoost model and report its validation RMSE to hyperopt.

    Every call opens its own MLflow run, tags it with ``model=xgboost`` and logs
    the given hyperparameters together with the resulting ``rmse`` metric.

    Parameters
    ----------
    params : dict
        Booster parameters forwarded to ``xgb.train``.

    Returns
    -------
    dict
        ``{'loss': <validation RMSE>, 'status': STATUS_OK}``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, "validation")],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        # Square root taken by hand: `squared=False` was deprecated in
        # scikit-learn 1.4 and removed in 1.6.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [40]:
# Hyperparameter search space handed to hyperopt's fmin.
search_space = {
    # quniform yields floats; scope.int casts them to the integer XGBoost expects.
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    # loguniform(label, low, high) draws exp(uniform(low, high)),
    # so the value lies in [exp(-3), exp(0)] ~ [0.05, 1].
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),  # ~ [0.0067, 0.37]
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),  # ~ [0.0025, 0.37]
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),  # ~ [0.37, 20.1]
    # 'reg:linear' is the deprecated alias of this objective.
    'objective': 'reg:squarederror',
    'seed': 42,
}

In [41]:
# Keep the Trials object in a variable so the per-trial history stays accessible
# even if the search is interrupted before fmin returns.
trials = Trials()

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials
)

  0%|                                                                                 | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:10.74865                                                                                                 
[1]	validation-rmse:9.64348                                                                                                  
[2]	validation-rmse:8.82056                                                                                                  
[3]	validation-rmse:8.21779                                                                                                  
[4]	validation-rmse:7.78031                                                                                                  
[5]	validation-rmse:7.46674                                                                                                  
[6]	validation-rmse:7.24104                                                                                                  
[7]	validation-rmse:7.08006                                                                                           





[0]	validation-rmse:11.42565                                                                                                 
[1]	validation-rmse:10.73951                                                                                                 
[2]	validation-rmse:10.14409                                                                                                 
[3]	validation-rmse:9.62883                                                                                                  
[4]	validation-rmse:9.18583                                                                                                  
[5]	validation-rmse:8.80485                                                                                                  
[6]	validation-rmse:8.47940                                                                                                  
[7]	validation-rmse:8.20195                                                                                           





[0]	validation-rmse:8.81049                                                                                                  
[1]	validation-rmse:7.46732                                                                                                  
[2]	validation-rmse:6.96694                                                                                                  
[3]	validation-rmse:6.78451                                                                                                  
[4]	validation-rmse:6.69228                                                                                                  
[5]	validation-rmse:6.65361                                                                                                  
[6]	validation-rmse:6.63155                                                                                                  
[7]	validation-rmse:6.62023                                                                                           





[0]	validation-rmse:11.04364                                                                                                 
[1]	validation-rmse:10.09468                                                                                                 
[2]	validation-rmse:9.32999                                                                                                  
[3]	validation-rmse:8.72087                                                                                                  
[4]	validation-rmse:8.23841                                                                                                  
[5]	validation-rmse:7.85897                                                                                                  
[6]	validation-rmse:7.56283                                                                                                  
[7]	validation-rmse:7.33249                                                                                           





[0]	validation-rmse:9.83071                                                                                                  
[1]	validation-rmse:8.41735                                                                                                  
[2]	validation-rmse:7.63124                                                                                                  
[3]	validation-rmse:7.17692                                                                                                  
[4]	validation-rmse:6.93203                                                                                                  
[5]	validation-rmse:6.79492                                                                                                  
[6]	validation-rmse:6.70227                                                                                                  
[7]	validation-rmse:6.65064                                                                                           





[0]	validation-rmse:11.50903                                                                                                 
[1]	validation-rmse:10.88634                                                                                                 
[2]	validation-rmse:10.33879                                                                                                 
[3]	validation-rmse:9.85834                                                                                                  
[4]	validation-rmse:9.43845                                                                                                  
[5]	validation-rmse:9.07112                                                                                                  
[6]	validation-rmse:8.75252                                                                                                  
[7]	validation-rmse:8.47560                                                                                           





[0]	validation-rmse:7.21728                                                                                                  
[1]	validation-rmse:6.69048                                                                                                  
[2]	validation-rmse:6.61455                                                                                                  
[3]	validation-rmse:6.59775                                                                                                  
[4]	validation-rmse:6.58200                                                                                                  
[5]	validation-rmse:6.57687                                                                                                  
[6]	validation-rmse:6.56860                                                                                                  
[7]	validation-rmse:6.56332                                                                                           





[0]	validation-rmse:11.59601                                                                                                 
[1]	validation-rmse:11.03924                                                                                                 
[2]	validation-rmse:10.53784                                                                                                 
[3]	validation-rmse:10.08770                                                                                                 
[4]	validation-rmse:9.68543                                                                                                  
[5]	validation-rmse:9.32575                                                                                                  
[6]	validation-rmse:9.00480                                                                                                  
[7]	validation-rmse:8.71884                                                                                           





[0]	validation-rmse:11.32923                                                                                                 
[1]	validation-rmse:10.56867                                                                                                 
[2]	validation-rmse:9.91973                                                                                                  
[3]	validation-rmse:9.36698                                                                                                  
[4]	validation-rmse:8.89558                                                                                                  
[5]	validation-rmse:8.50121                                                                                                  
[6]	validation-rmse:8.16915                                                                                                  
[7]	validation-rmse:7.89011                                                                                           





[0]	validation-rmse:11.74127                                                                                                 
[1]	validation-rmse:11.30486                                                                                                 
[2]	validation-rmse:10.90172                                                                                                 
[3]	validation-rmse:10.52976                                                                                                 
[4]	validation-rmse:10.18679                                                                                                 
[5]	validation-rmse:9.87142                                                                                                  
[6]	validation-rmse:9.58195                                                                                                  
[7]	validation-rmse:9.31607                                                                                           





[0]	validation-rmse:10.91915                                                                                                 
[1]	validation-rmse:9.90414                                                                                                  
[2]	validation-rmse:9.11069                                                                                                  
[3]	validation-rmse:8.50158                                                                                                  
[4]	validation-rmse:8.02815                                                                                                  
[5]	validation-rmse:7.67860                                                                                                  
[6]	validation-rmse:7.40794                                                                                                  
[7]	validation-rmse:7.21271                                                                                           





[0]	validation-rmse:11.57458                                                                                                 
[1]	validation-rmse:11.00100                                                                                                 
[2]	validation-rmse:10.48709                                                                                                 
[3]	validation-rmse:10.02694                                                                                                 
[4]	validation-rmse:9.61584                                                                                                  
[5]	validation-rmse:9.24921                                                                                                  
[6]	validation-rmse:8.92285                                                                                                  
[7]	validation-rmse:8.63461                                                                                           





[0]	validation-rmse:8.24033                                                                                                  
[1]	validation-rmse:7.09669                                                                                                  
[2]	validation-rmse:6.78742                                                                                                  
[3]	validation-rmse:6.68269                                                                                                  
[4]	validation-rmse:6.64418                                                                                                  
[5]	validation-rmse:6.62189                                                                                                  
[6]	validation-rmse:6.61310                                                                                                  
[7]	validation-rmse:6.60482                                                                                           





[0]	validation-rmse:8.28517                                                                                                  
[1]	validation-rmse:7.09823                                                                                                  
[2]	validation-rmse:6.75864                                                                                                  
[3]	validation-rmse:6.64471                                                                                                  
[4]	validation-rmse:6.59364                                                                                                  
[5]	validation-rmse:6.57092                                                                                                  
[6]	validation-rmse:6.55802                                                                                                  
[7]	validation-rmse:6.53852                                                                                           





[0]	validation-rmse:8.60556                                                                                                  
[1]	validation-rmse:7.25362                                                                                                  
[2]	validation-rmse:6.78384                                                                                                  
[3]	validation-rmse:6.60929                                                                                                  
[4]	validation-rmse:6.53398                                                                                                  
[5]	validation-rmse:6.49487                                                                                                  
[6]	validation-rmse:6.47650                                                                                                  
[7]	validation-rmse:6.46483                                                                                           





[0]	validation-rmse:9.43302                                                                                                  
[1]	validation-rmse:8.00091                                                                                                  
[2]	validation-rmse:7.30405                                                                                                  
[3]	validation-rmse:6.96856                                                                                                  
[4]	validation-rmse:6.80546                                                                                                  
[5]	validation-rmse:6.70892                                                                                                  
[6]	validation-rmse:6.65672                                                                                                  
[7]	validation-rmse:6.62920                                                                                           





[0]	validation-rmse:11.44319                                                                                                 
[1]	validation-rmse:10.76981                                                                                                 
[2]	validation-rmse:10.18199                                                                                                 
[3]	validation-rmse:9.67240                                                                                                  
[4]	validation-rmse:9.22878                                                                                                  
[5]	validation-rmse:8.84735                                                                                                  
[6]	validation-rmse:8.52008                                                                                                  
[7]	validation-rmse:8.23963                                                                                           

KeyboardInterrupt: 

In [17]:
# Hyperparameters copied from one of the better hyperopt runs. The search above
# was interrupted, so this is not necessarily the best configuration.
params = {
    'learning_rate': 0.0786595878556475,
    'max_depth': 51,
    'min_child_weight': 1.9905475557209629,
    # 'reg:linear' is the deprecated alias of this objective.
    'objective': 'reg:squarederror',
    'reg_alpha': 0.04386766359221766,
    'reg_lambda': 0.0047249573930281084,
    'seed': 42
}

In [18]:
# Enable MLflow autologging for XGBoost so the training call below is tracked
# without explicit log_* calls.
mlflow.xgboost.autolog()

booster = xgb.train(
    params=params,
    dtrain=train,
    num_boost_round=1000,
    evals=[(valid, "validation")],
    early_stopping_rounds=50
)
y_pred = booster.predict(valid)
# Square root taken by hand: `squared=False` was deprecated in
# scikit-learn 1.4 and removed in 1.6.
rmse = mean_squared_error(y_val, y_pred) ** 0.5

2024/05/22 23:17:44 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'b407a028d36d4e26b4d56324cbcd7e1a', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:11.57458
[1]	validation-rmse:11.00100
[2]	validation-rmse:10.48709
[3]	validation-rmse:10.02694
[4]	validation-rmse:9.61584
[5]	validation-rmse:9.24921
[6]	validation-rmse:8.92285
[7]	validation-rmse:8.63461
[8]	validation-rmse:8.37978
[9]	validation-rmse:8.15610
[10]	validation-rmse:7.95759
[11]	validation-rmse:7.78402
[12]	validation-rmse:7.63205
[13]	validation-rmse:7.49670
[14]	validation-rmse:7.37852
[15]	validation-rmse:7.27310
[16]	validation-rmse:7.18303
[17]	validation-rmse:7.10275
[18]	validation-rmse:7.03238
[19]	validation-rmse:6.97067
[20]	validation-rmse:6.91515
[21]	validation-rmse:6.86822
[22]	validation-rmse:6.82645
[23]	validation-rmse:6.78886
[24]	validation-rmse:6.75543
[25]	validation-rmse:6.72604
[26]	validation-rmse:6.70012
[27]	validation-rmse:6.67683
[28]	validation-rmse:6.65645
[29]	validation-rmse:6.63746
[30]	validation-rmse:6.62088
[31]	validation-rmse:6.60632
[32]	validation-rmse:6.59244
[33]	validation-rmse:6.58073
[34]	validation-rmse



In [33]:
# Ridge regression with default settings, for comparison with the models above.
lr = Ridge()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

# Validation RMSE. The square root is taken by hand because `squared=False`
# was deprecated in scikit-learn 1.4 and removed in 1.6.
mean_squared_error(y_val, y_pred) ** 0.5



7.70373515548786