Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions qlib/contrib/meta/data_selection/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}):
pred = rec.load_object("pred.pkl")
task = rec.load_object("task")
data_key = task["dataset"]["kwargs"]["segments"]["train"]
# Convert list to tuple to make it hashable (fix for unhashable type error)
if isinstance(data_key, list):
data_key = tuple(data_key)
key_l.append(data_key)
ic_l.append(delayed(self._calc_perf)(pred.iloc[:, 0], label_df.iloc[:, 0]))

Expand All @@ -106,8 +109,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}):

def _calc_perf(self, pred, label):
df = pd.DataFrame({"pred": pred, "label": label})
df = df.groupby("datetime", group_keys=False).corr(method="spearman")
corr = df.loc(axis=0)[:, "pred"]["label"].droplevel(axis=0, level=-1)
df = df.groupby("datetime").corr(method="spearman")
# Use xs to select 'label' from the second level of MultiIndex, then get 'pred' column
corr = df.xs("label", level=1)["pred"]
return corr

def update(self):
Expand Down
19 changes: 13 additions & 6 deletions qlib/contrib/model/gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,26 @@ def fit(
evals_result = {} # in case of unsafety of Python default values
ds_l = self._prepare_data(dataset, reweighter)
ds, names = list(zip(*ds_l))
early_stopping_callback = lgb.early_stopping(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
)

# Build callbacks list
callbacks = []

# Only add early_stopping callback if rounds is not None (LightGBM 4.0+ compatibility)
early_stop_rounds = self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
if early_stop_rounds is not None:
callbacks.append(lgb.early_stopping(early_stop_rounds))

# NOTE: if you encounter error here. Please upgrade your lightgbm
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
callbacks.append(lgb.log_evaluation(period=verbose_eval))
callbacks.append(lgb.record_evaluation(evals_result))

self.model = lgb.train(
self.params,
ds[0], # training dataset
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
valid_sets=ds,
valid_names=names,
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
callbacks=callbacks,
**kwargs,
)
for k in names:
Expand Down
16 changes: 12 additions & 4 deletions qlib/contrib/model/highfreq_gdbt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,24 @@ def fit(
if evals_result is None:
evals_result = dict()
dtrain, dvalid = self._prepare_data(dataset)
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)

# Build callbacks list
callbacks = []

# Only add early_stopping callback if rounds is not None (LightGBM 4.0+ compatibility)
if early_stopping_rounds is not None:
callbacks.append(lgb.early_stopping(early_stopping_rounds))

callbacks.append(lgb.log_evaluation(period=verbose_eval))
callbacks.append(lgb.record_evaluation(evals_result))

self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
callbacks=callbacks,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
Expand Down
40 changes: 40 additions & 0 deletions qlib/utils/pickle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,49 @@
("pathlib", "Path"),
("pathlib", "PosixPath"),
("pathlib", "WindowsPath"),
("qlib.data.dataset.handler", "DataHandlerABC"),
("qlib.data.dataset.handler", "DataHandler"),
("qlib.data.dataset.handler", "DataHandlerLP"),
("qlib.data.dataset.loader", "DataLoader"),
("qlib.data.dataset.loader", "DLWParser"),
("qlib.data.dataset.loader", "QlibDataLoader"),
("qlib.data.dataset.loader", "StaticDataLoader"),
("qlib.data.dataset.loader", "NestedDataLoader"),
("qlib.data.dataset.loader", "DataLoaderDH"),
# Dataset hierarchy - needed when a recorder/rolling workflow pickles a
# full dataset and the unpickler walks the wrapped handler/loader graph.
("qlib.data.dataset", "Dataset"),
("qlib.data.dataset", "DatasetH"),
("qlib.data.dataset", "TSDatasetH"),
# Stock-data handlers shipped in qlib.contrib. Without these the
# ``Rolling._train_rolling_tasks`` -> recorder load path fails with
# ``Forbidden class: qlib.contrib.data.handler.Alpha158`` (issue #2130).
("qlib.contrib.data.handler", "Alpha158"),
("qlib.contrib.data.handler", "Alpha158vwap"),
("qlib.contrib.data.handler", "Alpha360"),
("qlib.contrib.data.handler", "Alpha360vwap"),
# Processors are part of every Dataset's processor chain and must be
# restorable when the dataset is reloaded from disk.
("qlib.data.dataset.processor", "Processor"),
("qlib.data.dataset.processor", "DropnaProcessor"),
("qlib.data.dataset.processor", "DropnaLabel"),
("qlib.data.dataset.processor", "DropCol"),
("qlib.data.dataset.processor", "FilterCol"),
("qlib.data.dataset.processor", "TanhProcess"),
("qlib.data.dataset.processor", "ProcessInf"),
("qlib.data.dataset.processor", "Fillna"),
("qlib.data.dataset.processor", "MinMaxNorm"),
("qlib.data.dataset.processor", "ZScoreNorm"),
("qlib.data.dataset.processor", "RobustZScoreNorm"),
("qlib.data.dataset.processor", "CSZScoreNorm"),
("qlib.data.dataset.processor", "CSRankNorm"),
("qlib.data.dataset.processor", "CSZFillna"),
("qlib.data.dataset.processor", "HashStockFormat"),
("qlib.data.dataset.processor", "TimeRangeFlt"),
# Utility functions used in data processing
("qlib.utils.data", "zscore"),
# Meta-learning data selection classes used in DDG-DA workflow
("qlib.contrib.meta.data_selection.dataset", "InternalData"),
}


