In [1]:
import pickle

import pandas as pd

from sklearn.feature_extraction import DictVectorizer
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

In [2]:
from sklearn.pipeline import make_pipeline

In [3]:
import mlflow

# Point MLflow at the tracking server listening on localhost port 5000.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
# Select the experiment by name. This is the last expression of the cell, so
# the returned Experiment object is what gets displayed as the cell output.
mlflow.set_experiment("green-taxi-duration")

<Experiment: artifact_location='mlflow-artifacts:/2', creation_time=1734615170885, experiment_id='2', last_update_time=1734615170885, lifecycle_stage='active', name='green-taxi-duration', tags={}>

In [4]:
def read_dataframe(filename: str):
    """Load a trip parquet file and derive a filtered ``duration`` column.

    Args:
        filename: Path of the parquet file to read. The file must contain the
            columns ``lpep_pickup_datetime``, ``lpep_dropoff_datetime``,
            ``PULocationID`` and ``DOLocationID``.

    Returns:
        A new DataFrame holding only the trips whose duration lies between
        1 and 60 minutes (both inclusive), with an added ``duration`` column
        (minutes) and with ``PULocationID`` / ``DOLocationID`` cast to ``str``.
    """
    df = pd.read_parquet(filename)

    # Trip duration in minutes, derived from the pickup/dropoff timestamps.
    elapsed = df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    df['duration'] = elapsed.dt.total_seconds() / 60

    # Keep trips of 1..60 minutes. The explicit copy makes the result an
    # independent frame, so the column assignment below does not write into a
    # slice of the unfiltered data (which raises SettingWithCopyWarning).
    in_range = (df['duration'] >= 1) & (df['duration'] <= 60)
    df = df[in_range].copy()

    # Location ids are treated as categorical values, so store them as strings.
    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    return df


def prepare_dictionaries(df: pd.DataFrame):
    """Turn trip rows into feature dictionaries.

    Each record holds two keys: ``PU_DO`` (the pickup and dropoff location ids
    joined with an underscore) and ``trip_distance``.

    Args:
        df: Frame with ``PULocationID`` and ``DOLocationID`` columns that can
            be concatenated with a string (e.g. already cast to ``str``) and a
            ``trip_distance`` column. The frame is not modified.

    Returns:
        A list with one ``{'PU_DO': ..., 'trip_distance': ...}`` dict per row.
    """
    categorical = ['PU_DO']
    numerical = ['trip_distance']

    # Work on a private copy of just the needed columns so the caller's frame
    # does not silently gain a 'PU_DO' column (and so no warning is raised
    # when the caller passes a filtered slice).
    features = df[['PULocationID', 'DOLocationID'] + numerical].copy()
    features['PU_DO'] = features['PULocationID'] + '_' + features['DOLocationID']

    dicts = features[categorical + numerical].to_dict(orient='records')
    return dicts

In [5]:
# Name of the column the model is trained to predict.
target = 'duration'

# Training split: August 2024 trips.
df_train = read_dataframe('../data/green_tripdata_2024-08.parquet')
y_train = df_train[target].values
dict_train = prepare_dictionaries(df_train)

# Validation split: September 2024 trips.
df_val = read_dataframe('../data/green_tripdata_2024-09.parquet')
y_val = df_val[target].values
dict_val = prepare_dictionaries(df_val)

In [6]:
# Train and evaluate inside one MLflow run so that the parameters, the
# validation metric and the fitted pipeline are logged together.
with mlflow.start_run():
    params = dict(max_depth=20, n_estimators=100, min_samples_leaf=10, random_state=0)
    mlflow.log_params(params)

    # Vectorizer and regressor are bundled so the logged model accepts the
    # raw feature dictionaries directly.
    pipeline = make_pipeline(
        DictVectorizer(),
        RandomForestRegressor(**params, n_jobs=-1)
    )

    pipeline.fit(dict_train, y_train)
    y_pred = pipeline.predict(dict_val)

    # RMSE as the square root of the MSE. The `squared=False` keyword of
    # mean_squared_error is deprecated in scikit-learn 1.4 and removed in 1.6,
    # so it is avoided here. Arguments follow the (y_true, y_pred) convention.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    print(params, rmse)
    mlflow.log_metric('rmse', rmse)

    mlflow.sklearn.log_model(pipeline, artifact_path="model")



