LightGBM - GBDT model earlystop bug fix #3614

Merged (Nov 8, 2023)
merged 15 commits into from
Changes from 10 commits
112 changes: 87 additions & 25 deletions deepchem/models/gbdt_models/gbdt_model.py
@@ -40,10 +40,19 @@ def __init__(self,
early_stopping_rounds: int, optional (default 50)
Activates early stopping. Validation metric needs to improve at least once
in every early_stopping_rounds round(s) to continue training.
eval_metric: Union[str, Callbale]
eval_metric: Union[str, Callable]
If string, it should be a built-in evaluation metric to use.
If callable, it should be a custom evaluation metric, see official note for more details.
"""

try:
import xgboost
import lightgbm
        except ImportError:
raise ModuleNotFoundError(
"XGBoost or LightGBM modules not found. Please install them to use this class."
)

if model_dir is not None:
if not os.path.exists(model_dir):
os.makedirs(model_dir)
@@ -54,28 +63,44 @@ def __init__(self,
self.model_class = model.__class__
self.early_stopping_rounds = early_stopping_rounds
self.model_type = self._check_model_type()
self.eval_dict = dict()

if self.early_stopping_rounds <= 0:
raise ValueError("Early Stopping Rounds cannot be less than 1.")

if self.model.__class__.__name__.startswith('XGB'):
Review comment (Member):
Let's add a unit test for callbacks.

Review comment (Contributor Author):
Yes, I'll do that.
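A minimal sketch of such a test (illustrative only, not the test ultimately added in this PR; it assumes the `GBDTModel` wrapper in this file, a LightGBM regressor, and synthetic NumPy data):

```python
import numpy as np
import lightgbm
import deepchem as dc
from deepchem.models.gbdt_models.gbdt_model import GBDTModel


def test_lgbm_early_stopping_callback():
    # Synthetic regression data wrapped in a DeepChem dataset.
    rng = np.random.RandomState(0)
    X = rng.normal(size=(200, 5))
    y = X[:, 0] + 0.1 * rng.normal(size=200)
    dataset = dc.data.NumpyDataset(X, y)

    model = GBDTModel(lightgbm.LGBMRegressor(n_estimators=500,
                                             random_state=0),
                      early_stopping_rounds=10)
    model.fit(dataset)

    # record_evaluation stores one metric value per boosting round run,
    # so the recorded history can never exceed the configured round budget.
    rounds_ran = len(list(model.eval_dict["valid_0"].values())[0])
    assert rounds_ran <= 500
```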

self.callbacks = [
xgboost.callback.EarlyStopping(
rounds=self.early_stopping_rounds)
]
elif self.model.__class__.__name__.startswith('LGBM'):
self.callbacks = [
lightgbm.early_stopping(
stopping_rounds=self.early_stopping_rounds),
lightgbm.record_evaluation(self.eval_dict)
]
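        # Note (illustrative, not part of this diff): the two frameworks
        # expose early stopping differently. xgboost uses
        # xgboost.callback.EarlyStopping(rounds=...) and stores the history
        # in model.evals_result_, while lightgbm pairs
        # lightgbm.early_stopping(stopping_rounds=...) with
        # lightgbm.record_evaluation(self.eval_dict) to capture the
        # per-round history that fit() inspects later.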

if eval_metric is None:
if self.model_type == 'classification':
self.eval_metric: Optional[Union[str, Callable]] = 'auc'
elif self.model_type == 'regression':
self.eval_metric = 'mae'
if self.model_type == "classification":
self.eval_metric: Optional[Union[str, Callable]] = "auc"
elif self.model_type == "regression":
self.eval_metric = "mae"
else:
self.eval_metric = eval_metric
else:
self.eval_metric = eval_metric

def _check_model_type(self) -> str:
class_name = self.model.__class__.__name__
if class_name.endswith('Classifier'):
return 'classification'
elif class_name.endswith('Regressor'):
return 'regression'
elif class_name == 'NoneType':
return 'none'
if class_name.endswith("Classifier"):
return "classification"
elif class_name.endswith("Regressor"):
return "regression"
elif class_name == "NoneType":
return "none"
else:
raise ValueError(
'{} is not a supported model instance.'.format(class_name))
"{} is not a supported model instance.".format(class_name))

def fit(self, dataset: Dataset):
"""Fits GDBT model with all data.
@@ -98,7 +123,7 @@ def fit(self, dataset: Dataset):

seed = self.model.random_state
stratify = None
if self.model_type == 'classification':
if self.model_type == "classification":
stratify = y

# Find optimal n_estimators based on original learning_rate and early_stopping_rounds
@@ -107,17 +132,34 @@ def fit(self, dataset: Dataset):
test_size=0.2,
random_state=seed,
stratify=stratify)
self.model.fit(X_train,
y_train,
early_stopping_rounds=self.early_stopping_rounds,
eval_metric=self.eval_metric,
eval_set=[(X_test, y_test)])
self.model.fit(
X_train,
y_train,
callbacks=self.callbacks,
eval_metric=self.eval_metric,
eval_set=[(X_test, y_test)],
)

# Retrain the model on the whole dataset using best n_estimators * 1.25
if self.model.__class__.__name__.startswith('XGB'):
estimated_best_round = np.round(self.model.best_ntree_limit * 1.25)
res = list(self.model.evals_result_['validation_0'].values())
estimated_best_round = np.round(
(self.model.best_iteration + 1) * 1.25)
else:
res = list(self.eval_dict['valid_0'].values())
estimated_best_round = np.round(self.model.best_iteration_ * 1.25)
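        # Worked example (illustrative): if training stops with
        # best_iteration_ == 80 out of n_estimators == 500, the full-data
        # refit below uses np.round(80 * 1.25) == 100 estimators.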

# If early_stopping_rounds >= n_estimators, early stopping can never trigger.
if self.early_stopping_rounds < self.model.n_estimators:
# Check the number of boosting rounds
rounds_ran = len(res[0])
# If fewer rounds ran than n_estimators, early stopping was triggered.
if rounds_ran < self.model.n_estimators:
if self.model.__class__.__name__.startswith('XGB'):
assert self.model.best_iteration < self.model.n_estimators - 1
else:
assert self.model.best_iteration_ < self.model.n_estimators

self.model.n_estimators = np.int64(estimated_best_round)
self.model.fit(X, y, eval_metric=self.eval_metric)

@@ -139,11 +181,30 @@ def fit_with_eval(self, train_dataset: Dataset, valid_dataset: Dataset):
if len(y_train.shape) != 1 or len(y_valid.shape) != 1:
raise ValueError("GDBT model doesn't support multi-output(task)")

self.model.fit(X_train,
y_train,
early_stopping_rounds=self.early_stopping_rounds,
eval_metric=self.eval_metric,
eval_set=[(X_valid, y_valid)])
self.model.fit(
X_train,
y_train,
callbacks=self.callbacks,
eval_metric=self.eval_metric,
eval_set=[(X_valid, y_valid)],
)

if self.model.__class__.__name__.startswith('XGB'):
res = list(self.model.evals_result_['validation_0'].values())
else:
res = list(self.eval_dict['valid_0'].values())

# If early_stopping_rounds >= n_estimators, early stopping can never trigger.
if self.early_stopping_rounds < self.model.n_estimators:
# Check the number of boosting rounds
rounds_ran = len(res[0])
# If fewer rounds ran than n_estimators, early stopping was triggered.
if rounds_ran < self.model.n_estimators:
if self.model.__class__.__name__.startswith('XGB'):
assert self.model.best_iteration < self.model.n_estimators - 1
else:
assert self.model.best_iteration_ < self.model.n_estimators
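        # Illustrative usage of the flow above (not part of this diff;
        # the LGBMRegressor settings and dataset names are assumptions):
        #
        #   model = GBDTModel(lightgbm.LGBMRegressor(n_estimators=500,
        #                                            random_state=0),
        #                     early_stopping_rounds=50)
        #   model.fit_with_eval(train_dataset, valid_dataset)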


#########################################
@@ -156,5 +217,6 @@ class XGBoostModel(GBDTModel):
def __init__(self, *args, **kwargs):
warnings.warn(
"XGBoostModel is deprecated and has been renamed to GBDTModel.",
FutureWarning)
FutureWarning,
)
super(XGBoostModel, self).__init__(*args, **kwargs)
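For reference, a minimal self-contained sketch of the callback-based early stopping this PR switches LightGBM to (synthetic data; assumes lightgbm >= 3.3, where `early_stopping` and `record_evaluation` are top-level callback factories):

```python
import numpy as np
import lightgbm

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X @ rng.normal(size=5) + rng.normal(scale=0.1, size=200)
X_train, X_valid = X[:160], X[160:]
y_train, y_valid = y[:160], y[160:]

eval_dict = {}  # filled in place by record_evaluation
model = lightgbm.LGBMRegressor(n_estimators=500, random_state=0)
model.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric="mae",
    callbacks=[
        lightgbm.early_stopping(stopping_rounds=50),
        lightgbm.record_evaluation(eval_dict),
    ],
)

# One recorded value per boosting round actually run; comparing this
# count against n_estimators is how the wrapper detects that early
# stopping fired.
rounds_ran = len(list(eval_dict["valid_0"].values())[0])
print(rounds_ran, model.best_iteration_)
```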