In [1]:
import sys
import copy
from pathlib import Path

import qlib
import numpy as np
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
    backtest as normal_backtest,
    risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict

In [2]:
# Default CN dataset location; the same path is used as the download target directory.
# To fetch the data manually instead, run:
#   python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
provider_uri = "~/.qlib/qlib_data/cn_data"

if not exists_qlib_data(provider_uri):
    print(f"Qlib data is not found in {provider_uri}")
    # `get_data` is imported from the `scripts` directory next to the current
    # working directory, so that directory has to be on sys.path first
    scripts_dir = Path.cwd().parent / "scripts"
    sys.path.append(str(scripts_dir))
    from get_data import GetData

    downloader = GetData()
    downloader.qlib_data(target_dir=provider_uri, region=REG_CN)

qlib.init(provider_uri=provider_uri, region=REG_CN)

[35366:MainThread](2020-11-27 10:31:09,528) INFO - qlib.Initialization - [__init__.py:41] - default_conf: client.
[35366:MainThread](2020-11-27 10:31:09,531) INFO - qlib.Initialization - [__init__.py:76] - qlib successfully initialized based on client settings.
[35366:MainThread](2020-11-27 10:31:09,532) INFO - qlib.Initialization - [__init__.py:79] - data_path=/home/dongzho/.qlib/qlib_data/cn_data


In [3]:
# `market` is passed as the handler's "instruments"; `benchmark` is the backtest's "benchmark"
market, benchmark = "csi300", "SH000300"

## Model Training

In [4]:
###################################
# train model
###################################
# keyword arguments forwarded to the Alpha158 handler
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
}

# LightGBM hyper-parameters
lgb_kwargs = {
    "loss": "mse",
    "colsample_bytree": 0.8879,
    "learning_rate": 0.0421,
    "subsample": 0.8789,
    "lambda_l1": 205.6999,
    "lambda_l2": 580.9768,
    "max_depth": 8,
    "num_leaves": 210,
    "num_threads": 20,
}
model_config = {
    "class": "LGBModel",
    "module_path": "qlib.contrib.model.gbdt",
    "kwargs": lgb_kwargs,
}

# feature handler plus the train / valid / test date ranges
handler_config = {
    "class": "Alpha158",
    "module_path": "qlib.contrib.data.handler",
    "kwargs": data_handler_config,
}
dataset_segments = {
    "train": ("2008-01-01", "2014-12-31"),
    "valid": ("2015-01-01", "2016-12-31"),
    "test": ("2017-01-01", "2020-08-01"),
}
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": handler_config,
        "segments": dataset_segments,
    },
}

# the full task description; it is also what gets logged as experiment parameters
task = {
    "model": model_config,
    "dataset": dataset_config,
}

# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

# train inside the "train_model" experiment; `rid` lets later cells find this run again
with R.start(experiment_name="train_model"):
    R.log_params(**flatten_dict(task))
    model.fit(dataset)
    R.save_objects(trained_model=model)
    rid = R.get_recorder().id


[35366:MainThread](2020-11-27 10:31:29,731) INFO - qlib.timer - [log.py:81] - Time cost: 20.103s | Loading data Done
[35366:MainThread](2020-11-27 10:31:30,557) INFO - qlib.timer - [log.py:81] - Time cost: 0.241s | DropnaLabel Done
[35366:MainThread](2020-11-27 10:31:38,518) INFO - qlib.timer - [log.py:81] - Time cost: 7.960s | CSZScoreNorm Done
[35366:MainThread](2020-11-27 10:31:38,519) INFO - qlib.timer - [log.py:81] - Time cost: 8.786s | fit & process data Done
[35366:MainThread](2020-11-27 10:31:38,520) INFO - qlib.timer - [log.py:81] - Time cost: 28.891s | Init data Done
[35366:MainThread](2020-11-27 10:31:38,527) INFO - qlib.workflow - [exp.py:180] - Experiment 2 starts running ...
[35366:MainThread](2020-11-27 10:31:38,651) INFO - qlib.workflow - [recorder.py:234] - Recorder c81375e3b5474feb9c77711babd158c3 starts running under Experiment 2 ...
[35366:MainThread](2020-11-27 10:31:38,652) INFO - qlib.workflow - [expm.py:251] - No tracking URI is provided. The default tracking UR

## Optimization Based Strategy

In [5]:
from qlib.contrib.strategy.strategy import BaseStrategy


