In [23]:
import mlflow
# End the MLflow run that is still active in this kernel, if any (e.g. left
# over from an earlier, interrupted execution), so the experiment started
# below does not run inside a stale run.
mlflow.end_run()

In [24]:
import os

import qlib
from qlib.config import C
from qlib.config import REG_CN
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager

# Date boundaries of the experiment. The handler's fit window
# (fit_start_time / fit_end_time) covers exactly the "train" segment, so both
# places read the same constants and cannot silently drift apart.
INSTRUMENTS = "csi100"
TRAIN_START, TRAIN_END = "2008-01-01", "2014-12-31"
VALID_START, VALID_END = "2015-01-01", "2016-12-31"
TEST_START, TEST_END = "2017-01-01", "2020-08-01"
DATA_START, DATA_END = TRAIN_START, TEST_END  # start/end of the handler's data range

data_handler_template = {
    "start_time": DATA_START,
    "end_time": DATA_END,
    "fit_start_time": TRAIN_START,
    "fit_end_time": TRAIN_END,
    "instruments": INSTRUMENTS,
}

dataset_template = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_template,
        },
        # Initial split; the RollingGen in the next cell builds its rolling
        # windows starting from these segments.
        "segments": {
            "train": (TRAIN_START, TRAIN_END),
            "valid": (VALID_START, VALID_END),
            "test": (TEST_START, TEST_END),
        },
    },
}

# Recorders run after training: SignalRecord stores the predictions,
# SigAnaRecord analyses them.
record_template = [
    {
        "class": "SignalRecord",
        "module_path": "qlib.workflow.record_temp",
    },
    {
        "class": "SigAnaRecord",
        "module_path": "qlib.workflow.record_temp",
    },
]

# Both task templates share the same dataset and record configs (the same
# objects) and differ only in the model.

# use lgb
lgb_task_template = {
    "model": {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
    },
    "dataset": dataset_template,
    "record": record_template,
}

# use xgboost
xgboost_task_template = {
    "model": {
        "class": "XGBModel",
        "module_path": "qlib.contrib.model.xgboost",
    },
    "dataset": dataset_template,
    "record": record_template,
}

# Local qlib data directory (the target_dir the data was downloaded to).
provider_uri = "~/.qlib/qlib_data/cn_data"
qlib.init(provider_uri=provider_uri, region=REG_CN)

# MongoDB that stores the task pool. Export QLIB_MONGO_URL to point at another
# server (e.g. one whose URL carries credentials) instead of hard-coding it in
# the notebook; the local default is used when the variable is unset.
C["mongo"] = {
    "task_url": os.environ.get("QLIB_MONGO_URL", "mongodb://localhost:27017/"),
    "task_db_name": "rolling_db3",
}

exp_name = 'rolling_exp3'  # experiment name, will be used as the experiment in MLflow
task_pool = 'rolling_task3'  # task pool name, will be used as the collection name in MongoDB

[8348:MainThread](2021-03-09 14:55:48,543) INFO - qlib.Initialization - [config.py:279] - default_conf: client.
[8348:MainThread](2021-03-09 14:55:50,597) INFO - qlib.Initialization - [__init__.py:48] - qlib successfully initialized based on client settings.
[8348:MainThread](2021-03-09 14:55:50,601) INFO - qlib.Initialization - [__init__.py:49] - data_path=C:\Users\lzh222333\.qlib\qlib_data\cn_data


In [25]:
# Expand the templates into concrete rolling tasks. RollingGen shifts the
# date segments forward by `step` (presumably trading days: the first test
# window below, 2017-01-03..2019-04-08, spans ~550 trading days); ROLL_SD is
# its sliding-window mode.
tasks = task_generator(
    xgboost_task_template,  # default task name
    RollingGen(step=550, rtype=RollingGen.ROLL_SD),  # generate different date segments
    task_lgb=lgb_task_template,  # use "task_lgb" as the task name
)

# Preview the generated tasks. The output is long; comment out the next two
# lines to suppress it.
from pprint import pprint
pprint(tasks)

tm = TaskManager(task_pool=task_pool)
tm.create_task(tasks)  # all tasks will be saved to MongoDB

[{'dataset': {'class': 'DatasetH',
              'kwargs': {'handler': {'class': 'Alpha158',
                                     'kwargs': {'end_time': '2020-08-01',
                                                'fit_end_time': '2014-12-31',
                                                'fit_start_time': '2008-01-01',
                                                'instruments': 'csi100',
                                                'start_time': '2008-01-01'},
                                     'module_path': 'qlib.contrib.data.handler'},
                         'segments': {'test': (Timestamp('2017-01-03 00:00:00'),
                                               Timestamp('2019-04-08 00:00:00')),
                                      'train': (Timestamp('2008-01-02 00:00:00'),
                                                Timestamp('2014-12-31 00:00:00')),
                                      'valid': (Timestamp('2015-01-05 00:00:00'),
                                 

In [26]:
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.collect import TaskCollector  # used by the collection cell below
from qlib.model.trainer import task_train

# Train the tasks stored in the `task_pool` MongoDB collection with
# `task_train`, logging them under the MLflow experiment `exp_name`.
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method

2021-03-09 14:55:51.600 | INFO     | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2019-04-08 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 1}
[8348:MainThread](2021-03-09 14:56:46,051) INFO - qlib.timer - [log.py:81] - Time cost: 54.448s | Loading data D

True

In [27]:
def get_task_key(task):
    """Name a finished task by its key and the end date of its test segment.

    Parameters
    ----------
    task : dict
        Task config as generated above; must contain ``task["task_key"]`` and
        ``task["dataset"]["kwargs"]["segments"]["test"]``, whose second element
        must support ``strftime`` (a pandas Timestamp in the generated tasks).

    Returns
    -------
    tuple
        ``(task_key, "YYYY-MM-DD")``. TaskCollector uses it as the key of the
        collected results, e.g. ``("task_lgb", "2019-04-08")``.
    """
    task_key = task["task_key"]
    test_end = task["dataset"]["kwargs"]["segments"]["test"][1]
    return task_key, test_end.strftime("%Y-%m-%d")

def my_filter(task):
    """Keep only the "task_lgb" results whose test segment ends in 2019.

    Passed as the filter to ``TaskCollector.collect``: returns True to keep a
    task and False to drop it.
    """
    task_key, rolling_end = get_task_key(task)
    return task_key == "task_lgb" and rolling_end.startswith("2019")

# Collect the predictions of the finished tasks in experiment `exp_name`,
# naming each by "get_task_key" and keeping only those "my_filter" accepts.
# The result is a dict mapping (task_key, test_end_date) to a score Series.
pred_rolling = TaskCollector.collect(exp_name, get_task_key, my_filter) 
pred_rolling

Loading data: 100%|██████████| 4/4 [00:00<00:00, 37.38it/s]


{('task_lgb', '2019-04-08'): datetime    instrument
 2017-01-03  SH600000     -0.013089
             SH600010     -0.006642
             SH600015     -0.035137
             SH600016     -0.034634
             SH600018     -0.029493
                             ...   
 2019-04-08  SZ002415      0.049199
             SZ002450     -0.013450
             SZ002594      0.022395
             SZ002736      0.091433
             SZ300059     -0.016237
 Name: score, Length: 55000, dtype: float64}