diff --git a/qlib/__init__.py b/qlib/__init__.py index b43dd6fc42a..8a306787786 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -27,8 +27,8 @@ def init(default_conf="client", **kwargs): Parameters ---------- - default_conf: str - the default value is client. Accepted values: client/server. + default_conf: str or ModeType + the default value is client. Accepted values: ModeType.CLIENT/ModeType.SERVER (or "client"/"server"). **kwargs : clear_mem_cache: str the default value is True; diff --git a/qlib/config.py b/qlib/config.py index 4e5d62564f7..c9e289e103b 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -22,7 +22,7 @@ from typing import Callable, Optional, Union from typing import TYPE_CHECKING -from qlib.constant import REG_CN, REG_US, REG_TW +from qlib.constant import REG_CN, REG_US, REG_TW, ModeType if TYPE_CHECKING: from qlib.utils.time import Freq @@ -247,7 +247,7 @@ def register_from_C(config, skip_register=True): } MODE_CONF = { - "server": { + ModeType.SERVER: { # config it in qlib.init() "provider_uri": "", # redis @@ -260,7 +260,7 @@ def register_from_C(config, skip_register=True): "local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(), "mount_path": None, }, - "client": { + ModeType.CLIENT: { # config it in user's own code "provider_uri": QSETTINGS.provider_uri, # cache @@ -385,7 +385,7 @@ def get_data_uri(self, freq: Optional[Union[str, Freq]] = None) -> Path: def set_mode(self, mode): # raise KeyError - self.update(MODE_CONF[mode]) + self.update(MODE_CONF[ModeType(mode)]) # TODO: update region based on kwargs def set_region(self, region): @@ -420,7 +420,7 @@ def resolve_path(self): self["provider_uri"] = _provider_uri self["mount_path"] = _mount_path - def set(self, default_conf: str = "client", **kwargs): + def set(self, default_conf: Union[str, ModeType] = ModeType.CLIENT, **kwargs): """ configure qlib based on the input parameters @@ -435,8 +435,8 @@ def set(self, default_conf: str = "client", **kwargs): Parameters ---------- - default_conf : str - the default config template chosen by user: "server", "client" + default_conf : str or ModeType + the default config template chosen by user: ModeType.SERVER, ModeType.CLIENT (or "server", "client") """ from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415 diff --git a/qlib/constant.py b/qlib/constant.py index ac6c76ae22c..f9daa677a78 100644 --- a/qlib/constant.py +++ b/qlib/constant.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. # REGION CONST +from enum import Enum from typing import TypeVar import numpy as np @@ -20,3 +21,10 @@ ONE_MIN = pd.Timedelta("1min") EPS_T = pd.Timedelta("1s") # use 1 second to exclude the right interval point float_or_ndarray = TypeVar("float_or_ndarray", float, np.ndarray) + + +class ModeType(str, Enum): + """Mode type for qlib initialization: client or server.""" + + CLIENT = "client" + SERVER = "server"