In [1]:
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha360
from qlib.utils import init_instance_by_config
from qlib.contrib.strategy import TopkDropoutStrategy
from qlib.contrib.report import analysis_model, analysis_position
# from qlib.contrib.evaluate import (
#     backtest as normal_backtest,
#     risk_analysis,
# )
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict
from qlib.data.dataset.loader import QlibDataLoader
from qlib.contrib.data.handler import Alpha158   #Alpha158内置指标体系
from qlib.data.dataset.loader import QlibDataLoader
import qlib
from qlib.contrib.data.handler import Alpha158   #Alpha158内置指标体系


# Directory that stores the raw market data.
provider_uri = "/home/shared/qlib-main/qlib_data/cn_data"

# Initialise Qlib against that directory using the China-region settings.
qlib.init(region=REG_CN, provider_uri=provider_uri)
# Instrument universe fed to the data handler, and the benchmark code.
# NOTE(review): the universe is "csi100" while the benchmark code is
# "SH000300" — confirm this pairing is intended.
market = "csi100"
benchmark = "SH000300"

# Feature columns kept by the FilterCol inference processor.
_SELECTED_FEATURES = [
    "RESI5",
    "WVMA5",
    "RSQR5",
    "KLEN",
    "RSQR10",
    "CORR5",
    "CORD5",
    "CORR10",
    "ROC60",
    "RESI10",
    "VSTD5",
    "RSQR60",
    "CORR60",
    "WVMA60",
    "STD5",
    "RSQR20",
    "CORD60",
    "CORD10",
    "CORR20",
    "KLOW",
]

# Data-handler configuration.
# For a quick smoke run the window can be shrunk (e.g. start/fit_start
# "2020-01-01", fit_end "2020-01-31", end "2020-02-21").
data_handler_config = {
    # Overall span of data loaded by the handler.
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    # Fitting window handed to the handler.
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
    "infer_processors": [
        # Keep only the selected feature columns.
        {
            "class": "FilterCol",
            "kwargs": {
                "fields_group": "feature",
                "col_list": _SELECTED_FEATURES,
            },
        },
        # RobustZScoreNorm and Fillna standardise the features and fill
        # missing values.
        {
            "class": "RobustZScoreNorm",
            "kwargs": {"fields_group": "feature", "clip_outlier": True},
        },
        {
            "class": "Fillna",
            "kwargs": {"fields_group": "feature"},
        },
    ],
    "learn_processors": [
        # DropnaLabel drops the samples whose label is missing.
        {"class": "DropnaLabel"},
        # CSRankNorm applies cross-sectional rank normalisation to the label.
        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
    ],
    # Prediction target. Ref($close, -1) is the next trading day's close, so
    # this expression is the return from the close of T+1 to the close of T+2.
    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
}
    

# Task configuration, assembled from its two parts: the model and the dataset.

# Model: class name, the module that defines it, and its hyper-parameters.
_model_config = {
    "class": "TransformerModel",
    "module_path": "qlib.contrib.model.pytorch_transformer_ts_time",
    "kwargs": {
        "seed": 0,
        "n_jobs": 20,
        "GPU": 3,
    },
}

# Data handler used by the dataset: the module that defines it, its class
# name, and the handler arguments configured in `data_handler_config`.
_handler_config = {
    "module_path": "qlib.contrib.data.handler",
    "class": "Alpha158",
    "kwargs": data_handler_config,
}

# Date ranges of the train / validation / test splits.
# For a quick smoke run these can be shrunk to a few weeks of 2020.
_segments = {
    "train": ("2008-01-01", "2014-12-31"),
    "valid": ("2015-01-01", "2016-12-31"),
    "test": ("2017-01-01", "2020-08-01"),
}

# Factor dataset: a dataset class backed by a data handler (the trailing "H"
# in the class name stands for Data(H)andler) and the module that defines it.
_dataset_config = {
    "class": "TSDatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": _handler_config,
        "segments": _segments,
    },
}

task = {
    "model": _model_config,
    "dataset": _dataset_config,
}

# Build the model object described by task["model"].
model = init_instance_by_config(task["model"])

# Build the factor dataset described by task["dataset"]: the features
# (factors) and labels computed from the raw market data.
dataset = init_instance_by_config(task["dataset"])

# Train inside the "train_model" experiment: log the flattened task
# parameters, fit the model, store it as the "trained_model" object and keep
# the recorder id in `rid` so the run can be looked up again afterwards.
with R.start(experiment_name="train_model"):
    logged_params = flatten_dict(task)
    R.log_params(**logged_params)
    model.fit(dataset)
    R.save_objects(trained_model=model)
    train_recorder = R.get_recorder()
    rid = train_recorder.id

