Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OnlineToolR more user-friendly and fix some bugs #475

Merged
merged 5 commits into from Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion qlib/data/dataset/loader.py
Expand Up @@ -207,7 +207,12 @@ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame
df = self._data.loc(axis=0)[:, instruments]
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
if start_time is not None:
start_time = pd.Timestamp(start_time)
if end_time is not None:
end_time = pd.Timestamp(end_time)
return df.loc[start_time:end_time]

def _maybe_load_raw_data(self):
if self._data is not None:
Expand Down
2 changes: 1 addition & 1 deletion qlib/log.py
Expand Up @@ -70,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge

class TimeInspector:

timer_logger = get_module_logger("timer", level=logging.WARNING)
timer_logger = get_module_logger("timer", level=logging.INFO)

time_marks = []

Expand Down
6 changes: 3 additions & 3 deletions qlib/utils/serial.py
Expand Up @@ -92,16 +92,16 @@ def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list
@classmethod
def load(cls, filepath):
"""
Load the collector from a filepath.
Load the serializable class from a filepath.

Args:
filepath (str): the path of file

Raises:
TypeError: the pickled file must be `Collector`
TypeError: the pickled file must be `type(cls)`

Returns:
Collector: the instance of Collector
`type(cls)`: the instance of `type(cls)`
"""
with open(filepath, "rb") as f:
object = cls.get_backend().load(f)
Expand Down
5 changes: 2 additions & 3 deletions qlib/workflow/online/update.py
Expand Up @@ -135,10 +135,9 @@ def update(self, dataset: DatasetH = None):
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797

start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time > self.to_date:
if self.last_end >= self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
)
return

Expand Down
56 changes: 43 additions & 13 deletions qlib/workflow/online/utils.py
Expand Up @@ -12,6 +12,7 @@

from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.utils.exceptions import QlibException
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
Expand Down Expand Up @@ -90,15 +91,15 @@ class OnlineToolR(OnlineTool):
The implementation of OnlineTool based on (R)ecorder.
"""

def __init__(self, experiment_name: str):
def __init__(self, default_exp_name: str = None):
"""
Init OnlineToolR.

Args:
experiment_name (str): the experiment name.
default_exp_name (str): the default experiment name.
"""
super().__init__()
self.exp_name = experiment_name
self.default_exp_name = default_exp_name

def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Expand Down Expand Up @@ -127,45 +128,74 @@ def get_online_tag(self, recorder: Recorder) -> str:
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)

def reset_online_tag(self, recorder: Union[Recorder, List]):
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
"""
Offline all models and set the recorders to 'online'.

Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
exp_name (str): the experiment name. If None, then use default_exp_name.

"""
if exp_name is None:
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
recs = list_recorders(exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)

def online_models(self) -> list:
def online_models(self, exp_name: str = None) -> list:
"""
Get current `online` models

Args:
exp_name (str): the experiment name. If None, then use default_exp_name.

Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())

def update_online_pred(self, to_date=None):
def update_online_pred(self, to_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.

Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
"""
online_models = self.online_models()
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
online_models = self.online_models(exp_name=exp_name)
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()

self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
try:
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
except QlibException as e:
# skip the recorder without pred
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
continue
updater.update()

self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")
12 changes: 9 additions & 3 deletions qlib/workflow/recorder.py
Expand Up @@ -5,6 +5,9 @@
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime

from mlflow.exceptions import MlflowException
from qlib.utils.exceptions import QlibException
from ..utils.objm import FileManager
from ..log import get_module_logger

Expand Down Expand Up @@ -308,9 +311,12 @@ def save_objects(self, local_path=None, artifact_path=None, **kwargs):

def load_object(self, name):
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
try:
path = self.client.download_artifacts(self.id, name)
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
with Path(path).open("rb") as f:
return pickle.load(f)
except OSError as e:
raise QlibException(message=str(e))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exception should be defined in the interface of Recorder


def log_params(self, **kwargs):
for name, data in kwargs.items():
Expand Down