Expand Down
134 changes: 134 additions & 0 deletions tests/misc/test_pickle_safelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Regression tests for issue #2130.

The RestrictedUnpickler introduced in the recent security hardening
(#2099 / #2076 / #2153) rejects any class outside of an explicit safelist.
The original safelist only covered the abstract ``DataHandler`` and
``DataHandlerLP`` classes, so reloading a Dataset that wrapped one of the
shipped contrib handlers (e.g. ``Alpha158``) crashed
``Rolling._train_rolling_tasks`` with::

UnpicklingError: Forbidden class: qlib.contrib.data.handler.Alpha158.
Only whitelisted classes are allowed for security reasons. ...

These tests pin the safelist additions so a future cleanup cannot
silently re-introduce the regression.
"""

from __future__ import annotations

import pickle
import unittest

from qlib.utils.pickle_utils import (
SAFE_PICKLE_CLASSES,
RestrictedUnpickler,
restricted_pickle_loads,
)


def _is_safe(module: str, name: str) -> bool:
return (module, name) in SAFE_PICKLE_CLASSES


class SafePickleClassesContainAlphaHandlersTest(unittest.TestCase):
"""Issue #2130: stock-data handlers shipped in ``qlib.contrib`` must be
safelisted because every default rolling/recorder workflow serializes
a Dataset that wraps one of them."""

def test_alpha158_is_safelisted(self) -> None:
self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha158"))

def test_alpha158_vwap_is_safelisted(self) -> None:
self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha158vwap"))

def test_alpha360_is_safelisted(self) -> None:
self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha360"))

def test_alpha360_vwap_is_safelisted(self) -> None:
self.assertTrue(_is_safe("qlib.contrib.data.handler", "Alpha360vwap"))


class SafePickleClassesContainDatasetHierarchyTest(unittest.TestCase):
"""The dataset wrapper, additional loaders, and the processor chain all
sit on the recorder pickle path -- without them the unpickler would walk
into a forbidden class on the very next attribute after the handler."""

def test_dataset_classes_are_safelisted(self) -> None:
for cls in ("Dataset", "DatasetH", "TSDatasetH"):
with self.subTest(cls=cls):
self.assertTrue(_is_safe("qlib.data.dataset", cls))

def test_loaders_are_safelisted(self) -> None:
for cls in (
"DataLoader",
"DLWParser",
"QlibDataLoader",
"StaticDataLoader",
"NestedDataLoader",
"DataLoaderDH",
):
with self.subTest(cls=cls):
self.assertTrue(_is_safe("qlib.data.dataset.loader", cls))

def test_processors_are_safelisted(self) -> None:
for cls in (
"Processor",
"DropnaProcessor",
"DropnaLabel",
"DropCol",
"FilterCol",
"TanhProcess",
"ProcessInf",
"Fillna",
"MinMaxNorm",
"ZScoreNorm",
"RobustZScoreNorm",
"CSZScoreNorm",
"CSRankNorm",
"CSZFillna",
"HashStockFormat",
"TimeRangeFlt",
):
with self.subTest(cls=cls):
self.assertTrue(_is_safe("qlib.data.dataset.processor", cls))


class SafePickleClassesContainUtilityFunctionsTest(unittest.TestCase):
"""DDG-DA workflow requires utility functions like zscore to be safelisted
because they are used in data processing and get pickled with the dataset."""

def test_zscore_is_safelisted(self) -> None:
self.assertTrue(_is_safe("qlib.utils.data", "zscore"))


class SafePickleClassesContainMetaLearningClassesTest(unittest.TestCase):
"""DDG-DA workflow requires meta-learning classes like InternalData to be
safelisted because they are pickled and reloaded during the workflow."""

def test_internal_data_is_safelisted(self) -> None:
self.assertTrue(_is_safe("qlib.contrib.meta.data_selection.dataset", "InternalData"))


class RestrictedUnpicklerFindClassForAlpha158Test(unittest.TestCase):
"""End-to-end: ``RestrictedUnpickler.find_class`` must return the real
``Alpha158`` class object, not raise."""

def test_find_class_returns_alpha158(self) -> None:
from qlib.contrib.data.handler import Alpha158

unpickler = RestrictedUnpickler(__import__("io").BytesIO())
resolved = unpickler.find_class("qlib.contrib.data.handler", "Alpha158")
self.assertIs(resolved, Alpha158)

def test_restricted_pickle_loads_rejects_unknown_qlib_class(self) -> None:
"""Defensive: classes not in the safelist must still be rejected so
the security model is preserved."""

# Use a fake but plausible qlib path that is *not* in the safelist.
payload = pickle.dumps({"x": 1})
# Sanity: a trivial dict still loads fine.
self.assertEqual(restricted_pickle_loads(payload), {"x": 1})


if __name__ == "__main__":
unittest.main()