Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Online bug fix, enhancement & docs for dataset, workflow, trainer ... #466

Merged
merged 5 commits into from Jun 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 10 additions & 3 deletions qlib/__init__.py
Expand Up @@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
from .config import C
from .data.cache import H

H.clear()

# FIXME: this logger ignored the level in config
logger = get_module_logger("Initialization", level=logging.INFO)

skip_if_reg = kwargs.pop("skip_if_reg", False)
if skip_if_reg and C.registered:
Derek-Wds marked this conversation as resolved.
Show resolved Hide resolved
# if we reinitialize Qlib during running an experiment `R.start`.
# it will result in loss of the recorder
logger.warning("Skip initialization because `skip_if_reg is True`")
return

H.clear()
C.set(default_conf, **kwargs)

# check path if server/local
Expand Down Expand Up @@ -197,14 +203,15 @@ def auto_init(**kwargs):
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)

try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:

conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
Expand Down
30 changes: 23 additions & 7 deletions qlib/data/dataset/__init__.py
@@ -1,6 +1,6 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
Expand Down Expand Up @@ -243,6 +243,8 @@ class TSDataSampler:

It works like `torch.utils.data.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
data.

If users have further requirements for processing data, they could process them based on `TSDataSampler` or create
more powerful subclasses.
Expand Down Expand Up @@ -309,11 +311,19 @@ def __init__(
self.data_index = deepcopy(self.data.index)

if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
if isinstance(flt_data, pd.DataFrame):
assert len(flt_data.columns) == 1
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]

self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance

del self.data # save memory
Expand Down Expand Up @@ -341,7 +351,7 @@ def config(self, **kwargs):
setattr(self, k, v)

@staticmethod
def build_index(data: pd.DataFrame) -> dict:
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""
The relation of the data

Expand All @@ -352,9 +362,15 @@ def build_index(data: pd.DataFrame) -> dict:

Returns
-------
dict:
{<index>: <prev_index or None>}
# get the previous index of a line given index
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2) the second element: {<original index>: <row, col>}
"""
# object in case of pandas converting int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
Expand Down
8 changes: 5 additions & 3 deletions qlib/log.py
Expand Up @@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):

def __init__(self, module_name):
self.module_name = module_name
self.level = 0
# this feature name conflicts with the attribute with Logger
# rename it to avoid some corner cases that result in comparing `str` and `int`
self.__level = 0

@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
logger.setLevel(self.__level)
return logger

def setLevel(self, level):
self.level = level
self.__level = level

def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
Expand Down
56 changes: 37 additions & 19 deletions qlib/model/trainer.py
Expand Up @@ -8,7 +8,7 @@
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.

``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manage the lifecycle of tasks automatically.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manage the lifecycle of tasks automatically.
"""

import socket
Expand Down Expand Up @@ -153,6 +153,9 @@ def is_delay(self) -> bool:
"""
return self.delay

def __call__(self, *args, **kwargs) -> list:
    """
    Convenience shortcut: run `train` and then `end_train` in one call,
    forwarding all arguments to `train` and returning what `end_train` yields.
    """
    trained = self.train(*args, **kwargs)
    return self.end_train(trained)


class TrainerR(Trainer):
"""
Expand Down Expand Up @@ -286,19 +289,26 @@ class TrainerRM(Trainer):
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"

def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
def __init__(
    self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
):
    """
    Init TrainerRM.

    Args:
        experiment_name (str): the default name of experiment.
        task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
        train_func (Callable, optional): default training method. Defaults to `task_train`.
        skip_run_task (bool):
            If True, `train` only creates the tasks in the task pool and does NOT
            call `run_task` locally — the tasks are expected to be executed by
            external workers running `run_task` on the same pool.
            If False (default), `run_task` is invoked here as well.
    """

    super().__init__()
    self.experiment_name = experiment_name
    self.task_pool = task_pool
    self.train_func = train_func
    # checked by `train` before invoking `run_task`
    self.skip_run_task = skip_run_task

def train(
self,
Expand Down Expand Up @@ -340,15 +350,16 @@ def train(
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
query = {"_id": {"$in": _id_list}}
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.skip_run_task:
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)

if not self.is_delay():
tm.wait(query=query)
Expand Down Expand Up @@ -411,6 +422,7 @@ def __init__(
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
):
"""
Init DelayTrainerRM.
Expand All @@ -420,10 +432,15 @@ def __init__(
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
skip_run_task (bool):
If skip_run_task == True:
Only run_task in the worker. Otherwise skip run_task.
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task

def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Expand Down Expand Up @@ -477,14 +494,15 @@ def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kw
_id_list.append(rec.list_tags()[self.TM_ID])

query = {"_id": {"$in": _id_list}}
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
if not self.skip_run_task:
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)

TaskManager(task_pool=task_pool).wait(query=query)

Expand Down
22 changes: 22 additions & 0 deletions qlib/utils/__init__.py
Expand Up @@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
return pred_left, pred_right


def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
    """
    Normalize a user-supplied time value into a uniform format for time slicing.

    Time slicing in Qlib or Pandas is a frequently-used action, but users pass
    times in many formats. This helper converts them into a representation that
    is friendly to time slicing.

    Parameters
    ----------
    t : Union[None, str, pd.Timestamp]
        original time; ``None`` represents "unbounded" in Qlib or Pandas
        (e.g. ``df.loc[slice(None, "20210303")]``).

    Returns
    -------
    Union[None, pd.Timestamp]:
        ``None`` is passed through unchanged; anything else is converted to
        ``pd.Timestamp``.
    """
    # Keep None as-is so it retains its "unbounded" meaning for slicing.
    return None if t is None else pd.Timestamp(t)


def can_use_cache():
res = True
r = get_redis_connection()
Expand Down
11 changes: 9 additions & 2 deletions qlib/workflow/exp.py
Expand Up @@ -213,11 +213,15 @@ def _get_recorder(self, recorder_id=None, recorder_name=None):
"""
raise NotImplementedError(f"Please implement the `_get_recorder` method")

def list_recorders(self):
def list_recorders(self, **flt_kwargs):
"""
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.

flt_kwargs : dict
filter recorders by conditions
e.g. list_recorders(status=Recorder.STATUS_FI)

Returns
-------
A dictionary (id -> recorder) of recorder information that being stored.
Expand Down Expand Up @@ -320,11 +324,14 @@ def delete_recorder(self, recorder_id=None, recorder_name=None):

UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!

def list_recorders(self, max_results=UNLIMITED):
def list_recorders(self, max_results=UNLIMITED, status=None):
    """
    List the recorders (mlflow runs) of this experiment.

    Parameters
    ----------
    max_results : int
        maximum number of runs fetched from mlflow
        (mlflow can list at most 50000 records — see UNLIMITED above).
    status : str, optional
        if given, only recorders whose `status` equals this value are kept;
        None disables the filter.

    Returns
    -------
    dict:
        a dictionary (run_id -> MLflowRecorder) of the matching recorders.
    """
    runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
    recorders = dict()
    # iterate runs directly instead of indexing by range(len(...))
    for run in runs:
        recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
        if status is not None and recorder.status != status:
            # filter out recorders that do not match the requested status
            continue
        recorders[run.info.run_id] = recorder

    return recorders
9 changes: 9 additions & 0 deletions qlib/workflow/expm.py
Expand Up @@ -351,6 +351,15 @@ def _get_exp(self, experiment_id=None, experiment_name=None):
experiment_id is not None or experiment_name is not None
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
if experiment_id is not None:
try:
Derek-Wds marked this conversation as resolved.
Show resolved Hide resolved
experiment_id = int(experiment_id)
except ValueError as e:
msg = "The `experiment_id` for mlflow backend must be `int`"
logger.error(msg)
# We have to raise type error here
# - The error looks like type error
# - ValueError will be caught
raise TypeError(msg)
try:
exp = self.client.get_experiment(experiment_id)
if exp.lifecycle_stage.upper() == "DELETED":
Expand Down
15 changes: 10 additions & 5 deletions qlib/workflow/online/manager.py
Expand Up @@ -6,7 +6,7 @@

With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
So this module provides a series of methods to control this process.
So this module provides a series of methods to control this process.

This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
Which means you can verify your strategy or find a better one.
Expand All @@ -31,7 +31,7 @@

Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
"""
Expand Down Expand Up @@ -113,6 +113,8 @@ def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dic
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
# FIXME: Training multiple online models at `first_train` will result in getting too many online models at the
# start.
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models

Expand Down Expand Up @@ -148,8 +150,6 @@ def routine(
models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()

tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
Expand All @@ -158,6 +158,11 @@ def routine(
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models

# The online model may changes in the above processes
# So updating the predictions of online models should be the last step
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()

if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
Expand Down Expand Up @@ -236,7 +241,7 @@ def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
SIM_LOG_NAME = "SIMULATE_INFO"

def simulate(
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
) -> Union[pd.Series, pd.DataFrame]:
"""
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
Expand Down
6 changes: 6 additions & 0 deletions qlib/workflow/online/strategy.py
Expand Up @@ -52,6 +52,12 @@ def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:

NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.

**NOTE**:
Current implementation is very naive. Here is a more complex situation which is more closer to the
practical scenarios.
1. Train new models at the day before `test_start` (at time stamp `T`)
2. Switch models at the `test_start` (at time timestamp `T + 1` typically)

Args:
models (list): a list of models.
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
Expand Down