# Setting up

In [1]:
# Setting up

import warnings
warnings.filterwarnings('ignore')

import os
import pickle

import mlflow
import numpy as np
import pandas as pd

from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import xgboost as xgb
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [2]:
# Point MLflow at a local SQLite tracking store and select the experiment
# that the runs started later in this notebook are recorded under.
import mlflow
mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("house-price-prediction")

<Experiment: artifact_location='./mlruns/1', experiment_id='1', lifecycle_stage='active', name='house-price-prediction', tags={}>

In [3]:
# Single seed reused for the dataset split and for the XGBoost 'seed' parameter.
RANDOM_SEED=42

# Reading data

In [4]:
def read_data(filename):
    """Load the housing CSV, drop known outliers and add a total-area column.

    Rows whose ``GrLivArea`` exceeds 4000 while ``SalePrice`` is below 300000
    are removed. A ``TotalSF`` column is then set to the sum of
    ``TotalBsmtSF``, ``1stFlrSF`` and ``2ndFlrSF``.

    Args:
        filename: Path (or file-like object) handed to ``pd.read_csv``.

    Returns:
        The filtered DataFrame with the extra ``TotalSF`` column; the original
        row labels of the kept rows are preserved.
    """
    raw = pd.read_csv(filename)

    # Large living area but low price: treated as outliers and removed.
    is_outlier = (raw['GrLivArea'] > 4000) & (raw['SalePrice'] < 300000)
    kept = raw.drop(raw.loc[is_outlier].index)

    # Plain `+` is kept so a missing area still yields a missing total.
    total_area = kept['TotalBsmtSF'] + kept['1stFlrSF'] + kept['2ndFlrSF']
    return kept.assign(TotalSF=total_area)

In [5]:
# Load the training CSV through read_data (outlier rows removed, TotalSF added).
train = read_data('../data/train.csv')

# Data Processing

In [6]:
def prepare_features(df: pd.DataFrame, dv=None, fit: bool = True):
    """Turn a housing DataFrame into a vectorized feature matrix plus labels.

    The caller's DataFrame is left untouched, so the cell calling this
    function can be re-run safely: the label column is dropped and the
    categorical casts are applied on a copy.

    Args:
        df: Frame containing the ``SalePrice`` label column and the columns
            ``MSSubClass``, ``OverallCond``, ``YrSold`` and ``MoSold``.
        dv: Vectorizer exposing ``fit_transform`` / ``transform`` (a
            ``DictVectorizer`` in this notebook). When ``None`` and ``fit`` is
            true, a new sparse ``DictVectorizer`` is created.
        fit: When true (default, the original behavior) the vectorizer is
            fitted on ``df``; when false an already fitted ``dv`` is only
            applied via ``transform``.

    Returns:
        Tuple ``(features, labels, dv)``: the vectorized features, the
        ``SalePrice`` values as an array, and the vectorizer that was used.

    Raises:
        ValueError: If ``fit`` is false and no vectorizer is supplied.
    """
    if dv is None:
        if not fit:
            raise ValueError("a fitted vectorizer is required when fit=False")
        dv = DictVectorizer(sparse=True)

    df_label = df['SalePrice'].values

    # These columns hold numeric codes; cast them to strings so the
    # vectorizer one-hot encodes them instead of treating them as magnitudes.
    num_to_categ_cols = ['MSSubClass', 'OverallCond', 'YrSold', 'MoSold']
    features = (
        df.drop(columns=['SalePrice'])
          .astype({col: str for col in num_to_categ_cols})
    )

    records = features.to_dict("records")
    df_processed = dv.fit_transform(records) if fit else dv.transform(records)

    return df_processed, df_label, dv

In [7]:
# Create a sparse DictVectorizer and fit it on the training frame; X is the
# vectorized feature matrix, y the SalePrice labels.
dv = DictVectorizer(sparse=True)

X, y, dv = prepare_features(train, dv)