{'max_depth': 20, 'n_estimators': 100, 'min_samples_leaf': 10, 'random_state': 0} 6.163798291995869




🏃 View run rare-mouse-199 at: http://127.0.0.1:5000/#/experiments/2/runs/49a69ddb8a01467b9a14d2c10eea62fa
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/2


In [10]:
from mlflow.tracking import MlflowClient


In [20]:
# Address of the MLflow tracking server used by the client below.
MLFLOW_TRACKING_URI = 'http://127.0.0.1:5000'
# Run whose artifacts are fetched in the following cells.
# NOTE(review): this id differs from the run logged by the training cell in
# this notebook (49a69ddb8a01467b9a14d2c10eea62fa) - confirm it is the
# intended run.
RUN_ID = 'b4d3bca8aa8e46a6b8257fe4541b1136'

# Client bound explicitly to the tracking server above.
client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

In [21]:
# Fetch the 'dict_vectorizer.bin' artifact of RUN_ID through the client; the
# returned value is kept in `path` and opened as a local file in the next cell.
path = client.download_artifacts(run_id=RUN_ID, path='dict_vectorizer.bin')

In [22]:
# Deserialize the downloaded artifact into `dv`.
# NOTE: pickle.load can execute arbitrary code - only load artifacts that
# come from a trusted tracking server.
with open(path, 'rb') as f_in:
    dv = pickle.load(f_in)

In [23]:
# Display the object that was unpickled from the downloaded artifact.
dv

DictVectorizer()

In [8]:
# Load and display a stored predictions file from ./output. The run id
# embedded in the file name is the same value that appears in the
# `model_version` column of the displayed rows.
# NOTE(review): presumably written by a separate batch-scoring step that is
# not part of this notebook - confirm.
pd.read_parquet('./output/nyc-duration-pred_green_2024_10_49a69ddb8a01467b9a14d2c10eea62fa.parquet')

Unnamed: 0,ride_id,lpep_pickup_datetime,PULocationID,DOLocationID,actual_duration,predicted_duration,diff,model_version
0,3fe4051d-3ed5-4bd4-a447-2198ecbe520d,2024-10-01 00:52:13,75,238,10.433333,12.883188,-2.449855,49a69ddb8a01467b9a14d2c10eea62fa
1,fe9b4145-88fd-4311-8595-cbc481067726,2024-10-01 00:56:34,134,82,7.283333,21.107591,-13.824258,49a69ddb8a01467b9a14d2c10eea62fa
2,07775003-ec42-4756-9a0f-ba5e4a82f593,2024-10-01 00:23:31,202,260,21.766667,19.195790,2.570876,49a69ddb8a01467b9a14d2c10eea62fa
3,d2b56adb-09b9-4239-b853-3abf8e20358b,2024-10-01 00:25:02,130,218,12.233333,17.791162,-5.557829,49a69ddb8a01467b9a14d2c10eea62fa
4,8357e36b-97dc-4cf0-b178-2f5560823af6,2024-10-01 00:11:11,42,94,14.533333,22.176303,-7.642969,49a69ddb8a01467b9a14d2c10eea62fa
...,...,...,...,...,...,...,...,...
53624,ccaa3ba5-a289-497c-b71b-fd353ff01ab6,2024-10-31 21:58:14,65,97,10.500000,9.508959,0.991041,49a69ddb8a01467b9a14d2c10eea62fa
53625,284811cd-d098-400a-9d76-8d09e31d0888,2024-10-31 22:44:00,116,143,17.000000,22.948301,-5.948301,49a69ddb8a01467b9a14d2c10eea62fa
53626,7a5f831e-642d-43c8-8b4d-147cab577ad6,2024-10-31 22:06:00,7,129,13.000000,17.239260,-4.239260,49a69ddb8a01467b9a14d2c10eea62fa
53627,454250ae-0b83-40bb-a995-c6559754a828,2024-10-31 23:19:17,112,36,21.166667,19.029241,2.137426,49a69ddb8a01467b9a14d2c10eea62fa
