Skip to content

Commit

Permalink
Enhance Task Dict Var (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Dec 25, 2021
1 parent e33de44 commit 3493f29
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 10 deletions.
17 changes: 11 additions & 6 deletions examples/model_rolling/task_manager_rolling.py
Expand Up @@ -17,7 +17,7 @@
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM, task_train
from qlib.model.trainer import TrainerR, TrainerRM, task_train
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG


Expand All @@ -29,7 +29,7 @@ def __init__(
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_pool=None, # if user want to "rolling_task"
task_config=None,
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
Expand All @@ -43,14 +43,19 @@ def __init__(
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.experiment_name = experiment_name
self.task_pool = task_pool
if task_pool is None:
self.trainer = TrainerR(experiment_name=self.experiment_name)
else:
self.task_pool = task_pool
self.trainer = TrainerRM(self.experiment_name, self.task_pool)
self.task_config = task_config
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)

# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(task_pool=self.task_pool).remove()
if isinstance(self.trainer, TrainerRM):
TaskManager(task_pool=self.task_pool).remove()
exp = R.get_exp(experiment_name=self.experiment_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
Expand All @@ -66,10 +71,10 @@ def task_generating(self):

def task_training(self, tasks):
print("========== task_training ==========")
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
self.trainer.train(tasks)

def worker(self):
# NOTE: this is only used for TrainerRM
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
print("========== worker ==========")
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
Expand Down
60 changes: 58 additions & 2 deletions qlib/model/trainer.py
Expand Up @@ -86,10 +86,61 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
return R.get_recorder()


def get_item_from_obj(config: dict, name_path: str) -> object:
"""
Follow the name_path to get values from config
For example:
If we follow the example in in the Parameters section,
Timestamp('2008-01-02 00:00:00') will be returned
Parameters
----------
config : dict
e.g.
{'dataset': {'class': 'DatasetH',
'kwargs': {'handler': {'class': 'Alpha158',
'kwargs': {'end_time': '2020-08-01',
'fit_end_time': '<dataset.kwargs.segments.train.1>',
'fit_start_time': '<dataset.kwargs.segments.train.0>',
'instruments': 'csi100',
'start_time': '2008-01-01'},
'module_path': 'qlib.contrib.data.handler'},
'segments': {'test': (Timestamp('2017-01-03 00:00:00'),
Timestamp('2019-04-08 00:00:00')),
'train': (Timestamp('2008-01-02 00:00:00'),
Timestamp('2014-12-31 00:00:00')),
'valid': (Timestamp('2015-01-05 00:00:00'),
Timestamp('2016-12-30 00:00:00'))}}
}}
name_path : str
e.g.
"dataset.kwargs.segments.train.1"
Returns
-------
object
the retrieved object
"""
cur_cfg = config
for k in name_path.split("."):
if isinstance(cur_cfg, dict):
cur_cfg = cur_cfg[k]
elif k.isdigit():
cur_cfg = cur_cfg[int(k)]
else:
raise ValueError(f"Error when getting {k} from cur_cfg")
return cur_cfg


def fill_placeholder(config: dict, config_extend: dict):
"""
Detect placeholder in config and fill them with config_extend.
The item of dict must be single item(int, str, etc), dict and list. Tuples are not supported.
There are two type of variables:
- user-defined variables :
e.g. when config_extend is `{"<MODEL>": model, "<DATASET>": dataset}`, "<MODEL>" and "<DATASET>" in `config` will be replaced with `model` `dataset`
- variables extracted from `config` :
e.g. the variables like "<dataset.kwargs.segments.train.0>" will be replaced with the values from `config`
Parameters
----------
Expand Down Expand Up @@ -122,8 +173,13 @@ def fill_placeholder(config: dict, config_extend: dict):
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
item_queue.append(now_item[key])
tail += 1
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
now_item[key] = config_extend[now_item[key]]
elif isinstance(now_item[key], str):
if now_item[key] in config_extend.keys():
now_item[key] = config_extend[now_item[key]]
else:
m = re.match(r"<(?P<name_path>[^<>]+)>", now_item[key])
if m is not None:
now_item[key] = get_item_from_obj(config, m.groupdict()["name_path"])
return config


Expand Down
4 changes: 2 additions & 2 deletions qlib/tests/config.py
Expand Up @@ -50,8 +50,8 @@
def get_data_handler_config(
start_time="2008-01-01",
end_time="2020-08-01",
fit_start_time="2008-01-01",
fit_end_time="2014-12-31",
fit_start_time="<dataset.kwargs.segments.train.0>",
fit_end_time="<dataset.kwargs.segments.train.1>",
instruments=CSI300_MARKET,
):
return {
Expand Down

0 comments on commit 3493f29

Please sign in to comment.