In [1]:
import pandas as pd

In [2]:
import pickle

In [3]:
import seaborn as sns
import matplotlib.pyplot as plt

In [4]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [5]:
import mlflow


mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("nyc-taxi-experiment")

<Experiment: artifact_location='./mlruns/1', experiment_id='1', lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [6]:
def read_dataframe(filename):
    df = pd.read_parquet(filename)

    df.lpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
    df.lpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)

    df['duration'] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)

    df = df[(df.duration >= 1) & (df.duration <= 60)]

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    
    return df

In [7]:
df_train = read_dataframe('green_tripdata_2021-01.parquet')
df_val = read_dataframe('green_tripdata_2021-02.parquet')

In [8]:
len(df_train), len(df_val)

(73908, 61921)

In [9]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_val['PU_DO'] = df_val['PULocationID'] + '_' + df_val['DOLocationID']

In [10]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [11]:
target = 'duration'
y_train = df_train[target].values
y_val = df_val[target].values

In [12]:
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

mean_squared_error(y_val, y_pred, squared=False)

7.758715204122913

In [13]:
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

In [14]:
with mlflow.start_run():

    mlflow.set_tag("developer", "cristian")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    alpha = 0.001
    max_iter=500

    mlflow.log_param("alpha", alpha)
    mlflow.log_param("max_iter", max_iter)

    lr = Lasso(alpha,max_iter= max_iter)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")

### 

In [15]:
import xgboost as xgb

In [16]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [17]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [18]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [21]:
scope.int(hp.quniform('max_depth', 4, 100, 1))

<hyperopt.pyll.base.Apply at 0x7f0d2cae1910>

In [22]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:7.14199                                                     
[1]	validation-rmse:6.66080                                                     
[2]	validation-rmse:6.62573                                                     
[3]	validation-rmse:6.60526                                                     
[4]	validation-rmse:6.59262                                                     
[5]	validation-rmse:6.58527                                                     
[6]	validation-rmse:6.57823                                                     
[7]	validation-rmse:6.57354                                                     
[8]	validation-rmse:6.56588                                                     
[9]	validation-rmse:6.55801                                                     
[10]	validation-rmse:6.55155                                                    
[11]	validation-rmse:6.54464                                                    
[12]	validation-rmse:6.53915

[110]	validation-rmse:6.50131                                                   
[111]	validation-rmse:6.50126                                                   
[112]	validation-rmse:6.50024                                                   
[113]	validation-rmse:6.49880                                                   
[114]	validation-rmse:6.50153                                                   
[115]	validation-rmse:6.50126                                                   
[116]	validation-rmse:6.49761                                                   
[117]	validation-rmse:6.49628                                                   
[118]	validation-rmse:6.49522                                                   
[119]	validation-rmse:6.49388                                                   
[120]	validation-rmse:6.49234                                                   
[121]	validation-rmse:6.49235                                                   
[122]	validation-rmse:6.4918

[312]	validation-rmse:6.38069                                                   
[313]	validation-rmse:6.38055                                                   
[314]	validation-rmse:6.38030                                                   
[315]	validation-rmse:6.38022                                                   
[316]	validation-rmse:6.38018                                                   
[317]	validation-rmse:6.37964                                                   
[318]	validation-rmse:6.37950                                                   
[319]	validation-rmse:6.37922                                                   
[320]	validation-rmse:6.37718                                                   
[321]	validation-rmse:6.37812                                                   
[322]	validation-rmse:6.37782                                                   
[323]	validation-rmse:6.37778                                                   
[324]	validation-rmse:6.3775

[514]	validation-rmse:6.33622                                                   
[515]	validation-rmse:6.33617                                                   
[516]	validation-rmse:6.33603                                                   
[517]	validation-rmse:6.33546                                                   
[518]	validation-rmse:6.33545                                                   
[519]	validation-rmse:6.33539                                                   
[520]	validation-rmse:6.33524                                                   
[521]	validation-rmse:6.33515                                                   
[522]	validation-rmse:6.33496                                                   
[523]	validation-rmse:6.33466                                                   
[524]	validation-rmse:6.33453                                                   
[525]	validation-rmse:6.33472                                                   
[526]	validation-rmse:6.3340

[716]	validation-rmse:6.31935                                                   
[717]	validation-rmse:6.31972                                                   
[718]	validation-rmse:6.31884                                                   
[719]	validation-rmse:6.31886                                                   
[720]	validation-rmse:6.31881                                                   
[721]	validation-rmse:6.31826                                                   
[722]	validation-rmse:6.31830                                                   
[723]	validation-rmse:6.31822                                                   
[724]	validation-rmse:6.31822                                                   
[725]	validation-rmse:6.31759                                                   
[726]	validation-rmse:6.31765                                                   
[727]	validation-rmse:6.31768                                                   
[728]	validation-rmse:6.3177

