Online bug fix, enhancement & docs for dataset, workflow, trainer ... #466

Merged
merged 5 commits into from Jun 17, 2021
13 changes: 10 additions & 3 deletions qlib/__init__.py
@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
from .config import C
from .data.cache import H

H.clear()

# FIXME: this logger ignores the level in config
logger = get_module_logger("Initialization", level=logging.INFO)

skip_if_reg = kwargs.pop("skip_if_reg", False)
if skip_if_reg and C.registered:
# If we reinitialize Qlib while running an experiment (`R.start`),
# it will result in loss of the recorder.
logger.warning("Skip initialization because `skip_if_reg` is True")
return

H.clear()
C.set(default_conf, **kwargs)

# check path if server/local
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)

try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:

conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
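For context, a minimal sketch of how the new `skip_if_reg` flag behaves; the `provider_uri` path and experiment name below are illustrative:

# Sketch: re-initializing Qlib inside a running experiment.
import qlib
from qlib.workflow import R

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")  # illustrative data path
with R.start(experiment_name="demo"):
    # A second init inside `R.start` would clear the cache and lose the
    # active recorder; with skip_if_reg=True it becomes a no-op instead.
    qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", skip_if_reg=True)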
30 changes: 23 additions & 7 deletions qlib/data/dataset/__init__.py
@@ -1,6 +1,6 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
@@ -243,6 +243,8 @@ class TSDataSampler:

It works like `torch.utils.data.Dataset` and provides a very convenient interface for constructing a time-series
dataset based on tabular data.
- Along the time-step dimension, a smaller index indicates older (historical) data and a larger index indicates
newer (future) data.

If users have further requirements for processing data, they can process it based on `TSDataSampler` or create
more powerful subclasses.
@@ -309,11 +311,19 @@ def __init__(
self.data_index = deepcopy(self.data.index)

if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
if isinstance(flt_data, pd.DataFrame):
assert len(flt_data.columns) == 1
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data)[0]]

self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance

del self.data # save memory
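
A tiny sketch of the pitfall the NOTE above guards against: `bool(np.nan)` is truthy, so the filter is reindexed first and only then filled with `False` before the cast:

import numpy as np
import pandas as pd

print(bool(np.nan))  # True -- a raw bool cast would keep NaN rows
flt = pd.Series([True, np.nan, False])
print(flt.fillna(False).astype(bool).tolist())  # [True, False, False]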
@@ -341,7 +351,7 @@ def config(self, **kwargs):
setattr(self, k, v)

@staticmethod
def build_index(data: pd.DataFrame) -> dict:
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""
Build the indices that describe the relation of the data.

@@ -352,9 +362,15 @@ def build_index(data: pd.DataFrame) -> dict:

Returns
-------
dict:
{<index>: <prev_index or None>}
# get the previous index of a line given index
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2) the second element: {<original index>: <row, col>}
"""
# use object dtype in case pandas converts int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
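A small sketch (setup assumed) of what the new return value looks like for a `<datetime, instrument>` MultiIndexed dataframe:

import pandas as pd
from qlib.data.dataset import TSDataSampler

idx = pd.MultiIndex.from_product(
    [pd.to_datetime(["2021-01-11", "2021-01-12"]), ["SH600000", "SH600004"]],
    names=["datetime", "instrument"],
)
data = pd.DataFrame({"feature": range(4)}, index=idx)
idx_df, idx_map = TSDataSampler.build_index(data)
# idx_df: datetime rows x instrument columns, holding original row positions
# idx_map: {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}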
8 changes: 5 additions & 3 deletions qlib/log.py
@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):

def __init__(self, module_name):
self.module_name = module_name
self.level = 0
# The attribute name `level` conflicts with an attribute of `logging.Logger`;
# rename it to avoid corner cases that result in comparing `str` and `int`.
self.__level = 0

@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
logger.setLevel(self.__level)
return logger

def setLevel(self, level):
self.level = level
self.__level = level

def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
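A short sketch of the corner case the rename avoids: `logging.Logger` already exposes a `level` attribute, and level names can travel as `str`, so sharing the name risks comparing `str` with `int`:

import logging

print(logging.getLevelName(logging.INFO))  # 'INFO' -- a str
print(logging.getLevelName("INFO"))        # 20 -- back to an int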
56 changes: 37 additions & 19 deletions qlib/model/trainer.py
@@ -8,7 +8,7 @@
This is a concept called ``DelayTrainer``, which can be used in online simulation for parallel training.
In ``DelayTrainer``, the first step only saves some necessary info to model recorders, and the second step, which will be finished in the end, can do some concurrent and time-consuming operations such as model fitting.

``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the tasks' lifecycle automatically.
"""

import socket
@@ -153,6 +153,9 @@ def is_delay(self) -> bool:
"""
return self.delay

def __call__(self, *args, **kwargs) -> list:
return self.end_train(self.train(*args, **kwargs))
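
With this addition a trainer instance can be invoked directly; a sketch with illustrative names:

# Given a prepared list of task configs `tasks`,
# trainer(tasks) is shorthand for trainer.end_train(trainer.train(tasks)).
trainer = TrainerR(experiment_name="my_exp")  # illustrative experiment name
recorders = trainer(tasks)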


class TrainerR(Trainer):
"""
@@ -286,19 +289,26 @@ class TrainerRM(Trainer):
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"

def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
def __init__(
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
):
"""
Init TrainerRM.

Args:
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None means using the same name as experiment_name.
train_func (Callable, optional): default training method. Defaults to `task_train`.
skip_run_task (bool):
If True, skip calling `run_task` inside this trainer; the created tasks are left for external
workers to run. Otherwise, run the tasks here as well.
"""

super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
self.skip_run_task = skip_run_task

def train(
self,
@@ -340,15 +350,16 @@ def train(
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
query = {"_id": {"$in": _id_list}}
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.skip_run_task:
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)

if not self.is_delay():
tm.wait(query=query)
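
A sketch of the intended division of labor with `skip_run_task=True`, assuming the usual qlib locations of these helpers; the experiment and pool names are illustrative:

from qlib.model.trainer import TrainerRM, task_train
from qlib.workflow.task.manage import run_task

# On the submitting machine: create the tasks in MongoDB but do not run them.
trainer = TrainerRM("my_exp", "my_pool", skip_run_task=True)
trainer.train(tasks)  # blocks in tm.wait(...) until workers finish the tasks

# On each worker machine: consume the task pool.
run_task(task_train, "my_pool", experiment_name="my_exp")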
@@ -411,6 +422,7 @@ def __init__(
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
):
"""
Init DelayTrainerRM.
@@ -420,10 +432,15 @@
task_pool (str): task pool name in TaskManager. None means using the same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
skip_run_task (bool):
If True, skip calling `run_task` inside this trainer and leave the tasks to external workers.
E.g. start the trainer on a CPU VM, and then wait for the tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task

def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
@@ -477,14 +494,15 @@ def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kw
_id_list.append(rec.list_tags()[self.TM_ID])

query = {"_id": {"$in": _id_list}}
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
if not self.skip_run_task:
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)

TaskManager(task_pool=task_pool).wait(query=query)

22 changes: 22 additions & 0 deletions qlib/utils/__init__.py
@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
return pred_left, pred_right


def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
"""
Time slicing is a frequently-used operation in Qlib and pandas.
However, users often pass in all kinds of data formats to represent time.
This function converts these inputs into a uniform format that is friendly to time slicing.

Parameters
----------
t : Union[None, str, pd.Timestamp]
original time

Returns
-------
Union[None, pd.Timestamp]:
the converted time point; `None` means an unbounded side of a slice.
"""
if t is None:
# None represents unbounded in Qlib or Pandas (e.g. df.loc[slice(None, "20210303")]).
return t
else:
return pd.Timestamp(t)
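
A quick sketch of the conversion:

import pandas as pd
from qlib.utils import time_to_slc_point

print(time_to_slc_point(None))          # None -> an unbounded slice endpoint
print(time_to_slc_point("2021-03-03"))  # Timestamp('2021-03-03 00:00:00')
print(time_to_slc_point(pd.Timestamp("2021-03-03")))  # returned unchanged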


def can_use_cache():
res = True
r = get_redis_connection()
10 changes: 6 additions & 4 deletions qlib/workflow/__init__.py
@@ -215,9 +215,9 @@ def list_recorders(self, experiment_id=None, experiment_name=None):
-------
A dictionary (id -> recorder) of the recorder information that is stored.
"""
return self.get_exp(experiment_id, experiment_name).list_recorders()
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()

def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
"""
Method for retrieving an experiment with given id or name. When the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -262,7 +262,7 @@ def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True)

# Case 2
with R.start('test'):
exp = R.get_exp('test1')
exp = R.get_exp(experiment_name='test1')

# Case 3
exp = R.get_exp() -> a default experiment.
@@ -287,7 +287,9 @@ def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True)
-------
An experiment instance with given id or name.
"""
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
return self.exp_manager.get_exp(
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
)
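
Since the arguments are now keyword-only, positional calls fail fast instead of silently binding to the wrong parameter; a sketch:

from qlib.workflow import R

exp = R.get_exp(experiment_name="test1")  # OK
# R.get_exp("test1")  # TypeError under the new keyword-only signature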

def delete_exp(self, experiment_id=None, experiment_name=None):
"""
21 changes: 18 additions & 3 deletions qlib/workflow/exp.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Union
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
@@ -213,11 +214,15 @@ def _get_recorder(self, recorder_id=None, recorder_name=None):
"""
raise NotImplementedError(f"Please implement the `_get_recorder` method")

def list_recorders(self):
def list_recorders(self, **flt_kwargs):
"""
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
If users want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.

flt_kwargs : dict
filter recorders by conditions
e.g. list_recorders(status=Recorder.STATUS_FI)

Returns
-------
A dictionary (id -> recorder) of the recorder information that is stored.
@@ -320,11 +325,21 @@ def delete_recorder(self, recorder_id=None, recorder_name=None):

UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!

def list_recorders(self, max_results=UNLIMITED):
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
"""
Parameters
----------
max_results : int
the maximum number of results
status : str
the criteria based on status to filter results.
`None` indicates no filtering.
"""
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
recorders[runs[i].info.run_id] = recorder
if status is None or recorder.status == status:
recorders[runs[i].info.run_id] = recorder

return recorders
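
A sketch of the new status filter; the experiment name is illustrative:

from qlib.workflow import R
from qlib.workflow.recorder import Recorder

exp = R.get_exp(experiment_name="my_exp")
finished = exp.list_recorders(status=Recorder.STATUS_FI)  # only FINISHED runs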
13 changes: 11 additions & 2 deletions qlib/workflow/expm.py
@@ -109,7 +109,7 @@ def search_records(self, experiment_ids=None, **kwargs):
"""
raise NotImplementedError(f"Please implement the `search_records` method.")

def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
"""
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.

@@ -190,7 +190,7 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec
except ValueError:
if experiment_name is None:
experiment_name = self._default_exp_name
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
return self.create_exp(experiment_name), True

def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
@@ -351,6 +351,15 @@ def _get_exp(self, experiment_id=None, experiment_name=None):
experiment_id is not None or experiment_name is not None
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
if experiment_id is not None:
try:
experiment_id = int(experiment_id)
except ValueError as e:
msg = "The `experiment_id` for mlflow backend must be `int`"
logger.error(msg)
# We have to raise TypeError here:
# - the error is essentially a type error
# - a ValueError would be caught by the caller (see `_get_or_create_exp`) and silently create a new experiment
raise TypeError(msg)
try:
exp = self.client.get_experiment(experiment_id)
if exp.lifecycle_stage.upper() == "DELETED":
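A sketch of the resulting behavior; the experiment ids are illustrative:

from qlib.workflow import R

R.get_exp(experiment_id="1")      # fine: coerced to int
# R.get_exp(experiment_id="abc")  # TypeError -- a ValueError here would be
#                                 # swallowed by `_get_or_create_exp`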