Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions qlib/contrib/model/gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,26 @@ def fit(
evals_result = {} # in case of unsafety of Python default values
ds_l = self._prepare_data(dataset, reweighter)
ds, names = list(zip(*ds_l))
early_stopping_callback = lgb.early_stopping(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
)

# Build callbacks list
callbacks = []

# Only add early_stopping callback if rounds is not None (LightGBM 4.0+ doesn't accept None)
early_stop_rounds = self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
if early_stop_rounds is not None:
callbacks.append(lgb.early_stopping(early_stop_rounds))

# NOTE: if you encounter error here. Please upgrade your lightgbm
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
callbacks.append(lgb.log_evaluation(period=verbose_eval))
callbacks.append(lgb.record_evaluation(evals_result))

self.model = lgb.train(
self.params,
ds[0], # training dataset
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
valid_sets=ds,
valid_names=names,
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
callbacks=callbacks,
**kwargs,
)
for k in names:
Expand Down
13 changes: 9 additions & 4 deletions qlib/contrib/model/highfreq_gdbt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,21 @@ def fit(
if evals_result is None:
evals_result = dict()
dtrain, dvalid = self._prepare_data(dataset)
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)

# Build callbacks list
callbacks = []
if early_stopping_rounds is not None:
callbacks.append(lgb.early_stopping(early_stopping_rounds))
callbacks.append(lgb.log_evaluation(period=verbose_eval))
callbacks.append(lgb.record_evaluation(evals_result))

self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
callbacks=callbacks,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
Expand Down
4 changes: 3 additions & 1 deletion qlib/contrib/rolling/ddgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def _dump_meta_ipt(self):
sim_task = replace_task_handler_with_cache(sim_task, self.working_dir)

if self.sim_task_model == "gbdt":
sim_task["model"].setdefault("kwargs", {}).update({"early_stopping_rounds": None, "num_boost_round": 150})
sim_task["model"].setdefault("kwargs", {}).update({"num_boost_round": 150})
# Don't set early_stopping_rounds to disable it (LightGBM 4.0+ doesn't accept None)
sim_task["model"]["kwargs"].pop("early_stopping_rounds", None)

exp_name_sim = f"data_sim_s{self.step}"

Expand Down