class OptBasedStrategy(BaseStrategy):
    """Optimization Based Strategy

    For every trading day the strategy
      1. keeps holdings that cannot be traded and force-sells holdings without a score,
      2. estimates a covariance matrix from the trailing prices of the remaining candidates,
      3. asks the optimizer for target weights and converts them into target share amounts.

    Parameters
    ----------
    data_handler : object
        price source; must provide `get_range_selector(date, periods)` and `fetch(selector, ...)`.
    cov_estimator : callable
        called as `cov_estimator(price)`; the result is re-indexed by stock id on both axes.
    optimizer : callable
        called as `optimizer(cov, score)` or `optimizer(cov, score, init_weight)`; the result is
        used like a pd.Series of weights indexed by stock id.
    cov_window : int
        number of periods handed to `data_handler.get_range_selector` (default 252).
    weight_threshold : float
        target weights that are not greater than this value are dropped (default 1e-6).
    verbose : bool
        print the trade date, target weights and target positions on every call
        (default True, which is the historical behaviour).
    """

    def __init__(self, data_handler, cov_estimator, optimizer, cov_window=252, weight_threshold=1e-6, verbose=True):
        # let the base class run its own initialization before adding our state
        super().__init__()
        self.data_handler = data_handler
        self.cov_estimator = cov_estimator
        self.optimizer = optimizer
        self.cov_window = cov_window
        self.weight_threshold = weight_threshold
        self.verbose = verbose

    def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
        """
        Build the orders that move the current holdings to the optimized target portfolio.

        Parameters
        -----------
        score_series : pd.Series
            stock_id , score.
        current : Position()
            current of account.
        trade_exchange : Exchange()
            exchange.
        pred_date : pd.Timestamp
            date used to anchor the covariance window and to look up the close price
            that converts target weights into share amounts.
        trade_date : pd.Timestamp
            date.

        Returns
        -------
        The order list returned by `trade_exchange.generate_order_for_target_amount_position`.
        """
        score_series = score_series.dropna()

        # check stock holdings, if
        # 1. stock not tradable: target amount = current amount
        # 2. doesn't have score: target amount = 0 (force sell)
        # everything else is left for the optimizer to decide
        current_position = current.get_stock_amount_dict()
        target_position = {}
        for stock_id, amount in current_position.items():
            if not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
                target_position[stock_id] = amount
            elif stock_id not in score_series.index:
                target_position[stock_id] = 0

        # filter scores, if
        # 1. already decided by the rules above
        # 2. not tradable
        skipped = [
            stock_id
            for stock_id in score_series.index
            if stock_id in target_position
            or not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date)
        ]
        score_series = score_series[~score_series.index.isin(skipped)]

        # calc remaining value
        # explicit float dtype keeps an empty position (e.g. the first trading day) numeric
        current_value = pd.Series(
            {stock_id: current.get_stock_price(stock_id) * amount for stock_id, amount in current_position.items()},
            dtype=float,
        )
        risk_total_value = self.get_risk_degree(trade_date) * current.calculate_value()
        # NOTE(review): this also subtracts the value of holdings that are force-sold
        # (target amount 0), so their proceeds are not re-invested on the same day --
        # confirm that this is intended.
        traded_value = risk_total_value - current_value.loc[list(target_position)].sum()

        # portfolio init weight (normalized only when something is currently held)
        init_weight = current_value.reindex(score_series.index, fill_value=0)
        init_weight_sum = init_weight.sum()
        has_init_weight = init_weight_sum > 0
        if has_init_weight:
            init_weight = init_weight / init_weight_sum

        if score_series.empty:
            # nothing left to optimize: only the hold / force-sell decisions above apply
            target_weight = pd.Series(dtype=float)
        else:
            # covariance estimation
            selector = (self.data_handler.get_range_selector(pred_date, self.cov_window), score_series.index)
            price = self.data_handler.fetch(selector, level=None, squeeze=True)
            cov = self.cov_estimator(price)
            # align with the candidates; stocks missing from the estimate get NaN rows/columns
            cov = cov.reindex(index=score_series.index, columns=score_series.index)

            # optimize target portfolio
            if has_init_weight:
                target_weight = self.optimizer(cov, score_series, init_weight)
            else:
                target_weight = self.optimizer(cov, score_series)
            target_weight = target_weight[target_weight > self.weight_threshold]

        # convert weights into share amounts
        for stock_id, weight in target_weight.items():
            try:
                close = trade_exchange.get_close(stock_id, pred_date)
                if pd.isna(close) or close <= 0:
                    # a missing or non-positive price cannot be turned into a share amount
                    print('Exception:', f'skip {stock_id}: invalid close price {close!r} on {pred_date}')
                    continue
                target_position[stock_id] = int(traded_value * weight / close)
            except Exception as e:
                # best effort: a failure for one stock must not abort the whole day
                print('Exception:', f'skip {stock_id}: {e}')

        # for debug
        if self.verbose:
            print('trade date:', trade_date)
            print('target weight:', target_weight.to_dict())
            print('target position:', target_position)

        # generate order list
        order_list = trade_exchange.generate_order_for_target_amount_position(
            target_position=target_position,
            current_position=current_position,
            trade_date=trade_date,
        )

        return order_list

