In [1]:
import numpy as np
import pandas as pd
from scipy.stats import randint
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV

from data_utils import train_test_split, build_pipeline

In [2]:
# Single random seed reused by the data split, the forest and the search below.
SEED = 42

In [3]:
# Fetch the two splits; each one unpacks into a (features, target) pair.
train_split, test_split = train_test_split(seed=SEED)
df_train_X, train_y = train_split
df_test_X, test_y = test_split

# The pipeline is fitted on the training features only, then the fitted
# pipeline is applied to both splits (train first, then test).
pipeline = build_pipeline(df_train_X)
pipeline.fit(df_train_X)
train_X, test_X = (
    pipeline.transform(features) for features in (df_train_X, df_test_X)
)

# Quick sanity check of the transformed matrix dimensions.
print(f"train_X.shape = {train_X.shape}")
print(f"test_X.shape = {test_X.shape}")

train_X.shape = (16512, 16)
test_X.shape = (4128, 16)


In [4]:
# Search space sampled by RandomizedSearchCV. scipy's randint excludes `high`,
# so n_estimators is drawn from 1..199 and max_features from 1..7.
param_distributions = {
    'n_estimators': randint(low=1, high=200),
    'max_features': randint(low=1, high=8),
}

model = RandomForestRegressor(random_state=SEED)

# 10 sampled candidates x 5 folds = 50 cross-validation fits (plus the final
# refit). The fits are independent of each other, so n_jobs=-1 spreads them
# across all available CPU cores. Both the estimator and the search are
# seeded, so the sampled candidates and their scores do not depend on n_jobs.
random_search = RandomizedSearchCV(
    model,
    param_distributions=param_distributions,
    n_iter=10,
    cv=5,
    scoring='neg_mean_squared_error',
    n_jobs=-1,
    random_state=SEED,
)

random_search.fit(train_X, train_y)

RandomizedSearchCV(cv=5, estimator=RandomForestRegressor(random_state=42),
                   param_distributions={'max_features': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000023BFB080BC8>,
                                        'n_estimators': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000023BFB076508>},
                   random_state=42, scoring='neg_mean_squared_error')

In [5]:
# Tabulate the search results. The search scores are negated MSE values, so
# the sign is flipped before taking the square root to obtain an RMSE column.
df_result = pd.DataFrame(random_search.cv_results_).assign(
    rmse=lambda frame: np.sqrt(-frame['mean_test_score'])
)

# Show each sampled parameter set with its RMSE, best (lowest) first.
df_result.loc[:, ['params', 'rmse']].sort_values(by='rmse')

Unnamed: 0,params,rmse
0,"{'max_features': 7, 'n_estimators': 180}",49150.707569
4,"{'max_features': 7, 'n_estimators': 122}",49280.944983
7,"{'max_features': 5, 'n_estimators': 100}",49608.996081
8,"{'max_features': 3, 'n_estimators': 150}",50473.619304
6,"{'max_features': 3, 'n_estimators': 88}",50682.788882
5,"{'max_features': 3, 'n_estimators': 75}",50774.906624
2,"{'max_features': 3, 'n_estimators': 72}",50796.155224
3,"{'max_features': 5, 'n_estimators': 21}",50835.133603
1,"{'max_features': 5, 'n_estimators': 15}",51389.889203
9,"{'max_features': 5, 'n_estimators': 2}",64429.841433