In [8]:
def split_dataset(X, y, split_sizes=[0.8, 0.5], random_seed=42):
    """Split features and labels into train, validation and test parts.

    Two consecutive ``train_test_split`` calls are made with the same seed:
    the first keeps ``split_sizes[0]`` as ``train_size``; the second splits
    the remainder using ``split_sizes[1]`` as ``test_size``.

    Args:
        X: Feature matrix.
        y: Labels aligned with ``X``.
        split_sizes: Two-element sequence ``[train_size, test_size]`` as
            described above.
        random_seed: ``random_state`` used for both splits.

    Returns:
        ``(X_train, y_train, X_valid, y_valid, X_test, y_test)``.
    """
    # First cut: training part versus everything else.
    first_cut = train_test_split(
        X, y, train_size=split_sizes[0], random_state=random_seed
    )
    X_train, X_holdout, y_train, y_holdout = first_cut

    # Second cut: divide the held-out part into validation and test.
    second_cut = train_test_split(
        X_holdout, y_holdout, test_size=split_sizes[1], random_state=random_seed
    )
    X_valid, X_test, y_valid, y_test = second_cut

    return X_train, y_train, X_valid, y_valid, X_test, y_test

In [9]:
# Split into train / validation / test using the default split sizes and the shared seed.
X_train, y_train, X_valid, y_valid, X_test, y_test = split_dataset(X, y, random_seed=RANDOM_SEED)

# Training

In [10]:
# Wrap each split in an xgboost DMatrix, the data container passed to xgb.train below.
train_xgb = xgb.DMatrix(X_train, label=y_train)
valid_xgb = xgb.DMatrix(X_valid, label=y_valid)
test_xgb = xgb.DMatrix(X_test, label=y_test)

In [11]:
def train_model_search(train, valid, test, y_test, max_evals=10,
                       num_boost_round=100, early_stopping_rounds=50):
    """Run a hyperopt (TPE) search over XGBoost hyper-parameters.

    Every trial is recorded as its own MLflow run with its parameters, the
    validation RMSE ("validation_rmse") and the test RMSE ("rmse").

    The search minimizes the *validation* RMSE. The test RMSE is only logged
    for reporting: using it as the search loss would tune the
    hyper-parameters on the held-out set and bias the final estimate.

    Args:
        train: ``xgb.DMatrix`` used for fitting.
        valid: ``xgb.DMatrix`` used for early stopping and as search loss.
        test: ``xgb.DMatrix`` used only for the reported test metric.
        y_test: True labels matching ``test``.
        max_evals: Number of hyperopt trials.
        num_boost_round: Maximum boosting rounds per trial.
        early_stopping_rounds: Rounds without validation improvement before
            a trial's training stops.

    Returns:
        The dictionary returned by ``hyperopt.fmin`` with the best sampled
        value for each searched hyper-parameter. These are the raw sampled
        values, so ``max_depth`` may need casting to ``int`` before reuse.
    """
    def _rmse(y_true, y_pred):
        # np.sqrt(MSE) instead of the `squared=False` keyword, which newer
        # scikit-learn releases deprecate and then remove.
        return float(np.sqrt(mean_squared_error(y_true, y_pred)))

    def objective(params):
        with mlflow.start_run():
            mlflow.set_tag("model", "xgboost")
            mlflow.log_params(params)

            model_xgb = xgb.train(
                params=params,
                dtrain=train,
                evals=[(valid, 'validation')],
                num_boost_round=num_boost_round,
                early_stopping_rounds=early_stopping_rounds,
            )

            # Loss for the search: validation error only.
            valid_rmse = _rmse(valid.get_label(), model_xgb.predict(valid))
            mlflow.log_metric("validation_rmse", valid_rmse)

            # Held-out error, logged for reporting but never optimized on.
            test_rmse = _rmse(y_test, model_xgb.predict(test))
            mlflow.log_metric("rmse", test_rmse)

        return {'loss': valid_rmse, 'status': STATUS_OK}

    search_space = {
        'gamma': hp.loguniform('gamma', -5, -1),
        'colsample_bytree': hp.uniform('colsample_bytree', 0.3, 1),
        'subsample': hp.uniform('subsample', 0.4, 1),
        'max_depth': scope.int(hp.quniform('max_depth', 3, 10, 1)),
        'learning_rate': hp.loguniform('learning_rate', -3, 0),
        'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
        'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
        'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
        # 'reg:linear' is the deprecated alias of this objective.
        'objective': 'reg:squarederror',
        'seed': RANDOM_SEED
    }

    best_result = fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=max_evals,
        trials=Trials()
    )

    return best_result