In [6]:
from qlib.data.dataset.loader import QlibDataLoader
from qlib.data.dataset.handler import DataHandler
from qlib.model.riskmodel import ShrinkCovEstimator
from qlib.portfolio.optimizer import PortfolioOptimizer

In [7]:
# loader configured with a single field, "$close"
data_loader = QlibDataLoader(["$close"])
# handler built from "all", two dates and the loader above
# (presumably instruments, start time and end time -- confirm against DataHandler's signature)
data_handler = DataHandler("all", "2015-01-01", "2020-08-01", data_loader)
# NOTE(review): nan_option="mask" presumably controls how missing prices are treated -- confirm
cov_estimator = ShrinkCovEstimator(nan_option="mask")
# NOTE(review): "mvo" presumably selects mean-variance optimization; confirm the meaning of
# lamb / delta / tol against the PortfolioOptimizer documentation
optimizer = PortfolioOptimizer("mvo", lamb=2, delta=0.2, tol=1e-5)
# the strategy receives the price handler, the covariance estimator and the optimizer
strategy = OptBasedStrategy(data_handler, cov_estimator, optimizer)

[35366:MainThread](2020-11-27 10:31:56,951) INFO - qlib.timer - [log.py:81] - Time cost: 6.763s | Loading data Done
[35366:MainThread](2020-11-27 10:31:56,953) INFO - qlib.timer - [log.py:81] - Time cost: 6.766s | Init data Done


In [49]:
###################################
# prediction, backtest & analysis
###################################
# settings handed to the backtest
backtest_config = {
    "verbose": False,
    "limit_threshold": 0.095,
    "account": 100000000,
    "benchmark": benchmark,
    "deal_price": "close",
    "open_cost": 0.0005,
    "close_cost": 0.0015,
    "min_cost": 5,
}
port_analysis_config = {
    "strategy": strategy,
    "backtest": backtest_config,
}


# backtest and analysis
with R.start(experiment_name="backtest_analysis"):
    # reload the model saved under the "train_model" experiment
    trained_recorder = R.get_recorder(rid, experiment_name="train_model")
    model = trained_recorder.load_object("trained_model")

    # prediction, recorded under the recorder fetched without arguments
    recorder = R.get_recorder()
    ba_rid = recorder.id
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

    # backtest & analysis
    par = PortAnaRecord(recorder, port_analysis_config)
    par.generate()

1': 0.08936553334387595, 'SH601800': 0.011014844457113308, 'SH601939': 0.013378001170219945, 'SH603993': 0.013820193926861863, 'SZ000338': 0.002455991798001457, 'SZ000423': 0.004893338273543826, 'SZ000538': 0.010686211189620477, 'SZ002065': 0.09095125419435357, 'SZ002074': 0.010299013738522475, 'SZ002085': 0.19844965949420615, 'SZ002236': 0.09210003831704765, 'SZ002310': 0.05664352912360013, 'SZ300017': 0.0197442255539771}
target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272224, 'SH600009': 604839, 'SH600018': 3097398, 'SH600028': 335726, 'SH600196': 23243, 'SH600276': 71634, 'SH600519': 17354, 'SH600585': 269686, 'SH600900': 2501521, 'SH601111': 2400659, 'SH601800': 334062, 'SH601939': 1283164, 'SH603993': 742901, 'SZ000338': 95285, 'SZ000423': 21697, 'SZ000538': 14518, 'SZ002065': 498253, 'SZ002074': 111674, 'SZ002085': 591507, 'SZ002236': 394197, 'SZ002310': 2202674, 'SZ300017': 206128}
target weight: {'SH600000': 0.02310668460556249, 'SH600009': 0.06170206213753432, 'S

KeyboardInterrupt: 