[136]	validation-rmse:6.58885                                                   
[137]	validation-rmse:6.58839                                                   
[138]	validation-rmse:6.58819                                                   
[139]	validation-rmse:6.58771                                                   
[140]	validation-rmse:6.58728                                                   
[141]	validation-rmse:6.58668                                                   
[142]	validation-rmse:6.58632                                                   
[143]	validation-rmse:6.58599                                                   
[144]	validation-rmse:6.58535                                                   
[145]	validation-rmse:6.58502                                                   
[146]	validation-rmse:6.58424                                                   
[147]	validation-rmse:6.58390                                                   
[148]	validation-rmse:6.5834

[338]	validation-rmse:6.52740                                                   
[339]	validation-rmse:6.52714                                                   
[340]	validation-rmse:6.52696                                                   
[341]	validation-rmse:6.52687                                                   
[342]	validation-rmse:6.52675                                                   
[343]	validation-rmse:6.52648                                                   
[344]	validation-rmse:6.52632                                                   
[345]	validation-rmse:6.52611                                                   
[346]	validation-rmse:6.52592                                                   
[347]	validation-rmse:6.52564                                                   
[348]	validation-rmse:6.52540                                                   
[349]	validation-rmse:6.52527                                                   
[350]	validation-rmse:6.5248

[540]	validation-rmse:6.48918                                                   
[541]	validation-rmse:6.48909                                                   
[542]	validation-rmse:6.48897                                                   
[543]	validation-rmse:6.48868                                                   
[544]	validation-rmse:6.48854                                                   
[545]	validation-rmse:6.48845                                                   
[546]	validation-rmse:6.48826                                                   
[547]	validation-rmse:6.48812                                                   
[548]	validation-rmse:6.48805                                                   
[549]	validation-rmse:6.48785                                                   
[550]	validation-rmse:6.48760                                                   
[551]	validation-rmse:6.48751                                                   
[552]	validation-rmse:6.4874

[742]	validation-rmse:6.46083                                                   
[743]	validation-rmse:6.46069                                                   
[744]	validation-rmse:6.46057                                                   
[745]	validation-rmse:6.46044                                                   
[746]	validation-rmse:6.46030                                                   
[747]	validation-rmse:6.46021                                                   
[748]	validation-rmse:6.45984                                                   
[749]	validation-rmse:6.45975                                                   
[750]	validation-rmse:6.45966                                                   
[751]	validation-rmse:6.45956                                                   
[752]	validation-rmse:6.45942                                                   
[753]	validation-rmse:6.45930                                                   
[754]	validation-rmse:6.4591

[944]	validation-rmse:6.43840                                                   
[945]	validation-rmse:6.43835                                                   
[946]	validation-rmse:6.43811                                                   
[947]	validation-rmse:6.43802                                                   
[948]	validation-rmse:6.43787                                                   
[949]	validation-rmse:6.43783                                                   
[950]	validation-rmse:6.43770                                                   
[951]	validation-rmse:6.43763                                                   
[952]	validation-rmse:6.43762                                                   
[953]	validation-rmse:6.43754                                                   
[954]	validation-rmse:6.43760                                                   
[955]	validation-rmse:6.43746                                                   
[956]	validation-rmse:6.4374

[70]	validation-rmse:6.57743                                                    
[71]	validation-rmse:6.57029                                                    
[72]	validation-rmse:6.56374                                                    
[73]	validation-rmse:6.55756                                                    
[74]	validation-rmse:6.55159                                                    
[75]	validation-rmse:6.54594                                                    
[76]	validation-rmse:6.54079                                                    
[77]	validation-rmse:6.53551                                                    
[78]	validation-rmse:6.53110                                                    
[79]	validation-rmse:6.52640                                                    
[80]	validation-rmse:6.52204                                                    
[81]	validation-rmse:6.51780                                                    
[82]	validation-rmse:6.51390

[272]	validation-rmse:6.39359                                                   
[273]	validation-rmse:6.39323                                                   
[274]	validation-rmse:6.39302                                                   
[275]	validation-rmse:6.39289                                                   
[276]	validation-rmse:6.39256                                                   
[277]	validation-rmse:6.39231                                                   
[278]	validation-rmse:6.39228                                                   
[279]	validation-rmse:6.39190                                                   
[280]	validation-rmse:6.39167                                                   
[281]	validation-rmse:6.39136                                                   
[282]	validation-rmse:6.39096                                                   
[283]	validation-rmse:6.39096                                                   
[284]	validation-rmse:6.3908

[474]	validation-rmse:6.35696                                                   
[475]	validation-rmse:6.35681                                                   
[476]	validation-rmse:6.35666                                                   
[477]	validation-rmse:6.35655                                                   
[478]	validation-rmse:6.35650                                                   
[479]	validation-rmse:6.35637                                                   
[480]	validation-rmse:6.35627                                                   
[481]	validation-rmse:6.35611                                                   
[482]	validation-rmse:6.35596                                                   
[483]	validation-rmse:6.35573                                                   
[484]	validation-rmse:6.35553                                                   
[485]	validation-rmse:6.35537                                                   
[486]	validation-rmse:6.3552