In [12]:
# Run the hyper-parameter search; each trial is logged as a separate MLflow run.
train_model_search(train_xgb, valid_xgb, test_xgb, y_test)

[0]	validation-rmse:120020.85061                                                         
[1]	validation-rmse:71424.94852                                                          
[2]	validation-rmse:45747.24964                                                          
[3]	validation-rmse:33053.82943                                                          
[4]	validation-rmse:29002.69600                                                          
[5]	validation-rmse:27317.83608                                                          
[6]	validation-rmse:27004.58106                                                          
[7]	validation-rmse:27324.60676                                                          
[8]	validation-rmse:27529.58684                                                          
[9]	validation-rmse:27460.91663                                                          
[10]	validation-rmse:27407.72345                                                         
[11]	valid

[30]	validation-rmse:25022.41675                                                         
[31]	validation-rmse:24260.82719                                                         
[32]	validation-rmse:23787.16310                                                         
[33]	validation-rmse:23263.39419                                                         
[34]	validation-rmse:22930.24423                                                         
[35]	validation-rmse:22696.98876                                                         
[36]	validation-rmse:22422.30450                                                         
[37]	validation-rmse:22324.53165                                                         
[38]	validation-rmse:22185.66545                                                         
[39]	validation-rmse:22049.43671                                                         
[40]	validation-rmse:21962.20975                                                         
[41]	valid

[23]	validation-rmse:22881.90708                                                         
[24]	validation-rmse:22967.98197                                                         
[25]	validation-rmse:22967.27422                                                         
[26]	validation-rmse:23086.38642                                                         
[27]	validation-rmse:23062.30234                                                         
[28]	validation-rmse:23075.81581                                                         
[29]	validation-rmse:23004.48526                                                         
[30]	validation-rmse:23017.05672                                                         
[31]	validation-rmse:23100.67172                                                         
[32]	validation-rmse:23075.90718                                                         
[33]	validation-rmse:23135.69006                                                         
[34]	valid

[38]	validation-rmse:31531.24147                                                         
[39]	validation-rmse:31550.32385                                                         
[40]	validation-rmse:31272.23417                                                         
[41]	validation-rmse:31438.09909                                                         
[42]	validation-rmse:31338.18360                                                         
[43]	validation-rmse:31614.86172                                                         
[44]	validation-rmse:31816.33646                                                         
[45]	validation-rmse:31644.06236                                                         
[46]	validation-rmse:31689.27585                                                         
[47]	validation-rmse:31657.61865                                                         
[48]	validation-rmse:31598.09559                                                         
[49]	valid

[69]	validation-rmse:28730.53656                                                         
[70]	validation-rmse:28579.50776                                                         
[71]	validation-rmse:28483.81184                                                         
[72]	validation-rmse:28507.70108                                                         
[73]	validation-rmse:28480.55917                                                         
[74]	validation-rmse:28753.78547                                                         
[75]	validation-rmse:28890.02388                                                         
[76]	validation-rmse:28817.97861                                                         
[77]	validation-rmse:28790.71299                                                         
[78]	validation-rmse:28915.08589                                                         
[0]	validation-rmse:112893.24140                                                         
[1]	valida

[21]	validation-rmse:31048.17930                                                         
[22]	validation-rmse:31006.80750                                                         
[23]	validation-rmse:31177.56403                                                         
[24]	validation-rmse:31163.20649                                                         
[25]	validation-rmse:31303.64046                                                         
[26]	validation-rmse:31239.05372                                                         
[27]	validation-rmse:31124.57642                                                         
[28]	validation-rmse:31105.81606                                                         
[29]	validation-rmse:31120.85522                                                         
[30]	validation-rmse:31147.68739                                                         
[31]	validation-rmse:31117.51559                                                         
[32]	valid

[49]	validation-rmse:26959.01486                                                         
[50]	validation-rmse:26970.45828                                                         
[51]	validation-rmse:26972.80641                                                         
[52]	validation-rmse:26970.75483                                                         
[53]	validation-rmse:26967.28874                                                         
[54]	validation-rmse:26973.88770                                                         
[55]	validation-rmse:26971.33250                                                         
[56]	validation-rmse:26978.02305                                                         
[57]	validation-rmse:26985.23457                                                         
[58]	validation-rmse:26987.47139                                                         
[59]	validation-rmse:26967.62936                                                         
[60]	valid

