Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using config to register custom operators #476

Merged
merged 2 commits into from
Jun 24, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 12 additions & 6 deletions qlib/data/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .base import Expression, ExpressionOps
from ..log import get_module_logger
from ..utils import get_cls_kwargs

try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
Expand Down Expand Up @@ -1496,15 +1497,20 @@ def reset(self):
self._ops = {}

def register(self, ops_list):
Copy link
Collaborator

@you-n-g you-n-g Jun 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add more docs about the register function (especially the ops_list).

for operator in ops_list:
if not issubclass(operator, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
for _operator in ops_list:
if isinstance(_operator, dict):
_ops_class, _ = get_cls_kwargs(_operator)
else:
_ops_class = _operator

if operator.__name__ in self._ops:
if not issubclass(_ops_class, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))

if _ops_class.__name__ in self._ops:
get_module_logger(self.__class__.__name__).warning(
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
"The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__)
)
self._ops[operator.__name__] = operator
self._ops[_ops_class.__name__] = _ops_class

def __getattr__(self, key):
if key not in self._ops:
Expand Down