###################################
# prediction, backtest & analysis
###################################
# Configuration consumed by PortAnaRecord: which executor simulates the
# trading, which strategy turns model scores into orders, and the backtest
# settings (period, capital, benchmark and exchange parameters).
port_analysis_config = {
    # Executor that steps through the backtest one day at a time and produces
    # portfolio metrics.
    "executor": {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
        },
    },
    "strategy": {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy",
        "kwargs": {
            # The signal is given as a (model, dataset) pair. This replaces
            # the separate `model` / `dataset` kwargs, which Qlib keeps only
            # for backward compatibility and flags with a DeprecationWarning;
            # the compatibility path builds this very same pair.
            "signal": (model, dataset),
            "topk": 50,
            "n_drop": 5,
        },
    },
    "backtest": {
        # Same date range as the dataset's "test" segment.
        "start_time": "2017-01-01",
        "end_time": "2020-08-01",
        "account": 100000000,  # initial capital
        "benchmark": benchmark,
        "exchange_kwargs": {
            "freq": "day",
            # NOTE(review): presumably the daily price-limit threshold beyond
            # which a stock is treated as untradable — confirm against Qlib.
            "limit_threshold": 0.095,
            "deal_price": "close",
            # Cost settings for opening / closing positions and the minimum
            # cost — presumably rates and a per-trade floor; confirm units.
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    },
}

# Run prediction, backtest and analysis inside their own experiment.
with R.start(experiment_name="backtest_analysis"):
    # Fetch the training run by its recorder id and reload the model that was
    # saved there under the name "trained_model".
    trained_run = R.get_recorder(recorder_id=rid, experiment_name="train_model")
    model = trained_run.load_object("trained_model")

    # Recorder of the current (backtest) run; its id is kept in `ba_rid`.
    recorder = R.get_recorder()
    ba_rid = recorder.id

    # Prediction: generate the signal record from the model and the dataset.
    sr = SignalRecord(model=model, dataset=dataset, recorder=recorder)
    sr.generate()

    # Backtest & analysis: generate the portfolio-analysis record at daily
    # frequency from `port_analysis_config`.
    par = PortAnaRecord(recorder, port_analysis_config, "day")
    par.generate()


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError. XGBModel is skipped(optional: maybe installing xgboost can fix it).


[13272:MainThread](2023-07-25 12:03:28,588) INFO - qlib.Initialization - [config.py:413] - default_conf: client.
[13272:MainThread](2023-07-25 12:03:28,593) INFO - qlib.workflow - [expm.py:31] - experiment manager uri is at file:/home/shared/qlib-main/mlruns
[13272:MainThread](2023-07-25 12:03:28,594) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[13272:MainThread](2023-07-25 12:03:28,595) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/shared/qlib-main/qlib_data/cn_data')}
[13272:MainThread](2023-07-25 12:03:28,935) INFO - qlib.TransformerModel - [pytorch_transformer_ts_time.py:63] - Naive Transformer:
batch_size : 256
device : cuda:3
[13272:MainThread](2023-07-25 12:03:53,017) INFO - qlib.timer - [log.py:117] - Time cost: 19.854s | Loading data Done
[13272:MainThread](2023-07-25 12:03:53,382) INFO - qlib.timer - [log.py:117] - Time cost: 0.010s | FilterCol Done
[13272:MainThread](2

'The following are prediction results of the TransformerModel model.'
                          score
datetime   instrument          
2017-01-03 SH600000   -1.385936
           SH600010   -0.762190
           SH600015   -0.685163
           SH600016   -0.442984
           SH600018   -0.802854


[13272:MainThread](2023-07-25 12:18:25,233) INFO - qlib.backtest caller - [__init__.py:94] - Create new exchange
backtest loop: 100%|██████████| 871/871 [00:16<00:00, 51.75it/s]
[13272:MainThread](2023-07-25 12:19:04,545) INFO - qlib.workflow - [record_temp.py:499] - Portfolio analysis record 'port_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
[13272:MainThread](2023-07-25 12:19:04,558) INFO - qlib.workflow - [record_temp.py:524] - Indicator analysis record 'indicator_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
[13272:MainThread](2023-07-25 12:19:04,621) INFO - qlib.timer - [log.py:117] - Time cost: 0.045s | waiting `async_log` Done


'The following are analysis results of benchmark return(1day).'
                       risk
mean               0.000477
std                0.012295
annualized_return  0.113561
information_ratio  0.598699
max_drawdown      -0.370479
'The following are analysis results of the excess return without cost(1day).'
                       risk
mean              -0.000066
std                0.002869
annualized_return -0.015731
information_ratio -0.355394
max_drawdown      -0.115043
'The following are analysis results of the excess return with cost(1day).'
                       risk
mean              -0.000253
std                0.002867
annualized_return -0.060314
information_ratio -1.363674
max_drawdown      -0.229483
'The following are analysis results of indicators(1day).'
     value
ffr    1.0
pa     0.0
pos    0.0