[12]	validation-rmse:26434.68442                                                         
[13]	validation-rmse:26055.11495                                                         
[14]	validation-rmse:25611.90886                                                         
[15]	validation-rmse:25112.37219                                                         
[16]	validation-rmse:24931.21948                                                         
[17]	validation-rmse:24962.65870                                                         
[18]	validation-rmse:24795.84591                                                         
[19]	validation-rmse:25057.32547                                                         
[20]	validation-rmse:25062.33352                                                         
[21]	validation-rmse:24951.91692                                                         
[22]	validation-rmse:25064.64681                                                         
[23]	valid

In [13]:
def train_best_model(train, valid, test, y_test, dv, params=None,
                     preprocessor_path="../models/preprocessor.b"):
    """Train the final XGBoost model and log it, with its preprocessor, to MLflow.

    Inside a single MLflow run this logs the hyper-parameters, trains with
    early stopping on ``valid``, logs the test RMSE as "rmse", pickles ``dv``
    to ``preprocessor_path`` and logs it as an artifact, and finally logs the
    model under "models_mlflow".

    Args:
        train: ``xgb.DMatrix`` used for fitting.
        valid: ``xgb.DMatrix`` used for early stopping.
        test: ``xgb.DMatrix`` used for the reported test metric.
        y_test: True labels matching ``test``.
        dv: Fitted vectorizer to persist next to the model.
        params: XGBoost parameters passed to ``xgb.train`` as-is. Defaults to
            the values previously selected from the hyper-parameter search.
        preprocessor_path: Where the pickled vectorizer is written before it
            is logged as an artifact.
    """
    if params is None:
        params = {
            'colsample_bytree': 0.9250870893919794,
            'gamma': 0.007995628667745471,
            'learning_rate': 0.20384373996439606,
            'max_depth': 4,
            'min_child_weight': 0.41092408055939844,
            'reg_alpha': 0.007444391334457018,
            'reg_lambda': 0.017392816466180783,
            'subsample': 0.7772671896767146,
            'seed': RANDOM_SEED
        }
    best_params = dict(params)

    with mlflow.start_run():
        mlflow.log_params(best_params)

        model_xgb = xgb.train(
            params=best_params,
            dtrain=train,
            evals=[(valid, 'validation')],
            num_boost_round=2200,
            early_stopping_rounds=50)

        y_pred = model_xgb.predict(test)
        # np.sqrt(MSE) instead of the `squared=False` keyword, which newer
        # scikit-learn releases deprecate and then remove.
        rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
        mlflow.log_metric("rmse", rmse)

        # Make sure the target directory exists; open(..., "wb") does not create it.
        output_dir = os.path.dirname(preprocessor_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(preprocessor_path, "wb") as f_out:
            pickle.dump(dv, f_out)
        mlflow.log_artifact(preprocessor_path, artifact_path="preprocessor")

        mlflow.xgboost.log_model(model_xgb, artifact_path="models_mlflow")

In [14]:
# Train the final model and log it, together with the fitted vectorizer, to MLflow.
train_best_model(train_xgb, valid_xgb, test_xgb, y_test, dv)

Parameters: { "n_estimators" } might not be used.

  This could be a false alarm, with some parameters getting used by language bindings but
  then being mistakenly passed down to XGBoost core, or some parameter actually being used
  but getting flagged wrongly here. Please open an issue if you find any such cases.


[0]	validation-rmse:158494.29941
[1]	validation-rmse:126484.51480
[2]	validation-rmse:101066.17423
[3]	validation-rmse:81220.04929
[4]	validation-rmse:65258.07088
[5]	validation-rmse:52934.61472
[6]	validation-rmse:44064.58494
[7]	validation-rmse:37790.94475
[8]	validation-rmse:32820.68480
[9]	validation-rmse:29195.42834
[10]	validation-rmse:26515.86373
[11]	validation-rmse:24829.85717
[12]	validation-rmse:24181.78056
[13]	validation-rmse:23545.26726
[14]	validation-rmse:22802.64532
[15]	validation-rmse:22222.40097
[16]	validation-rmse:22270.32785
[17]	validation-rmse:22007.41964
[18]	validation-rmse:21984.39723
[19]	validation-rmse:22045.26249
[20]	validation-rmse:22302.4