From 629cab2399889c8af8976c26144ecaf48490ecf0 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sat, 23 May 2026 19:37:38 +0800 Subject: [PATCH] fix: allow disabling LightGBM early stopping --- qlib/contrib/model/gbdt.py | 12 +++- qlib/contrib/model/highfreq_gdbt_model.py | 7 ++- tests/model/test_lightgbm_early_stopping.py | 65 +++++++++++++++++++++ 3 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 tests/model/test_lightgbm_early_stopping.py diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 22c29cd4997..b3c10998dec 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -68,19 +68,25 @@ def fit( evals_result = {} # in case of unsafety of Python default values ds_l = self._prepare_data(dataset, reweighter) ds, names = list(zip(*ds_l)) - early_stopping_callback = lgb.early_stopping( - self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds + early_stopping_rounds = ( + self.early_stopping_rounds + if early_stopping_rounds is None + else early_stopping_rounds ) + callbacks = [] + if early_stopping_rounds is not None: + callbacks.append(lgb.early_stopping(early_stopping_rounds)) # NOTE: if you encounter error here. Please upgrade your lightgbm verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) evals_result_callback = lgb.record_evaluation(evals_result) + callbacks.extend([verbose_eval_callback, evals_result_callback]) self.model = lgb.train( self.params, ds[0], # training dataset num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round, valid_sets=ds, valid_names=names, - callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback], + callbacks=callbacks, **kwargs, ) for k in names: diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index ad0641136f2..e2f09846d8c 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -124,16 +124,19 @@ def fit( if evals_result is None: evals_result = dict() dtrain, dvalid = self._prepare_data(dataset) - early_stopping_callback = lgb.early_stopping(early_stopping_rounds) + callbacks = [] + if early_stopping_rounds is not None: + callbacks.append(lgb.early_stopping(early_stopping_rounds)) verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) evals_result_callback = lgb.record_evaluation(evals_result) + callbacks.extend([verbose_eval_callback, evals_result_callback]) self.model = lgb.train( self.params, dtrain, num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback], + callbacks=callbacks, ) evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] diff --git a/tests/model/test_lightgbm_early_stopping.py b/tests/model/test_lightgbm_early_stopping.py new file mode 100644 index 00000000000..90606331a3f --- /dev/null +++ b/tests/model/test_lightgbm_early_stopping.py @@ -0,0 +1,65 @@ +from qlib.contrib.model import gbdt, highfreq_gdbt_model + + +def test_lgb_model_skips_none_early_stopping(monkeypatch): + early_stopping_calls = [] + train_kwargs = {} + model = gbdt.LGBModel(early_stopping_rounds=None) + + monkeypatch.setattr( + model, + "_prepare_data", + lambda dataset, reweighter=None: [("train-dataset", "train")], + ) + monkeypatch.setattr( + gbdt.lgb, + "early_stopping", + lambda rounds: early_stopping_calls.append(rounds) or "early", + ) + monkeypatch.setattr(gbdt.lgb, "log_evaluation", lambda period: "log") + monkeypatch.setattr(gbdt.lgb, "record_evaluation", lambda evals_result: "record") + + def train(*args, **kwargs): + train_kwargs.update(kwargs) + return object() + + monkeypatch.setattr(gbdt.lgb, "train", train) + + model.fit(object(), evals_result={"train": {}}) + + assert early_stopping_calls == [] + assert train_kwargs["callbacks"] == ["log", "record"] + + +def test_hflgb_model_skips_none_early_stopping(monkeypatch): + early_stopping_calls = [] + train_kwargs = {} + model = highfreq_gdbt_model.HFLGBModel() + + monkeypatch.setattr( + model, "_prepare_data", lambda dataset: ("train-dataset", "valid-dataset") + ) + monkeypatch.setattr( + highfreq_gdbt_model.lgb, + "early_stopping", + lambda rounds: early_stopping_calls.append(rounds) or "early", + ) + monkeypatch.setattr(highfreq_gdbt_model.lgb, "log_evaluation", lambda period: "log") + monkeypatch.setattr( + highfreq_gdbt_model.lgb, "record_evaluation", lambda evals_result: "record" + ) + + def train(*args, **kwargs): + train_kwargs.update(kwargs) + return object() + + monkeypatch.setattr(highfreq_gdbt_model.lgb, "train", train) + + model.fit( + object(), + early_stopping_rounds=None, + evals_result={"train": {"loss": [1.0]}, "valid": {"loss": [1.1]}}, + ) + + assert early_stopping_calls == [] + assert train_kwargs["callbacks"] == ["log", "record"]