diff --git a/src/flexcode/regression_models.py b/src/flexcode/regression_models.py index cb42678..d26308f 100644 --- a/src/flexcode/regression_models.py +++ b/src/flexcode/regression_models.py @@ -52,7 +52,8 @@ def __init__(self, max_basis, params): raise Exception("RandomForest requires sklearn to be installed") super(RandomForest, self).__init__(max_basis) - self.models = [sklearn.ensemble.RandomForestRegressor() + self.n_estimators = params.get("n_estimators", 10) + self.models = [sklearn.ensemble.RandomForestRegressor(self.n_estimators) for ii in range(self.max_basis)] def fit(self, x_train, z_basis): @@ -68,12 +69,15 @@ def predict(self, x_test): class XGBoost(FlexCodeRegression): def __init__(self, max_basis, params): + if not XGBOOST_AVAILABLE: + raise Exception("XGBoost requires xgboost to be installed") super(XGBoost, self).__init__(max_basis) self.params = {'max_depth' : params.get("max_depth", 6), 'eta' : params.get("eta", 0.3), 'silent' : params.get("silent", 1), - 'objective' : 'reg:linear'} + 'objective' : params.get("objective", 'reg:linear') + } self.num_round = params.get("num_round", 500) def fit(self, x_train, z_basis):