[676]	validation-rmse:6.33529                                                   
[677]	validation-rmse:6.33510                                                   
[678]	validation-rmse:6.33504                                                   
[679]	validation-rmse:6.33496                                                   
[680]	validation-rmse:6.33486                                                   
[681]	validation-rmse:6.33472                                                   
[682]	validation-rmse:6.33454                                                   
[683]	validation-rmse:6.33440                                                   
[684]	validation-rmse:6.33432                                                   
[685]	validation-rmse:6.33425                                                   
[686]	validation-rmse:6.33408                                                   
[687]	validation-rmse:6.33399                                                   
[688]	validation-rmse:6.3339

[878]	validation-rmse:6.31870                                                   
[879]	validation-rmse:6.31863                                                   
[880]	validation-rmse:6.31856                                                   
[881]	validation-rmse:6.31841                                                   
[882]	validation-rmse:6.31841                                                   
[883]	validation-rmse:6.31826                                                   
[884]	validation-rmse:6.31806                                                   
[885]	validation-rmse:6.31803                                                   
[886]	validation-rmse:6.31805                                                   
[887]	validation-rmse:6.31809                                                   
[888]	validation-rmse:6.31803                                                   
[889]	validation-rmse:6.31797                                                   
[890]	validation-rmse:6.3178

[78]	validation-rmse:6.53922                                                    
[79]	validation-rmse:6.53936                                                    
[80]	validation-rmse:6.53917                                                    
[81]	validation-rmse:6.53920                                                    
[82]	validation-rmse:6.53949                                                    
[83]	validation-rmse:6.54049                                                    
[84]	validation-rmse:6.53982                                                    
[85]	validation-rmse:6.53963                                                    
[86]	validation-rmse:6.53901                                                    
[87]	validation-rmse:6.53926                                                    
[88]	validation-rmse:6.53919                                                    
[89]	validation-rmse:6.53886                                                    
[90]	validation-rmse:6.54277

[139]	validation-rmse:6.42886                                                   
[140]	validation-rmse:6.42863                                                   
[141]	validation-rmse:6.42840                                                   
[142]	validation-rmse:6.42806                                                   
[143]	validation-rmse:6.42788                                                   
[144]	validation-rmse:6.42762                                                   
[145]	validation-rmse:6.42756                                                   
[146]	validation-rmse:6.42702                                                   
[147]	validation-rmse:6.42690                                                   
[148]	validation-rmse:6.42664                                                   
[149]	validation-rmse:6.42636                                                   
[150]	validation-rmse:6.42624                                                   
[151]	validation-rmse:6.4259

[341]	validation-rmse:6.41533                                                   
[0]	validation-rmse:16.04767                                                    
[1]	validation-rmse:12.62269                                                    
[2]	validation-rmse:10.35953                                                    
[3]	validation-rmse:8.91442                                                     
[4]	validation-rmse:8.01093                                                     
[5]	validation-rmse:7.46192                                                     
[6]	validation-rmse:7.12539                                                     
[7]	validation-rmse:6.92655                                                     
[8]	validation-rmse:6.79466                                                     
[9]	validation-rmse:6.70736                                                     
[10]	validation-rmse:6.65389                                                    
[11]	validation-rmse:6.61994

KeyboardInterrupt: 

In [24]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [25]:
X_train

<73908x13221 sparse matrix of type '<class 'numpy.float64'>'
	with 147816 stored elements in Compressed Sparse Row format>

In [18]:
mlflow.xgboost.autolog(disable=True)

In [26]:
from xgboost import XGBRegressor


In [32]:
with mlflow.start_run():
    
#     train = xgb.DMatrix(X_train, label=y_train)
#     valid = xgb.DMatrix(X_val, label=y_val)
    mlflow.set_tag("user","Galileo")

    best_params = {
        'learning_rate': 0.09585355369315604,
        'max_depth': 30,
        'min_child_weight': 1.060597050922164,
        'objective': 'reg:linear',
        'reg_alpha': 0.018060244040060163,
        'reg_lambda': 0.011658731377413597,
        'seed': 42
    }

    mlflow.log_params(best_params)

    booster = XGBRegressor(
        params=best_params,
#         dtrain=train,
        num_boost_round=1000,
        evals=[(X_val, 'validation')],
#         early_stopping_rounds=50
    )
    booster.fit(X_train,y_train)

    y_pred = booster.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

Parameters: { "evals", "num_boost_round", "params" } might not be used.

  This could be a false alarm, with some parameters getting used by language bindings but
  then being mistakenly passed down to XGBoost core, or some parameter actually being used
  but getting flagged wrongly here. Please open an issue if you find any such cases.






In [23]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)
        



In [6]:
import argparse
import os
import pickle

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error


In [None]:


def load_pickle(filename: str):
    with open(filename, "rb") as f_in:
        return pickle.load(f_in)


def run(data_path):

    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_valid, y_valid = load_pickle(os.path.join(data_path, "valid.pkl"))

    rf = RandomForestRegressor(max_depth=10, random_state=0)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_valid)

    rmse = mean_squared_error(y_valid, y_pred, squared=False)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_path",
        default="./output",
        help="the location where the processed NYC taxi trip data was saved."
    )
    args = parser.parse_args()

    run(args.data_path)