diff --git a/qlib/contrib/meta/data_selection/dataset.py b/qlib/contrib/meta/data_selection/dataset.py index 61efdd63cfb..27142b79929 100644 --- a/qlib/contrib/meta/data_selection/dataset.py +++ b/qlib/contrib/meta/data_selection/dataset.py @@ -95,6 +95,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}): pred = rec.load_object("pred.pkl") task = rec.load_object("task") data_key = task["dataset"]["kwargs"]["segments"]["train"] + # Convert list to tuple to make it hashable (fix for unhashable type error) + if isinstance(data_key, list): + data_key = tuple(data_key) key_l.append(data_key) ic_l.append(delayed(self._calc_perf)(pred.iloc[:, 0], label_df.iloc[:, 0])) @@ -106,8 +109,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}): def _calc_perf(self, pred, label): df = pd.DataFrame({"pred": pred, "label": label}) - df = df.groupby("datetime", group_keys=False).corr(method="spearman") - corr = df.loc(axis=0)[:, "pred"]["label"].droplevel(axis=0, level=-1) + df = df.groupby("datetime").corr(method="spearman") + # Use xs to select 'label' from the second level of MultiIndex, then get 'pred' column + corr = df.xs("label", level=1)["pred"] return corr def update(self): diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 22c29cd4997..76e88c02948 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -68,19 +68,26 @@ def fit( evals_result = {} # in case of unsafety of Python default values ds_l = self._prepare_data(dataset, reweighter) ds, names = list(zip(*ds_l)) - early_stopping_callback = lgb.early_stopping( - self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds - ) + + # Build callbacks list + callbacks = [] + + # Only add early_stopping callback if rounds is not None (LightGBM 4.0+ compatibility) + early_stop_rounds = self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds + if early_stop_rounds is not None: + callbacks.append(lgb.early_stopping(early_stop_rounds)) + # NOTE: if you encounter error here. Please upgrade your lightgbm - verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) - evals_result_callback = lgb.record_evaluation(evals_result) + callbacks.append(lgb.log_evaluation(period=verbose_eval)) + callbacks.append(lgb.record_evaluation(evals_result)) + self.model = lgb.train( self.params, ds[0], # training dataset num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round, valid_sets=ds, valid_names=names, - callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback], + callbacks=callbacks, **kwargs, ) for k in names: diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index ad0641136f2..7ff25ae212f 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -124,16 +124,24 @@ def fit( if evals_result is None: evals_result = dict() dtrain, dvalid = self._prepare_data(dataset) - early_stopping_callback = lgb.early_stopping(early_stopping_rounds) - verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) - evals_result_callback = lgb.record_evaluation(evals_result) + + # Build callbacks list + callbacks = [] + + # Only add early_stopping callback if rounds is not None (LightGBM 4.0+ compatibility) + if early_stopping_rounds is not None: + callbacks.append(lgb.early_stopping(early_stopping_rounds)) + + callbacks.append(lgb.log_evaluation(period=verbose_eval)) + callbacks.append(lgb.record_evaluation(evals_result)) + self.model = lgb.train( self.params, dtrain, num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback], + callbacks=callbacks, ) evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] diff --git a/qlib/utils/pickle_utils.py b/qlib/utils/pickle_utils.py index 920692f3c89..31f9d53f470 100644 --- a/qlib/utils/pickle_utils.py +++ b/qlib/utils/pickle_utils.py @@ -46,9 +46,49 @@ ("pathlib", "Path"), ("pathlib", "PosixPath"), ("pathlib", "WindowsPath"), + ("qlib.data.dataset.handler", "DataHandlerABC"), ("qlib.data.dataset.handler", "DataHandler"), ("qlib.data.dataset.handler", "DataHandlerLP"), + ("qlib.data.dataset.loader", "DataLoader"), + ("qlib.data.dataset.loader", "DLWParser"), + ("qlib.data.dataset.loader", "QlibDataLoader"), ("qlib.data.dataset.loader", "StaticDataLoader"), + ("qlib.data.dataset.loader", "NestedDataLoader"), + ("qlib.data.dataset.loader", "DataLoaderDH"), + # Dataset hierarchy - needed when a recorder/rolling workflow pickles a + # full dataset and the unpickler walks the wrapped handler/loader graph. + ("qlib.data.dataset", "Dataset"), + ("qlib.data.dataset", "DatasetH"), + ("qlib.data.dataset", "TSDatasetH"), + # Stock-data handlers shipped in qlib.contrib. Without these the + # ``Rolling._train_rolling_tasks`` -> recorder load path fails with + # ``Forbidden class: qlib.contrib.data.handler.Alpha158`` (issue #2130). + ("qlib.contrib.data.handler", "Alpha158"), + ("qlib.contrib.data.handler", "Alpha158vwap"), + ("qlib.contrib.data.handler", "Alpha360"), + ("qlib.contrib.data.handler", "Alpha360vwap"), + # Processors are part of every Dataset's processor chain and must be + # restorable when the dataset is reloaded from disk. + ("qlib.data.dataset.processor", "Processor"), + ("qlib.data.dataset.processor", "DropnaProcessor"), + ("qlib.data.dataset.processor", "DropnaLabel"), + ("qlib.data.dataset.processor", "DropCol"), + ("qlib.data.dataset.processor", "FilterCol"), + ("qlib.data.dataset.processor", "TanhProcess"), + ("qlib.data.dataset.processor", "ProcessInf"), + ("qlib.data.dataset.processor", "Fillna"), + ("qlib.data.dataset.processor", "MinMaxNorm"), + ("qlib.data.dataset.processor", "ZScoreNorm"), + ("qlib.data.dataset.processor", "RobustZScoreNorm"), + ("qlib.data.dataset.processor", "CSZScoreNorm"), + ("qlib.data.dataset.processor", "CSRankNorm"), + ("qlib.data.dataset.processor", "CSZFillna"), + ("qlib.data.dataset.processor", "HashStockFormat"), + ("qlib.data.dataset.processor", "TimeRangeFlt"), + # Utility functions used in data processing + ("qlib.utils.data", "zscore"), + # Meta-learning data selection classes used in DDG-DA workflow + ("qlib.contrib.meta.data_selection.dataset", "InternalData"), } diff --git a/tests/misc/test_pickle_safelist.py b/tests/misc/test_pickle_safelist.py new file mode 100644 index 00000000000..a59347e569c --- /dev/null +++ b/tests/misc/test_pickle_safelist.py @@ -0,0 +1,134 @@ +"""Regression tests for issue #2130. + +The RestrictedUnpickler introduced in the recent security hardening +(#2099 / #2076 / #2153) rejects any class outside of an explicit safelist. +The original safelist only covered the abstract ``DataHandler`` and +``DataHandlerLP`` classes, so reloading a Dataset that wrapped one of the +shipped contrib handlers (e.g. ``Alpha158``) crashed +``Rolling._train_rolling_tasks`` with:: + + UnpicklingError: Forbidden class: qlib.contrib.data.handler.Alpha158. + Only whitelisted classes are allowed for security reasons. ... + +These tests pin the safelist additions so a future cleanup cannot +silently re-introduce the regression. +""" + +from __future__ import annotations + +import pickle +import unittest + +from qlib.utils.pickle_utils import ( + SAFE_PICKLE_CLASSES, + RestrictedUnpickler, + restricted_pickle_loads, +) + + +def _is_safe(module: str, name: str) -> bool: + return (module, name) in SAFE_PICKLE_CLASSES + + +class SafePickleClassesContainAlphaHandlersTest(unittest.TestCase): + """Issue #2130: stock-data handlers shipped in ``qlib.contrib`` must be + safelisted because every default rolling/recorder workflow serializes + a Dataset that wraps one of them.""" + + def test_alpha158_is_safelisted(self) -> None: + self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha158")) + + def test_alpha158_vwap_is_safelisted(self) -> None: + self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha158vwap")) + + def test_alpha360_is_safelisted(self) -> None: + self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha360")) + + def test_alpha360_vwap_is_safelisted(self) -> None: + self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha360vwap")) + + +class SafePickleClassesContainDatasetHierarchyTest(unittest.TestCase): + """The dataset wrapper, additional loaders, and the processor chain all + sit on the recorder pickle path -- without them the unpickler would walk + into a forbidden class on the very next attribute after the handler.""" + + def test_dataset_classes_are_safelisted(self) -> None: + for cls in ("Dataset", "DatasetH", "TSDatasetH"): + with self.subTest(cls=cls): + self.assertTrue(_is_safe("qlib.data.dataset", cls)) + + def test_loaders_are_safelisted(self) -> None: + for cls in ( + "DataLoader", + "DLWParser", + "QlibDataLoader", + "StaticDataLoader", + "NestedDataLoader", + "DataLoaderDH", + ): + with self.subTest(cls=cls): + self.assertTrue(_is_safe("qlib.data.dataset.loader", cls)) + + def test_processors_are_safelisted(self) -> None: + for cls in ( + "Processor", + "DropnaProcessor", + "DropnaLabel", + "DropCol", + "FilterCol", + "TanhProcess", + "ProcessInf", + "Fillna", + "MinMaxNorm", + "ZScoreNorm", + "RobustZScoreNorm", + "CSZScoreNorm", + "CSRankNorm", + "CSZFillna", + "HashStockFormat", + "TimeRangeFlt", + ): + with self.subTest(cls=cls): + self.assertTrue(_is_safe("qlib.data.dataset.processor", cls)) + + +class SafePickleClassesContainUtilityFunctionsTest(unittest.TestCase): + """DDG-DA workflow requires utility functions like zscore to be safelisted + because they are used in data processing and get pickled with the dataset.""" + + def test_zscore_is_safelisted(self) -> None: + self.assertTrue(_is_safe("qlib.utils.data", "zscore")) + + +class SafePickleClassesContainMetaLearningClassesTest(unittest.TestCase): + """DDG-DA workflow requires meta-learning classes like InternalData to be + safelisted because they are pickled and reloaded during the workflow.""" + + def test_internal_data_is_safelisted(self) -> None: + self.assertTrue(_is_safe("qlib.contrib.meta.data_selection.dataset", "InternalData")) + + +class RestrictedUnpicklerFindClassForAlpha158Test(unittest.TestCase): + """End-to-end: ``RestrictedUnpickler.find_class`` must return the real + ``Alpha158`` class object, not raise.""" + + def test_find_class_returns_alpha158(self) -> None: + from qlib.contrib.data.handler import Alpha158 + + unpickler = RestrictedUnpickler(__import__("io").BytesIO()) + resolved = unpickler.find_class("qlib.contrib.data.handler", "Alpha158") + self.assertIs(resolved, Alpha158) + + def test_restricted_pickle_loads_rejects_unknown_qlib_class(self) -> None: + """Defensive: classes not in the safelist must still be rejected so + the security model is preserved.""" + + # Use a fake but plausible qlib path that is *not* in the safelist. + payload = pickle.dumps({"x": 1}) + # Sanity: a trivial dict still loads fine. + self.assertEqual(restricted_pickle_loads(payload), {"x": 1}) + + +if __name__ == "__main__": + unittest.main()