diff --git a/CHANGES.rst b/CHANGES.rst new file mode 100644 index 0000000000..a340b19e14 --- /dev/null +++ b/CHANGES.rst @@ -0,0 +1,152 @@ +Changelog +==================== +Here you can see the full list of changes between each QLib release. + +Version 0.1.0 +-------------------- +This is the initial release of QLib library. + +Version 0.1.1 +-------------------- +Performance optimize. Add more features and operators. + +Version 0.1.2 +-------------------- +- Support operator syntax. Now ``High() - Low()`` is equivalent to ``Sub(High(), Low())``. +- Add more technical indicators. + +Version 0.1.3 +-------------------- +Bug fix and add instruments filtering mechanism. + +Version 0.2.0 +-------------------- +- Redesign ``LocalProvider`` database format for performance improvement. +- Support load features as string fields. +- Add scripts for database construction. +- More operators and technical indicators. + +Version 0.2.1 +-------------------- +- Support registering user-defined ``Provider``. +- Support use operators in string format, e.g. ``['Ref($close, 1)']`` is valid field format. +- Support dynamic fields in ``$some_field`` format. And exising fields like ``Close()`` may be deprecated in the future. + +Version 0.2.2 +-------------------- +- Add ``disk_cache`` for reusing features (enabled by default). +- Add ``qlib.contrib`` for experimental model construction and evaluation. + + +Version 0.2.3 +-------------------- +- Add ``backtest`` module +- Decoupling the Strategy, Account, Position, Exchange from the backtest module + +Version 0.2.4 +-------------------- +- Add ``profit attribution`` module +- Add ``rick_control`` and ``cost_control`` strategies + +Version 0.3.0 +-------------------- +- Add ``estimator`` module + +Version 0.3.1 +-------------------- +- Add ``filter`` module + +Version 0.3.2 +-------------------- +- Add real price trading, if the ``factor`` field in the data set is incomplete, use ``adj_price`` trading +- Refactor ``handler`` ``launcher`` ``trainer`` code +- Support ``backtest`` configuration parameters in the configuration file +- Fix bug in position ``amount`` is 0 +- Fix bug of ``filter`` module + +Version 0.3.3 +------------------- +- Fix bug of ``filter`` module + +Version 0.3.4 +-------------------- +- Support for ``finetune model`` +- Refactor ``fetcher`` code + +Version 0.3.5 +-------------------- +- Support multi-label training, you can provide multiple label in ``handler``. (But LightGBM doesn't support due to the algorithm itself) +- Refactor ``handler`` code, dataset.py is no longer used, and you can deploy your own labels and features in ``feature_label_config`` +- Handler only offer DataFrame. Also, ``trainer`` and model.py only receive DataFrame +- Change ``split_rolling_data``, we roll the data on market calender now, not on normal date +- Move some date config from ``handler`` to ``trainer`` + +Version 0.4.0 +-------------------- +- Add `data` package that holds all data-related codes +- Reform the data provider structure +- Create a server for data centralized management `qlib-server`_ +- Add a `ClientProvider` to work with server +- Add a pluggable cache mechanism +- Add a recursive backtracking algorithm to inspect the furthest reference date for an expression + +.. 
note::
+    The ``D.instruments`` function does not support the ``start_time``, ``end_time``, and ``as_list`` parameters. If you want to get the results returned by previous versions of ``D.instruments``, you can do this:
+
+
+    >>> from qlib.data import D
+    >>> instruments = D.instruments(market='csi500')
+    >>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)
+
+
+Version 0.4.1
+--------------------
+- Add support for Windows
+- Fix the ``instruments`` type bug
+- Fix the bug that ``features`` is empty (it would cause failures when updating)
+- Fix the ``cache`` lock and update bug
+- Fix a bug so that the same cache is reused for the same field (previously a new cache was added alongside the original one)
+- Change the "logger handler" to be set from the config
+- Change model loading to support version 0.4.0 and later
+- The default value of the ``method`` parameter of the ``risk_analysis`` function is changed from **ci** to **si**
+
+
+Version 0.4.2
+--------------------
+- Refactor DataHandler
+- Add the ``ALPHA360`` DataHandler
+
+
+Version 0.4.3
+--------------------
+- Implement the online inference and trading framework
+- Refactor the interfaces of the backtest and strategy modules
+
+
+Version 0.4.4
+--------------------
+- Optimize cache generation performance
+- Add the report module
+- Fix a bug when using ``ServerDatasetCache`` offline.
+- In previous versions of ``long_short_backtest``, ``np.nan`` could appear in long_short. This has been fixed in ``0.4.4``, so the results of ``long_short_backtest`` will differ from previous versions.
+- In the ``0.4.2`` version of the ``risk_analysis`` function, ``N`` is ``250``; from ``0.4.3`` on, ``N`` is ``252``, so ``0.4.2`` results are about ``0.002122`` smaller than ``0.4.3`` results. The backtest result is therefore slightly different between ``0.4.2`` and ``0.4.3``.
+- Refactor the arguments of the backtest function.
+    - **NOTE**:
+        - The default arguments of the topk margin strategy are changed. Please pass the arguments explicitly if you want to get the same backtest result as in previous versions.
+        - The TopkWeightStrategy is changed slightly. It will try to sell the stocks in excess of ``topk``. (The backtest result of TopkAmountStrategy remains the same.)
+- The margin ratio mechanism is supported in the Topk Margin strategies.
+
+
+Version 0.4.5
+--------------------
+- Add a multi-kernel implementation for both client and server.
+    - Support a new way to load data from the client which skips the dataset cache.
+    - Change the default dataset method from the single-kernel implementation to the multi-kernel implementation.
+- Accelerate high-frequency data reading by optimizing the relevant modules.
+- Support a new way to write the config file by using a dict.
+
+Version 0.4.6
+--------------------
+- Some bugs are fixed
+    - The default config in `Version 0.4.5` is not friendly to daily frequency data.
+    - Backtest error in TopkWeightStrategy when `WithInteract=True`.
diff --git a/README.md b/README.md
index 8eeee9c7c4..6eda5f0f94 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,189 @@
+Qlib is an AI-oriented quantitative investment platform that aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
+
+With Qlib, you can easily apply your favorite model to create better Quant investment strategies.
+ + +- [Framework of Qlib](#framework-of-qlib) +- [Quick start](#quick-start) + - [Installation](#installation) + - [Get Data](#get-data) + - [Auto Quant research workflow with _estimator_](#auto-quant-research-workflow-with-estimator) + - [Customized Quant research workflow by code](#customized-quant-research-workflow-by-code) +- [More About Qlib](#more-about-qlib) + - [Offline mode and online mode](#offline-mode-and-online-mode) + - [Performance of Qlib Data Server](#performance-of-qlib-data-server) + + + +# Framework of Qlib +![framework](docs/_static/img/framework.png) + +At module level, Qlib is a platform that consists of the above components. Each components is loose-coupling and can be used stand-alone. + +| Name | Description| +|------| -----| +| _Data layer_ | _DataServer_ focus on providing high performance infrastructure for user to retreive and get raw data. _DataEnhancement_ will preprocess the data and provide the best dataset to be fed in to the models | +| _Interday Model_ | _Interday model_ focus on produce forecasting signals(aka. _alpha_). Models are trained by _Model Creator_ and managed by _Model Manager_. User could choose one or multiple models for forecasting. Multiple models could be combined with _Ensemble_ module | +| _Interday Strategy_ | _Portfolio Generator_ will take forecasting signals as input and output the orders based on current position to achieve target portfolio | +| _Intraday Trading_ | _Order Executor_ is responsible for executing orders output by _Interday Strategy_ and returning the executed results. | +| _Analysis_ | User could get detailed analysis report of forecasting signal and portfolio in this part. | + +* The modules with hand-drawn style is under development and will be released in the future. +* The modules with dashed border is highly user-customizable and extendible. + + +# Quick start + +## Installation + +To install Qlib from source you need _Cython_ in addition to the normal dependencies above: + +```bash +pip install cython +``` + +Clone the repository and then run: +```bash +python setup.py install +``` + + +## Get Data +- Load and prepare the Data: execute the following command to load the stock data: + ```bash + python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + ``` + + +## Auto Quant research workflow with _estimator_ +Qlib provides a tool named `estimator` to run whole workflow automatically(including building dataset, train models, backtest, analysis) + +1. Run _estimator_ (_config.yaml_ for: [estimator_config.yaml](example/estimator/estimator_config.yaml)): + + ```bash + estimator -c example/estimator/estimator_config.yaml + ``` + + Estimator result: + + ```bash + pred_long mean 0.001386 + std 0.004403 + annual 0.349379 + sharpe 4.998428 + mdd -0.049486 + pred_short mean 0.002703 + std 0.004680 + annual 0.681071 + sharpe 9.166842 + mdd -0.053523 + pred_long_short mean 0.004089 + std 0.007028 + annual 1.030451 + sharpe 9.236475 + mdd -0.045817 + sub_bench mean 0.000953 + std 0.004688 + annual 0.240123 + sharpe 3.226878 + mdd -0.064588 + sub_cost mean 0.000718 + std 0.004694 + annual 0.181003 + sharpe 2.428964 + mdd -0.072977 + ``` + See the full documnents for [Use _Estimator_ to Start An Experiment](TODO:URL). + +2. Analysis + + Run `examples/estimator/analyze_from_estimator.ipynb` in `jupyter notebook` + 1. forecasting signal analysis + - Model Performance + ![Model Performance](docs/_static/img/model_performance.png) + + 2. 
portfolio analysis + - Report + ![Report](docs/_static/img/report.png) + + +## Customized Quant research workflow by code +Automatical workflow may not suite the research workflow of all Quant researchers. To support flexible Quant research workflow, Qlib also provide modulized interface to allow researchers to build their own workflow. [Here](TODO_URL) is a demo for customized Quant research workflow by code + + + +# More About Qlib +The detailed documents are organized in [docs](docs). +[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats. +```bash +cd docs/ +conda install sphinx sphinx_rtd_theme -y +# Otherwise, you can install them with pip +# pip install sphinx sphinx_rtd_theme +make html +``` +You can also view the [latest document](TODO_URL) online directly. + + + +## Offline mode and online mode +The data server of Qlib can both deployed as offline mode and online mode. The default mode is offline mode. + +Under offline mode, the data will be deployed locally. + +Under online mode, the data will be deployed as a shared data service. The data and their cache will be shared by clients. The data retrieving performance is expected to be improved due to higher rate of cache hits. It will use less disk space, too. The documents of the online mode can be found in [Qlib-Server](TODO_link). The online mode can be deployed automatically with [Azure CLI based scripts](TODO_link) + +## Performance of Qlib Data Server +The performance of data processing is important to datadriven methods like AI technologies. As an AI-oriented platform, Qlib provides a solution for data storage and data processing. To demonstrate the performance of Qlib, We +compare Qlib with several other solutions. + +The task for the solutions is to create a dataset from the +basic OHLCV daily data of a stock market, which involves +data query and processing. + + + +Most general purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general purpose database solution. +Such overheads greatly slow down the data loading process. +Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation. + + + + # Contributing diff --git a/README.rst b/README.rst new file mode 100644 index 0000000000..c09800898a --- /dev/null +++ b/README.rst @@ -0,0 +1,34 @@ +QLib +========== + +QLib is a Quantitative-research Library, which can provide research data with highly consistency, reusability and extensibility. + +.. note:: Anaconda python is strongly recommended for this library. See https://www.anaconda.com/download/. + + +Install +---------- + +Install as root: + +.. code-block:: bash + + $ python setup.py install + + +Install as single user (if you have no root permission): + +.. code-block:: bash + + $ python setup.py install --user + + +To verify your installation, open your python shell: + +.. code-block:: python + + >>> import qlib + >>> qlib.__version__ + '0.2.2' + +You can also run ``tests/test_data_sim.py`` to verify your installation. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000..11ee1e7986 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = python3 -msphinx +SPHINXPROJ = Quantlab +SOURCEDIR = . 
+BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/img/cumulative_return.png b/docs/_static/img/cumulative_return.png new file mode 100644 index 0000000000..ea24c28d67 Binary files /dev/null and b/docs/_static/img/cumulative_return.png differ diff --git a/docs/_static/img/framework.png b/docs/_static/img/framework.png new file mode 100644 index 0000000000..673f10e033 Binary files /dev/null and b/docs/_static/img/framework.png differ diff --git a/docs/_static/img/model_performance.png b/docs/_static/img/model_performance.png new file mode 100644 index 0000000000..58f46fd3d3 Binary files /dev/null and b/docs/_static/img/model_performance.png differ diff --git a/docs/_static/img/rank_label.png b/docs/_static/img/rank_label.png new file mode 100644 index 0000000000..29a82401f3 Binary files /dev/null and b/docs/_static/img/rank_label.png differ diff --git a/docs/_static/img/report.png b/docs/_static/img/report.png new file mode 100644 index 0000000000..2687bb5c19 Binary files /dev/null and b/docs/_static/img/report.png differ diff --git a/docs/_static/img/risk_analysis.png b/docs/_static/img/risk_analysis.png new file mode 100644 index 0000000000..a15610d883 Binary files /dev/null and b/docs/_static/img/risk_analysis.png differ diff --git a/docs/_static/img/score_ic.png b/docs/_static/img/score_ic.png new file mode 100644 index 0000000000..14ecd4f14f Binary files /dev/null and b/docs/_static/img/score_ic.png differ diff --git a/docs/advanced/backtest.rst b/docs/advanced/backtest.rst new file mode 100644 index 0000000000..ac2a9621f1 --- /dev/null +++ b/docs/advanced/backtest.rst @@ -0,0 +1,61 @@ +.. _backtest: +=================== +Backtest: Model&Strategy Testing +=================== +.. currentmodule:: qlib + +Introduction +=================== +By ``Backtest``, users can check the performance of custom model/strategy. + +`Backtest` can test the predicted score of the `Model` module and the customized `Strategy` module. + +Example +=========================== + +Users need to generate a score file(a pandas DataFrame) with MultiIndex and a `score` column. And users need to assign a strategy used in backtest, if strategy is not assigned, +a 'TopkAmountStrategy' strategy with(topk=20, buffer_margin=150, risk_degree=0.95, limit_threshold=0.0095) will be used. +If strategy module is not user's interested part, 'TopkAmountStrategy' is enough. + +The simple example is as follows. + +.. code-block:: python + + from qlib.contrib.evaluate import backtest + report, positions = backtest(pred_test, topk=50, margin=0.5, verbose=False, limit_threshold=0.0095) + + +Score file +-------------- + +The score file is a pandas DataFrame, its index is and it must +contains a "score" column. + +A score file sample is shown as follows. + +.. code-block:: python + + instrument datetime score + SH600000 2019-01-04 -0.505488 + SZ002531 2019-01-04 -0.320391 + SZ000999 2019-01-04 0.583808 + SZ300569 2019-01-04 0.819628 + SZ001696 2019-01-04 -0.137140 + ... ... 
+ SZ000996 2019-04-30 -1.027618 + SH603127 2019-04-30 0.225677 + SH603126 2019-04-30 0.462443 + SH603133 2019-04-30 -0.302460 + SZ300760 2019-04-30 -0.126383 + +``Model`` module can produce the score file, please refer to `Model `_. + +Strategy +-------------- + +To know more abot ``Strategy``, please refer to `Strategy `_. + + +Api +============== +Please refer to `Backtest Api <../reference/api.html>`_. \ No newline at end of file diff --git a/docs/advanced/cache.rst b/docs/advanced/cache.rst new file mode 100644 index 0000000000..ef9c6e052e --- /dev/null +++ b/docs/advanced/cache.rst @@ -0,0 +1,84 @@ +.. _cache: +==================== +Cache: Frequently-Used Data +==================== + +.. currentmodule:: qlib + +The `cache` is a pluggable module to help accelerate providing data by saving some frequently-used data as cache file. Qlib provides a `Memcache` class to cache the most-frequently-used data in memory, an inheritable `ExpressionCache` class, and an inheritable `DatasetCache` class. + +`Memcache` is a memory cache mechanism that composes of three `MemCacheUnit` instances to cache **Calendar**, **Instruments**, and **Features**. The MemCache is defined globally in `cache.py` as `H`. User can use `H['c'], H['i'], H['f']` to get/set memcache. + +.. autoclass:: qlib.data.cache.MemCacheUnit + :members: + +.. autoclass:: qlib.data.cache.MemCache + :members: + +`ExpressionCache` is a disk cache mechanism that saves expressions such as **Mean($close, 5)**. Users can inherit this base class to define their own cache mechanism. Users need to override `self._uri` method to define how their cache file path is generated, `self._expression` method to define what data they want to cache and how to cache it. + +`DatasetCache` is a disk cache mechanism that saves datasets. A certain dataset is regulated by a stockpool configuration (or a series of instruments, though not recommended), a list of expressions or static feature fields, the start time and end time for the collected features and the frequency. Users need to override `self._uri` method to define how their cache file path is generated, `self._expression` method to define what data they want to cache and how to cache it. + +`ExpressionCache` and `DatasetCache` actually provides the same interfaces with `ExpressionProvider` and `DatasetProvider` so that the disk cache layer is transparent to users and will only be used if they want to define their own cache mechanism. The users can plug the cache mechanism into the server system by assigning the cache class they want to use in `config.py`: + +.. code-block:: python + + 'ExpressionCache': 'ServerExpressionCache', + 'DatasetCache': 'ServerDatasetCache', + +User can find the cache interface here. + +ExpressionCache +==================== +.. autoclass:: qlib.data.cache.ExpressionCache + :members: + +DatasetCache +===================== +.. autoclass:: qlib.data.cache.DatasetCache + :members: + + +Qlib has currently provided `ServerExpressionCache` class and `ServerDatasetCache` class as the cache mechanisms used for QlibServer. The class interface and file structure designed for server cache mechanism is listed below. + +ServerExpressionCache +===================== +.. autoclass:: qlib.data.cache.ServerExpressionCache + + +ServerDatasetCache +==================== +.. autoclass:: qlib.data.cache.ServerDatasetCache + + +Data and cache file structure on server +======================================== +.. 
code-block:: json + + - data/ + [raw data] updated by data providers + - calendars/ + - day.txt + - instruments/ + - all.txt + - csi500.txt + - ... + - features/ + - sh600000/ + - open.day.bin + - close.day.bin + - ... + - ... + [cached data] updated by server when raw data is updated + - calculated features/ + - sh600000/ + - [hash(instrtument, field_expression, freq)] + - all-time expression -cache data file + - .meta : an assorted meta file recording the instrument name, field name, freq, and visit times + - ... + - cache/ + - [hash(stockpool_config, field_expression_list, freq)] + - all-time Dataset-cache data file + - .meta : an assorted meta file recording the stockpool config, field names and visit times + - .index : an assorted index file recording the line index of all calendars + - ... diff --git a/docs/advanced/data.rst b/docs/advanced/data.rst new file mode 100644 index 0000000000..2e2155d4c5 --- /dev/null +++ b/docs/advanced/data.rst @@ -0,0 +1,176 @@ +.. _data: +============================ +Data: Data Framework&Usage +============================ + +Introduction +============================ + +``Qlib`` provides some methods for obtaining and processing data, and allows users to customize their own methods. + + +Raw Data +============================ + +Qlib provides the script 'scripts/get_data.py' to download the raw data that will be used to initialize the qlib package, please refer to `Initialization <../start/initialization.rst>`_. + +When Qlib is initialized, users can choose A-share mode or US stocks mode, please refer to `Initialization <../start/initialization.rst>`_. + +A-share Mode +-------------------------------- + +If users use Qlib in A-share mode, A-share data is required. The script'scripts/get_data.py' provides methods to download A-share data. If users want to use A-share mode, they need to do as follows. + +- Download data in csv format + Run the following command to download A-share data in csv format. + + .. code-block:: bash + + python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data + + Users can find A-share data in csv format in the'~/.qlib/csv_data/cn_data' directory. + +- Convert data from csv format to Qlib format + Qlib provides the 'scripts/dump_bin.py' to convert data from csv format to qlib format. + Assuming that the users store the A-share data in csv format in path '~/.qlib/csv_data/cn_data', they need to execute the following command to convert the data from csv format to Qlib format: + + .. code-block:: bash + + python scripts/dump_bin.py dump --csv_path ~/.qlib/csv_data/cn_data --qlib_dir ~/.qlib/qlib_data/cn_data --include_fields open,close,high,low,volume,factor + + + When initializing Qlib, users only need to execute `qlib.init(mount_path='~/.qlib/qlib_data/cn_data', region='us')`. Please refer to `Api`_. + +US Stock Mode +------------------------- +If users use Qlib in US Stock mode, US stock data is required. Qlib does not mention script to download US stock data. If users want to use Qlib in US stock mode, they need to do as follows. + +- Prepare data in csv format + Users need to prepare US stock data in csv format by themselves, which is in the same format as the A-share data in csv format. In order to refer to the format, please download the A-share csv data as follows. + + .. 
code-block:: bash + + python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data + + +- Convert data from csv format to Qlib format + Qlib provides the 'scripts/dump_bin.py' to convert data from csv format to qlib format. + Assuming that the users store the US Stock data in csv format in path '~/.qlib/csv_data/us_data', they need to execute the following command to convert the data from csv format to Qlib format: + + .. code-block:: bash + + python scripts/dump_bin.py dump --csv_path ~/.qlib/csv_data/us_data --qlib_dir ~/.qlib/qlib_data/us_data --include_fields open,close,high,low,volume,factor + + + When initializing Qlib, users only need to execute `qlib.init(mount_path='~/.qlib/qlib_data/us_data', region='us')`. Please refer to `Api`_. + +Please refer to `Script Api <../reference/api.html>`_ for more details. + +Data Retrieval +======================== + +Please refer to `Data Retrieval <../start/getdata.html>`_. + + +Data Handler +================= + +Data Handler is a part of estimator and can also be used as a single module. + +Data Handler can process the raw data. It uses the API in 'qlib.data' to get the raw data, It uses the API in'data' to obtain the original data, and then processes the data, such as standardizing features, removing NaN data, etc. + +Interface +----------------- + +Qlib provides a base class `qlib.contrib.estimator.BaseDataHandler <../reference/api.html#class-qlib.contrib.estimator.BaseDataHandler>`_, which provides the following interfaces: + +- `setup_feature` + Implement the interface to load the data features. + +- `setup_label` + Implement the interface to load the data labels and calculate user's labels. + +- `setup_processed_data` + Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on. + +Qlib also provides two functions to help user init the data handler, user can override them for user's need. + +- `_init_kwargs` + User can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df. + Kwargs are the other attributes in data.args, like dropna_label, dropna_feature + +- `_init_raw_df` + User can init the raw df, feature names and label names of data handler in this function. + If the index of feature df and label df are not same, user need to override this method to merge them (e.g. inner, left, right merge). + +If users want to load features and labels through config, users can inherit `qlib.contrib.estimator.handler.ConfigDataHandler`, Qlib also have provided some preprocess method in this subclass. +If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`. + +Usage +------------------ +'Data Handler' can be used as a single module, which provides the following mehtod: + +- `get_split_data` + - According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model' + +- `get_rolling_data` + - According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling. + + +Example +------------------ + +'Data Handler' can be run with 'estimator' by modifying the configuration file, and can also be used as a single module. + +Know more about how to run 'Data Handler' with estimator, please refer to `Estimator `_. 
+ +Qlib provides data handler 'QLibDataHandlerV1', the following example shows how to run 'QLibDataHandlerV1' as a single module. + +.. note:: User needs to initialize package qlib with qlib.init first, please refer to `initialization `_. + + +.. code-block:: Python + + from qlib.contrib.estimator.handler import QLibDataHandlerV1 + from qlib.contrib.model.gbdt import LGBModel + + DATA_HANDLER_CONFIG = { + "dropna_label": True, + "start_date": "2007-01-01", + "end_date": "2020-08-01", + "market": "csi500", + } + + TRAINER_CONFIG = { + "train_start_date": "2007-01-01", + "train_end_date": "2014-12-31", + "validate_start_date": "2015-01-01", + "validate_end_date": "2016-12-31", + "test_start_date": "2017-01-01", + "test_end_date": "2020-08-01", + } + + exampleDataHandler = QLibDataHandlerV1(**DATA_HANDLER_CONFIG) + + # example of 'get_split_data' + x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG) + + # example of 'get_rolling_data' + + for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG): + print(x_train, y_train, x_validate, y_validate, x_test, y_test) + + +.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the ``fit``, ``predict``, and ``score`` methods of the 'Model' , please refer to `Model `_. + +Also, the above example has been given in `examples.estimator.train_backtest_analyze.ipynb`. + +To know more abot 'Data Handler', please refer to `Data Handler Api <../reference/api.html#handler>`_. + + + + +Api +====================== + +Please refer to `Data Api <../reference/api.html#>`_. \ No newline at end of file diff --git a/docs/advanced/estimator.rst b/docs/advanced/estimator.rst new file mode 100644 index 0000000000..28c10ded0e --- /dev/null +++ b/docs/advanced/estimator.rst @@ -0,0 +1,720 @@ +.. _estimator: +=================== +Estimator: Workflow Management +=================== +.. currentmodule:: qlib + +Introduction +=================== + +By ``Estimator``, user can start an 'experiment', which has the following process: + +- Data loading +- Data processing +- Data slicing +- Model static training, rolling training +- Model saving & loading +- Back testing + +Qlib will capture the standard input and output, and backtest performance files of this experiment, and identifiers such as names are stored on disk or on a database. + +Example +=================== + +The following is an example: + +.. note:: Make sure user have installed the latest version of `qlib`, see detail in `Qlib installation <../start/installation.html>`_. + +If user want to use the models and data provided by `Qlib`, then user only need to do as follows. + +First, Write a simple configuration file as following, + +.. 
code-block:: YAML + + experiment: + name: estimator_example + observer_type: file_storage + mode: train + + model: + class: LGBModel + module_path: qlib.contrib.model.gbdt + args: + loss: mse + colsample_bytree: 0.8879 + learning_rate: 0.0421 + subsample: 0.8789 + lambda_l1: 205.6999 + lambda_l2: 580.9768 + max_depth: 8 + num_leaves: 210 + num_threads: 20 + data: + class: QLibDataHandlerV1 + args: + dropna_label: True + filter: + market: csi500 + trainer: + class: StaticTrainer + args: + rolling_period: 360 + train_start_date: 2007-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-12-31 + test_start_date: 2017-01-01 + test_end_date: 2020-08-01 + strategy: + class: TopkAmountStrategy + args: + topk: 50 + buffer_margin: 230 + backtest: + normal_backtest_args: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: SH000905 + deal_price: vwap + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + long_short_backtest_args: + topk: 50 + + qlib_data: + # when testing, please modify the following parameters according to the specific environment + mount_path: "~/.qlib/qlib_data/cn_data" + region: "cn" + + +Then run the following command: + +.. code-block:: bash + + estimator -c configuration.yaml + +.. note:: 'estimator' is a built-in command of our program. + + +Then Make a happy one-click alchemy! + +Configuration file +=================== + +Before using `estimator`, user need to prepare a configuration file. Next Qlib will show user how to prepare each part of the configuration file. + +About the experiment +-------------------- + +First, configuration file needs to have a field about the experiment, whose key is `experiment`, this field and its contents determine how `estimator` tracks and persists this "experiment". Qlib used `sacred`, a lightweight open source tool designed to help us configure, organize, generate logs and manage experiment results. The field `experiment` will determine the partial behavior of `sacred`. + +Usually, in the running process of `estimator`, those following will be managed by `sacred`: + +- `model.bin`, model binary file +- `pred.pkl`, model prediction result file +- `analysis.pkl`, backtest performance analysis file +- `positions.pkl`, backtest position record file +- `run`, the experiment information object, usually contains some meta information such as the experiment name, experiment date, etc. + +Usually it should contain the following: + +.. code-block:: YAML + + experiment: + name: test_experiment + observer_type: mongo + mongo_url: mongodb://MONGO_URL + db_name: public + finetune: false + exp_info_path: /home/test_user/exp_info.json + mode: test + loader: + id: 677 + + +The meaning of each field is as follows: + +- `name` + The experiment name, str type, `sacred` will use this experiment name as an identifier for some important internal processes. Usually, user can see this field in `sacred` by `run` object. The default value is `test_experiment`. + +- `observer_type` + Observer type, str type, there are two values which are `file_storage` and `mongo` respectively. If it is `file_storage`, all the above-mentioned managed contents will be stored in the `dir` directory, separated by the number of times of experiments as a subfolder. If it is `mongo`, the content will be stored in the database. The default is `file_storage`. + + - For `file_storage` observer. 
+ - `dir` + Directory url, str type, directory for `file_storage` observer type, files captures and managed by sacred with observer type of `file_storage` will be save to this directory, default is the directory of `config.json`. + + - For `mongo` observer. + - `mongo_url` + Database URL, str type, required if the observer type is `mongo`. + + - `db_name` + Database name, str type, required if the observer type is `mongo`. + +- `finetune` + Estimator will produce a model based on this flag + + The following table is the processing logic for different situations. + + ========== =========================================== ==================================== =========================================== ========================================== + . Static Rolling + . Finetune=True Finetune=False Finetune=True Finetune=False + ========== =========================================== ==================================== =========================================== ========================================== + Train - Need to provide model(Static or Rolling) - No need to provide model - Need to provide model(Static or Rolling) - Need to provide model(Static or Rolling) + - The args in model section will be - The args in model section will be - The args in model section will be - The args in model section will be + used for finetuning used for training used for finetuning used for finetuning + - Update based on the provided model - Train model from scratch - Update based on the provided model - Based on the provided model update + and parameters and parameters - Train model from scratch + - **Each rolling time slice is based on** - **Train each rolling time slice** + **a model updated from the previous** **separately** + **time** + Test - Model must exist, otherwise an exception will be raised. + - For `StaticTrainer`, user need to train a model and record 'exp_info' for 'Test'. + - For `RollingTrainer`, user need to train a set of models until the latest time, and record 'exp_info' for 'Test'. + ========== ============================================================================================================================================================================= + + .. note:: + + 1. finetune parameters: share model.args parameters. + + 2. provide model: from `loader.model_index`, load the index of the model(starting from 0). + + 3. If `loader.model_index` is None: + - In 'Static Finetune=True', if provide 'Rolling', use the last model to update. + + - For RollingTrainer with Finetune=Ture. + + - If StaticTrainer is used in loader, the model will be used for initialization for finetuning. + + - If RollingTrainer is used in loader, the existing models will be used without any modification and the new models will be initialized with the model in the last period and finetune one by one. + + +- `exp_info_path` + experiment info save path, str type, save the experiment info and model prediction score after the experiment is finished. Optional parameter, the default value is `config_file_dir/ex_name/exp_info.json` + +- `mode` + `train` or `test`, str type, if `mode` is test, it will load the model according to the parameters of `loader`. The default value is `train`. + Also note that when the load model failed, it will `fit` model. + +- `loader` + If the `mode` is `test` or `finetune` is `true`, it will be used. + + - `model_index` + Model index, int type. The index of the loaded model in loader_models (starting at 0) for the first `finetune`. The default value is None. 
+ + - `exp_info_path` + Loader model experiment info path, str type. If the field exists, the following parameters will be parsed from `exp_info_path`, and the following parameters will not work. This field and `id` must exist one. + + - `id` + The experiment id of the model that needs to be loaded, int type. If the `mode` is `test`, this value is required. This field and `exp_info_path` must exist one. + + - `name` + The experiment name of the model that needs to be loaded, str type. The default value is the current experiment `name`. + + - `observer_type` + The experiment observer type of the model that needs to be loaded, str type. The default value is the current experiment `observer_type`. + +Detail Observer Type +~~~~~~~~~~~~~~~~~~~ + +The observer type is a concept of the `sacred` module, which determines how files, standard input and output which are managed by sacred are stored. + +file_storage +^^^^^^^^^^^^ + +If user's choice is `file_storage`, the config may be as following: + +.. code-block:: YAML + + experiment: + name: test_experiment + dir: # default is dir of `config.yml` + observer_type: file_storage + +mongo +^^^^^^^^^^^^ + +If user's choice is `mongo`, the config may be as following: + +.. code-block:: YAML + + experiment: + name: test_experiment + observer_type: mongo + mongo_url: mongodb://MONGO_URL + db_name: public + +The difference with `file_storage` is that user need to indicate `mongo_url` and `db_name` for a mongo observer. + +Note about Mongo Observer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Also note that if user choose mongo observer, user need to make sure: + +- have an environment with the mongodb installed and a mongo database dedicated for storing the experiments results. +- The python environment(the version of python and package) to run the experiments and the one to fetch the results are consistent. + +Note about mode test +~~~~~~~~~~~~~~~~~~~~ +Also note that if user choose `mode` test, user need to make sure: + +- The loader of `test_start_date` must be less than or equal to the current `test_start_date`. +- If other parameters of the `loader` model args are different, a warning will appear. + + +About the model +----------------- + +User can use a specified model by configuration with hyper-parameters, + +For Custom Models +~~~~~~~~~~~~~~~~~ + +Qlib support custom models, but it must be a subclass of the `qlib.contrib.model.Model`, the config for custom model may be as following, + +.. code-block:: YAML + + model: + class: SomeModel + module_path: /tmp/my_experment/custom_model.py + args: + loss: binary + + +The class `SomeModel` should be in the module `custom_model`, and Qlib could parse the `module_path` to load the class. + +Learn more about how to integrate custom model into Qlib, see detial in `Integration <../start/integration.html>`_. + +About data +----------------- + +Qlib have provided a implemented data handler `ALPHA360`, which is used to load raw data, prepare features and label columns, preprocess data, split training, validation, and test sets. It is a subclass of `qlib.contrib.estimator.handler.BaseDataHandler` which provides some interfaces, for example: + +- `setup_feature` + Implement the interface to load the data features. + +- `setup_label` + Implement the interface to load the data labels and calculate user's labels. + +- `setup_processed_data` + Implement this interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on. + +The `ALPHA360` implements these interfaces as a subclass. 
Its original data is the stock data of CSI 500, and the tag is the excess return of `t+2` day. + +Qlib also provided two functions to help user init the data handler, user can override them for user's need. + +- `_init_kwargs` + User can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df. + Kwargs are the other attributes in data.args, like dropna_label, dropna_feature + +- `_init_raw_df` + User can init the raw df, feature names and label names of data handler in this function. + If the index of feature df and label df are not same, user need to override this method to merge them (e.g. inner, left, right merge). + +User can use the specified data handler by config as following, + +.. code-block:: YAML + + data: + class: ALPHA360 + provider_uri: C:\\Users\\v-shuyw\\qlib\\qlib_data\\qlib + args: + start_date: 2005-01-01 + end_date: 2018-04-30 + dropna_label: True + feature_label_config: /data/qlib/feature_config/feature_config.yaml + filter: + market: csi500 + filter_pipeline: + - + class: NameDFilter + module_path: qlib.filter + args: + name_rule_re: S(?!Z3) + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + - + class: ExpressionDFilter + module_path: qlib.filter + args: + rule_expression: $open/$factor<=45 + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + +- `class` + Data handler class, str type, which should be a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If the user want to write a data handler to retrieve the data in qlib. QlibDataHandler is suggested. + +- `module_path` + The module path, str type, absolute url is also supported, indicates the path of the `class` implementation of data processor class. The default value is `qlib.contrib.estimator.handler`. + +- `market` + Index name, str type, the default value is `csi500`. In the 0.3.2 version, move to filter section. + +- `train_start_date` + Training start time, str type, default value is `2005-01-01`. + +- `start_date` + Data start date, str type. + +- `end_date` + Data end date, str type. the data from start_date to end_date decides which part of data will be loaded in datahandler, user can only use these data in the following parts. + +- `dropna_feature` (Optional in args) + Drop Nan feature, bool type, default value is False. + +- `dropna_label` (Optional in args) + Drop Nan label, bool type, default value is True. Some multi-label tasks will use this. + +- `normalize_method` (Optional in args) + Normalzie data by given method. str type. Qlib give two normalize method, `MinMax` and `Std`. + If users wants to build their own method, please override `_process_normalize_feature`. + +- `feature_label_config` (Optional in args) + Features and labels config location, str type (or dict type), indicates the path and filename of user's features and labels config. User can configure the features and labels data in this yaml. (Or, users can just put their data config here directly.) + Here is a reference of data config: + + .. 
code-block:: YAML + + static_fields: ['$open/$close', '$high/$close', '$low/$close', '$vwap/$close'] + static_names: ['OPN', 'HIGH', 'LOW', 'VWAP'] + windows: [5, 10, 20, 30, 60] + dynamic_fields: ['Ref($vwap, {w})/$close', + 'Mean($vwap, {w})/$close'] + dynamic_names: ['ROC{w}', + 'MA{w}'] + labels: ['Ref($vwap, -2)/Ref($vwap, -1) - 1'] + + + - `static_fields` + Single feature list, list type, each element in the list represents a kind of feature. + + - `static_names` + Single feature name list, list type, each element in the list represents the name of each element in the `static_fields`. + + - `windows` + Time windows, list type, each element represent a time window which will be used when calculating dynamic features. + + - `dynamic_fields` + Dynamic feature which will be expanded by windows, list type, each element in the list represents a feature need to be calculated through time window. + + - `dynamic_names` + Dynamic feature name list, list type, each element in the list represents the name of each element in the `dynamic_fields`. + + - `labels` + Data labels, list type, the labels of the data. + + Qlib gave `ALPHA360` a default config in code, user can use it directly. + + +- `filter` + Dynamically filtering the stocks based on the filter pipeline. + + - `market` + index name, str type, the default value is `csi500`. In the 0.3.2 version, move to this section. + + - `filter_pipeline` + Filter rule list, list type, the default value is []. Can be customized according to user needs. + + - `class` + Filter class name, str type. + + - `module_path` + The module path, str type. + + - `args` + The filter class parameters, this parameters are set according to the `class`, and all the parameters as kwargs to `class`. + + +For Custom Data Handler +~~~~~~~~~~~~~~~~~~~~~~ + +Qlib support custom data handler, but it must be a subclass of the `qlib.contrib.estimator.handler.BaseDataHandler`, the config for custom data handler may be as following, + +.. code-block:: YAML + + data: + class: SomeDataHandler + module_path: /tmp/my_experment/custom_data_handler.py + provider_uri: C:\\Users\\v-shuyw\\qlib\\qlib_data\\qlib + args: + start_date: 2005-01-01 + end_date: 2018-04-30 + feature_label_config: /data/qlib/feature_config/feature_config.yaml + +The class `SomeDataHandler` should be in the module `custom_data_handler`, and Qlib could parse the `module_path` to load the class. + +If user want to load features and labels through config, user can inherit `qlib.contrib.estimator.handler.ConfigDataHandler`, Qlib also have provided some preprocess method in this subclass. +If user want to use qlib data, `QLibDataHandler` is recommended, user can inherit user's custom class through this one, which is also a subclass of `ConfigDataHandler`. + + +About training +----------------- + +User can specify the trainer `trainer` through the config file, which is subclass of `qlib.contrib.estimator.trainer.BaseTrainer` and implement three important interfaces for training the model, restoring the model, and getting model predictions, for example: + +- `train` + Implement this interface to train the model. + +- `load` + Implement this interface to recover the model from disk. + +- `get_pred` + Implement this interface to get model prediction results. + +Qlib have provided two implemented trainer, + +- `StaticTrainer` + The static trainer will be trained using the training, validation, and test data of the data processor static slicing. 
+ +- `RollingTrainer` + The rolling trainer will use the rolling iterator of the data processor to split data for rolling training. + + +User can specify `trainer` through the configuration file: + +.. code-block:: YAML + + trainer: + class: StaticTrainer // or RollingTrainer + args: + rolling_period: 360 + train_start_date: 2005-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-06-30 + test_start_date: 2016-07-01 + test_end_date: 2017-07-31 + +- `class` + Trainer class, trt should be a subclass of `qlib.contrib.estimator.trainer.BaseTrainer`, and need to implement three important interfaces, the default value is `StaticTrainer`. + +- `module_path` + The module path, str type, absolute url is also supported, indicates the path of the trainer class implementation. + +- `rolling_period` + The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. Only used in `RollingTrainer`. + +- `train_start_date` + Training start time, str type. + +- `train_end_date` + Training end time, str type. + +- `validate_start_date` + Validation start time, str type. + +- `validate_end_date` + Validation end time, str type. + +- `test_start_date` + Test start time, str type. + +- `test_end_date` + Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`. + +For Custom Trainer +~~~~~~~~~~~~~~~~~~ + +Qlib support custom trainer, but it must be a subclass of the `qlib.contrib.estimator.trainer.BaseTrainer`, the config for custom trainer may be as following, + +.. code-block:: YAML + + trainer: + class: SomeTrainer + module_path: /tmp/my_experment/custom_trainer.py + args: + train_start_date: 2005-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-06-30 + test_start_date: 2016-07-01 + test_end_date: 2017-07-31 + + +The class `SomeTrainer` should be in the module `custom_trainer`, and Qlib could parse the `module_path` to load the class. + +About strategy +----------------- + +User can specify strategy through a config file, for example: + +.. code-block:: YAML + + strategy : + class: TopkAmountStrategy + args: + topk: 50 + buffer_margin: 300 + +- `class` + The strategy class, str type, should be a subclass of `qlib.contrib.strategy.strategy.BaseStrategy`. The default value is `TopkAmountStrategy`. + +- `module_path` + The module location, str type, absolute url is also supported, and absolute path is also supported, indicates the location of the policy class implementation. + +- `topk` + A threshold for buying rank, integer type, determines the threshold for the topk-margin strategy buy rank. The default value is 30. + +- `margin` + The sell buffer threshold, integer type, determines the buffer threshold, those who are outside the margin will be sold. The default value is 350. + + +For Custom Strategy +^^^^^^^^^^^^^^^^^^^ + +Qlib support custom strategy, but it must be a subclass of the `qlib.contrib.strategy.strategy.BaseStrategy`, the config for custom strategy may be as following, + + +.. code-block:: YAML + + strategy : + class: SomeStrategy + module_path: /tmp/my_experment/custom_strategy.py + +The class `SomeStrategy` should be in the module `custom_strategy`, and Qlib could parse the `module_path` to load the class. + +About backtest +----------------- + +User can specify `backtest` through a config file, for example: + +.. 
code-block:: YAML + + backtest : + normal_backtest_args: + topk: 50 + benchmark: SH000905 + account: 500000 + deal_price: vwap + min_cost: 5 + subscribe_fields: + - $close + - $change + - $factor + + long_short_backtest_args: + topk: 50 + subscribe_fields: + - $close + - $factor + +- `normal_backtest_args` + Normal backtest parameters. All the parameters in this section will be passed to the `qlib.contrib.evaluate.backtest` function in the form of `**kwargs`. + +- `long_short_backtest_args` + long short backtest parameters. All the parameters in this section will be passed to the `qlib.contrib.evaluate.long_short_backtest` function in the form of `**kwargs`. + +- `benchmark` + Stock index symbol, str or list type, the default value is `None`. + + .. note:: + + * If `benchmark` is None, it will use the average change of the day of all stocks in 'pred' as the 'bench'. + + * If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'. + + * If `benchmark` is str, it will use the daily change as the 'bench'. + + +- `account` + Backtest initial cash, integer type. The `account` in `strategy` section is deprecated. It only works when `account` is not set in `backtest` section. It will be overridden by `account` in the `backtest` section. The default value is 1e9. + +- `deal_price` + Order transaction price field, str type, the default value is vwap. + +- `min_cost` + Min transaction cost, float type, the default value is 5. + +- `subscribe_fields` + Subscribe quote fields, array type, the default value is [`deal_price`, $close, $change, $factor]. + + +Experiment Result +=================== + +User can check the experiment results from file storage directly, or check the experiment results from database, or user can get the experiment results through two API of a module `fetcher` provided by us. + +- `get_experiments()` + The API takes two parameters. The first parameter is experiment name. The default are all experiments. The second parameter is the observer type. User can get experiment name dictionary with list of ids and test end date with this API as follows: + + +.. code-block:: JSON + + { + "ex_a": [ + { + "id": 1, + "test_end_date": "2017-01-01" + } + ], + "ex_b": [ + ... + ] + } + + +- `get_experiment(exp_name, exp_id, fields=None)` + The API takes three parameters, the first parameter is the experiment name, the second parameter is the experiment id, and the third parameter is field list. + If fields is None, will get all fields. + + .. note:: + Currently supported fields: + ['model', 'analysis', 'positions', 'report_normal', 'report_long', 'report_short', 'report_long_short', 'pred', 'task_config', 'label'] + +.. code-block:: JSON + + { + 'analysis': analysis_df, + 'pred': pred_df, + 'positions': positions_dic, + 'report_normal': report_normal_df, + 'report_long_short': report_long_short_df + } + + +Here is a simple example of `FileFetcher`, which could fetch files from `file_storage` observer. + + +.. code-block:: python + + >>> from qlib.contrib.estimator.fetcher import FileFetcher + >>> f = FileFetcher(experiments_dir=r'./') + >>> print(f.get_experiments()) + + { + 'test_experiment': [ + { + 'id': '1', + 'config': ... + }, + { + 'id': '2', + 'config': ... + }, + { + 'id': '3', + 'config': ... 
+ } + ] + } + + + >>> print(f.get_experiment('test_experiment', '1')) + + risk + pred_long mean 0.001964 + std 0.001880 + sharpe 16.516510 + mdd -0.006503 + annual 0.490902 + pred_long_short mean 0.005570 + std 0.005056 + +If users uses mongo observer when training, user should initialize their fether with mongo_url + +.. code-block:: python + + >>> from qlib.contrib.estimator.fetcher import MongoFetcher + >>> f = MongoFetcher(mongo_url=..., db_name=...) + diff --git a/docs/advanced/model.rst b/docs/advanced/model.rst new file mode 100644 index 0000000000..b779f4203a --- /dev/null +++ b/docs/advanced/model.rst @@ -0,0 +1,179 @@ +=================== +Model: Train&Predict +=================== + +Introduction +=================== + +By ``Model``, users can use known data and features to train the model and predict the future score of the stock. + +Interface +=================== + +Qlib provides a base class `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_, which models should inherit from. + +The base class provides the following interfaces: + +- `def __init__` + - Initialization. + - If users use `estimator <../advanced/estimator.html>`_ to start an experiment, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file. + +- `def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)` + - Train model. + - Parameter: + - ``x_train``, pd.DataFrame type, train feature + The following example explains the value of x_train: + + .. code-block:: YAML + + KMID KLEN KMID2 KUP KUP2 + instrument datetime + SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275 + 2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998 + 2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666 + 2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001 + 2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648 + ... ... ... ... ... ... + SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052 + 2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824 + 2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000 + 2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433 + 2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216 + + + ``x_train`` is a pandas DataFrame, whose index is MultiIndex . Each column of `x_train` corresponds to a feature, and the column name is the feature name. + + .. note:: + + The number and names of the columns is determined by the data handler, please refer to `Data Handler `_ and `Estimator Data `_. + + - ``y_train``, pd.DataFrame type, train label + The following example explains the value of y_train: + + .. code-block:: YAML + + LABEL3 + instrument datetime + SH600004 2012-01-04 -0.798456 + 2012-01-05 -1.366716 + 2012-01-06 -0.491026 + 2012-01-09 0.296900 + 2012-01-10 0.501426 + ... ... + SZ300273 2014-12-25 -0.465540 + 2014-12-26 0.233864 + 2014-12-29 0.471368 + 2014-12-30 0.411914 + 2014-12-31 1.342723 + + ``y_train`` is a pandas DataFrame, whose index is MultiIndex . The 'LABEL3' column represents the value of train label. + + .. note:: + + The number and names of the columns is determined by the data handler, please refer to `Data Handler `_. 
+ + - ``x_valid``, pd.DataFrame type, validation feature + The form of ``x_valid`` is same as ``x_train`` + + + - ``y_valid``, pd.DataFrame type, validation label + The form of ``y_valid`` is same as ``y_train`` + + - ``w_train``(Optional args, default is None), pd.DataFrame type, train weight + ``w_train`` is a pandas DataFrame, whose shape and index is same as ``x_train``. The float value in ``w_train`` represents the weight of the feature at the same position in ``x_train``. + + - ``w_valid``(Optional args, default is None), pd.DataFrame type, validation weight + ``w_valid`` is a pandas DataFrame, whose shape and index is same as ``x_valid``. The float value in ``w_train`` represents the weight of the feature at the same position in ``x_train``. + +- `def predict(self, x_test, **kwargs)` + - Predict test data 'x_test' + - Parameter: + - ``x_test``, pd.DataFrame type, test features + The form of ``x_test`` is same as ``x_train`` in 'fit' method. + - Return: + - ``label``, np.ndarray type, test label + The label of ``x_test`` that predicted by model. + +- `def score(self, x_test, y_test, w_test=None, **kwargs)` + - Evaluate model with test feature/label + - Parameter: + - ``x_test``, pd.DataFrame type, test feature + The form of ``x_test`` is same as ``x_train`` in 'fit' method. + + - ``x_test``, pd.DataFrame type, test label + The form of ``y_test`` is same as ``y_train`` in 'fit' method. + + - ``w_test``, pd.DataFrame type, test weight + The form of ``w_test`` is same as ``w_train`` in 'fit' method. + - Return: float type, evaluation score + +For other interfaces such as ``save``, ``load``, ``finetune``, please refer to `Model Api <../reference/api.html#module-qlib.contrib.model.base>`_. + +Example +================== + +'Model' can be run with 'estimator' by modifying the configuration file, and can also be used as a single module. + +Know more about how to run 'Model' with estimator, please refer to `Estimator `_. + +Qlib provides LightGBM and DNN models as the baseline, the following example shows how to run LightGBM as a single module. + +.. note:: User needs to initialize package qlib with qlib.init first, please refer to `initialization `_. + + +.. code-block:: Python + + from qlib.contrib.estimator.handler import QLibDataHandlerV1 + from qlib.contrib.model.gbdt import LGBModel + + DATA_HANDLER_CONFIG = { + "dropna_label": True, + "start_date": "2007-01-01", + "end_date": "2020-08-01", + "market": MARKET, + } + + TRAINER_CONFIG = { + "train_start_date": "2007-01-01", + "train_end_date": "2014-12-31", + "validate_start_date": "2015-01-01", + "validate_end_date": "2016-12-31", + "test_start_date": "2017-01-01", + "test_end_date": "2020-08-01", + } + + x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerV1( + **DATA_HANDLER_CONFIG + ).get_split_data(**TRAINER_CONFIG) + + + MODEL_CONFIG = { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + } + # use default model + # custom Model, refer to: TODO: Model api url + model = LGBModel(**MODEL_CONFIG) + model.fit(x_train, y_train, x_validate, y_validate) + _pred = model.predict(x_test) + +.. note:: 'QLibDataHandlerV1' is the data handler provided by Qlib, please refer to `Data Handler `_. + +Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``. 
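+
+After the model is fitted, it can also be evaluated on the held-out test set and its predictions wrapped into a ``pandas`` object for later use (e.g. as the prediction score consumed by the ``Strategy`` and ``Backtest`` modules). The following is only a sketch that continues the example above; it assumes ``predict`` returns a 1-d array aligned with the rows of ``x_test``.
+
+.. code-block:: Python
+
+    import numpy as np
+    import pandas as pd
+
+    # Evaluate the fitted model on the test set (returns a float score).
+    test_score = model.score(x_test, y_test)
+    print("test score:", test_score)
+
+    # Wrap the raw predictions into a Series indexed like x_test,
+    # so that downstream modules can align them by instrument and datetime.
+    pred_score = pd.Series(np.asarray(_pred).ravel(), index=x_test.index)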
+
+Custom Model
+===================
+
+Qlib supports custom models. For details on how to customize a model and integrate it into Qlib, please refer to `How to integrate Model into Qlib <../start/integration.html>`_.
+
+
+Api
+===================
+Please refer to `Model Api <../reference/api.html#module-qlib.contrib.model.base>`_ for the Model Api.
diff --git a/docs/advanced/report.rst b/docs/advanced/report.rst
new file mode 100644
index 0000000000..029fc0e419
--- /dev/null
+++ b/docs/advanced/report.rst
@@ -0,0 +1,76 @@
+===========================
+Report: Graphical Results
+===========================
+
+Introduction
+===================
+
+By ``Report``, users can view the graphical results of an experiment.
+
+There are the following graphics to view:
+
+- analysis_position
+    - report_graph
+    - score_ic_graph
+    - cumulative_return_graph
+    - risk_analysis_graph
+    - rank_label_graph
+
+- analysis_model
+    - model_performance_graph
+
+
+Example
+===================
+
+.. note::
+
+    The following is a simple example of drawing.
+    For more features, please see the function documentation, e.g. ``help(qcr.analysis_position.report_graph)``.
+
+
+Get all supported graphics. Please see the API section at the bottom of the page for details:
+
+.. code-block:: python
+
+    >>> import qlib.contrib.report as qcr
+    >>> print(qcr.GRAPH_NAME_LIST)
+    ['analysis_position.report_graph', 'analysis_position.score_ic_graph', 'analysis_position.cumulative_return_graph', 'analysis_position.risk_analysis_graph', 'analysis_position.rank_label_graph', 'analysis_model.model_performance_graph']
+
+
+
+
+
+API
+===================
+
+
+
+.. automodule:: qlib.contrib.report.analysis_position.report
+    :members:
+
+
+
+.. automodule:: qlib.contrib.report.analysis_position.score_ic
+    :members:
+
+
+
+.. automodule:: qlib.contrib.report.analysis_position.cumulative_return
+    :members:
+
+
+
+.. automodule:: qlib.contrib.report.analysis_position.risk_analysis
+    :members:
+
+
+
+.. automodule:: qlib.contrib.report.analysis_position.rank_label
+    :members:
+
+
+
+.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance
+    :members:
+
diff --git a/docs/advanced/strategy.rst b/docs/advanced/strategy.rst
new file mode 100644
index 0000000000..4c9bf4c339
--- /dev/null
+++ b/docs/advanced/strategy.rst
@@ -0,0 +1,121 @@
+.. _strategy:
+==============================
+Strategy: Portfolio Management
+==============================
+.. currentmodule:: qlib
+
+Introduction
+===================
+
+By ``Strategy``, users can adopt different trading strategies, which means that users can use different algorithms to generate investment portfolios based on the predicted scores of the ``Model`` module.
+
+``Qlib`` provides several trading strategy classes, and users can also customize strategies according to their own needs.
+
+Base Class & Interface
+======================
+
+BaseStrategy
+------------------
+
+Qlib provides a base class ``qlib.contrib.strategy.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.
+
+- `get_risk_degree`
+    Return the proportion of your total value to be used in investment. A dynamic `risk_degree` enables market timing.
+
+- `generate_order_list`
+    Return the order list.
+
+Users can inherit 'BaseStrategy' to customize their strategy class.
+
+WeightStrategyBase
+--------------------
+
+Qlib also provides a class ``qlib.contrib.strategy.WeightStrategyBase`` that is a subclass of `BaseStrategy`.
+
+`WeightStrategyBase` only focuses on the target positions, and automatically generates an order list based on the positions. It provides the `generate_target_weight_position` interface.
+
+- `generate_target_weight_position`
+    Generate the target position according to the current position and the trading date, and return the target position.
+
+    .. note:: Cash is not considered.
+
+`WeightStrategyBase` implements the interface `generate_order_list`, whose process is as follows.
+
+- Call the `generate_target_weight_position` method to generate the target position.
+- Generate the target amount of stocks from the target position.
+- Generate the order list from the target amount.
+
+Users can inherit `WeightStrategyBase` and implement the interface `generate_target_weight_position` to customize their strategy class, which then only needs to focus on the target positions.
+
+Implemented Strategy
+====================
+
+Qlib provides several implemented strategy classes, such as `TopkWeightStrategy`, `TopkAmountStrategy` and `TopkDropoutStrategy`.
+
+TopkWeightStrategy
+------------------
+`TopkWeightStrategy` is a subclass of `WeightStrategyBase` and implements the interface `generate_target_weight_position`.
+
+The implemented interface `generate_target_weight_position` adopts the ``Topk`` algorithm to calculate the target position; it ensures that the weight of each stock is as even as possible.
+
+.. note::
+    ``TopK`` algorithm: Define a threshold `margin`. On each trading day, the stocks whose predicted scores rank behind `margin` will be sold, and then the stocks with the best predicted scores will be bought to keep the number of held stocks at k.
+
+
+
+TopkAmountStrategy
+------------------
+`TopkAmountStrategy` is a subclass of `BaseStrategy` and implements the interface `generate_order_list`, whose process is as follows.
+
+- Adopt the ``Topk`` algorithm to calculate the target amount of each stock.
+- Generate the order list from the target amount.
+
+
+
+TopkDropoutStrategy
+-------------------
+`TopkDropoutStrategy` is a subclass of `BaseStrategy` and implements the interface `generate_order_list`, whose process is as follows.
+
+- Adopt the ``TopkDropout`` algorithm to calculate the target amount of each stock.
+
+    .. note::
+
+        ``TopkDropout`` algorithm: On each trading day, the held stocks with the worst predicted scores will be sold, and then the stocks with the best predicted scores will be bought to keep the number of held stocks at k. Because a fixed number of stocks are sold and bought every day, this algorithm keeps the turnover rate at a fixed value.
+
+- Generate the order list from the target amount.
+
+Example
+====================
+``Strategy`` can be specified in the ``Backtest`` module; an example is as follows.
+
+.. code-block:: python
+
+    from qlib.contrib.strategy.strategy import TopkAmountStrategy
+    from qlib.contrib.evaluate import backtest
+
+    STRATEGY_CONFIG = {
+        "topk": 50,
+        "buffer_margin": 230,
+    }
+    BACKTEST_CONFIG = {
+        "verbose": False,
+        "limit_threshold": 0.095,
+        "account": 100000000,
+        "benchmark": BENCHMARK,
+        "deal_price": "vwap",
+    }
+
+    # use default strategy
+    # custom Strategy, refer to: TODO: Strategy api url
+    # `pred_score` is the prediction score produced by the ``Model`` module
+    strategy = TopkAmountStrategy(**STRATEGY_CONFIG)
+    report_normal, positions_normal = backtest(
+        pred_score, strategy=strategy, **BACKTEST_CONFIG
+    )
+
+Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
+
+To know more about ``Backtest``, please refer to `Backtest: Model&Strategy Testing `_.
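+
+To customize a strategy on top of ``WeightStrategyBase`` as described above, a minimal sketch could look like the following. It holds the ``topk`` stocks with the highest prediction scores at equal weights; the exact signature of `generate_target_weight_position` and the form of its arguments (``score``, ``current``, ``trade_date``) are assumptions for illustration, so please check the base class for the actual interface.
+
+.. code-block:: python
+
+    from qlib.contrib.strategy import WeightStrategyBase
+
+
+    class TopkEqualWeightStrategy(WeightStrategyBase):
+        """Hypothetical strategy: hold the top-k scored stocks with equal weights."""
+
+        def __init__(self, topk=50, **kwargs):
+            super().__init__(**kwargs)
+            self.topk = topk
+
+        def generate_target_weight_position(self, score, current, trade_date):
+            # `score` is assumed to be a pandas Series of prediction scores
+            # indexed by stock id for the given trading date.
+            top_stocks = score.sort_values(ascending=False).head(self.topk)
+            # Equal weight for every selected stock; cash is not considered.
+            return {stock: 1.0 / self.topk for stock in top_stocks.index}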
+ +Api +=================== +Please refer to `Strategy Api <../reference/api.html>`_. diff --git a/docs/advanced/tuner.rst b/docs/advanced/tuner.rst new file mode 100644 index 0000000000..47dbb30315 --- /dev/null +++ b/docs/advanced/tuner.rst @@ -0,0 +1,327 @@ +.. _tuner: + +Tuner +=================== +.. currentmodule:: qlib + +Introduction +------------------- + +Welcome to use Tuner, this document is based on that you can use Estimator proficiently and correctly. + +You can find the optimal hyper-parameters and combinations of models, trainers, strategies and data labels. + +The usage of program `tuner` is similar with `estimator`, you need provide the URL of the configuration file. +The `tuner` will do the following things: + +- Construct tuner pipeline +- Search and save best hyper-parameters of one tuner +- Search next tuner in pipeline +- Save the global best hyper-parameters and combination + +Each tuner is consisted with a kind of combination of modules, and its goal is searching the optimal hyper-parameters of this combination. +The pipeline is consisted with different tuners, it is aim at finding the optimal combination of modules. + +The result will be printed on screen and saved in file, you can check the result in your experiment saving files. + +Example +~~~~~~~ + +Let's see an example, + +First make sure you have the latest version of `qlib` installed. + +Then, you need to privide a configuration to setup the experiment. +We write a simple configuration example as following, + +.. code-block:: YAML + + experiment: + name: tuner_experiment + tuner_class: QLibTuner + qlib_client: + auto_mount: False + logging_level: INFO + optimization_criteria: + report_type: model + report_factor: model_score + optim_type: max + tuner_pipeline: + - + model: + class: SomeModel + space: SomeModelSpace + trainer: + class: RollingTrainer + strategy: + class: TopkAmountStrategy + space: TopkAmountStrategySpace + max_evals: 2 + + time_period: + rolling_period: 360 + train_start_date: 2005-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-06-30 + test_start_date: 2016-07-01 + test_end_date: 2018-04-30 + data: + class: ALPHA360 + provider_uri: /data/qlib + args: + start_date: 2005-01-01 + end_date: 2018-04-30 + dropna_label: True + dropna_feature: True + filter: + market: csi500 + filter_pipeline: + - + class: NameDFilter + module_path: qlib.data.filter + args: + name_rule_re: S(?!Z3) + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + - + class: ExpressionDFilter + module_path: qlib.data.filter + args: + rule_expression: $open/$factor<=45 + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + backtest: + normal_backtest_args: + verbose: False + limit_threshold: 0.095 + account: 500000 + benchmark: SH000905 + deal_price: vwap + long_short_backtest_args: + topk: 50 + +Next, we run the following command, and you can see: + +.. code-block:: bash + + ~/v-yindzh/Qlib/cfg$ tuner -c tuner_config.yaml + + Searching params: {'model_space': {'colsample_bytree': 0.8870905643607678, 'lambda_l1': 472.3188735122233, 'lambda_l2': 92.75390994877243, 'learning_rate': 0.09741751430635413, 'loss': 'mse', 'max_depth': 8, 'num_leaves': 160, 'num_threads': 20, 'subsample': 0.7536051584789751}, 'strategy_space': {'buffer_margin': 250, 'topk': 40}} + ... + (Estimator experiment screen log) + ... 
+ Searching params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}} + ... + (Estimator experiment screen log) + ... + Local best params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}} + Time cost: 489.87220 | Finished searching best parameters in Tuner 0. + Time cost: 0.00069 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_0/local_best_params.json . + Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 2',)}, 'model_space': {'input_dim': 158, 'lr': 0.001, 'lr_decay': 0.9100529502185579, 'lr_decay_steps': 162.48901403763966, 'optimizer': 'gd', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 300, 'topk': 35}} + ... + (Estimator experiment screen log) + ... + Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}} + ... + (Estimator experiment screen log) + ... + Local best params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}} + Time cost: 550.74039 | Finished searching best parameters in Tuner 1. + Time cost: 0.00023 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_1/local_best_params.json . + Time cost: 1784.14691 | Finished tuner pipeline. + Time cost: 0.00014 | Finished save global best tuner parameters. + Best Tuner id: 0. + You can check the best parameters at tuner_experiment/global_best_params.json. + + +Finally, you can check the results of your experiment in the given path. + +Configuration file +------------------ + +Before using `tuner`, you need to prepare a configuration file. Next we will show you how to prepare each part of the configuration file. + +About the experiment +~~~~~~~~~~~~~~~~~~~~ + +First, your configuration file needs to have a field about the experiment, whose key is `experiment`, this field and its contents determine the saving path and tuner class. + +Usually it should contain the following content: + +.. code-block:: YAML + + experiment: + name: tuner_experiment + tuner_class: QLibTuner + +Also, there are some optional fields. The meaning of each field is as follows: + +- `name` + The experiment name, str type, the program will use this experiment name to construct a directory to save the process of the whole experiment and the results. The default value is `tuner_experiment`. + +- `dir` + The saving path, str type, the program will construct the experiment directory in this path. The default value is the path where configuration locate. 
+ +- `tuner_class` + The class of tuner, str type, must be an already implemented model, such as `QLibTuner` in `qlib`, or a custom tuner, but it must be a subclass of `qlib.contrib.tuner.Tuner`, the default value is `QLibTuner`. + +- `tuner_module_path` + The module path, str type, absolute url is also supported, indicates the path of the implementation of tuner. The default value is `qlib.contrib.tuner.tuner` + +About the optimization criteria +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You need to designate a factor to optimize, for tuner need a factor to decide which case is better than other cases. +Usually, we use the result of `estimator`, such as backtest results and the score of model. + +This part needs contain these fields: + +.. code-block:: YAML + + optimization_criteria: + report_type: model + report_factor: model_pearsonr + optim_type: max + +- `report_type` + The type of the report, str type, determines which kind of report you want to use. If you want to use the backtest result type, you can choose `pred_long`, `pred_long_short`, `pred_short`, `sub_bench` and `sub_cost`. If you want to use the model result type, you can only choose `model`. + +- `report_factor` + The factor you want to use in the report, str type, determines which factor you want to optimize. If your `report_type` is backtest result type, you can choose `annual`, `sharpe`, `mdd`, `mean` and `std`. If your `report_type` is model result type, you can choose `model_score` and `model_pearsonr`. + +- `optim_type` + The optimization type, str type, determines what kind of optimization you want to do. you can minimize the factor or maximize the factor, so you can choose `max`, `min` or `correlation` at this field. + Note: `correlation` means the factor's best value is 1, such as `model_pearsonr` (a corraltion coefficient). + +If you want to process the factor or you want fetch other kinds of factor, you can override the `objective` method in your own tuner. + +About the tuner pipeline +~~~~~~~~~~~~~~~~~~~~~~~~ + +The tuner pipeline contains different tuners, and the `tuner` program will process each tuner in pipeline. Each tuner will get an optimal hyper-parameters of its specific combination of modules. The pipeline will contrast the results of each tuner, and get the best combination and its optimal hyper-parameters. So, you need to configurate the pipeline and each tuner, here is an example: + +.. code-block:: YAML + + tuner_pipeline: + - + model: + class: SomeModel + space: SomeModelSpace + trainer: + class: RollingTrainer + strategy: + class: TopkAmountStrategy + space: TopkAmountStrategySpace + max_evals: 2 + +Each part represents a tuner, and its modules which are to be tuned. Space in each part is the hyper-parameters' space of a certain module, you need to create your searching space and modify it in `/qlib/contrib/tuner/space.py`. We use `hyperopt` package to help us to construct the space, you can see the detail of how to use it in https://github.com/hyperopt/hyperopt/wiki/FMin . + +- model + You need to provide the `class` and the `space` of the model. If the model is user's own implementation, you need to privide the `module_path`. + +- trainer + You need to proveide the `class` of the trainer. If the trainer is user's own implementation, you need to privide the `module_path`. + +- strategy + You need to provide the `class` and the `space` of the strategy. If the strategy is user's own implementation, you need to privide the `module_path`. 
+ +- data_label + The label of the data, you can search which kinds of labels will lead to a better result. This part is optional, and you only need to provide `space`. + +- max_evals + Allow up to this many function evaluations in this tuner. The default value is 10. + +If you don't want to search some modules, you can fix their spaces in `space.py`. We will not give the default module. + +About the time period +~~~~~~~~~~~~~~~~~~~~~ + +You need to use the same dataset to evaluate your different `estimator` experiments in `tuner` experiment. Two experiments using different dataset are uncomparable. You can specify `time_period` through the configuration file: + +.. code-block:: YAML + + time_period: + rolling_period: 360 + train_start_date: 2005-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-06-30 + test_start_date: 2016-07-01 + test_end_date: 2018-04-30 + +- `rolling_period` + The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. If you use `RollingTrainer`, this config will be used, or it will be ignored. + +- `train_start_date` + Training start time, str type. + +- `train_end_date` + Training end time, str type. + +- `validate_start_date` + Validation start time, str type. + +- `validate_end_date` + Validation end time, str type. + +- `test_start_date` + Test start time, str type. + +- `test_end_date` + Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`. + +About the data and backtest +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +`data` and `backtest` are all same in the whole `tuner` experiment. Different `estimator` experiments must use the same data and backtest method. So, these two parts of config are same with that in `estimator` configuration. You can see the precise defination of these parts in `estimator` introduction. We only provide an example here. + +.. code-block:: YAML + + data: + class: ALPHA360 + provider_uri: /data/qlib + args: + start_date: 2005-01-01 + end_date: 2018-04-30 + dropna_label: True + dropna_feature: True + feature_label_config: /home/v-yindzh/v-yindzh/QLib/cfg/feature_config.yaml + filter: + market: csi500 + filter_pipeline: + - + class: NameDFilter + module_path: qlib.filter + args: + name_rule_re: S(?!Z3) + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + - + class: ExpressionDFilter + module_path: qlib.filter + args: + rule_expression: $open/$factor<=45 + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + backtest: + normal_backtest_args: + verbose: False + limit_threshold: 0.095 + account: 500000 + benchmark: SH000905 + deal_price: vwap + long_short_backtest_args: + topk: 50 + +Experiment Result +----------------- + +All the results are stored in experiment file directly, you can check them directly in the corresponding files. +What we save are as following: + +- Global optimal parameters +- Local optimal parameters of each tuner +- Config file of this `tuner` experiment +- Every `estimator` experiments result in the process + diff --git a/docs/changelog/changelog.rst b/docs/changelog/changelog.rst new file mode 100644 index 0000000000..2414029293 --- /dev/null +++ b/docs/changelog/changelog.rst @@ -0,0 +1,2 @@ +.. include:: ../../CHANGES.rst + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000..0e815d7e02 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,224 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + + +# QLib documentation build configuration file, created by +# sphinx-quickstart on Wed Sep 27 15:16:05 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +import pkg_resources + + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.todo', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u"QLib" +copyright = u"Microsoft" +author = u"Microsoft" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = pkg_resources.get_distribution("qlib").version +# The full version, including alpha/beta/rc tags. +release = pkg_resources.get_distribution("qlib").version + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'en_US' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +# If true, '()' will be appended to :func: etc. cross-reference text. +add_function_parentheses = False + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +add_module_names = True + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. 
+# html_context = { +# "display_github": False, +# "last_updated": True, +# "commit": True, +# "github_user": "Microsoft", +# "github_repo": "QLib", +# 'github_version': 'master', +# 'conf_py_path': '/docs/', + +# } +# +html_theme_options = { + 'collapse_navigation': False, + 'display_version': False, + 'navigation_depth': 3, +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +#html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# This is required for the alabaster theme +# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars +html_sidebars = { + '**': [ + 'about.html', + 'navigation.html', + 'relations.html', # needs 'show_related': True theme option to display + 'searchbox.html', + ] +} + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'qlibdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, "qlib.tex", u"QLib Documentation", u"Microsoft", "manual"), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'qlib', u'QLib Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'QLib', u'QLib Documentation', + author, 'QLib', 'One line description of project.', + 'Miscellaneous'), +] + + + +# -- Options for Epub output ---------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = project +epub_author = author +epub_publisher = author +epub_copyright = copyright + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +# +# epub_identifier = '' + +# A unique identification for the text. +# +# epub_uid = '' + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + + +autodoc_member_order = 'bysource' +autodoc_default_flags = ['members'] diff --git a/docs/hidden/client.rst b/docs/hidden/client.rst new file mode 100644 index 0000000000..242c1afc27 --- /dev/null +++ b/docs/hidden/client.rst @@ -0,0 +1,171 @@ +.. _client: + +Qlib Client-Server Framework +=================== + +.. currentmodule:: qlib + +Introduction +----------- +Client-Server is designed to solve following problems + +- Manage the data in a centralized way. Users don't have to manage data of different versions. 
+- Reduce the amount of cache to be generated. +- Make the data can be accessed in a remote way. + +Therefore, we designed the client-server framework to solve these problems. +We will maintain a server and provide the data. + +You have to initialize you qlib with specific config for using the client-server framework. +Here is a typical initialization process. + +qlib ``init`` commonly used parameters; ``nfs-common`` must be installed on the server where the client is located, execute: ``sudo apt install nfs-common``: + - ``provider_uri``: nfs-server path; the format is ``host: data_dir``, for example: ``172.23.233.89:/data2/gaochao/sync_qlib/qlib``. If using offline, it can be a local data directory + - ``mount_path``: local data directory, ``provider_uri`` will be mounted to this directory + - ``auto_mount``: whether to automatically mount ``provider_uri`` to ``mount_path`` during qlib ``init``; You can also mount it manually: sudo mount.nfs ``provider_uri`` ``mount_path``. If on PAI, it is recommended to set ``auto_mount=True`` + - ``flask_server``: data service host; if you are on the intranet, you can use the default host: 172.23.233.89 + - ``flask_port``: data service port + + +If running on 10.150.144.153 or 10.150.144.154 server, it's recommended to use the following code to ``init`` qlib: + +.. code-block:: python + + >>> import qlib + >>> qlib.init(auto_mount=False, mount_path='/data/csdesign/qlib') + >>> from qlib.data import D + >>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head() + [39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client. + [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings. + [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib + [39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710 + [39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server! + Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode + Out[5]: + $close + instrument datetime + SH600000 2008-01-02 119.079704 + 2008-01-03 113.120125 + 2008-01-04 117.878860 + 2008-01-07 124.505539 + 2008-01-08 125.395004 + + +If running on PAI, it's recommended to use the following code to ``init`` qlib: + +.. code-block:: python + + >>> import qlib + >>> qlib.init(auto_mount=True, mount_path='/data/csdesign/qlib', provider_uri='172.23.233.89:/data2/gaochao/sync_qlib/qlib') + >>> from qlib.data import D + >>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head() + [39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client. + [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings. + [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib + [39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710 + [39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server! 
+ Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode + Out[5]: + $close + instrument datetime + SH600000 2008-01-02 119.079704 + 2008-01-03 113.120125 + 2008-01-04 117.878860 + 2008-01-07 124.505539 + 2008-01-08 125.395004 + + +If running on Windows, open **NFS** features and write correct **mount_path**, it's recommended to use the following code to ``init`` qlib: + +1.windows System open NFS Features + * Open ``Programs and Features``. + * Click ``Turn Windows features on or off``. + * Scroll down and check the option ``Services for NFS``, then click OK + Reference address: https://graspingtech.com/mount-nfs-share-windows-10/ +2.config correct mount_path + * In windows, mount path must be not exist path and root path, + * correct format path eg: `H`, `i`... + * error format path eg: `C`, `C:/user/name`, `qlib_data`... + +.. code-block:: python + + >>> import qlib + >>> qlib.init(auto_mount=True, mount_path='H', provider_uri='172.23.233.89:/data2/gaochao/sync_qlib/qlib') + >>> from qlib.data import D + >>> D.features(['SH600000'], ['$close'], start_time='20080101', end_time='20090101').head() + [39336:MainThread](2019-05-28 21:35:42,800) INFO - Initialization - [__init__.py:16] - default_conf: client. + [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings. + [39336:MainThread](2019-05-28 21:35:42,801) INFO - Initialization - [__init__.py:56] - provider_uri=172.23.233.89:/data2/gaochao/sync_qlib/qlib + [39336:Thread-68](2019-05-28 21:35:42,809) INFO - Client - [client.py:28] - Connect to server ws://172.23.233.89:9710 + [39336:Thread-72](2019-05-28 21:35:43,489) INFO - Client - [client.py:31] - Disconnect from server! + Opening /data/csdesign/qlib/cache/d239a3b191daa9a5b1b19a59beb47b33 in read-only mode + Out[5]: + $close + instrument datetime + SH600000 2008-01-02 119.079704 + 2008-01-03 113.120125 + 2008-01-04 117.878860 + 2008-01-07 124.505539 + 2008-01-08 125.395004 + + + + + +The client will mount the data in `provider_uri` on `mount_path`. Then the server and client will communicate with flask and transporting data with this NFS. + + +If you have a local qlib data files and want to use the qlib data offline instead of online with client server framework. +It is also possible with specific config. +You can created such a config. `client_config_local.yml` + +.. code-block:: YAML + + provider_uri: /data/csdesign/qlib + calendar_provider: 'LocalCalendarProvider' + instrument_provider: 'LocalInstrumentProvider' + feature_provider: 'LocalFeatureProvider' + expression_provider: 'LocalExpressionProvider' + dataset_provider: 'LocalDatasetProvider' + provider: 'LocalProvider' + dataset_cache: 'SimpleDatasetCache' + local_cache_path: '~/.cache/qlib/' + +`provider_uri` is the directory of your local data. + +.. code-block:: python + + >>> import qlib + >>> qlib.init_from_yaml_conf('client_config_local.yml') + >>> from qlib.data import D + >>> D.features(['SH600001'], ['$close'], start_time='20180101', end_time='20190101').head() + 21232:MainThread](2019-05-29 10:16:05,066) INFO - Initialization - [__init__.py:16] - default_conf: client. + [21232:MainThread](2019-05-29 10:16:05,066) INFO - Initialization - [__init__.py:54] - qlib successfully initialized based on client settings. 
+ [21232:MainThread](2019-05-29 10:16:05,067) INFO - Initialization - [__init__.py:56] - provider_uri=/data/csdesign/qlib + Out[9]: + $close + instrument datetime + SH600001 2008-01-02 21.082111 + 2008-01-03 23.195362 + 2008-01-04 23.874615 + 2008-01-07 24.880930 + 2008-01-08 24.277143 + +Limitations +----------- +1. The following API under the client-server module may not be as fast as the older off-line API. + - Cal.calendar + - Inst.list_instruments +2. The rolling operation expression with parameter `0` can not be updated rightly under mechanism of the client-server framework. + +API +******************** + +The client is based on `python-socketio`_ which is a framework that supports WebSocket client for Python language. The client can only propose requests and receive results, which do not include any calculating procedure. + +Class +-------------------- + +.. automodule:: qlib.data.client + + diff --git a/docs/hidden/online.rst b/docs/hidden/online.rst new file mode 100644 index 0000000000..da4fc99d47 --- /dev/null +++ b/docs/hidden/online.rst @@ -0,0 +1,285 @@ +.. _online: + +Online +=================== +.. currentmodule:: qlib + +Introduction +------------------- + +Welcome to use Online, this module simulates what will be like if we do the real trading use our model and strategy. + +Just like Estimator and other modules in Qlib, you need to determine parameters through the configuration file, +and in this module, you need to add an account in a folder to do the simulation. Then in each coming day, +this module will use the newest information to do the trade for your account, +the performance can be viewed at any time using the API we defined. + +Each account will experience the following processes, the ‘pred_date’ represents the date you predict the target +positions after trading, also, the ‘trade_date’ is the date you do the trading. + +- Generate the order list (pre_date) +- Execute the order list (trade_date) +- Update account (trade_date) + +In the meantime, you can just create an account and use this module to test its performance in a period. + +- Simulate (start_date, end_date) + +This module need to save your account in a folder, the model and strategy will be saved as pickle files, +and the position and report will be saved as excel. +The file structure can be viewed at fileStruct_. + + +Example +------------------- + +Let's take an example, + +.. note:: Make sure you have the latest version of `qlib` installed. + +If you want to use the models and data provided by `qlib`, you only need to do as follows. + +Firstly, write a simple configuration file as following, + +.. code-block:: YAML + + strategy: + class: TopkAmountStrategy + module_path: qlib.contrib.strategy + args: + market: csi500 + trade_freq: 5 + + model: + class: ScoreFileModel + module_path: qlib.contrib.online.online_model + args: + loss: mse + model_path: ./model.bin + + init_cash: 1000000000 + +We then can use this command to create a folder and do trading from 2017-01-01 to 2018-08-01. + +.. code-block:: bash + + online simulate -id v-test -config ./config/config.yaml -exchange_config ./config/exchange.yaml -start 2017-01-01 -end 2018-08-01 -path ./user_data/ + +The start date (2017-01-01) is the add date of the user, which also is the first predict date, +and the end date (2018-08-01) is the last trade date. You can use "`online generate -date 2018-08-02...`" +command to continue generate the order_list at next trading date. 
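+
+For example, to carry the simulation above forward by one more trading day, the daily loop could look like the following (the date and paths are illustrative; the individual commands are described in detail below):
+
+.. code-block:: bash
+
+    >> online generate -date 2018-08-02 -path ./user_data/
+    >> online execute -date 2018-08-02 -exchange_config ./config/exchange.yaml -path ./user_data/
+    >> online update -date 2018-08-02 -path ./user_data/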
+ +If Your account was saved in "./user_data/", you can see the performance of your account compared to a benchmark by + +.. code-block:: bash + + >> online show -id v-test -path ./user_data/ -bench SH000905 + + ... + Result of porfolio: + sub_bench: + risk + mean 0.001157 + std 0.003039 + annual 0.289131 + sharpe 6.017635 + mdd -0.013185 + sub_cost: + risk + mean 0.000800 + std 0.003043 + annual 0.199944 + sharpe 4.155963 + mdd -0.015517 + +Here 'SH000905' represents csi500 and 'SH000300' represents csi300 + +Manage your account +-------------------- + +Any account processed by `online` should be saved in a folder. you can use commands +defined to manage your accounts. + +- add an new account + This will add an new account with user_id='v-test', add_date='2019-10-15' in ./user_data. + + .. code-block:: bash + + >> online add_user -id {user_id} -config {config_file} -path {folder_path} -date {add_date} + >> online add_user -id v-test -config config.yaml -path ./user_data/ -date 2019-10-15 + +- remove an account + .. code-block:: bash + + >> online remove_user -id {user_id} -path {folder_path} + >> online remove_user -id v-test -path ./user_data/ + +- show the performance + Here benchmark indicates the baseline is to be compared with yours. + + .. code-block:: bash + + >> online show -id {user_id} -path {folder_path} -bench {benchmark} + >> online show -id v-test -path ./user_data/ -bench SH000905 + +The default value of all the parameter 'date' below is trade date +(will be today if today is trading date and information has been updated in `qlib`). + +The 'generate' and 'update' will check whether input date is valid, the following 3 processes should +be called at each trading date. + +- generate the order list + generate the order list at trade date, and save them in {folder_path}/{user_id}/temp/ as a json file. + + .. code-block:: bash + + >> online generate -date {date} -path {folder_path} + >> online generate -date 2019-10-16 -path ./user_data/ + +- execute the order list + execute the order list and generate the transactions result in {folder_path}/{user_id}/temp/ at trade date + + .. code-block:: bash + + >> online execute -date {date} -exchange_config {exchange_config_path} -path {folder_path} + >> online execute -date 2019-10-16 -exchange_config ./config/exchange.yaml -path ./user_data/ + + A simple exchange config file can be as + + .. code-block:: yaml + + open_cost: 0.003 + close_cost: 0.003 + limit_threshold: 0.095 + deal_price: vwap + + +- update accounts + update accounts in "{folder_path}/" at trade date + + .. code-block:: bash + + >> online update -date {date} -path {folder_path} + >> online update -date 2019-10-16 -path ./user_data/ + +API +------------------ + +All those operations are based on defined in `qlib.contrib.online.operator` + +.. automodule:: qlib.contrib.online.operator + +.. _fileStruct: + +File structure +------------------ + +'user_data' indicates the root of folder. +Name that bold indicates it’s a folder, otherwise it’s a document. + +.. 
code-block:: yaml + + {user_folder} + │ users.csv: (Init date for each users) + │ + └───{user_id1}: (users' sub-folder to save their data) + │ │ position.xlsx + │ │ report.csv + │ │ model_{user_id1}.pickle + │ │ strategy_{user_id1}.pickle + │ │ + │ └───score + │ │ └───{YYYY} + │ │ └───{MM} + │ │ │ score_{YYYY-MM-DD}.csv + │ │ + │ └───trade + │ └───{YYYY} + │ └───{MM} + │ │ orderlist_{YYYY-MM-DD}.json + │ │ transaction_{YYYY-MM-DD}.csv + │ + └───{user_id2} + │ │ position.xlsx + │ │ report.csv + │ │ model_{user_id2}.pickle + │ │ strategy_{user_id2}.pickle + │ │ + │ └───score + │ └───trade + .... + + +Configuration file +------------------ + +The configure file used in `online` should contain the model and strategy information. + +About the model +~~~~~~~~~~~~~~~~~~~~ + +First, your configuration file needs to have a field about the model, +this field and its contents determine the model we used when generating score at predict date. + +Followings are two examples for ScoreFileModel and a model that read a score file and return score at trade date. + +.. code-block:: YAML + + model: + class: ScoreFileModel + module_path: qlib.contrib.online.OnlineModel + args: + loss: mse + +.. code-block:: YAML + + model: + class: ScoreFileModel + module_path: qlib.contrib.online.OnlineModel + args: + score_path: + +If your model doesn't belong to above models, you need to coding your model manually. +Your model should be a subclass of models defined in 'qlib.contfib.model'. And it must +contains 2 methods used in `online` module. + + +About the strategy +~~~~~~~~~~~~~~~~~~~~ + +Your need define the strategy used to generate the order list at predict date. + +Followings are two examples for a TopkAmountStrategy + +.. code-block:: YAML + + strategy: + class: TopkAmountStrategy + module_path: qlib.contrib.strategy.strategy + args: + topk: 100 + buffer_margin: 300 + +Generated files +------------------ + +The 'online_generate' command will create the order list at {folder_path}/{user_id}/temp/, +the name of that is orderlist_{YYYY-MM-DD}.json, YYYY-MM-DD is the date that those orders to be executed. + +The format of json file is like + +.. code-block:: python + + { + 'sell': { + {'$stock_id1': '$amount1'}, + {'$stock_id2': '$amount2'}, ... + }, + 'buy': { + {'$stock_id1': '$amount1'}, + {'$stock_id2': '$amount2'}, ... + } + } + +Then after executing the order list (either by 'online_execute' or other executors), a transaction file +will be created also at {folder_path}/{user_id}/temp/. diff --git a/docs/hidden/tuner.rst b/docs/hidden/tuner.rst new file mode 100644 index 0000000000..35d606c9c1 --- /dev/null +++ b/docs/hidden/tuner.rst @@ -0,0 +1,327 @@ +.. _tuner: + +Tuner +=================== +.. currentmodule:: qlib + +Introduction +------------------- + +Welcome to use Tuner, this document is based on that you can use Estimator proficiently and correctly. + +You can find the optimal hyper-parameters and combinations of models, trainers, strategies and data labels. + +The usage of program `tuner` is similar with `estimator`, you need provide the URL of the configuration file. +The `tuner` will do the following things: + +- Construct tuner pipeline +- Search and save best hyper-parameters of one tuner +- Search next tuner in pipeline +- Save the global best hyper-parameters and combination + +Each tuner is consisted with a kind of combination of modules, and its goal is searching the optimal hyper-parameters of this combination. 
+The pipeline is consisted with different tuners, it is aim at finding the optimal combination of modules. + +The result will be printed on screen and saved in file, you can check the result in your experiment saving files. + +Example +~~~~~~~ + +Let's see an example, + +First make sure you have the latest version of `qlib` installed. + +Then, you need to privide a configuration to setup the experiment. +We write a simple configuration example as following, + +.. code-block:: YAML + + experiment: + name: tuner_experiment + tuner_class: QLibTuner + qlib_client: + auto_mount: False + logging_level: INFO + optimization_criteria: + report_type: model + report_factor: model_score + optim_type: max + tuner_pipeline: + - + model: + class: SomeModel + space: SomeModelSpace + trainer: + class: RollingTrainer + strategy: + class: TopkAmountStrategy + space: TopkAmountStrategySpace + max_evals: 2 + + time_period: + rolling_period: 360 + train_start_date: 2005-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-06-30 + test_start_date: 2016-07-01 + test_end_date: 2018-04-30 + data: + class: ALPHA360 + provider_uri: /data/qlib + args: + start_date: 2005-01-01 + end_date: 2018-04-30 + dropna_label: True + dropna_feature: True + filter: + market: csi500 + filter_pipeline: + - + class: NameDFilter + module_path: qlib.data.filter + args: + name_rule_re: S(?!Z3) + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + - + class: ExpressionDFilter + module_path: qlib.data.filter + args: + rule_expression: $open/$factor<=45 + fstart_time: 2018-01-01 + fend_time: 2018-12-11 + backtest: + normal_backtest_args: + verbose: False + limit_threshold: 0.095 + account: 500000 + benchmark: SH000905 + deal_price: vwap + long_short_backtest_args: + topk: 50 + +Next, we run the following command, and you can see: + +.. code-block:: bash + + ~/v-yindzh/Qlib/cfg$ tuner -c tuner_config.yaml + + Searching params: {'model_space': {'colsample_bytree': 0.8870905643607678, 'lambda_l1': 472.3188735122233, 'lambda_l2': 92.75390994877243, 'learning_rate': 0.09741751430635413, 'loss': 'mse', 'max_depth': 8, 'num_leaves': 160, 'num_threads': 20, 'subsample': 0.7536051584789751}, 'strategy_space': {'buffer_margin': 250, 'topk': 40}} + ... + (Estimator experiment screen log) + ... + Searching params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}} + ... + (Estimator experiment screen log) + ... + Local best params: {'model_space': {'colsample_bytree': 0.6667379039007301, 'lambda_l1': 382.10698024977904, 'lambda_l2': 117.02506488151757, 'learning_rate': 0.18514539615228137, 'loss': 'mse', 'max_depth': 6, 'num_leaves': 200, 'num_threads': 12, 'subsample': 0.9449255686969292}, 'strategy_space': {'buffer_margin': 200, 'topk': 30}} + Time cost: 489.87220 | Finished searching best parameters in Tuner 0. + Time cost: 0.00069 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_0/local_best_params.json . 
+ Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 2',)}, 'model_space': {'input_dim': 158, 'lr': 0.001, 'lr_decay': 0.9100529502185579, 'lr_decay_steps': 162.48901403763966, 'optimizer': 'gd', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 300, 'topk': 35}} + ... + (Estimator experiment screen log) + ... + Searching params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}} + ... + (Estimator experiment screen log) + ... + Local best params: {'data_label_space': {'labels': ('Ref($vwap, -2)/Ref($vwap, -1) - 1',)}, 'model_space': {'input_dim': 158, 'lr': 0.1, 'lr_decay': 0.9882802970847494, 'lr_decay_steps': 164.76742865207729, 'optimizer': 'adam', 'output_dim': 1}, 'strategy_space': {'buffer_margin': 250, 'topk': 35}} + Time cost: 550.74039 | Finished searching best parameters in Tuner 1. + Time cost: 0.00023 | Finished saving local best tuner parameters to: tuner_experiment/estimator_experiment/estimator_experiment_1/local_best_params.json . + Time cost: 1784.14691 | Finished tuner pipeline. + Time cost: 0.00014 | Finished save global best tuner parameters. + Best Tuner id: 0. + You can check the best parameters at tuner_experiment/global_best_params.json. + + +Finally, you can check the results of your experiment in the given path. + +Configuration file +------------------ + +Before using `tuner`, you need to prepare a configuration file. Next we will show you how to prepare each part of the configuration file. + +About the experiment +~~~~~~~~~~~~~~~~~~~~ + +First, your configuration file needs to have a field about the experiment, whose key is `experiment`, this field and its contents determine the saving path and tuner class. + +Usually it should contain the following content: + +.. code-block:: YAML + + experiment: + name: tuner_experiment + tuner_class: QLibTuner + +Also, there are some optional fields. The meaning of each field is as follows: + +- `name` + The experiment name, str type, the program will use this experiment name to construct a directory to save the process of the whole experiment and the results. The default value is `tuner_experiment`. + +- `dir` + The saving path, str type, the program will construct the experiment directory in this path. The default value is the path where configuration locate. + +- `tuner_class` + The class of tuner, str type, must be an already implemented model, such as `QLibTuner` in `qlib`, or a custom tuner, but it must be a subclass of `qlib.contrib.tuner.Tuner`, the default value is `QLibTuner`. + +- `tuner_module_path` + The module path, str type, absolute url is also supported, indicates the path of the implementation of tuner. The default value is `qlib.contrib.tuner.tuner` + +About the optimization criteria +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You need to designate a factor to optimize, for tuner need a factor to decide which case is better than other cases. +Usually, we use the result of `estimator`, such as backtest results and the score of model. + +This part needs contain these fields: + +.. code-block:: YAML + + optimization_criteria: + report_type: model + report_factor: model_pearsonr + optim_type: max + +- `report_type` + The type of the report, str type, determines which kind of report you want to use. 
If you want to use the backtest result type, you can choose `pred_long`, `pred_long_short`, `pred_short`, `sub_bench` and `sub_cost`. If you want to use the model result type, you can only choose `model`. + +- `report_factor` + The factor you want to use in the report, str type, determines which factor you want to optimize. If your `report_type` is backtest result type, you can choose `annual`, `sharpe`, `mdd`, `mean` and `std`. If your `report_type` is model result type, you can choose `model_score` and `model_pearsonr`. + +- `optim_type` + The optimization type, str type, determines what kind of optimization you want to do. you can minimize the factor or maximize the factor, so you can choose `max`, `min` or `correlation` at this field. + Note: `correlation` means the factor's best value is 1, such as `model_pearsonr` (a corraltion coefficient). + +If you want to process the factor or you want fetch other kinds of factor, you can override the `objective` method in your own tuner. + +About the tuner pipeline +~~~~~~~~~~~~~~~~~~~~~~~~ + +The tuner pipeline contains different tuners, and the `tuner` program will process each tuner in pipeline. Each tuner will get an optimal hyper-parameters of its specific combination of modules. The pipeline will contrast the results of each tuner, and get the best combination and its optimal hyper-parameters. So, you need to configurate the pipeline and each tuner, here is an example: + +.. code-block:: YAML + + tuner_pipeline: + - + model: + class: SomeModel + space: SomeModelSpace + trainer: + class: RollingTrainer + strategy: + class: TopkAmountStrategy + space: TopkAmountStrategySpace + max_evals: 2 + +Each part represents a tuner, and its modules which are to be tuned. Space in each part is the hyper-parameters' space of a certain module, you need to create your searching space and modify it in `/qlib/contrib/tuner/space.py`. We use `hyperopt` package to help us to construct the space, you can see the detail of how to use it in https://github.com/hyperopt/hyperopt/wiki/FMin . + +- model + You need to provide the `class` and the `space` of the model. If the model is user's own implementation, you need to privide the `module_path`. + +- trainer + You need to proveide the `class` of the trainer. If the trainer is user's own implementation, you need to privide the `module_path`. + +- strategy + You need to provide the `class` and the `space` of the strategy. If the strategy is user's own implementation, you need to privide the `module_path`. + +- data_label + The label of the data, you can search which kinds of labels will lead to a better result. This part is optional, and you only need to provide `space`. + +- max_evals + Allow up to this many function evaluations in this tuner. The default value is 10. + +If you don't want to search some modules, you can fix their spaces in `space.py`. We will not give the default module. + +About the time period +~~~~~~~~~~~~~~~~~~~~~ + +You need to use the same dataset to evaluate your different `estimator` experiments in `tuner` experiment. Two experiments using different dataset are uncomparable. You can specify `time_period` through the configuration file: + +.. 
+About the time period
+~~~~~~~~~~~~~~~~~~~~~
+
+You need to use the same dataset to evaluate your different `estimator` experiments in a `tuner` experiment. Two experiments that use different datasets are not comparable. You can specify the `time_period` through the configuration file:
+
+.. code-block:: YAML
+
+    time_period:
+        rolling_period: 360
+        train_start_date: 2005-01-01
+        train_end_date: 2014-12-31
+        validate_start_date: 2015-01-01
+        validate_end_date: 2016-06-30
+        test_start_date: 2016-07-01
+        test_end_date: 2018-04-30
+
+- `rolling_period`
+    The rolling period, integer type, indicates how many time steps to roll forward when rolling the data. The default value is `60`. This config is only used with `RollingTrainer`; otherwise it is ignored.
+
+- `train_start_date`
+    Training start time, str type.
+
+- `train_end_date`
+    Training end time, str type.
+
+- `validate_start_date`
+    Validation start time, str type.
+
+- `validate_end_date`
+    Validation end time, str type.
+
+- `test_start_date`
+    Test start time, str type.
+
+- `test_end_date`
+    Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
+
+About the data and backtest
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+`data` and `backtest` are shared across the whole `tuner` experiment. Different `estimator` experiments must use the same data and backtest method, so these two parts of the config are the same as those in the `estimator` configuration. You can see the precise definition of these parts in the `estimator` introduction. We only provide an example here.
+
+.. code-block:: YAML
+
+    data:
+        class: ALPHA360
+        provider_uri: /data/qlib
+        args:
+            start_date: 2005-01-01
+            end_date: 2018-04-30
+            dropna_label: True
+            dropna_feature: True
+            feature_label_config: /home/v-yindzh/v-yindzh/QLib/cfg/feature_config.yaml
+        filter:
+            market: csi500
+            filter_pipeline:
+              -
+                class: NameDFilter
+                module_path: qlib.filter
+                args:
+                    name_rule_re: S(?!Z3)
+                    fstart_time: 2018-01-01
+                    fend_time: 2018-12-11
+              -
+                class: ExpressionDFilter
+                module_path: qlib.filter
+                args:
+                    rule_expression: $open/$factor<=45
+                    fstart_time: 2018-01-01
+                    fend_time: 2018-12-11
+    backtest:
+        normal_backtest_args:
+            verbose: False
+            limit_threshold: 0.095
+            account: 500000
+            benchmark: SH000905
+            deal_price: vwap
+        long_short_backtest_args:
+            topk: 50
+
+Experiment Result
+-----------------
+
+All the results are stored in the experiment directory; you can check them directly in the corresponding files.
+What we save is as follows:
+
+- Global optimal parameters
+- Local optimal parameters of each tuner
+- Config file of this `tuner` experiment
+- The result of every `estimator` experiment run in the process
+
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000000..3ad0cfdd71
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,57 @@
+============================================================
+QLib Documentation
+============================================================
+
+QLib is a Quantitative-research Library, which can provide research data with high consistency, reusability and extensibility.
+
+.. _user_guide:
+
+Document Structure
+====================
+
+.. toctree::
+    :hidden:
+
+    Home
+
+.. toctree::
+    :maxdepth: 3
+    :caption: INTRODUCTION:
+
+    Introduction
+
+.. toctree::
+    :maxdepth: 3
+    :caption: GETTING STARTED:
+
+    Installation
+    Initialization
+    Data Retrieval
+    Integrate Custom Models
+
+
+.. toctree::
+    :maxdepth: 3
+    :caption: ADVANCED FEATURES:
+
+    Estimator: Workflow Management
+    Data: Data Framework&Usage
+    Model: Train&Predict
+    Strategy: Portfolio Management
+    Backtest: Model&Strategy Testing
+    Report: Graphical Results
+    Cache: Frequently-Used Data
+    Tuner: Tuner
+
+
+.. toctree::
+    :maxdepth: 3
+    :caption: REFERENCE:
+
+    Api
+
+.. toctree::
+    :maxdepth: 3
+    :caption: Change Log:
+
+    Change Log
diff --git a/docs/introduction/introduction.rst b/docs/introduction/introduction.rst
new file mode 100644
index 0000000000..597ebdc4c3
--- /dev/null
+++ b/docs/introduction/introduction.rst
@@ -0,0 +1,45 @@
+===================
+Qlib
+===================
+
+Introduction
+==================
+
+``Qlib`` is an AI-oriented quantitative investment platform that aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
+
+With ``Qlib``, you can easily apply your favorite model to create better Quant investment strategies.
+
+
+Framework
+==================
+
+.. image:: ../_static/img/framework.png
+    :alt: Framework
+
+
+At the module level, Qlib is a platform that consists of the above components. Each component is loosely coupled and can be used stand-alone.
+
+====================== ========================================================================
+Name                   Description
+====================== ========================================================================
+`Data layer`           `DataServer` focuses on providing a high-performance infrastructure for
+                       users to retrieve raw data. `DataEnhancement` will preprocess the data
+                       and provide the best dataset to be fed into the models.
+
+`Interday Model`       `Interday model` focuses on producing forecasting signals (aka `alpha`).
+                       Models are trained by `Model Creator` and managed by `Model Manager`.
+                       Users could choose one or multiple models for forecasting. Multiple models
+                       could be combined with the `Ensemble` module.
+
+`Interday Strategy`    `Portfolio Generator` will take forecasting signals as input and output
+                       the orders based on the current position to achieve the target portfolio.
+
+`Intraday Trading`     `Order Executor` is responsible for executing orders output by
+                       `Interday Strategy` and returning the executed results.
+
+`Analysis`             Users could get a detailed analysis report of the forecasting signal and
+                       the portfolio in this part.
+====================== ========================================================================
+
+- The modules with hand-drawn style are under development and will be released in the future.
+- The modules with dashed borders are highly user-customizable and extensible.
\ No newline at end of file
diff --git a/docs/reference/api.rst b/docs/reference/api.rst
new file mode 100644
index 0000000000..23e5aee8cc
--- /dev/null
+++ b/docs/reference/api.rst
@@ -0,0 +1,49 @@
+================================
+API Reference
+================================
+
+
+
+Here you can find all ``QLib`` interfaces.
+
+
+Data
+====================
+
+Provider
+--------------------
+
+.. automodule:: qlib.data.data
+    :members:
+
+Filter
+--------------------
+
+.. automodule:: qlib.data.filter
+    :members:
+
+Feature
+--------------------
+
+Class
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.base
+    :members:
+
+Operator
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.ops
+    :members:
+
+Contrib
+====================
+
+Model
+--------------------
+.. automodule:: qlib.contrib.model.base
+    :members:
+
+Evaluate
+--------------------
+.. automodule:: qlib.contrib.evaluate
+    :members:
\ No newline at end of file
diff --git a/docs/start/getdata.rst b/docs/start/getdata.rst
new file mode 100644
index 0000000000..8a2d297e16
--- /dev/null
+++ b/docs/start/getdata.rst
@@ -0,0 +1,141 @@
+.. _getdata:
+=============================
+Data Retrieval
+=============================
+
+.. currentmodule:: qlib
+
+Introduction
+====================
+
+Users can get stock data with Qlib. The following are some examples.
+
+Examples
+====================
+
+Init qlib package:
+
+.. note:: In order to get the data, users need to initialize the package qlib with qlib.init first.
+
+Please refer to `initialization `_
+
+
+It is recommended to use the following code to initialize qlib:
+
+.. code-block:: python
+
+   >>> import qlib
+   >>> qlib.init(mount_path='~/.qlib/qlib_data/cn_data')
+
+
+Load trading calendar with the given time range and frequency:
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2]
+   [Timestamp('2010-01-04 00:00:00'), Timestamp('2010-01-05 00:00:00')]
+
+Parse a given market name into a stockpool config:
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> D.instruments(market='all')
+   {'market': 'all', 'filter_pipe': []}
+
+Load instruments of a certain stockpool in the given time range:
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> instruments = D.instruments(market='csi500')
+   >>> D.list_instruments(instruments=instruments, start_time='2010-01-01', end_time='2017-12-31', as_list=True)[:6]
+   ['SH600000', 'SH600003', 'SH600004', 'SH600005', 'SH600006', 'SH600007']
+
+Load dynamic instruments from a base market according to a name filter:
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> from qlib.data.filter import NameDFilter
+   >>> nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')
+   >>> instruments = D.instruments(market='csi500', filter_pipe=[nameDFilter])
+   >>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)
+   ['SH600655', 'SH600755', 'SH603355', 'SH603555']
+
+Load dynamic instruments from a base market according to an expression filter:
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> from qlib.data.filter import ExpressionDFilter
+   >>> expressionDFilter = ExpressionDFilter(rule_expression='$close>100')
+   >>> instruments = D.instruments(market='csi500', filter_pipe=[expressionDFilter])
+   >>> D.list_instruments(instruments=instruments, start_time='2015-01-01', end_time='2016-02-15', as_list=True)
+   ['SH600601', 'SH600651', 'SH600654']
+
+To know more about how to use the filter or how to build one's own filter, go to API Reference: `filter API <../reference/api.html#filter>`_
+
+Load features of certain instruments in a given time range:
+
+.. note:: This is not a recommended way to get features.
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> instruments = ['SH600000']
+   >>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
+   >>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
+                              $close     $volume  Ref($close,1)  Mean($close,3)  \
+   instrument datetime
+   SH600000   2010-01-04  81.809998  17144536.0            NaN       81.809998
+              2010-01-05  82.419998  29827816.0      81.809998       82.114998
+              2010-01-06  80.800003  25070040.0      82.419998       81.676666
+              2010-01-07  78.989998  22077858.0      80.800003       80.736666
+              2010-01-08  79.879997  17019168.0      78.989998       79.889999
+
+                          Sub($high,$low)
+   instrument datetime
+   SH600000   2010-01-04         2.741158
+              2010-01-05         3.049736
+              2010-01-06         1.621399
+              2010-01-07         2.856926
+              2010-01-08         1.930397
+
+Load features of a certain stockpool in a given time range:
+
+.. note:: Since the server needs to cache all-time data for the requested stockpool and fields, it may take longer to process your request the first time. From the second time on, your request will be processed and responded to almost instantly, even if you change the time span.
+
+.. code-block:: python
+
+   >>> from qlib.data import D
+   >>> from qlib.data.filter import NameDFilter, ExpressionDFilter
+   >>> nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')
+   >>> expressionDFilter = ExpressionDFilter(rule_expression='($close/$factor)>100')
+   >>> instruments = D.instruments(market='csi500', filter_pipe=[nameDFilter, expressionDFilter])
+   >>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
+   >>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
+
+                               $close        $volume  Ref($close, 1)  \
+   instrument datetime
+   SH600655   2015-06-15  4342.160156  258706.359375     4530.459961
+              2015-06-16  4409.270020  257349.718750     4342.160156
+              2015-06-17  4312.330078  235214.890625     4409.270020
+              2015-06-18  4086.729980  196772.859375     4312.330078
+              2015-06-19  3678.250000  182916.453125     4086.729980
+
+                          Mean($close, 3)  Sub($high,$low)
+   instrument datetime
+   SH600655   2015-06-15      4480.743327       285.251465
+              2015-06-16      4427.296712       298.301270
+              2015-06-17      4354.586751       356.098145
+              2015-06-18      4269.443359       363.554932
+              2015-06-19      4025.770020       368.954346
+
+
+.. note:: When calling D.features() at the client, use the parameter 'disk_cache=0' to skip the dataset cache and 'disk_cache=1' to generate and use the dataset cache. In addition, when calling it at the server, you can use 'disk_cache=2' to update the dataset cache.
+
+Api
+====================
+To know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#Data>`_
\ No newline at end of file
diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst
new file mode 100644
index 0000000000..c9e2919cad
--- /dev/null
+++ b/docs/start/initialization.rst
@@ -0,0 +1,51 @@
+.. _initialization:
+====================
+Initialize Qlib
+====================
+
+.. currentmodule:: qlib
+
+
+Initialize ``qlib`` Package
+===========================
+
+Please execute the following process to initialize the ``qlib`` package:
+
+- Download and prepare the Data: execute the following command to download the stock data.
+    .. code-block:: bash
+
+        python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
+
+    To know more about how to use get_data.py, refer to `Raw Data <../advanced/data.html#raw-data>`_.
+
+
+- Run the initialization code: run the following code in python:
+
+    .. code-block:: Python
+
+        import qlib
+        # region in [REG_CN, REG_US]
+        from qlib.config import REG_CN
+        mount_path = "~/.qlib/qlib_data/cn_data"  # target_dir
+        qlib.init(mount_path=mount_path, region=REG_CN)
+
+
+Parameters
+===============================
+
+In fact, in addition to 'mount_path' and 'region', qlib.init has other parameters. The following are all the parameters of qlib.init:
+
+- ``mount_path``: type: str. The local directory where the data loaded by 'get_data.py' is stored.
+- ``region``: type: str, optional parameter (default: `qlib.config.REG_CN`, i.e. 'cn'). If region == `qlib.config.REG_CN`, 'qlib' will be initialized in A-share mode. If region == `qlib.config.REG_US`, 'qlib' will be initialized in US stock mode.
+
+    .. note::
+
+        The value of 'region' should be consistent with the data stored in 'mount_path'. Currently, 'scripts/get_data.py' only supports downloading A-share data. If users need to use the US stock mode, they need to prepare their own US stock data and store it in 'mount_path'.
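+For example, to initialize ``qlib`` in US stock mode with US stock data you have prepared yourself, a minimal call could look like the following sketch (the '~/.qlib/qlib_data/us_data' directory is only an illustrative path, not one created by 'scripts/get_data.py'):
+
+.. code-block:: Python
+
+    import qlib
+    from qlib.config import REG_US
+
+    # Assumes the US stock data has already been placed in this directory.
+    qlib.init(mount_path="~/.qlib/qlib_data/us_data", region=REG_US)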
\ No newline at end of file
diff --git a/docs/start/installation.rst b/docs/start/installation.rst
new file mode 100644
index 0000000000..d7bb346394
--- /dev/null
+++ b/docs/start/installation.rst
@@ -0,0 +1,55 @@
+.. _installation:
+====================
+Installation
+====================
+
+.. currentmodule:: qlib
+
+
+How to Install Qlib
+====================
+
+``Qlib`` only supports Python 3, and supports versions up to Python 3.8.
+
+Please execute the following process to install ``Qlib``:
+
+- Change the directory to Qlib, in which the file 'setup.py' exists
+- Then, execute the following command:
+
+    .. code-block:: bash
+
+        $ pip install numpy
+        $ pip install --upgrade cython
+        $ python setup.py install
+
+
+.. note::
+    It's recommended to use anaconda/miniconda to set up the environment.
+    ``Qlib`` needs the lightgbm and tensorflow packages; use pip to install them.
+
+.. note::
+    Do not import qlib in the ``Qlib`` folder, otherwise errors may occur.
+
+
+
+Use the following code to confirm that the installation was successful:
+
+.. code-block:: python
+
+    >>> import qlib
+    >>> qlib.__version__
+
+
+..
+    .. note:: Please read this documentation carefully since there are lots of changes in qlib.
+
+..
+    .. note:: On client side, there are some configs you need to notice like the providers, flask_server, flask_port and mount_path. The default is built for 10.150.144.153 since the server data path is pre-mounted to the mount_path. Don't change these configs unless you have some special test purposes.
+
+
+..
+    .. note:: You can always refer to the server docs on http://10.150.144.154:10002
+
+
+
+
diff --git a/docs/start/integration.rst b/docs/start/integration.rst
new file mode 100644
index 0000000000..3f37e5355f
--- /dev/null
+++ b/docs/start/integration.rst
@@ -0,0 +1,140 @@
+=========================================
+Integrate Custom Models into QLib
+=========================================
+
+Introduction
+===================
+The baseline models of qlib include lightgbm and dnn. In addition to using the default models, users can integrate their own custom models into qlib.
+
+In order to use a custom model, users can do as follows.
+
+- Define a custom model class, which should be a subclass of `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_
+- Write a configuration file that describes the path and parameters of the custom model
+
+The following is an example of integrating a custom lightgbm model into qlib:
+
+Define a custom model class
+===========================
+Custom models need to inherit `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ and override the methods in it.
+
+- Override the `__init__` method
+    - Qlib passes the initialization parameters to the \_\_init\_\_ method.
+    - The parameters must be consistent with the hyperparameters in the configuration file.
+    - Code Example: In the following example, the hyperparameter field of the configuration file should contain parameters such as ‘loss:mse’.
+    .. code-block:: Python
+
+        def __init__(self, loss='mse', **kwargs):
+            if loss not in {'mse', 'binary'}:
+                raise NotImplementedError
+            self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score
+            self._params.update(objective=loss, **kwargs)
+            self._model = None
+
+- Override the `fit` method
+    - Qlib calls the fit method to train the model.
+    - The parameters must include the training features 'x_train', training labels 'y_train', validation features 'x_valid' and validation labels 'y_valid' at least.
+    - The parameters could include some optional parameters with default values, such as the training weight 'w_train', the validation weight 'w_valid' and 'num_boost_round = 1000'.
+    - Code Example: In the following example, 'num_boost_round = 1000' is an optional parameter.
+    .. code-block:: Python
+
+        def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
+                w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
+
+            # LightGBM needs a 1D array as its label
+            if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
+                y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
+            else:
+                raise ValueError('LightGBM doesn\'t support multi-label training')
+
+            w_train_weight = None if w_train is None else w_train.values
+            w_valid_weight = None if w_valid is None else w_valid.values
+
+            dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
+            dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
+            self._model = lgb.train(
+                self._params,
+                dtrain,
+                num_boost_round=num_boost_round,
+                valid_sets=[dtrain, dvalid],
+                valid_names=['train', 'valid'],
+                **kwargs
+            )
+
+- Override the `predict` method
+    - The parameters include the features of the test data, and the method returns the prediction labels.
+    - Please refer to `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ for the parameter types of the predict method.
+    - Code Example: In the following example, the user needs to use the fitted model to predict the labels (such as 'preds') of the test data 'x_test' and return them.
+    .. code-block:: Python
+
+        def predict(self, x_test:pd.DataFrame, **kwargs) -> numpy.ndarray:
+            if self._model is None:
+                raise ValueError('model is not fitted yet!')
+            return self._model.predict(x_test.values)
+
+- Override the `score` method
+    - The parameters include the features and labels of the test data, and the method returns the loss, whose type is determined by the 'loss' passed to the __init__ method.
+    - Code Example: In the following example, the user needs to calculate the weighted loss with the test data 'x_test', the test labels 'y_test' and the weight 'w_test'.
+    .. code-block:: Python
+
+        def score(self, x_test:pd.DataFrame, y_test:pd.DataFrame, w_test:pd.DataFrame = None) -> float:
+            # Remove rows from x, y and w, which contain Nan in any columns in y_test.
+            x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
+            preds = self.predict(x_test)
+            w_test_weight = None if w_test is None else w_test.values
+            scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
+            return scorer(y_test.values, preds, sample_weight=w_test_weight)
+
+- Override the `save` method & `load` method
+    - The `save` method's parameters include a `filename` that represents an absolute path; the user needs to save the model to that path.
+    - The `load` method's parameters include a `buffer` read from the `filename` passed to the `save` method; the user needs to load the model from the `buffer`.
+    - Code Example:
+    .. code-block:: Python
+
+        def save(self, filename):
+            if self._model is None:
+                raise ValueError('model is not fitted yet!')
+            self._model.save_model(filename)
+
+        def load(self, buffer):
+            self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
+
+
+Write the configuration
+=======================
+
+The configuration file is described in detail in the `estimator <../advanced/estimator.html#Example>`_ document. In order to integrate the custom model into qlib, you need to modify the "model" field in the configuration file.
+
+- Example: The following example describes the ‘model’ field of the configuration file for the custom lightgbm model mentioned above, where ‘module_path’ is the module path, ‘class’ is the class name, and ‘args’ contains the hyperparameters passed into the __init__ method. All parameters in the field are passed to 'self._params' through '\*\*kwargs' in `__init__`, except 'loss = mse'.
+
+.. code-block:: YAML
+
+    model:
+        class: LGBModel
+        module_path: qlib.contrib.model.gbdt
+        args:
+            loss: mse
+            colsample_bytree: 0.8879
+            learning_rate: 0.0421
+            subsample: 0.8789
+            lambda_l1: 205.6999
+            lambda_l2: 580.9768
+            max_depth: 8
+            num_leaves: 210
+            num_threads: 20
+
+Test the custom model
+=====================
+Assuming that the configuration file is named test.yaml, users can run the following command to test the custom model:
+
+.. code-block:: bash
+
+    estimator -c test.yaml
+
+.. note:: 'estimator' is a built-in command of our program.
+
+'Model' can also be tested as a single module. An example is given in 'examples.estimator.train_backtest_analyze.ipynb'.
+
+Know More about 'Model'
+=======================
+
+If users want to know more about 'Model', please refer to the document `Use 'Model' to Train&Predict <../advanced/model.rst>`_ and the api `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_.
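+For reference, the following sketch shows how such a model can be exercised as a single module without `estimator`, following the pattern used in 'examples/train_and_backtest.py'; the dates and handler arguments below are taken from that example and should be adapted to your own data.
+
+.. code-block:: Python
+
+    import qlib
+    from qlib.config import REG_CN
+    from qlib.contrib.estimator.handler import QLibDataHandlerV1
+    from qlib.contrib.model.gbdt import LGBModel
+
+    qlib.init(mount_path="~/.qlib/qlib_data/cn_data", region=REG_CN)
+
+    # Split the data with the default handler.
+    x_train, y_train, x_valid, y_valid, x_test, y_test = QLibDataHandlerV1(
+        dropna_label=True, start_date="2007-01-01", end_date="2020-08-01", market="csi500"
+    ).get_split_data(
+        train_start_date="2007-01-01", train_end_date="2014-12-31",
+        validate_start_date="2015-01-01", validate_end_date="2016-12-31",
+        test_start_date="2017-01-01", test_end_date="2020-08-01",
+    )
+
+    # Fit the model and predict scores on the test set.
+    model = LGBModel(loss="mse", num_leaves=210, learning_rate=0.0421)
+    model.fit(x_train, y_train, x_valid, y_valid)
+    preds = model.predict(x_test)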
\ No newline at end of file diff --git a/examples/estimator/analyze_from_estimator.ipynb b/examples/estimator/analyze_from_estimator.ipynb new file mode 100644 index 0000000000..2e489f3bfc --- /dev/null +++ b/examples/estimator/analyze_from_estimator.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import json\n", + "import yaml\n", + "import pickle\n", + "from pathlib import Path\n", + "\n", + "import qlib\n", + "import pandas as pd\n", + "from qlib.config import REG_CN\n", + "from qlib.utils import exists_qlib_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# use default data\n", + "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n", + "mount_path = \"~/.qlib/qlib_data/cn_data\" # target_dir\n", + "if not exists_qlib_data(mount_path):\n", + " print(f\"Qlib data is not found in {mount_path}\")\n", + " sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath(\"scripts\")))\n", + " from get_data import GetData\n", + " GetData().qlib_data_cn(mount_path)\n", + "qlib.init(mount_path=mount_path, region=REG_CN)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CUR_DIR = Path.cwd()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with CUR_DIR.joinpath('estimator_config.yaml').open() as fp:\n", + " estimator_name = yaml.load(fp, Loader=yaml.FullLoader)['experiment']['name']\n", + "with CUR_DIR.joinpath(estimator_name, 'exp_info.json').open() as fp:\n", + " latest_id = json.load(fp)['id']\n", + " \n", + "estimator_dir = CUR_DIR.joinpath(estimator_name, 'sacred', latest_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# read estimator result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred_df = pd.read_pickle(estimator_dir.joinpath('pred.pkl'))\n", + "report_normal_df = pd.read_pickle(estimator_dir.joinpath('report_normal.pkl'))\n", + "report_normal_df.index.names = ['index']\n", + "\n", + "_report_long_short_df = pd.DataFrame(pd.read_pickle(estimator_dir.joinpath('report_long_short.pkl')), columns=['long_short'])\n", + "_report_long_df = pd.DataFrame(pd.read_pickle(estimator_dir.joinpath('report_long.pkl')), columns=['long'])\n", + "_report_short_df = pd.DataFrame(pd.read_pickle(estimator_dir.joinpath('report_short.pkl')), columns=['short'])\n", + "report_long_short_df = pd.concat([_report_long_short_df, _report_long_df, _report_short_df], axis=1)\n", + "\n", + "analysis_df = pd.read_pickle(estimator_dir.joinpath('analysis.pkl'))\n", + "positions = pickle.load(estimator_dir.joinpath('positions.pkl').open('rb'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get label data from qlib" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from qlib.data import D\n", + "pred_df_dates = pred_df.index.get_level_values(level='datetime')\n", + "features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n", + "features_df.columns = ['label']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# analyze graphs" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from qlib.contrib.report import analysis_model, analysis_position" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## analysis position" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.report_graph(report_normal_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### score IC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred_label = pd.concat([features_df, pred_df], axis=1, sort=True).reindex(features_df.index)\n", + "analysis_position.score_ic_graph(pred_label)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### cumulative return" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### risk analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "analysis_position.risk_analysis_graph(analysis_df, report_normal_df, report_long_short_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### rank label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## analysis model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### model performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "analysis_model.model_performance_graph(pred_label)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/estimator/estimator_config.yaml b/examples/estimator/estimator_config.yaml new file mode 100644 index 0000000000..392e91e23c --- /dev/null +++ b/examples/estimator/estimator_config.yaml @@ -0,0 +1,56 @@ +experiment: + name: estimator_example + observer_type: file_storage + mode: train + +model: + class: LGBModel + module_path: qlib.contrib.model.gbdt + args: + loss: mse + colsample_bytree: 0.8879 + learning_rate: 0.0421 + subsample: 0.8789 + lambda_l1: 205.6999 + lambda_l2: 580.9768 + max_depth: 8 + num_leaves: 210 + num_threads: 20 +data: + class: QLibDataHandlerV1 + args: + dropna_label: True + filter: + market: csi500 +trainer: + class: StaticTrainer + args: + rolling_period: 360 + train_start_date: 2007-01-01 + train_end_date: 2014-12-31 + validate_start_date: 2015-01-01 + validate_end_date: 2016-12-31 + test_start_date: 2017-01-01 + test_end_date: 2020-08-01 +strategy: + class: 
TopkAmountStrategy + args: + topk: 50 + buffer_margin: 230 +backtest: + normal_backtest_args: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: SH000905 + deal_price: vwap + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + long_short_backtest_args: + topk: 50 + +qlib_data: + # when testing, please modify the following parameters according to the specific environment + mount_path: "~/.qlib/qlib_data/cn_data" + region: "cn" diff --git a/examples/train_and_backtest.py b/examples/train_and_backtest.py new file mode 100644 index 0000000000..d33b152803 --- /dev/null +++ b/examples/train_and_backtest.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from pathlib import Path + +import qlib +import pandas as pd +from qlib.config import REG_CN +from qlib.contrib.model.gbdt import LGBModel +from qlib.contrib.estimator.handler import QLibDataHandlerV1 +from qlib.contrib.strategy.strategy import TopkAmountStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + long_short_backtest, + risk_analysis, +) +from qlib.utils import exists_qlib_data + + +if __name__ == "__main__": + + # use default data + mount_path = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(mount_path): + print(f"Qlib data is not found in {mount_path}") + sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) + from get_data import GetData + + GetData().qlib_data_cn(mount_path) + + qlib.init(mount_path=mount_path, region=REG_CN) + + MARKET = "CSI500" + BENCHMARK = "SH000905" + + ################################### + # train model + ################################### + DATA_HANDLER_CONFIG = { + "dropna_label": True, + "start_date": "2007-01-01", + "end_date": "2020-08-01", + "market": MARKET, + } + + TRAINER_CONFIG = { + "train_start_date": "2007-01-01", + "train_end_date": "2014-12-31", + "validate_start_date": "2015-01-01", + "validate_end_date": "2016-12-31", + "test_start_date": "2017-01-01", + "test_end_date": "2020-08-01", + } + + # use default DataHandler + # custom DataHandler, refer to: TODO: DataHandler api url + x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerV1(**DATA_HANDLER_CONFIG).get_split_data( + **TRAINER_CONFIG + ) + + MODEL_CONFIG = { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + } + # use default model + # custom Model, refer to: TODO: Model api url + model = LGBModel(**MODEL_CONFIG) + model.fit(x_train, y_train, x_validate, y_validate) + _pred = model.predict(x_test) + _pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns) + + # backtest requires pred_score + pred_score = pd.DataFrame(index=_pred.index) + pred_score["score"] = _pred.iloc(axis=1)[0] + + # save pred_score to file + pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser() + pred_score_path.parent.mkdir(exist_ok=True, parents=True) + pred_score.to_pickle(pred_score_path) + + ################################### + # backtest + ################################### + STRATEGY_CONFIG = { + "topk": 50, + "buffer_margin": 230, + } + BACKTEST_CONFIG = { + "verbose": False, + "limit_threshold": 0.095, + "account": 100000000, + "benchmark": BENCHMARK, + "deal_price": "vwap", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + } + + # use default strategy + # custom Strategy, refer 
to: TODO: Strategy api url + strategy = TopkAmountStrategy(**STRATEGY_CONFIG) + report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG) + + # long short backtest + long_short_reports = long_short_backtest(pred_score, topk=50) + + ################################### + # analyze + # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb + ################################### + analysis = dict() + analysis["pred_long"] = risk_analysis(long_short_reports["long"]) + analysis["pred_short"] = risk_analysis(long_short_reports["short"]) + analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"]) + analysis["sub_bench"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["sub_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"]) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + print(analysis_df) diff --git a/examples/train_backtest_analyze.ipynb b/examples/train_backtest_analyze.ipynb new file mode 100644 index 0000000000..3f98bfea70 --- /dev/null +++ b/examples/train_backtest_analyze.ipynb @@ -0,0 +1,377 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "import qlib\n", + "import pandas as pd\n", + "from qlib.config import REG_CN\n", + "from qlib.contrib.model.gbdt import LGBModel\n", + "from qlib.contrib.estimator.handler import QLibDataHandlerV1\n", + "from qlib.contrib.strategy.strategy import TopkAmountStrategy\n", + "from qlib.contrib.evaluate import (\n", + " backtest as normal_backtest,\n", + " long_short_backtest,\n", + " risk_analysis,\n", + ")\n", + "from qlib.utils import exists_qlib_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# use default data\n", + "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n", + "mount_path = \"~/.qlib/qlib_data/cn_data\" # target_dir\n", + "if not exists_qlib_data(mount_path):\n", + " print(f\"Qlib data is not found in {mount_path}\")\n", + " sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath(\"scripts\")))\n", + " from get_data import GetData\n", + " GetData().qlib_data_cn(mount_path)\n", + "qlib.init(mount_path=mount_path, region=REG_CN)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MARKET = \"csi500\"\n", + "BENCHMARK = \"SH000905\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# train model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "outputPrepend" + ] + }, + "outputs": [], + "source": [ + "###################################\n", + "# train model\n", + "###################################\n", + "DATA_HANDLER_CONFIG = {\n", + " \"dropna_label\": True,\n", + " \"start_date\": \"2007-01-01\",\n", + " \"end_date\": \"2020-08-01\",\n", + " \"market\": MARKET,\n", + "}\n", + "\n", + "TRAINER_CONFIG = {\n", + " \"train_start_date\": \"2007-01-01\",\n", + " \"train_end_date\": \"2014-12-31\",\n", + " \"validate_start_date\": \"2015-01-01\",\n", + " \"validate_end_date\": \"2016-12-31\",\n", + " \"test_start_date\": \"2017-01-01\",\n", + " \"test_end_date\": \"2020-08-01\",\n", + "}\n", + "\n", + "# use default DataHandler\n", + "# custom 
DataHandler, refer to: TODO: DataHandler api url\n", + "x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerV1(\n", + " **DATA_HANDLER_CONFIG\n", + ").get_split_data(**TRAINER_CONFIG)\n", + "\n", + "\n", + "MODEL_CONFIG = {\n", + " \"loss\": \"mse\",\n", + " \"colsample_bytree\": 0.8879,\n", + " \"learning_rate\": 0.0421,\n", + " \"subsample\": 0.8789,\n", + " \"lambda_l1\": 205.6999,\n", + " \"lambda_l2\": 580.9768,\n", + " \"max_depth\": 8,\n", + " \"num_leaves\": 210,\n", + " \"num_threads\": 20,\n", + "}\n", + "# use default model\n", + "# custom Model, refer to: TODO: Model api url\n", + "model = LGBModel(**MODEL_CONFIG)\n", + "model.fit(x_train, y_train, x_validate, y_validate)\n", + "_pred = model.predict(x_test)\n", + "_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)\n", + "\n", + "# backtest requires pred_score\n", + "pred_score = pd.DataFrame(index=_pred.index)\n", + "pred_score[\"score\"] = _pred.iloc(axis=1)[0]\n", + "\n", + "# save pred_score to file\n", + "pred_score_path = Path(\"~/tmp/qlib/pred_score.pkl\").expanduser()\n", + "pred_score_path.parent.mkdir(exist_ok=True, parents=True)\n", + "pred_score.to_pickle(pred_score_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# backtest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "###################################\n", + "# backtest\n", + "###################################\n", + "STRATEGY_CONFIG = {\n", + " \"topk\": 50,\n", + " \"buffer_margin\": 230,\n", + "}\n", + "BACKTEST_CONFIG = {\n", + " \"verbose\": False,\n", + " \"limit_threshold\": 0.095,\n", + " \"account\": 100000000,\n", + " \"benchmark\": BENCHMARK,\n", + " \"deal_price\": \"vwap\",\n", + " \"open_cost\": 0.0005,\n", + " \"close_cost\": 0.0015,\n", + " \"min_cost\": 5,\n", + " \n", + "}\n", + "\n", + "# use default strategy\n", + "# custom Strategy, refer to: TODO: Strategy api url\n", + "strategy = TopkAmountStrategy(**STRATEGY_CONFIG)\n", + "report_normal, positions_normal = normal_backtest(\n", + " pred_score, strategy=strategy, **BACKTEST_CONFIG\n", + ")\n", + "\n", + "# long short backtest\n", + "long_short_reports = long_short_backtest(pred_score, topk=50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# analyze" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "###################################\n", + "# analyze\n", + "# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb\n", + "###################################\n", + "analysis = dict()\n", + "analysis[\"pred_long\"] = risk_analysis(long_short_reports[\"long\"])\n", + "analysis[\"pred_short\"] = risk_analysis(long_short_reports[\"short\"])\n", + "analysis[\"pred_long_short\"] = risk_analysis(long_short_reports[\"long_short\"])\n", + "analysis[\"sub_bench\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n", + "analysis[\"sub_cost\"] = risk_analysis(\n", + " report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"]\n", + ")\n", + "analysis_df = pd.concat(analysis) # type: pd.DataFrame\n", + "print(analysis_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# analyze graphs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from qlib.contrib.report import analysis_model, 
analysis_position" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# get label data\n", + "from qlib.data import D\n", + "pred_df_dates = pred_score.index.get_level_values(level='datetime')\n", + "features_df = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n", + "features_df.columns = ['label']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## analysis position" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.report_graph(report_normal)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### score IC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred_label = pd.concat([features_df, pred_score], axis=1, sort=True).reindex(features_df.index)\n", + "analysis_position.score_ic_graph(pred_label)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### cumulative return" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.cumulative_return_graph(positions_normal, report_normal, features_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### risk analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_long_short_df = pd.concat(long_short_reports, axis=1)\n", + "analysis_position.risk_analysis_graph(analysis_df, report_normal, report_long_short_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### rank label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_position.rank_label_graph(positions_normal, features_df, pred_df_dates.min(), pred_df_dates.max())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## analysis model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### model performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "analysis_model.model_performance_graph(pred_label)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/qlib/__init__.py b/qlib/__init__.py new file mode 100644 index 0000000000..3007f96312 --- /dev/null +++ b/qlib/__init__.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +__version__ = "0.4.6.dev" + +import os +import copy +import logging +import re +import subprocess +import platform +from pathlib import Path + +from .utils import can_use_cache + + +# init qlib +def init(default_conf="client", **kwargs): + from .config import ( + C, + _default_client_config, + _default_server_config, + _default_region_config, + REG_CN, + ) + from .data.data import register_all_wrappers + from .log import get_module_logger, set_log_with_config + + _logging_config = C.logging_config + if "logging_config" in kwargs: + _logging_config = kwargs["logging_config"] + + # set global config + if _logging_config: + set_log_with_config(_logging_config) + + LOG = get_module_logger("Initialization", level=logging.INFO) + LOG.info(f"default_conf: {default_conf}.") + if default_conf == "server": + base_config = copy.deepcopy(_default_server_config) + elif default_conf == "client": + base_config = copy.deepcopy(_default_client_config) + else: + raise ValueError("Unknown system type") + if base_config: + base_config.update(_default_region_config[kwargs.get("region", REG_CN)]) + for k, v in base_config.items(): + C[k] = v + + for k, v in kwargs.items(): + C[k] = v + if k not in C: + LOG.warning("Unrecognized config %s" % k) + + if default_conf == "client": + C["mount_path"] = str(Path(C["mount_path"]).expanduser().resolve()) + if not (C["expression_cache"] is None and C["dataset_cache"] is None): + # check redis + if not can_use_cache(): + LOG.warning( + f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!" + ) + C["expression_cache"] = None + C["dataset_cache"] = None + + # check path if server/local + if re.match("^[^/ ]+:.+", C["provider_uri"]) is None: + if not os.path.exists(C["provider_uri"]): + if C["auto_mount"]: + LOG.error( + "Invalid provider uri: {}, please check if a valid provider uri has been set. This path does not exist.".format( + C["provider_uri"] + ) + ) + else: + LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"])) + else: + mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"]) + # If the provider uri looks like this 172.23.233.89//data/csdesign' + # It will be a nfs path. The client provider will be used + if not C["auto_mount"]: + if not os.path.exists(C["mount_path"]): + raise FileNotFoundError( + "Invalid mount path: {}! 
Please mount manually: {} or Set init parameter `auto_mount=True`".format( + C["mount_path"], mount_command + ) + ) + else: + # Judging system type + sys_type = platform.system() + if "win" in sys_type.lower(): + # system: window + exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":")) + result = exec_result.read() + if "85" in result: + LOG.warning("already mounted or window mount path already exists") + elif "53" in result: + raise OSError("not find network path") + elif "error" in result or "错误" in result: + raise OSError("Invalid mount path") + elif C["provider_uri"] in result: + LOG.info("window success mount..") + else: + raise OSError(f"unknown error: {result}") + + # config mount path + C["mount_path"] = C["mount_path"] + ":\\" + else: + # system: linux/Unix/Mac + # check mount + _remote_uri = C["provider_uri"] + _remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri + _mount_path = C["mount_path"] + _mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path + _check_level_num = 2 + _is_mount = False + while _check_level_num: + with subprocess.Popen( + 'mount | grep "{}"'.format(_remote_uri), + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as shell_r: + _command_log = shell_r.stdout.readlines() + if len(_command_log) > 0: + for _c in _command_log: + _temp_mount = _c.decode("utf-8").split(" ")[2] + _temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount + if _temp_mount == _mount_path: + _is_mount = True + break + if _is_mount: + break + _remote_uri = "/".join(_remote_uri.split("/")[:-1]) + _mount_path = "/".join(_mount_path.split("/")[:-1]) + _check_level_num -= 1 + + if not _is_mount: + try: + os.makedirs(C["mount_path"], exist_ok=True) + except Exception: + raise OSError( + "Failed to create directory {}, please create {} manually!".format( + C["mount_path"], C["mount_path"] + ) + ) + + # check nfs-common + command_res = os.popen("dpkg -l | grep nfs-common") + command_res = command_res.readlines() + if not command_res: + raise OSError( + "nfs-common is not found, please install it by execute: sudo apt install nfs-common" + ) + # manually mount + command_status = os.system(mount_command) + if command_status == 256: + raise OSError( + "mount {} on {} error! Needs SUDO! Please mount manually: {}".format( + C["provider_uri"], C["mount_path"], mount_command + ) + ) + elif command_status == 32512: + # LOG.error("Command error") + raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"])) + elif command_status == 0: + LOG.info("Mount finished") + else: + LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path)) + + LOG.info("qlib successfully initialized based on %s settings." 
% default_conf) + register_all_wrappers() + try: + if C["auto_mount"]: + LOG.info(f"provider_uri={C['provider_uri']}") + else: + LOG.info(f"mount_path={C['mount_path']}") + except KeyError: + LOG.info(f"provider_uri={C['provider_uri']}") + + if "flask_server" in C: + LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") + + +def init_from_yaml_conf(conf_path): + """init_from_yaml_conf + + :param conf_path: A path to the qlib config in yml format + """ + import yaml + + with open(conf_path) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + default_conf = config.pop("default_conf", "client") + init(default_conf, **config) diff --git a/qlib/config.py b/qlib/config.py new file mode 100644 index 0000000000..6c64c2ba2d --- /dev/null +++ b/qlib/config.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +# REGION CONST +REG_CN = "cn" +REG_US = "US" + +_default_config = { + # data provider config + "calendar_provider": "LocalCalendarProvider", + "instrument_provider": "LocalInstrumentProvider", + "feature_provider": "LocalFeatureProvider", + "expression_provider": "LocalExpressionProvider", + "dataset_provider": "LocalDatasetProvider", + "provider": "LocalProvider", + # config it in qlib.init() + "provider_uri": "", + # cache + "expression_cache": None, + "dataset_cache": None, + "calendar_cache": None, + # for simple dataset cache + "local_cache_path": None, + "kernels": 16, + # How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data. + "maxtasksperchild": None, + "default_disk_cache": 1, # 0:skip/1:use + "disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True + "mem_cache_size_limit": 500, + # memory cache expire second, only in used 'ClientDatasetCache' and 'client D.calendar' + # default 1 hour + "mem_cache_expire": 60 * 60, + # memory cache space limit, default 5GB, only in used client + "mem_cache_space_limit": 1024 * 1024 * 1024 * 5, + # cache dir name + "dataset_cache_dir_name": "dataset_cache", + "features_cache_dir_name": "features_cache", + # redis + # in order to use cache + "redis_host": "127.0.0.1", + "redis_port": 6379, + "redis_task_db": 1, + # This value can be reset via qlib.init + "logging_level": "INFO", + # Global configuration of qlib log + # logging_level can control the logging level more finely + "logging_config": { + "version": 1, + "formatters": { + "logger_format": { + "format": "[%(process)s:%(threadName)s](%(asctime)s) %(levelname)s - %(name)s - [%(filename)s:%(lineno)d] - %(message)s" + } + }, + "filters": { + "field_not_found": { + "()": "qlib.log.LogFilter", + "param": [".*?WARN: data not found for.*?"], + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "logger_format", + "filters": ["field_not_found"], + } + }, + "loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}}, + }, +} + +_default_server_config = { + # data provider config + "calendar_provider": "LocalCalendarProvider", + "instrument_provider": "LocalInstrumentProvider", + "feature_provider": "LocalFeatureProvider", + "expression_provider": "LocalExpressionProvider", + "dataset_provider": "LocalDatasetProvider", + "provider": "LocalProvider", + # config it in qlib.init() + "provider_uri": "", + # redis + "redis_host": "127.0.0.1", + "redis_port": 6379, + "redis_task_db": 1, + "kernels": 64, + # cache + "expression_cache": "ServerExpressionCache", + "dataset_cache": 
"ServerDatasetCache", +} + +_default_client_config = { + # data provider config + "calendar_provider": {"class": "LocalCalendarProvider", "kwargs": {"remote": True}}, + "instrument_provider": "LocalInstrumentProvider", + "feature_provider": {"class": "LocalFeatureProvider", "kwargs": {"remote": True}}, + "expression_provider": "LocalExpressionProvider", + "dataset_provider": "LocalDatasetProvider", + "provider": "LocalProvider", + # config it in user's own code + "provider_uri": "~/.qlib/qlib_data/cn_data", + # cache + # Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled. + "expression_cache": {"class": "ServerExpressionCache", "kwargs": {"remote": True}}, + "dataset_cache": {"class": "ServerDatasetCache", "kwargs": {"remote": True}}, + "calendar_cache": None, + # client config + "kernels": 16, + "mount_path": "~/.qlib/qlib_data/cn_data", + "auto_mount": False, # The nfs is already mounted on our server[auto_mount: False]. + # The nfs should be auto-mounted by qlib on other + # serversS(such as PAI) [auto_mount:True] + "timeout": 100, + "logging_level": "INFO", + "region": REG_CN, +} + + +_default_region_config = { + REG_CN: { + "trade_unit": 100, + "limit_threshold": 0.1, + "deal_price": "vwap", + }, + REG_US: { + "trade_unit": 1, + "limit_threshold": None, + "deal_price": "close", + }, +} + + +class Config: + def __getitem__(self, key): + return _default_config[key] + + def __getattr__(self, attr): + try: + return _default_config[attr] + except KeyError: + return AttributeError(f"No such {attr} in _default_config") + + def __setitem__(self, key, value): + _default_config[key] = value + + def __setattr__(self, attr, value): + _default_config[attr] = value + + def __contains__(self, item): + return item in _default_config + + def __getstate__(self): + return _default_config + + def __setstate__(self, state): + _default_config.update(state) + + def __str__(self): + return str(_default_config) + + def __repr__(self): + return str(_default_config) + + +# global config +C = Config() diff --git a/qlib/contrib/__init__.py b/qlib/contrib/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py new file mode 100644 index 0000000000..31746819cd --- /dev/null +++ b/qlib/contrib/backtest/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# -*- coding: utf-8 -*- +from .order import Order +from .account import Account +from .position import Position +from .exchange import Exchange +from .report import Report diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py new file mode 100644 index 0000000000..4335a6af24 --- /dev/null +++ b/qlib/contrib/backtest/account.py @@ -0,0 +1,174 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +import copy + +from .position import Position +from .report import Report +from .order import Order + + +""" +rtn & earning in the Account + rtn: + from order's view + 1.change if any order is executed, sell order or buy order + 2.change at the end of today, (today_clse - stock_price) * amount + earning + from value of current position + earning will be updated at the end of trade date + earning = today_value - pre_value + **is consider cost** + while earning is the difference of two position value, so it considers cost, it is the true return rate + in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning +""" + + +class Account: + def __init__(self, init_cash, last_trade_date=None): + self.init_vars(init_cash, last_trade_date) + + def init_vars(self, init_cash, last_trade_date=None): + # init cash + self.init_cash = init_cash + self.current = Position(cash=init_cash) + self.positions = {} + self.rtn = 0 + self.ct = 0 + self.to = 0 + self.val = 0 + self.report = Report() + self.earning = 0 + self.last_trade_date = last_trade_date + + def get_positions(self): + return self.positions + + def get_cash(self): + return self.current.position["cash"] + + def update_state_from_order(self, order, trade_val, cost, trade_price): + # update cash + if order.direction == Order.SELL: # 0 for sell + self.current.position["cash"] += trade_val - cost + elif order.direction == Order.BUY: # 1 for buy + self.current.position["cash"] -= trade_val + cost + else: + raise NotImplementedError("{} ".format(order.direction)) + # update turnover + self.to += trade_val + # update cost + self.ct += cost + # update return + # update self.rtn from order + if order.direction == Order.SELL: # 0 for sell + # when sell stock, get profit from price change + profit = trade_val - self.current.get_stock_price(order.stock_id) * order.deal_amount + self.rtn += profit # note here do not consider cost + elif order.direction == Order.BUY: # 1 for buy + # when buy stock, we get return for the rtn computing method + # profit in buy order is to make self.rtn is consistent with self.earning at the end of date + profit = self.current.get_stock_price(order.stock_id) * order.deal_amount - trade_val + self.rtn += profit + + def update_order(self, order, trade_val, cost, trade_price): + # if stock is sold out, no stock price information in Position, then we should update account first, then update current position + # if stock is bought, there is no stock in current position, update current, then update account + if order.direction == Order.SELL: + # sell stock + self.update_state_from_order(order, trade_val, cost, trade_price) + # update current position + # for may sell all of stock_id + self.current.update_order(order, trade_price) + else: + # buy stock + # deal order, then update state + self.current.update_order(order, trade_price) + self.update_state_from_order(order, trade_val, cost, trade_price) + + def update_daily_end(self, today, trader): + """ + today: pd.TimeStamp + quote: pd.DataFrame (code, date), collumns + when the end of trade date + - update rtn + - update price for each asset + - update value for this account + - update earning (2nd view of return ) + - update holding day, count of stock + - update position hitory + - update report + :return: None + """ + # update price for stock in the position and the profit from changed_price + stock_list = self.current.get_stock_list() + profit = 0 + for code in stock_list: + # if suspend, no new price to be updated, profit is 0 + if 
trader.check_stock_suspended(code, today): + continue + else: + today_close = trader.get_close(code, today) + profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] + self.current.update_stock_price(stock_id=code, price=today_close) + self.rtn += profit + # update holding day count + self.current.add_count_all() + # update value + self.val = self.current.calculate_value() + # update earning (2nd view of return) + # account_value - last_account_value + # for the first trade date, account_value - init_cash + # self.report.is_empty() to judge is_first_trade_date + # get last_account_value, today_account_value, today_stock_value + if self.report.is_empty(): + last_account_value = self.init_cash + else: + last_account_value = self.report.get_latest_account_value() + today_account_value = self.current.calculate_value() + today_stock_value = self.current.calculate_stock_value() + self.earning = today_account_value - last_account_value + # update report for today + # judge whether the the trading is begin. + # and don't add init account state into report, due to we don't have excess return in those days. + self.report.update_report_record( + trade_date=today, + account_value=today_account_value, + cash=self.current.position["cash"], + return_rate=(self.earning + self.ct) / last_account_value, + # here use earning to calculate return, position's view, earning consider cost, true return + # in order to make same definition with original backtest in evaluate.py + turnover_rate=self.to / last_account_value, + cost_rate=self.ct / last_account_value, + stock_value=today_stock_value, + ) + # set today_account_value to position + self.current.position["today_account_value"] = today_account_value + self.current.update_weight_all() + # update positions + # note use deepcopy + self.positions[today] = copy.deepcopy(self.current) + + # finish today's updation + # reset the daily variables + self.rtn = 0 + self.ct = 0 + self.to = 0 + self.last_trade_date = today + + def load_account(self, account_path): + report = Report() + position = Position() + last_trade_date = position.load_position(account_path / "position.xlsx") + report.load_report(account_path / "report.csv") + + # assign values + self.init_vars(position.init_cash) + self.current = position + self.report = report + self.last_trade_date = last_trade_date if last_trade_date else None + + def save_account(self, account_path): + self.current.save_position(account_path / "position.xlsx", self.last_trade_date) + self.report.save_report(account_path / "report.csv") diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py new file mode 100644 index 0000000000..ea7220133d --- /dev/null +++ b/qlib/contrib/backtest/backtest.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
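+
+# Usage sketch (illustrative): `pred` is assumed to be a DataFrame with a
+# MultiIndex containing a `datetime` level and a single `score` column, and
+# `strategy` a strategy object such as those in qlib.contrib.strategy, e.g.
+#
+#   trade_exchange = Exchange(deal_price="vwap", limit_threshold=0.1)
+#   report_df, positions = backtest(
+#       pred, strategy, trade_exchange,
+#       shift=1, verbose=True, account=1e9, benchmark="SH000905",
+#   )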
+ + +import numpy as np +import pandas as pd +from ...utils import get_date_by_shift, get_date_range +from ..online.executor import SimulatorExecutor +from ...data import D +from .account import Account +from ...config import C +from ...log import get_module_logger + +LOG = get_module_logger("backtest") + + +def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark): + """Parameters + ---------- + pred : pandas.DataFrame + predict should has index and one `score` column + strategy : Strategy() + strategy part for backtest + trade_exchange : Exchange() + exchage for backtest + shift : int + whether to shift prediction by one day + verbose : bool + whether to print log + account : float + init account value + benchmark : str/list/pd.Series + `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T. + example: + print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()) + 2017-01-04 0.011693 + 2017-01-05 0.000721 + 2017-01-06 -0.004322 + 2017-01-09 0.006874 + 2017-01-10 -0.003350 + + `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'. + `benchmark` is str, will use the daily change as the 'bench'. + benchmark code, default is SH000905 CSI500 + """ + trade_account = Account(init_cash=account) + _pred_dates = pred.index.get_level_values(level="datetime") + predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) + if isinstance(benchmark, pd.Series): + bench = benchmark + else: + _codes = benchmark if isinstance(benchmark, list) else [benchmark] + _temp_result = D.features( + _codes, + ["$close/Ref($close,1)-1"], + predict_dates[0], + get_date_by_shift(predict_dates[-1], shift=shift), + disk_cache=1, + ) + bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean() + + trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift)) + executor = SimulatorExecutor(trade_exchange, verbose=verbose) + + # trading apart + for pred_date, trade_date in zip(predict_dates, trade_dates): + # for loop predict date and trading date + # print + if verbose: + LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date)) + + # 1. Load the score_series at pred_date + try: + score = pred.loc(axis=0)[:, pred_date] # (stock_id, trade_date) multi_index, score in pdate + score_series = score.reset_index(level="datetime", drop=True)[ + "score" + ] # pd.Series(index:stock_id, data: score) + except KeyError: + LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date)) + score_series = None + + if score_series is not None and score_series.count() > 0: # in case of the scores are all None + # 2. Update your strategy (and model) + strategy.update(score_series, pred_date, trade_date) + + # 3. Generate order list + order_list = strategy.generate_order_list( + score_series=score_series, + current=trade_account.current, + trade_exchange=trade_exchange, + pred_date=pred_date, + trade_date=trade_date, + ) + else: + order_list = [] + # 4. Get result after executing order list + # NOTE: The following operation will modify order.amount. + # NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated + trade_info = executor.execute(trade_account, order_list, trade_date) + + # 5. 
Update account information according to transaction + update_account(trade_account, trade_info, trade_exchange, trade_date) + + # generate backtest report + report_df = trade_account.report.generate_report_dataframe() + report_df["bench"] = bench + positions = trade_account.get_positions() + return report_df, positions + + +def update_account(trade_account, trade_info, trade_exchange, trade_date): + """Update the account and strategy + Parameters + ---------- + trade_account : Account() + trade_info : list of [Order(), float, float, float] + (order, trade_val, trade_cost, trade_price), trade_info with out factor + trade_exchange : Exchange() + used to get the $close_price at trade_date to update account + trade_date : pd.Timestamp + """ + # update account + for [order, trade_val, trade_cost, trade_price] in trade_info: + if order.deal_amount == 0: + continue + trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) + # at the end of trade date, update the account based the $close_price of stocks. + trade_account.update_daily_end(today=trade_date, trader=trade_exchange) diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py new file mode 100644 index 0000000000..68a5067185 --- /dev/null +++ b/qlib/contrib/backtest/exchange.py @@ -0,0 +1,430 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import random +import logging + +import numpy as np +import pandas as pd + +from ...data import D +from .order import Order +from ...config import C, REG_CN +from ...log import get_module_logger + + +class Exchange: + def __init__( + self, + trade_dates=None, + codes="all", + deal_price=None, + subscribe_fields=[], + limit_threshold=None, + open_cost=0.0015, + close_cost=0.0025, + trade_unit=None, + min_cost=5, + extra_quote=None, + ): + """__init__ + + :param trade_dates: list of pd.Timestamp + :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) + :param deal_price: str, 'close', 'open', 'vwap' + :param subscribe_fields: list, subscribe fields + :param limit_threshold: float, 0.1 for example, default None + :param open_cost: cost rate for open, default 0.0015 + :param close_cost: cost rate for close, default 0.0025 + :param trade_unit: trade unit, 100 for China A market + :param min_cost: min cost, default 5 + :param extra_quote: pandas, dataframe consists of + columns: like ['$vwap', '$close', '$factor', 'limit']. + The limit indicates that the etf is tradable on a specific day. + Necessary fields: + $close is for calculating the total value at end of each day. + Optional fields: + $vwap is only necessary when we use the $vwap price as the deal price + $factor is for rounding to the trading unit + limit will be set to False by default(False indicates we can buy this + target on this day). + index: MultipleIndex(instrument, pd.Datetime) + """ + if trade_unit is None: + trade_unit = C.trade_unit + if limit_threshold is None: + limit_threshold = C.limit_threshold + if deal_price is None: + deal_price = C.deal_price + + self.logger = get_module_logger("online operator", level=logging.INFO) + + self.trade_unit = trade_unit + + # TODO: the quote, trade_dates, codes are not necessray. + # It is just for performance consideration. + if limit_threshold is None: + if C.region == REG_CN: + self.logger.warning(f"limit_threshold not set. 
The stocks hit the limit may be bought/sold") + elif abs(limit_threshold) > 0.1: + if C.region == REG_CN: + self.logger.warning(f"limit_threshold may not be set to a reasonable value") + + if deal_price[0] != "$": + self.deal_price = "$" + deal_price + else: + self.deal_price = deal_price + if isinstance(codes, str): + codes = D.instruments(codes) + self.codes = codes + # Necessary fields + # $close is for calculating the total value at end of each day. + # $factor is for rounding to the trading unit + # $change is for calculating the limit of the stock + + necessary_fields = {self.deal_price, "$close", "$change", "$factor"} + subscribe_fields = list(necessary_fields | set(subscribe_fields)) + all_fields = list(necessary_fields | set(subscribe_fields)) + self.all_fields = all_fields + self.open_cost = open_cost + self.close_cost = close_cost + self.min_cost = min_cost + self.limit_threshold = limit_threshold + # TODO: the quote, trade_dates, codes are not necessray. + # It is just for performance consideration. + if trade_dates is not None and len(trade_dates): + start_date, end_date = trade_dates[0], trade_dates[-1] + else: + self.logger.warning("trade_dates have not been assigned, all dates will be loaded") + start_date, end_date = None, None + + self.extra_quote = extra_quote + self.set_quote(codes, start_date, end_date) + + def set_quote(self, codes, start_date, end_date): + if len(codes) == 0: + codes = D.instruments() + self.quote = D.features(codes, self.all_fields, start_date, end_date, disk_cache=True).dropna(subset=["$close"]) + self.quote.columns = self.all_fields + + if self.quote[self.deal_price].isna().any(): + self.logger.warning("{} field data contains nan.".format(self.deal_price)) + + if self.quote["$factor"].isna().any(): + # The 'factor.day.bin' file not exists, and `factor` field contains `nan` + # Use adjusted price + self.trade_w_adj_price = True + self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") + else: + # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` + # Use normal price + self.trade_w_adj_price = False + # update limit + # check limit_threshold + if self.limit_threshold is None: + self.quote["limit"] = False + else: + # set limit + self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold) + + quote_df = self.quote + if self.extra_quote is not None: + # process extra_quote + if "$close" not in self.extra_quote: + raise ValueError("$close is necessray in extra_quote") + if self.deal_price not in self.extra_quote.columns: + self.extra_quote[self.deal_price] = self.extra_quote["$close"] + self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.") + if "$factor" not in self.extra_quote.columns: + self.extra_quote["$factor"] = 1.0 + self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") + if "limit" not in self.extra_quote.columns: + self.extra_quote["limit"] = False + self.logger.warning("No limit set for extra_quote. 
All stock will be tradable.") + assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} + quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) + + # update quote: pd.DataFrame to dict, for search use + self.quote = quote_df.to_dict("index") + + def _update_limit(self, buy_limit, sell_limit): + self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit) + + def check_stock_limit(self, stock_id, trade_date): + """Parameter + stock_id + trade_date + is limtited + """ + return self.quote[(stock_id, trade_date)]["limit"] + + def check_stock_suspended(self, stock_id, trade_date): + # is suspended + return (stock_id, trade_date) not in self.quote + + def is_stock_tradable(self, stock_id, trade_date): + # check if stock can be traded + # same as check in check_order + if self.check_stock_suspended(stock_id, trade_date) or self.check_stock_limit(stock_id, trade_date): + return False + else: + return True + + def check_order(self, order): + # check limit and suspended + if self.check_stock_suspended(order.stock_id, order.trade_date) or self.check_stock_limit( + order.stock_id, order.trade_date + ): + return False + else: + return True + + def deal_order(self, order, trade_account=None, position=None): + """ + Deal order when the actual transaction + + :param order: Deal the order. + :param trade_account: Trade account to be updated after dealing the order. + :param position: position to be updated after dealing the order. + :return: trade_val, trade_cost, trade_price + """ + # need to check order first + # TODO: check the order unit limit in the exchange!!!! + # The order limit is related to the adj factor and the cur_amount. + # factor = self.quote[(order.stock_id, order.trade_date)]['$factor'] + # cur_amount = trade_account.current.get_stock_amount(order.stock_id) + if self.check_order(order) is False: + raise AttributeError("need to check order first") + if trade_account is not None and position is not None: + raise ValueError("trade_account and position can only choose one") + + trade_price = self.get_deal_price(order.stock_id, order.trade_date) + trade_val, trade_cost = self._calc_trade_info_by_order( + order, trade_account.current if trade_account else position + ) + # update account + if trade_val > 0: + # If the order can only be deal 0 trade_val. 
Nothing to be updated + # Otherwise, it will result some stock with 0 amount in the position + if trade_account: + trade_account.update_order( + order=order, + trade_val=trade_val, + cost=trade_cost, + trade_price=trade_price, + ) + elif position: + position.update_order(order, trade_price) + + return trade_val, trade_cost, trade_price + + def get_quote_info(self, stock_id, trade_date): + return self.quote[(stock_id, trade_date)] + + def get_close(self, stock_id, trade_date): + return self.quote[(stock_id, trade_date)]["$close"] + + def get_deal_price(self, stock_id, trade_date): + deal_price = self.quote[(stock_id, trade_date)][self.deal_price] + if np.isclose(deal_price, 0.0) or np.isnan(deal_price): + self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!") + self.logger.warning(f"setting deal_price to close price") + deal_price = self.get_close(stock_id, trade_date) + return deal_price + + def get_factor(self, stock_id, trade_date): + return self.quote[(stock_id, trade_date)]["$factor"] + + def generate_amount_position_from_weight_position(self, weight_position, cash, trade_date): + """ + The generate the target position according to the weight and the cash. + NOTE: All the cash will assigned to the tadable stock. + + Parameter: + weight_position : dict {stock_id : weight}; allocate cash by weight_position + among then, weight must be in this range: 0 < weight < 1 + cash : cash + trade_date : trade date + """ + + # calculate the total weight of tradable value + tradable_weight = 0.0 + for stock_id in weight_position: + if self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date): + # weight_position must be greater than 0 and less than 1 + if weight_position[stock_id] < 0 or weight_position[stock_id] > 1: + raise ValueError( + "weight_position is {}, " + "weight_position is not in the range of (0, 1).".format(weight_position[stock_id]) + ) + tradable_weight += weight_position[stock_id] + + if tradable_weight - 1.0 >= 1e-5: + raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight)) + + amount_dict = {} + for stock_id in weight_position: + if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date): + amount_dict[stock_id] = ( + cash + * weight_position[stock_id] + / tradable_weight + // self.get_deal_price(stock_id=stock_id, trade_date=trade_date) + ) + return amount_dict + + def get_real_deal_amount(self, current_amount, target_amount, factor): + """ + Calculate the real adjust deal amount when considering the trading unit + + :param current_amount: + :param target_amount: + :param factor: + :return real_deal_amount; Positive deal_amount indicates buying more stock. 
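+        A worked example (illustrative numbers, assuming real-price trading with
+        trade_unit=100 and factor=1.0):
+            current_amount=300, target_amount=521 -> the raw difference 221 is
+            rounded down to the trade unit, so 200 is returned (buy 200 shares);
+            current_amount=521, target_amount=300 -> -200 is returned (sell);
+            target_amount=0 -> -current_amount is returned without rounding (sell all).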
+ """ + if current_amount == target_amount: + return 0 + elif current_amount < target_amount: + deal_amount = target_amount - current_amount + deal_amount = self.round_amount_by_trade_unit(deal_amount, factor) + return deal_amount + else: + if target_amount == 0: + return -current_amount + else: + deal_amount = current_amount - target_amount + deal_amount = self.round_amount_by_trade_unit(deal_amount, factor) + return -deal_amount + + def generate_order_for_target_amount_position(self, target_position, current_position, trade_date): + """Parameter: + target_position : dict { stock_id : amount } + current_postion : dict { stock_id : amount} + trade_unit : trade_unit + down sample : for amount 321 and trade_unit 100, deal_amount is 300 + deal order on trade_date + """ + # split buy and sell for further use + buy_order_list = [] + sell_order_list = [] + # three parts: kept stock_id, dropped stock_id, new stock_id + # handle kept stock_id + + # because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different; + # so here we sort stock_id, and then randomly shuffle the order of stock_id + # because the same random seed is used, the final stock_id order is fixed + sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys()))) + random.seed(0) + random.shuffle(sorted_ids) + for stock_id in sorted_ids: + + # Do not generate order for the nontradable stocks + if not self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date): + continue + + target_amount = target_position.get(stock_id, 0) + current_amount = current_position.get(stock_id, 0) + factor = self.quote[(stock_id, trade_date)]["$factor"] + + deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor) + if deal_amount == 0: + continue + elif deal_amount > 0: + # buy stock + buy_order_list.append( + Order( + stock_id=stock_id, + amount=deal_amount, + direction=Order.BUY, + trade_date=trade_date, + factor=factor, + ) + ) + else: + # sell stock + sell_order_list.append( + Order( + stock_id=stock_id, + amount=abs(deal_amount), + direction=Order.SELL, + trade_date=trade_date, + factor=factor, + ) + ) + # return order_list : buy + sell + return sell_order_list + buy_order_list + + def calculate_amount_position_value(self, amount_dict, trade_date, only_tradable=False): + """Parameter + position : Position() + amount_dict : {stock_id : amount} + """ + value = 0 + for stock_id in amount_dict: + if ( + self.check_stock_suspended(stock_id=stock_id, trade_date=trade_date) is False + and self.check_stock_limit(stock_id=stock_id, trade_date=trade_date) is False + ): + value += self.get_deal_price(stock_id=stock_id, trade_date=trade_date) * amount_dict[stock_id] + return value + + def round_amount_by_trade_unit(self, deal_amount, factor): + """Parameter + deal_amount : float, adjusted amount + factor : float, adjusted factor + return : float, real amount + """ + if not self.trade_w_adj_price: + # the minimal amount is 1. Add 0.1 for solving precision problem. 
+ return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor + return deal_amount + + def _calc_trade_info_by_order(self, order, position): + """ + Calculation of trade info + + :param order: + :param position: Position + :return: trade_val, trade_cost + """ + + trade_price = self.get_deal_price(order.stock_id, order.trade_date) + if order.direction == Order.SELL: + # sell + if position is not None: + if np.isclose(order.amount, position.get_stock_amount(order.stock_id)): + # when selling last stock. The amount don't need rounding + order.deal_amount = order.amount + else: + order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + else: + # TODO: We don't know current position. + # We choose to sell all + order.deal_amount = order.amount + + trade_val = order.deal_amount * trade_price + trade_cost = max(trade_val * self.close_cost, self.min_cost) + elif order.direction == Order.BUY: + # buy + if position is not None: + cash = position.get_cash() + trade_val = order.amount * trade_price + if cash < trade_val * (1 + self.open_cost): + # The money is not enough + order.deal_amount = self.round_amount_by_trade_unit( + cash / (1 + self.open_cost) / trade_price, order.factor + ) + else: + # THe money is enough + order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + else: + # Unknown amount of money. Just round the amount + order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + + trade_val = order.deal_amount * trade_price + trade_cost = trade_val * self.open_cost + else: + raise NotImplementedError("order type {} error".format(order.type)) + + return trade_val, trade_cost diff --git a/qlib/contrib/backtest/order.py b/qlib/contrib/backtest/order.py new file mode 100644 index 0000000000..740773b2fd --- /dev/null +++ b/qlib/contrib/backtest/order.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +class Order: + + SELL = 0 + BUY = 1 + + def __init__(self, stock_id, amount, trade_date, direction, factor): + """Parameter + direction : Order.SELL for sell; Order.BUY for buy + stock_id : str + amount : float + trade_date : pd.Timestamp + factor : float + presents the weight factor assigned in Exchange() + """ + # check direction + if direction not in {Order.SELL, Order.BUY}: + raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") + self.stock_id = stock_id + # amount of generated orders + self.amount = amount + # amount of successfully completed orders + self.deal_amount = 0 + self.trade_date = trade_date + self.direction = direction + self.factor = factor diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py new file mode 100644 index 0000000000..b614c08d01 --- /dev/null +++ b/qlib/contrib/backtest/position.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import pandas as pd +import copy +import pathlib +from .order import Order + +""" +Position module +""" + +""" +current state of position +a typical example is :{ + : { + 'count': , + 'amount': , + 'price': , + 'weight': , + }, +} + +""" + + +class Position: + """Position""" + + def __init__(self, cash=0, position_dict={}, today_account_value=0): + # NOTE: The position dict must be copied!!! 
+ # Otherwise the initial value + self.init_cash = cash + self.position = position_dict.copy() + self.position["cash"] = cash + self.position["today_account_value"] = today_account_value + + def init_stock(self, stock_id, amount, price=None): + self.position[stock_id] = {} + self.position[stock_id]["count"] = 0 # update count in the end of this date + self.position[stock_id]["amount"] = amount + self.position[stock_id]["price"] = price + self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date + + def buy_stock(self, stock_id, amount, price): + if stock_id not in self.position: + self.init_stock(stock_id=stock_id, amount=amount, price=price) + else: + # exist, add amount + self.position[stock_id]["amount"] += amount + + def sell_stock(self, stock_id, amount): + if stock_id not in self.position: + raise KeyError("{} not in current position".format(stock_id)) + else: + # decrease the amount of stock + self.position[stock_id]["amount"] -= amount + # check if to delete + if self.position[stock_id]["amount"] < -1e-5: + raise ValueError( + "only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, amount) + ) + elif abs(self.position[stock_id]["amount"]) <= 1e-5: + self.del_stock(stock_id) + + def del_stock(self, stock_id): + del self.position[stock_id] + + def update_order(self, order, trade_price): + # handle order, order is a order class, defined in exchange.py + if order.direction == Order.BUY: + # BUY + self.buy_stock(stock_id=order.stock_id, amount=order.deal_amount, price=trade_price) + elif order.direction == Order.SELL: + # SELL + self.sell_stock(stock_id=order.stock_id, amount=order.deal_amount) + else: + raise NotImplementedError("do not suppotr order direction {}".format(order.direction)) + + def update_stock_price(self, stock_id, price): + self.position[stock_id]["price"] = price + + def update_stock_count(self, stock_id, count): + self.position[stock_id]["count"] = count + + def update_stock_weight(self, stock_id, weight): + self.position[stock_id]["weight"] = weight + + def update_cash(self, cash): + self.position["cash"] = cash + + def calculate_stock_value(self): + stock_list = self.get_stock_list() + value = 0 + for stock_id in stock_list: + value += self.position[stock_id]["amount"] * self.position[stock_id]["price"] + return value + + def calculate_value(self): + value = self.calculate_stock_value() + value += self.position["cash"] + return value + + def get_stock_list(self): + stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"}) + return stock_list + + def get_stock_price(self, code): + return self.position[code]["price"] + + def get_stock_amount(self, code): + return self.position[code]["amount"] + + def get_stock_count(self, code): + return self.position[code]["count"] + + def get_stock_weight(self, code): + return self.position[code]["weight"] + + def get_cash(self): + return self.position["cash"] + + def get_stock_amount_dict(self): + """generate stock amount dict {stock_id : amount of stock} """ + d = {} + stock_list = self.get_stock_list() + for stock_code in stock_list: + d[stock_code] = self.get_stock_amount(code=stock_code) + return d + + def get_stock_weight_dict(self, only_stock=False): + """get_stock_weight_dict + generate stock weight fict {stock_id : value weight of stock in the position} + it is meaningful in the beginning or the end of each trade date + + :param only_stock: If only_stock=True, the weight of each stock in total stock will be returned + If only_stock=False, the 
weight of each stock in total assets(stock + cash) will be returned + """ + if only_stock: + position_value = self.calculate_stock_value() + else: + position_value = self.calculate_value() + d = {} + stock_list = self.get_stock_list() + for stock_code in stock_list: + d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value + return d + + def add_count_all(self): + stock_list = self.get_stock_list() + for code in stock_list: + self.position[code]["count"] += 1 + + def update_weight_all(self): + weight_dict = self.get_stock_weight_dict() + for stock_code, weight in weight_dict.items(): + self.update_stock_weight(stock_code, weight) + + def save_position(self, path, last_trade_date): + path = pathlib.Path(path) + p = copy.deepcopy(self.position) + cash = pd.Series() + cash["init_cash"] = self.init_cash + cash["cash"] = p["cash"] + cash["today_account_value"] = p["today_account_value"] + cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None + del p["cash"] + del p["today_account_value"] + positions = pd.DataFrame.from_dict(p, orient="index") + with pd.ExcelWriter(path) as writer: + positions.to_excel(writer, sheet_name="position") + cash.to_excel(writer, sheet_name="info") + + def load_position(self, path): + """load position information from a file + should have format below + sheet "position" + columns: ['stock', 'count', 'amount', 'price', 'weight'] + 'count': , + 'amount': , + 'price': , + 'weight': , + + sheet "cash" + index: ['init_cash', 'cash', 'today_account_value'] + 'init_cash': , + 'cash': , + 'today_account_value': + """ + path = pathlib.Path(path) + positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0) + cash_record = pd.read_excel(open(path, "rb"), sheet_name="info", index_col=0) + positions = positions.to_dict(orient="index") + init_cash = cash_record.loc["init_cash"].values[0] + cash = cash_record.loc["cash"].values[0] + today_account_value = cash_record.loc["today_account_value"].values[0] + last_trade_date = cash_record.loc["last_trade_date"].values[0] + + # assign values + self.position = {} + self.init_cash = init_cash + self.position = positions + self.position["cash"] = cash + self.position["today_account_value"] = today_account_value + + return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date) diff --git a/qlib/contrib/backtest/profit_attribution.py b/qlib/contrib/backtest/profit_attribution.py new file mode 100644 index 0000000000..d51fc450eb --- /dev/null +++ b/qlib/contrib/backtest/profit_attribution.py @@ -0,0 +1,324 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import numpy as np +import pandas as pd +from .position import Position +from ...data import D +from ...config import C +import datetime +from pathlib import Path + + +def get_benchmark_weight( + bench, + start_date=None, + end_date=None, + path=None, +): + """get_benchmark_weight + + get the stock weight distribution of the benchmark + + :param bench: + :param start_date: + :param end_date: + :param path: + + :return: The weight distribution of the the benchmark described by a pandas dataframe + Every row corresponds to a trading day. + Every column corresponds to a stock. + Every cell represents the strategy. 
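+        Example of the returned dataframe (illustrative values):
+            code        SH600000  SH600004  ...
+            date
+            2017-01-03    0.0105    0.0033  ...
+            2017-01-04    0.0104    0.0033  ...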
+ + """ + if not path: + path = Path(C.mount_path).expanduser() / "raw" / "AIndexMembers" / "weights.csv" + # TODO: the storage of weights should be implemented in a more elegent way + # TODO: The benchmark is not consistant with the filename in instruments. + bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"]) + bench_weight_df = bench_weight_df[bench_weight_df["index"] == bench] + bench_weight_df["date"] = pd.to_datetime(bench_weight_df["date"]) + if start_date is not None: + bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date] + if end_date is not None: + bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date] + bench_stock_weight = bench_weight_df.pivot_table(index="date", columns="code", values="weight") / 100.0 + return bench_stock_weight + + +def get_stock_weight_df(positions): + """get_stock_weight_df + :param positions: Given a positions from backtest result. + :return: A weight distribution for the position + """ + stock_weight = [] + index = [] + for date in sorted(positions.keys()): + pos = positions[date] + if isinstance(pos, dict): + pos = Position(position_dict=pos) + index.append(date) + stock_weight.append(pos.get_stock_weight_dict(only_stock=True)) + return pd.DataFrame(stock_weight, index=index) + + +def decompose_portofolio_weight(stock_weight_df, stock_group_df): + """decompose_portofolio_weight + + ''' + :param stock_weight_df: a pandas dataframe to describe the portofolio by weight. + every row corresponds to a day + every column corresponds to a stock. + Here is an example below. + code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \ + date + 2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN + 2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN + .... + :param stock_group_df: a pandas dataframe to describe the stock group. + every row corresponds to a day + every column corresponds to a stock. + the value in the cell repreponds the group id. + Here is a example by for stock_group_df for industry. The value is the industry code + instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \ + datetime + 2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + 2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + ... + :return: Two dict will be returned. The group_weight and the stock_weight_in_group. + The key is the group. The value is a Series or Dataframe to describe the weight of group or weight of stock + """ + all_group = np.unique(stock_group_df.values.flatten()) + all_group = all_group[~np.isnan(all_group)] + + group_weight = {} + stock_weight_in_group = {} + for group_key in all_group: + group_mask = stock_group_df == group_key + group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1) + stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0) + return group_weight, stock_weight_in_group + + +def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df): + """ + :param stock_weight_df: a pandas dataframe to describe the portofolio by weight. + every row corresponds to a day + every column corresponds to a stock. + Here is an example below. 
+ code SH600004 SH600006 SH600017 SH600022 SH600026 SH600037 \ + date + 2016-01-05 0.001543 0.001570 0.002732 0.001320 0.003000 NaN + 2016-01-06 0.001538 0.001569 0.002770 0.001417 0.002945 NaN + 2016-01-07 0.001555 0.001546 0.002772 0.001393 0.002904 NaN + 2016-01-08 0.001564 0.001527 0.002791 0.001506 0.002948 NaN + 2016-01-11 0.001597 0.001476 0.002738 0.001493 0.003043 NaN + .... + + :param stock_group_df: a pandas dataframe to describe the stock group. + every row corresponds to a day + every column corresponds to a stock. + the value in the cell repreponds the group id. + Here is a example by for stock_group_df for industry. The value is the industry code + instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \ + datetime + 2016-01-05 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + 2016-01-06 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + 2016-01-07 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + 2016-01-08 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + 2016-01-11 801780.0 801170.0 801040.0 801880.0 801180.0 801160.0 + ... + + :param stock_ret_df: a pandas dataframe to describe the stock return. + every row corresponds to a day + every column corresponds to a stock. + the value in the cell repreponds the return of the group. + Here is a example by for stock_ret_df. + instrument SH600000 SH600004 SH600005 SH600006 SH600007 SH600008 \ + datetime + 2016-01-05 0.007795 0.022070 0.099099 0.024707 0.009473 0.016216 + 2016-01-06 -0.032597 -0.075205 -0.098361 -0.098985 -0.099707 -0.098936 + 2016-01-07 -0.001142 0.022544 0.100000 0.004225 0.000651 0.047226 + 2016-01-08 -0.025157 -0.047244 -0.038567 -0.098177 -0.099609 -0.074408 + 2016-01-11 0.023460 0.004959 -0.034384 0.018663 0.014461 0.010962 + ... + + :return: It will decompose the portofolio to the group weight and group return. + """ + all_group = np.unique(stock_group_df.values.flatten()) + all_group = all_group[~np.isnan(all_group)] + + group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df) + + group_ret = {} + for group_key in stock_weight_in_group: + stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index) + stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index) + + temp_stock_ret_df = stock_ret_df[ + (stock_ret_df.index >= stock_weight_in_group_start_date) + & (stock_ret_df.index <= stock_weight_in_group_end_date) + ] + + group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1) + # If no weight is assigned, then the return of group will be np.nan + group_ret[group_key][group_weight[group_key] == 0.0] = np.nan + + group_weight_df = pd.DataFrame(group_weight) + group_ret_df = pd.DataFrame(group_ret) + return group_weight_df, group_ret_df + + +def get_daily_bin_group(bench_values, stock_values, group_n): + """get_daily_bin_group + Group the values of the stocks of benchmark into several bins in a day. + Put the stocks into these bins. + + :param bench_values: A series contains the value of stocks in benchmark. + The index is the stock code. + :param stock_values: A series contains the value of stocks of your portofolio + The index is the stock code. + :param group_n: Bins will be produced + + :return: A series with the same size and index as the stock_value. + The value in the series is the group id of the bins. + The No.1 bin contains the biggest values. 
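+        Example (illustrative, group_n=2): if the benchmark values split at a
+        median of 50.0, a stock valued 80.0 falls into bin 1 (the biggest values)
+        and a stock valued 30.0 into bin 2.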
+ """ + stock_group = stock_values.copy() + + # get the bin split points based on the daily proportion of benchmark + split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1)) + # Modify the biggest uppper bound and smallest lowerbound + split_points[0], split_points[-1] = -np.inf, np.inf + for i, (lb, up) in enumerate(zip(split_points, split_points[1:])): + stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i + return stock_group + + +def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None): + if group_method == "category": + # use the value of the benchmark as the category + return stock_group_field_df + elif group_method == "bins": + assert group_n is not None + # place the values into `group_n` fields. + # Each bin corresponds to a category. + new_stock_group_df = stock_group_field_df.copy().loc[ + bench_stock_weight_df.index.min() : bench_stock_weight_df.index.max() + ] + for idx, row in (~bench_stock_weight_df.isna()).iterrows(): + bench_values = stock_group_field_df.loc[idx, row[row].index] + new_stock_group_df.loc[idx] = get_daily_bin_group( + bench_values, stock_group_field_df.loc[idx], group_n=group_n + ) + return new_stock_group_df + + +def brinson_pa( + positions, + bench="SH000905", + group_field="industry", + group_method="category", + group_n=None, + deal_price="vwap", +): + """brinson profit attribution + + :param positions: The position produced by the backtest class + :param bench: The benchmark for comparing. TODO: if no benchmark is set, the equal-weighted is used. + :param group_field: The field used to set the group for assets allocation. + `industry` and `market_value` is often used. + :param group_method: 'category' or 'bins'. The method used to set the group for asstes allocation + `bin` will split the value into `group_n` bins and each bins represents a group + :param group_n: . Only used when group_method == 'bins'. + + :return: + A dataframe with three columns: RAA(excess Return of Assets Allocation), RSS(excess Return of Stock Selectino), RTotal(Total excess Return) + Every row corresponds to a trading day, the value corresponds to the next return for this trading day + The middle info of brinson profit attribution + """ + # group_method will decide how to group the group_field. + dates = sorted(positions.keys()) + + start_date, end_date = min(dates), max(dates) + + bench_stock_weight = get_benchmark_weight(bench, start_date, end_date) + + # The attributes for allocation will not + if not group_field.startswith("$"): + group_field = "$" + group_field + if not deal_price.startswith("$"): + deal_price = "$" + deal_price + + # FIXME: In current version. Some attributes(such as market_value) of some + # suspend stock is NAN. So we have to get more date to forward fill the NAN + shift_start_date = start_date - datetime.timedelta(days=250) + instruments = D.list_instruments( + D.instruments(market="all"), + start_time=shift_start_date, + end_time=end_date, + as_list=True, + ) + stock_df = D.features( + instruments, + [group_field, deal_price], + start_time=shift_start_date, + end_time=end_date, + freq="day", + ) + stock_df.columns = [group_field, "deal_price"] + + stock_group_field = stock_df[group_field].unstack().T + # FIXME: some attributes of some suspend stock is NAN. 
+ stock_group_field = stock_group_field.fillna(method="ffill") + stock_group_field = stock_group_field.loc[start_date:end_date] + + stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n) + + deal_price_df = stock_df["deal_price"].unstack().T + deal_price_df = deal_price_df.fillna(method="ffill") + + # NOTE: + # The return will be slightly different from the of the return in the report. + # Here the position are adjusted at the end of the trading day with close + stock_ret = (deal_price_df - deal_price_df.shift(1)) / deal_price_df.shift(1) + stock_ret = stock_ret.shift(-1).loc[start_date:end_date] + + port_stock_weight_df = get_stock_weight_df(positions) + + # decomposing the portofolio + port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret) + bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret) + + # if the group return of the portofolio is NaN, replace it with the market + # value + mod_port_group_ret_df = port_group_ret_df.copy() + mod_port_group_ret_df[mod_port_group_ret_df.isna()] = bench_group_ret_df + + Q1 = (bench_group_weight_df * bench_group_ret_df).sum(axis=1) + Q2 = (port_group_weight_df * bench_group_ret_df).sum(axis=1) + Q3 = (bench_group_weight_df * mod_port_group_ret_df).sum(axis=1) + Q4 = (port_group_weight_df * mod_port_group_ret_df).sum(axis=1) + + return ( + pd.DataFrame( + { + "RAA": Q2 - Q1, # The excess profit from the assets allocation + "RSS": Q3 - Q1, # The excess profit from the stocks selection + # The excess profit from the interaction of assets allocation and stocks selection + "RIN": Q4 - Q3 - Q2 + Q1, + "RTotal": Q4 - Q1, # The totoal excess profit + } + ), + { + "port_group_ret": port_group_ret_df, + "port_group_weight": port_group_weight_df, + "bench_group_ret": bench_group_ret_df, + "bench_group_weight": bench_group_weight_df, + "stock_group": stock_group, + "bench_stock_weight": bench_stock_weight, + "port_stock_weight": port_stock_weight_df, + "stock_ret": stock_ret, + }, + ) diff --git a/qlib/contrib/backtest/readme.md b/qlib/contrib/backtest/readme.md new file mode 100644 index 0000000000..578b248cad --- /dev/null +++ b/qlib/contrib/backtest/readme.md @@ -0,0 +1,184 @@ +# backtest + +modules + +simulate true trading environment + + +- Order +- Exchange +- Position +- Account +- Report + +backtest demo + + auto-update cross different modules from trade order + +strategy framework + +## Order + +trade order +- Order.SELL: sell order, default 0 +- Order.BUY: buy order, default 1 +- direction: `Order.SELL` for sell, `Order.BUY` for buy +- sotck_id +- amount +- trade_date : pd.Timestamp + +## Exchange + +the stock exanchge, deal the trade order, provide stock market information + +### Exchange Property +- trade_dates : list of pd.Timestamp +- codes : list stock_id list +- deal_price : str, 'close', 'open', 'vwap' +- quote : dataframe by D.features, trading data cache, default None +- limit_threshold : float, 0.1 for example, default None +- open_cost : cost rate for open, default 0.0015 +- close_cost : cost rate for close, default 0.0025 +- min_cost : min transaction cost, default 5 +- trade_unit : trade unit, 100 for China A market + +### Exchange Function +- check_stock_limit : buy limit, True for cannot trade, limit_threshold +- check_stock_suspended : check if suspended +- check_order : check is executable, include limit and suspend +- deal_order : (order, trade_account=None, position=None),if the order id 
executable, return trade_val, trade_cost, trade_price +- get price information realated, in this way need to check suspend first, (stock_id, trade_date) + - get_close + - get_deal_price +- generate_amount_position_from_weight_position : for strategy use +- generate_order_for_target_amount_position : generate order_list from target_position ( {stock_id : amount} ) and current_position({stock_id : amount}) +- calculate_amount_position_value : value +- compare function : compare position dict + +## Position + +state of asset + +including cash and stock + +for each stock, contain +- count : holding days +- amount : stock amount +- price : stock price + +### Functions: +- update_order + - buy_stock + - sell_stock +- update postion information + - cash + - price + - amount + - count +- calculate value : use price in postion to calculate value + - calculate_stock_value : without cash + - calculate_value : with cash +- get information + - get_stock_list + - get_stock_price + - get_stock_amount + - get_stock_count + - get_cash +- add_count_all : add 1 to all stock count +- transform + - get_stock_amount_dict + - get_stock_weight_dict : use price in postion to calculate value + +## Report + +daily report for account + +- account postion value for each trade date +- daily return rate for each trade date +- turnover for each trade date +- trade cost for each trade date +- value for each trade date +- cash +- latest_report_date : pd.TimeStamp + +### Function +- is_empty +- get_latest_date +- get_latest_account_value +- update_report_record +- generate_report_dataframe + +## Account + +state for the stock_trader + +- curent position : Position() class +- trading related + - return + - turnover + - cost + - earning +- postion value + - val + - cash +- report : Report() +- today + +### Funtions + +- get + - get_positions + - get_cash +- init_state +- update_order(order, trade_val, cost) : update current postion and trading metrix after the order is dealed +- update_daily_end() : when the end of trade date, summarize today + - update rtn , from order's view + - update price for each stock still in current position + - update value for this account + - update earning (2nd view of return , position' view) + - update holding day, count of stock + - update position hitory + - update report + + +## backtest_demo + +trade strategy: + + parameters : + topk : int, select topk stocks + buffer_margin : size of buffer margin + + description : + hold topk stocks at each trade date + when adjust position + the score model will generate scores for each stock + if the stock of current position not in top buffer_margin score, sell them out; + then equally buy recommended stocks + + the previous version of this strategy is in evaluate.py + + demo.py accomplishes same trading strategy with modules of Order, Exchange, Position, Report and Account + + test_strategy_demo.py did the consistency check between evaluate.py and demo.py + + strategy.py provide a strategy framework to do the backtest + +## Strategy + +strategy framework + + strategy will generate orders if given pred_scores and market environment information + there are two stages: + 1. generate target position + 2. 
generate order from target postion and current position + +document for the framework + + the document shows some examples to accomplish those two stages + +two strategy demo: +- Strategy_amount_demo +- Strategy_weight_demo + +backtest_demo with using strategy diff --git a/qlib/contrib/backtest/report.py b/qlib/contrib/backtest/report.py new file mode 100644 index 0000000000..beb9759d0d --- /dev/null +++ b/qlib/contrib/backtest/report.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from collections import OrderedDict +import pandas as pd +import pathlib + + +class Report: + # daily report of the account + # contain those followings: returns, costs turnovers, accounts, cash, bench, value + # update report + def __init__(self): + self.init_vars() + + def init_vars(self): + self.accounts = OrderedDict() # account postion value for each trade date + self.returns = OrderedDict() # daily return rate for each trade date + self.turnovers = OrderedDict() # turnover for each trade date + self.costs = OrderedDict() # trade cost for each trade date + self.values = OrderedDict() # value for each trade date + self.cashes = OrderedDict() + self.latest_report_date = None # pd.TimeStamp + + def is_empty(self): + return len(self.accounts) == 0 + + def get_latest_date(self): + return self.latest_report_date + + def get_latest_account_value(self): + return self.accounts[self.latest_report_date] + + def update_report_record( + self, + trade_date=None, + account_value=None, + cash=None, + return_rate=None, + turnover_rate=None, + cost_rate=None, + stock_value=None, + ): + # check data + if None in [ + trade_date, + account_value, + cash, + return_rate, + turnover_rate, + cost_rate, + stock_value, + ]: + raise ValueError( + "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]" + ) + # update report data + self.accounts[trade_date] = account_value + self.returns[trade_date] = return_rate + self.turnovers[trade_date] = turnover_rate + self.costs[trade_date] = cost_rate + self.values[trade_date] = stock_value + self.cashes[trade_date] = cash + # update latest_report_date + self.latest_report_date = trade_date + # finish daily report update + + def generate_report_dataframe(self): + report = pd.DataFrame() + report["account"] = pd.Series(self.accounts) + report["return"] = pd.Series(self.returns) + report["turnover"] = pd.Series(self.turnovers) + report["cost"] = pd.Series(self.costs) + report["value"] = pd.Series(self.values) + report["cash"] = pd.Series(self.cashes) + report.index.name = "date" + return report + + def save_report(self, path): + r = self.generate_report_dataframe() + r.to_csv(path) + + def load_report(self, path): + """load report from a file + should have format like + columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash'] + :param + path: str/ pathlib.Path() + """ + path = pathlib.Path(path) + r = pd.read_csv(open(path, "rb"), index_col=0) + r.index = pd.DatetimeIndex(r.index) + + index = r.index + self.init_vars() + for date in index: + self.update_report_record( + trade_date=date, + account_value=r.loc[date]["account"], + cash=r.loc[date]["cash"], + return_rate=r.loc[date]["return"], + turnover_rate=r.loc[date]["turnover"], + cost_rate=r.loc[date]["cost"], + stock_value=r.loc[date]["value"], + ) diff --git a/qlib/contrib/estimator/__init__.py b/qlib/contrib/estimator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qlib/contrib/estimator/config.py 
b/qlib/contrib/estimator/config.py new file mode 100644 index 0000000000..2ae8d4a0d4 --- /dev/null +++ b/qlib/contrib/estimator/config.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import yaml +import copy +import os +import json +import tempfile +from pathlib import Path +from ...config import REG_CN + + +class EstimatorConfigManager(object): + def __init__(self, config_path): + + if not config_path: + raise ValueError("Config path is invalid.") + self.config_path = config_path + + with open(config_path) as fp: + config = yaml.load(fp, Loader=yaml.FullLoader) + self.config = copy.deepcopy(config) + + self.ex_config = ExperimentConfig(config.get("experiment", dict()), self) + self.data_config = DataConfig(config.get("data", dict()), self) + self.model_config = ModelConfig(config.get("model", dict()), self) + self.trainer_config = TrainerConfig(config.get("trainer", dict()), self) + self.strategy_config = StrategyConfig(config.get("strategy", dict()), self) + self.backtest_config = BacktestConfig(config.get("backtest", dict()), self) + self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self) + + # If the start_date and end_date are not given in data_config, they will be referred from the trainer_config. + handler_start_date = self.data_config.handler_parameters.get("start_date", None) + handler_end_date = self.data_config.handler_parameters.get("end_date", None) + if handler_start_date is None: + self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"] + if handler_end_date is None: + self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"] + + +class ExperimentConfig(object): + TRAIN_MODE = "train" + TEST_MODE = "test" + + OBSERVER_FILE_STORAGE = "file_storage" + OBSERVER_MONGO = "mongo" + + def __init__(self, config, CONFIG_MANAGER): + """__init__ + + :param config: The config dict for experiment + :param CONFIG_MANAGER: The estimator config manager + """ + self.name = config.get("name", "test_experiment") + # The dir of the result of all the experiments + self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path)) + # The dir of the result of current experiment + self.ex_dir = os.path.join(self.global_dir, self.name) + if not os.path.exists(self.ex_dir): + os.makedirs(self.ex_dir) + self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir) + self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE) + self.sacred_dir = os.path.join(self.ex_dir, "sacred") + self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE) + self.mongo_url = config.get("mongo_url", None) + self.db_name = config.get("db_name", None) + self.finetune = config.get("finetune", False) + + # The path of the experiment id of the experiment + self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json")) + exp_info_dir = Path(self.exp_info_path).parent + exp_info_dir.mkdir(parents=True, exist_ok=True) + + # Test mode config + loader_args = config.get("loader", dict()) + if self.mode == ExperimentConfig.TEST_MODE or self.finetune: + loader_exp_info_path = loader_args.get("exp_info_path", None) + self.loader_model_index = loader_args.get("model_index", None) + if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)): + with open(loader_exp_info_path) as fp: + loader_dict = json.load(fp) + for k, v in loader_dict.items(): + setattr(self, "loader_{}".format(k), v) + # 
Check loader experiment id + assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id." + else: + self.loader_id = loader_args.get("id", None) + if self.loader_id is None: + raise ValueError("If mode is test or finetune is True, loader must contain id.") + + self.loader_observer_type = loader_args.get("observer_type", self.observer_type) + self.loader_name = loader_args.get("name", self.name) + self.loader_dir = loader_args.get("dir", self.global_dir) + + self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url) + self.loader_db_name = loader_args.get("db_name", self.db_name) + + +class DataConfig(object): + def __init__(self, config, CONFIG_MANAGER): + """__init__ + + :param config: The config dict for data + :param CONFIG_MANAGER: The estimator config manager + """ + self.handler_module_path = config.get("module_path", "qlib.contrib.estimator.handler") + self.handler_class = config.get("class", "ALPHA360") + self.handler_parameters = config.get("args", dict()) + self.handler_filter = config.get("filter", dict()) + # Update provider uri. + + +class ModelConfig(object): + def __init__(self, config, CONFIG_MANAGER): + """__init__ + + :param config: The config dict for model + :param CONFIG_MANAGER: The estimator config manager + """ + self.model_class = config.get("class", "Model") + self.model_module_path = config.get("module_path", "qlib.contrib.model") + self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model") + self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin")) + self.parameters = config.get("args", dict()) + # Make dir if need. + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + + +class TrainerConfig(object): + def __init__(self, config, CONFIG_MANAGER): + """__init__ + + :param config: The config dict for trainer + :param CONFIG_MANAGER: The estimator config manager + """ + self.trainer_class = config.get("class", "StaticTrainer") + self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer") + self.parameters = config.get("args", dict()) + + +class StrategyConfig(object): + def __init__(self, config, CONFIG_MANAGER): + """__init__ + + :param config: The config dict for strategy + :param CONFIG_MANAGER: The estimator config manager + """ + self.strategy_class = config.get("class", "TopkAmountStrategy") + self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy") + self.parameters = config.get("args", dict()) + + +class BacktestConfig(object): + def __init__(self, config, CONFIG_MANAGE): + """__init__ + + :param config: The config dict for strategy + :param CONFIG_MANAGE: The estimator config manager + """ + self.normal_backtest_parameters = config.get("normal_backtest_args", dict()) + self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict()) + + +class QlibDataConfig(object): + def __init__(self, config, CONFIG_MANAGE): + """__init__ + + :param config: The config dict for qlib_client + :param CONFIG_MANAGE: The estimator config manager + """ + self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data") + self.auto_mount = config.pop("auto_mount", False) + self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data") + self.region = config.pop("region", REG_CN) + self.args = config diff --git a/qlib/contrib/estimator/estimator.py b/qlib/contrib/estimator/estimator.py new file mode 100644 index 0000000000..f41a3383c4 --- /dev/null +++ 
b/qlib/contrib/estimator/estimator.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# coding=utf-8 + +import pandas as pd + +import os +import copy +import json +import yaml +import pickle + +import qlib +from ..evaluate import risk_analysis +from ..evaluate import backtest as normal_backtest +from ..evaluate import long_short_backtest +from .config import ExperimentConfig +from .fetcher import create_fetcher_with_config + +from ...log import get_module_logger, TimeInspector +from ...utils import get_module_by_module_path, compare_dict_value + + +class Estimator(object): + def __init__(self, config_manager, sacred_ex): + + # Set logger. + self.logger = get_module_logger("Estimator") + + # 1. Set config manager. + self.config_manager = config_manager + + # 2. Set configs. + self.ex_config = config_manager.ex_config + self.data_config = config_manager.data_config + self.model_config = config_manager.model_config + self.trainer_config = config_manager.trainer_config + self.strategy_config = config_manager.strategy_config + self.backtest_config = config_manager.backtest_config + + # If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader + if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune: + self.compare_config_with_config_manger(self.config_manager) + + # 3. Set sacred_experiment. + self.ex = sacred_ex + + # 4. Init data handler. + self.data_handler = None + self._init_data_handler() + + # 5. Init trainer. + self.trainer = None + self._init_trainer() + + # 6. Init strategy. + self.strategy = None + self._init_strategy() + + def _init_data_handler(self): + handler_module = get_module_by_module_path(self.data_config.handler_module_path) + + # Set market + market = self.data_config.handler_filter.get("market", None) + if market is None: + if "market" in self.data_config.handler_parameters: + self.logger.warning( + "Warning: The market in data.args section is deprecated. " + "It only works when market is not set in data.filter section. " + "It will be overridden by market in the data.filter section." 
+ ) + market = self.data_config.handler_parameters["market"] + else: + market = "csi500" + + self.data_config.handler_parameters["market"] = market + + data_filter_list = [] + handler_filters = self.data_config.handler_filter.get("filter_pipeline", list()) + for h_filter in handler_filters: + filter_module_path = h_filter.get("module_path", "qlib.data.filter") + filter_class_name = h_filter.get("class", "") + filter_parameters = h_filter.get("args", {}) + filter_module = get_module_by_module_path(filter_module_path) + filter_class = getattr(filter_module, filter_class_name) + data_filter = filter_class(**filter_parameters) + data_filter_list.append(data_filter) + + self.data_config.handler_parameters["data_filter_list"] = data_filter_list + handler_class = getattr(handler_module, self.data_config.handler_class) + self.data_handler = handler_class(**self.data_config.handler_parameters) + + def _init_trainer(self): + + model_module = get_module_by_module_path(self.model_config.model_module_path) + trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path) + model_class = getattr(model_module, self.model_config.model_class) + trainer_class = getattr(trainer_module, self.trainer_config.trainer_class) + + self.trainer = trainer_class( + model_class, + self.model_config.save_path, + self.model_config.parameters, + self.data_handler, + self.ex, + **self.trainer_config.parameters + ) + + def _init_strategy(self): + + module = get_module_by_module_path(self.strategy_config.strategy_module_path) + strategy_class = getattr(module, self.strategy_config.strategy_class) + self.strategy = strategy_class(**self.strategy_config.parameters) + + def run(self): + if self.ex_config.mode == ExperimentConfig.TRAIN_MODE: + self.trainer.train() + elif self.ex_config.mode == ExperimentConfig.TEST_MODE: + self.trainer.load() + else: + raise ValueError("unexpected mode: %s" % self.ex_config.mode) + analysis = self.backtest() + self.logger.info(analysis) + self.logger.info( + "experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name) + ) + + # Remove temp dir + # shutil.rmtree(self.ex_config.tmp_run_dir) + + def backtest(self): + TimeInspector.set_time_mark() + # 1. Get pred and prediction score of model(s). + pred = self.trainer.get_test_score() + performance = self.trainer.get_test_performance() + # 2. Normal Backtest. + report_normal, positions_normal = self._normal_backtest(pred) + # 3. Long-Short Backtest. + long_short_reports = self._long_short_backtest(pred) + # 4. Analyze + analysis_df = self._analyze(report_normal, long_short_reports) + # 5. Save. + self._save_backtest_result( + pred, + analysis_df, + positions_normal, + report_normal, + long_short_reports, + performance, + ) + return analysis_df + + def _normal_backtest(self, pred): + TimeInspector.set_time_mark() + if "account" not in self.backtest_config.normal_backtest_parameters: + if "account" in self.strategy_config.parameters: + self.logger.warning( + "Warning: The account in strategy section is deprecated. " + "It only works when account is not set in backtest section. " + "It will be overridden by account in the backtest section." 
+ ) + self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"] + report_normal, positions_normal = normal_backtest( + pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters + ) + TimeInspector.log_cost_time("Finished normal backtest.") + return report_normal, positions_normal + + def _long_short_backtest(self, pred): + TimeInspector.set_time_mark() + long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters) + TimeInspector.log_cost_time("Finished long-short backtest.") + return long_short_reports + + @staticmethod + def _analyze(report_normal, long_short_reports): + TimeInspector.set_time_mark() + + analysis = dict() + analysis["pred_long"] = risk_analysis(long_short_reports["long"]) + analysis["pred_short"] = risk_analysis(long_short_reports["short"]) + analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"]) + analysis["sub_bench"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["sub_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"]) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + TimeInspector.log_cost_time( + "Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean()) + ) + return analysis_df + + def _save_backtest_result(self, pred, analysis, positions, report_normal, long_short_reports, performance): + # 1. Result dir. + result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result") + if not os.path.exists(result_dir): + os.makedirs(result_dir) + + self.ex.add_info( + "task_config", + json.loads(json.dumps(self.config_manager.config, default=str)), + ) + + # 2. Pred. + TimeInspector.set_time_mark() + pred_pkl_path = os.path.join(result_dir, "pred.pkl") + pred.to_pickle(pred_pkl_path) + self.ex.add_artifact(pred_pkl_path) + TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path)) + + # 3. Ana. + TimeInspector.set_time_mark() + analysis_pkl_path = os.path.join(result_dir, "analysis.pkl") + analysis.to_pickle(analysis_pkl_path) + self.ex.add_artifact(analysis_pkl_path) + TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path)) + + # 4. Pos. + TimeInspector.set_time_mark() + positions_pkl_path = os.path.join(result_dir, "positions.pkl") + with open(positions_pkl_path, "wb") as fp: + pickle.dump(positions, fp) + self.ex.add_artifact(positions_pkl_path) + TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path)) + + # 5. Report normal. + TimeInspector.set_time_mark() + report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl") + report_normal.to_pickle(report_normal_pkl_path) + self.ex.add_artifact(report_normal_pkl_path) + TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path)) + + # 6. Report long short. + for k, name in zip( + ["long", "short", "long_short"], + ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"], + ): + TimeInspector.set_time_mark() + pkl_path = os.path.join(result_dir, name) + long_short_reports[k].to_pickle(pkl_path) + self.ex.add_artifact(pkl_path) + TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path)) + + # 7. Origin test label. 
+ TimeInspector.set_time_mark() + label_pkl_path = os.path.join(result_dir, "label.pkl") + self.data_handler.get_origin_test_label_with_date( + self.trainer_config.parameters["test_start_date"], + self.trainer_config.parameters["test_end_date"], + ).to_pickle(label_pkl_path) + self.ex.add_artifact(label_pkl_path) + TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path)) + + # 8. Experiment info, save the model(s) performance here. + TimeInspector.set_time_mark() + cur_ex_id = self.ex.experiment.current_run._id + exp_info = { + "id": cur_ex_id, + "name": self.ex_config.name, + "performance": performance, + "observer_type": self.ex_config.observer_type, + } + + if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO: + exp_info.update( + { + "mongo_url": self.ex_config.mongo_url, + "db_name": self.ex_config.db_name, + } + ) + else: + exp_info.update({"dir": self.ex_config.global_dir}) + + with open(self.ex_config.exp_info_path, "w") as fp: + json.dump(exp_info, fp, indent=4, sort_keys=True) + self.ex.add_artifact(self.ex_config.exp_info_path) + TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path)) + + @staticmethod + def compare_config_with_config_manger(config_manager): + """Compare loader model args and current config with ConfigManage + + :param config_manager: ConfigManager + :return: + """ + fetcher = create_fetcher_with_config(config_manager, load_form_loader=True) + loader_mode_config = fetcher.get_experiment( + exp_name=config_manager.ex_config.loader_name, + exp_id=config_manager.ex_config.loader_id, + fields=["task_config"], + )["task_config"] + with open(config_manager.config_path) as fp: + current_config = yaml.load(fp.read()) + current_config = json.loads(json.dumps(current_config, default=str)) + + logger = get_module_logger("Estimator") + + loader_mode_config = copy.deepcopy(loader_mode_config) + current_config = copy.deepcopy(current_config) + + # Require test_mode_config.test_start_date <= current_config.test_start_date + loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {}) + cur_trainer_args = current_config.get("trainer", {}).get("args", {}) + loader_start_date = loader_trainer_args.pop("test_start_date") + cur_test_start_date = cur_trainer_args.pop("test_start_date") + assert ( + loader_start_date <= cur_test_start_date + ), "Require: loader_mode_config.test_start_date <= current_config.test_start_date" + + # TODO: For the user's own extended `Trainer`, the support is not very good + if "RollingTrainer" == current_config.get("trainer", {}).get("class", None): + loader_period = loader_trainer_args.pop("rolling_period") + cur_period = cur_trainer_args.pop("rolling_period") + assert ( + loader_period == cur_period + ), "Require: loader_mode_config.rolling_period == current_config.rolling_period" + + compare_section = ["trainer", "model", "data"] + for section in compare_section: + changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {})) + if changes: + logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section)) diff --git a/qlib/contrib/estimator/fetcher.py b/qlib/contrib/estimator/fetcher.py new file mode 100644 index 0000000000..920c258c3a --- /dev/null +++ b/qlib/contrib/estimator/fetcher.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
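# Illustrative sketch (values are placeholders, not output of a real run): the exp_info.json
# written by Estimator._save_backtest_result above is plain JSON, so a later test or finetune
# run can point ``experiment.loader.exp_info_path`` at it to locate the saved run.
import json

example_exp_info = {
    "id": 1,                          # sacred run id of the finished experiment
    "name": "test_experiment",        # experiment name
    "observer_type": "file_storage",  # "mongo" runs carry mongo_url/db_name instead of dir
    "dir": "/path/to/experiments",    # global experiment dir for file_storage observers
    "performance": {"model_score": 0.05, "model_pearsonr": 0.04},
}
print(json.dumps(example_exp_info, indent=4, sort_keys=True))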
+ +# coding=utf-8 + +import copy +import json +import yaml +import pickle +import gridfs +import pymongo +from pathlib import Path +from abc import abstractmethod + +from .config import EstimatorConfigManager, ExperimentConfig + + +class Fetcher(object): + """Sacred Experiments Fetcher""" + + @abstractmethod + def _get_experiment(self, exp_name, exp_id): + """Get experiment basic info with experiment and experiment id + + :param exp_name: experiment name + :param exp_id: experiment id + :return: dict + Must contain keys: _id, experiment, info, stop_time. + Here is an example below for FileFetcher. + exp = { + '_id': exp_id, # experiment id + 'path': path, # experiment result path + 'experiment': {'name': exp_name}, # experiment + 'info': info, # experiment config info + 'stop_time': run.get('stop_time', None) # The time the experiment ended + } + + """ + pass + + @abstractmethod + def _list_experiments(self, exp_name=None): + """Get experiment basic info list with experiment name + + :param exp_name: experiment name + :return: list + + """ + pass + + @abstractmethod + def _iter_artifacts(self, experiment): + """Get information about the data in the experiment results + + :param experiment: `self._get_experiment` method result + :return: iterable + Each element contains two elements. + first element : data name + second element : data uri + """ + pass + + @abstractmethod + def _load_data(self, uri): + """Load data with uri + + :param uri: data uri + :return: bytes + """ + pass + + @staticmethod + def model_dict_to_buffer_list(model_dict): + """ + + :param model_dict: + :return: + """ + model_list = [] + is_static_model = False + if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin": + is_static_model = True + model_list.append(list(model_dict.values())[0]) + else: + sep = "model.bin_" + model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys())) + min_id, max_id = min(model_ids), max(model_ids) + for i in range(min_id, max_id + 1): + model_key = sep + str(i) + model = model_dict.get(model_key, None) + if model is None: + print( + "WARNING: In Fetcher, {} is missing when the get model is in the get_experiment function.".format( + model_key + ) + ) + break + else: + model_list.append(model) + + if is_static_model: + return model_list[0] + + return model_list + + def get_experiments(self, exp_name=None): + """Get experiments with name. + + :param exp_name: str + If `exp_name` is set to None, then all experiments will return. + :return: dict + Experiments info dict(Including experiment id and task_config to run the + experiment). Here is an example below. + { + 'a_experiment': [ + { + 'id': '1', + 'task_config': {...} + }, + ... + ] + ... + } + """ + res = dict() + for ex in self._list_experiments(exp_name): + name = ex["experiment"]["name"] + tmp = { + "id": ex["_id"], + "task_config": ex["info"].get("task_config", {}), + "ex_run_stop_time": ex.get("stop_time", None), + } + res.setdefault(name, []).append(tmp) + return res + + def get_experiment(self, exp_name, exp_id, fields=None): + """ + + :param exp_name: + :param exp_id: + :param fields: list + Experiment result fields, if fields is None, will get all fields. 
+ Currently supported fields: + ['model', 'analysis', 'positions', 'report_normal', 'report_long', 'report_short', + 'report_long_short', 'pred', 'task_config', 'label'] + :return: dict + """ + fields = copy.copy(fields) + ex = self._get_experiment(exp_name, exp_id) + results = dict() + model_dict = dict() + for name, uri in self._iter_artifacts(ex): + # When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename) + prefix = name.split(".")[0] + if fields and prefix not in fields: + continue + data = self._load_data(uri) + if prefix == "model": + model_dict[name] = data + else: + results[prefix] = pickle.loads(data) + # Sort model + if model_dict: + results["model"] = self.model_dict_to_buffer_list(model_dict) + + # Info + results["task_config"] = ex["info"].get("task_config", {}) + return results + + def estimator_config_to_dict(self, exp_name, exp_id): + """Save configuration to file + + :param exp_name: + :param exp_id: + :return: config dict + """ + + return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"] + + +class FileFetcher(Fetcher): + """File Fetcher""" + + def __init__(self, experiments_dir): + self.experiments_dir = Path(experiments_dir) + + def _get_experiment(self, exp_name, exp_id): + path = self.experiments_dir / exp_name / "sacred" / str(exp_id) + info_path = path / "info.json" + run_path = path / "run.json" + + if info_path.exists(): + with info_path.open("r") as f: + info = json.load(f) + else: + info = {} + + if run_path.exists(): + with run_path.open("r") as f: + run = json.load(f) + else: + run = {} + + exp = { + "_id": exp_id, + "path": path, + "experiment": {"name": exp_name}, + "info": info, + "stop_time": run.get("stop_time", None), + } + return exp + + def _list_experiments(self, exp_name=None): + runs = [] + for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")): + exp_name, exp_id = path.parents[1].name, path.name + runs.append(self._get_experiment(exp_name, exp_id)) + return runs + + def _iter_artifacts(self, experiment): + if experiment is None: + return [] + + for fname in experiment["path"].iterdir(): + if fname.suffix == ".pkl" or ".bin" in fname.suffix: + name, uri = fname.name, str(fname) + yield name, uri + + def _load_data(self, uri): + with open(uri, "rb") as f: + data = f.read() + return data + + +class MongoFetcher(Fetcher): + """MongoDB Fetcher""" + + def __init__(self, mongo_url, db_name): + self.mongo_url = mongo_url + self.db_name = db_name + self.client = None + self.db = None + self.runs = None + self.fs = None + self._setup_mongo_client() + + def _setup_mongo_client(self): + self.client = pymongo.MongoClient(self.mongo_url) + self.db = self.client[self.db_name] + self.runs = self.db.runs + self.fs = gridfs.GridFS(self.db) + + def _get_experiment(self, exp_name, exp_id): + return self.runs.find_one({"_id": exp_id}) + + def _list_experiments(self, exp_name=None): + if exp_name is None: + return self.runs.find() + return self.runs.find({"experiment.name": exp_name}) + + def _iter_artifacts(self, experiment): + if experiment is None: + return [] + for artifact in experiment.get("artifacts", []): + name, uri = artifact["name"], artifact["file_id"] + yield name, uri + + def _load_data(self, uri): + data = self.fs.get(uri).read() + return data + + +def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False): + """Create fetcher with loader config + + :param config_manager: + :param load_form_loader + :return: + 
""" + flag = "" + if load_form_loader: + flag = "loader_" + if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE: + return FileFetcher(eval("config_manager.ex_config.{}_dir".format("loader" if load_form_loader else "global"))) + elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO: + return MongoFetcher( + mongo_url=eval("config_manager.ex_config.{}mongo_url".format(flag)), + db_name=eval("config_manager.ex_config.{}db_name".format(flag)), + ) + else: + return NotImplementedError("Unkown Backend") diff --git a/qlib/contrib/estimator/handler.py b/qlib/contrib/estimator/handler.py new file mode 100644 index 0000000000..e63eb55ecc --- /dev/null +++ b/qlib/contrib/estimator/handler.py @@ -0,0 +1,584 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# coding=utf-8 +import abc +import bisect +import logging + +import pandas as pd +import numpy as np + +from ...log import get_module_logger, TimeInspector +from ...data import D +from ...utils import parse_config, transform_end_date + +from . import processor as processor_module + + +class BaseDataHandler(abc.ABC): + def __init__(self, processors=[], **kwargs): + """ + :param start_date: + :param end_date: + :param kwargs: + """ + # Set logger + self.logger = get_module_logger("DataHandler") + + # init data using kwargs + self._init_kwargs(**kwargs) + + # Setup data. + self.raw_df, self.feature_names, self.label_names = self._init_raw_df() + + # Setup preprocessor + self.processors = [] + for klass in processors: + if isinstance(klass, str): + try: + klass = getattr(processor_module, klass) + except: + raise ValueError("unknown Processor %s" % klass) + self.processors.append(klass(self.feature_names, self.label_names, **kwargs)) + + def _init_kwargs(self, **kwargs): + """ + init the kwargs of DataHandler + """ + pass + + def _init_raw_df(self): + """ + init raw_df, feature_names, label_names of DataHandler + if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge). + + """ + df_features = self.setup_feature() + feature_names = df_features.columns + + df_labels = self.setup_label() + label_names = df_labels.columns + + raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left") + + return raw_df, feature_names, label_names + + def reset_label(self, df_labels): + for col in self.label_names: + del self.raw_df[col] + self.label_names = df_labels.columns + self.raw_df = self.raw_df.merge(df_labels, left_index=True, right_index=True, how="left") + + def split_rolling_periods( + self, + train_start_date, + train_end_date, + validate_start_date, + validate_end_date, + test_start_date, + test_end_date, + rolling_period, + calendar_freq="day", + ): + """ + Calculating the Rolling split periods, the period rolling on market calendar. 
+ :param train_start_date: + :param train_end_date: + :param validate_start_date: + :param validate_end_date: + :param test_start_date: + :param test_end_date: + :param rolling_period: The market period of rolling + :param calendar_freq: The frequence of the market calendar + :yield: Rolling split periods + """ + + def get_start_index(calendar, start_date): + start_index = bisect.bisect_left(calendar, start_date) + return start_index + + def get_end_index(calendar, end_date): + end_index = bisect.bisect_right(calendar, end_date) + return end_index - 1 + + calendar = self.raw_df.index.get_level_values("datetime").unique() + + train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date)) + train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date)) + valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date)) + valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date)) + test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date)) + test_end_index = test_start_index + rolling_period - 1 + + need_stop_split = False + + bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date)) + + while not need_stop_split: + + if test_end_index > bound_test_end_index: + test_end_index = bound_test_end_index + need_stop_split = True + + yield ( + calendar[train_start_index], + calendar[train_end_index], + calendar[valid_start_index], + calendar[valid_end_index], + calendar[test_start_index], + calendar[test_end_index], + ) + + train_start_index += rolling_period + train_end_index += rolling_period + valid_start_index += rolling_period + valid_end_index += rolling_period + test_start_index += rolling_period + test_end_index += rolling_period + + def get_rolling_data( + self, + train_start_date, + train_end_date, + validate_start_date, + validate_end_date, + test_start_date, + test_end_date, + rolling_period, + calendar_freq="day", + ): + # Set generator. + for period in self.split_rolling_periods( + train_start_date, + train_end_date, + validate_start_date, + validate_end_date, + test_start_date, + test_end_date, + rolling_period, + calendar_freq, + ): + ( + x_train, + y_train, + x_validate, + y_validate, + x_test, + y_test, + ) = self.get_split_data(*period) + yield x_train, y_train, x_validate, y_validate, x_test, y_test + + def get_split_data( + self, + train_start_date, + train_end_date, + validate_start_date, + validate_end_date, + test_start_date, + test_end_date, + ): + """ + all return types are DataFrame + """ + ## TODO: loc can be slow, expecially when we put it at the second level index. 
+ if self.raw_df.index.names[0] == "instrument": + df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date] + df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date] + df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date] + else: + df_train = self.raw_df.loc[train_start_date:train_end_date] + df_validate = self.raw_df.loc[validate_start_date:validate_end_date] + df_test = self.raw_df.loc[test_start_date:test_end_date] + + TimeInspector.set_time_mark() + df_train, df_validate, df_test = self.setup_process_data(df_train, df_validate, df_test) + TimeInspector.log_cost_time("Finished setup processed data.") + + x_train = df_train[self.feature_names] + y_train = df_train[self.label_names] + + x_validate = df_validate[self.feature_names] + y_validate = df_validate[self.label_names] + + x_test = df_test[self.feature_names] + y_test = df_test[self.label_names] + + return x_train, y_train, x_validate, y_validate, x_test, y_test + + def setup_process_data(self, df_train, df_valid, df_test): + """ + process the train, valid and test data + :return: the processed train, valid and test data. + """ + for processor in self.processors: + df_train, df_valid, df_test = processor(df_train, df_valid, df_test) + return df_train, df_valid, df_test + + def get_origin_test_label_with_date(self, test_start_date, test_end_date, freq="day"): + """Get origin test label + + :param test_start_date: test start date + :param test_end_date: test end date + :param freq: freq + :return: pd.DataFrame + """ + test_end_date = transform_end_date(test_end_date, freq=freq) + return self.raw_df.loc[(slice(None), slice(test_start_date, test_end_date)), self.label_names] + + @abc.abstractmethod + def setup_feature(self): + """ + Implement this method to load raw feature. + the format of the feature is below + return: df_features + """ + pass + + @abc.abstractmethod + def setup_label(self): + """ + Implement this method to load and calculate label. + the format of the label is below + + return: df_label + """ + pass + + +class QLibDataHandler(BaseDataHandler): + def __init__(self, start_date, end_date, *args, **kwargs): + # Dates. + self.start_date = start_date + self.end_date = end_date + super().__init__(*args, **kwargs) + + def _init_kwargs(self, **kwargs): + + # Instruments + instruments = kwargs.get("instruments", None) + if instruments is None: + market = kwargs.get("market", "csi500").lower() + data_filter_list = kwargs.get("data_filter_list", list()) + self.instruments = D.instruments(market, filter_pipe=data_filter_list) + else: + self.instruments = instruments + + # Config of features and labels + self._fields = kwargs.get("fields", []) + self._names = kwargs.get("names", []) + self._labels = kwargs.get("labels", []) + self._label_names = kwargs.get("label_names", []) + + # Check arguments + assert len(self._fields) > 0, "features list is empty" + assert len(self._labels) > 0, "labels list is empty" + + # Check end_date + # If test_end_date is -1 or greater than the last date, the last date is used + self.end_date = transform_end_date(self.end_date) + + def setup_feature(self): + """ + Load the raw data. 
+ return: df_features + """ + TimeInspector.set_time_mark() + + if len(self._names) == 0: + names = ["F%d" % i for i in range(len(self._fields))] + else: + names = self._names + + df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date) + df_features.columns = names + + TimeInspector.log_cost_time("Finished loading features.") + + return df_features + + def setup_label(self): + """ + Build up labels in df through users' method + :return: df_labels + """ + TimeInspector.set_time_mark() + + if len(self._label_names) == 0: + label_names = ["LABEL%d" % i for i in range(len(self._labels))] + else: + label_names = self._label_names + + df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date) + df_labels.columns = label_names + + TimeInspector.log_cost_time("Finished loading labels.") + + return df_labels + + +def parse_config_to_fields(config): + """create factors from config + + config = { + 'kbar': {}, # whether to use some hard-code kbar features + 'price': { # whether to use raw price features + 'windows': [0, 1, 2, 3, 4], # use price at n days ago + 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use + }, + 'volume': { # whether to use raw volume features + 'windows': [0, 1, 2, 3, 4], # use volume at n days ago + }, + 'rolling': { # whether to use rolling operator based features + 'windows': [5, 10, 20, 30, 60], # rolling windows size + 'include': ['ROC', 'MA', 'STD'], # rolling operator to use + #if include is None we will use default operators + 'exclude': ['RANK'], # rolling operator not to use + } + } + """ + fields = [] + names = [] + if "kbar" in config: + fields += [ + "($close-$open)/$open", + "($high-$low)/$open", + "($close-$open)/($high-$low+1e-12)", + "($high-Greater($open, $close))/$open", + "($high-Greater($open, $close))/($high-$low+1e-12)", + "(Less($open, $close)-$low)/$open", + "(Less($open, $close)-$low)/($high-$low+1e-12)", + "(2*$close-$high-$low)/$open", + "(2*$close-$high-$low)/($high-$low+1e-12)", + ] + names += [ + "KMID", + "KLEN", + "KMID2", + "KUP", + "KUP2", + "KLOW", + "KLOW2", + "KSFT", + "KSFT2", + ] + if "price" in config: + windows = config["price"].get("windows", range(5)) + feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"]) + for field in feature: + field = field.lower() + fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows] + names += [field.upper() + str(d) for d in windows] + if "volume" in config: + windows = config["volume"].get("windows", range(5)) + fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows] + names += ["VOLUME" + str(d) for d in windows] + if "rolling" in config: + windows = config["rolling"].get("windows", [5, 10, 20, 30, 60]) + include = config["rolling"].get("include", None) + exclude = config["rolling"].get("exclude", []) + # `exclude` in dataset config unnecessary filed + # `include` in dataset config necessary field + use = lambda x: x not in exclude and (include is None or x in include) + if use("ROC"): + fields += ["Ref($close, %d)/$close" % d for d in windows] + names += ["ROC%d" % d for d in windows] + if use("MA"): + fields += ["Mean($close, %d)/$close" % d for d in windows] + names += ["MA%d" % d for d in windows] + if use("STD"): + fields += ["Std($close, %d)/$close" % d for d in windows] + names += ["STD%d" % d for d in windows] + if use("BETA"): + fields += ["Slope($close, %d)/$close" % d for d in windows] + names += ["BETA%d" % d for 
d in windows] + if use("RSQR"): + fields += ["Rsquare($close, %d)" % d for d in windows] + names += ["RSQR%d" % d for d in windows] + if use("RESI"): + fields += ["Resi($close, %d)/$close" % d for d in windows] + names += ["RESI%d" % d for d in windows] + if use("MAX"): + fields += ["Max($high, %d)/$close" % d for d in windows] + names += ["MAX%d" % d for d in windows] + if use("LOW"): + fields += ["Min($low, %d)/$close" % d for d in windows] + names += ["MIN%d" % d for d in windows] + if use("QTLU"): + fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows] + names += ["QTLU%d" % d for d in windows] + if use("QTLD"): + fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows] + names += ["QTLD%d" % d for d in windows] + if use("RANK"): + fields += ["Rank($close, %d)" % d for d in windows] + names += ["RANK%d" % d for d in windows] + if use("RSV"): + fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows] + names += ["RSV%d" % d for d in windows] + if use("IMAX"): + fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows] + names += ["IMAX%d" % d for d in windows] + if use("IMIN"): + fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows] + names += ["IMIN%d" % d for d in windows] + if use("IMXD"): + fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows] + names += ["IMXD%d" % d for d in windows] + if use("CORR"): + fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows] + names += ["CORR%d" % d for d in windows] + if use("CORD"): + fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows] + names += ["CORD%d" % d for d in windows] + if use("CNTP"): + fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows] + names += ["CNTP%d" % d for d in windows] + if use("CNTN"): + fields += ["Mean($closeRef($close, 1), %d)-Mean($close= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True) + if self.fillna_feature: + x.fillna(0, inplace=True) + return x + + TimeInspector.set_time_mark() + + # Copy + df_new = df.copy() + + # Label + cols = df.columns[df.columns.str.contains("^LABEL")] + df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm) + + # Features + cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")] + df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")] + df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm) + + _cols = [ + "KMID", + "KSFT", + "OPEN", + "HIGH", + "LOW", + "CLOSE", + "VWAP", + "ROC", + "MA", + "BETA", + "RESI", + "QTLU", + "QTLD", + "RSV", + "SUMP", + "SUMN", + "SUMD", + "VSUMP", + "VSUMN", + "VSUMD", + ] + pat = "|".join(["^" + x for x in _cols]) + cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))] + df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")] + df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^RSQR")] + df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")] + df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^MIN|^LOW0")] + df_new[cols] = 
df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^CORR|^CORD")] + df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm) + + cols = df.columns[df.columns.str.contains("^WVMA")] + df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm) + + TimeInspector.log_cost_time("Finished preprocessing data.") + + return df_new diff --git a/qlib/contrib/estimator/trainer.py b/qlib/contrib/estimator/trainer.py new file mode 100644 index 0000000000..d19051de92 --- /dev/null +++ b/qlib/contrib/estimator/trainer.py @@ -0,0 +1,315 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# coding=utf-8 + +from abc import abstractmethod + +import pandas as pd +import numpy as np +from scipy.stats import pearsonr + +from ...log import get_module_logger, TimeInspector +from .handler import BaseDataHandler +from .launcher import CONFIG_MANAGER +from .fetcher import create_fetcher_with_config +from ...utils import drop_nan_by_y_index, transform_end_date + + +class BaseTrainer(object): + def __init__(self, model_class, model_save_path, model_args, data_handler: BaseDataHandler, sacred_ex, **kwargs): + # 1. Model. + self.model_class = model_class + self.model_save_path = model_save_path + self.model_args = model_args + + # 2. Data handler. + self.data_handler = data_handler + + # 3. Sacred ex. + self.ex = sacred_ex + + # 4. Logger. + self.logger = get_module_logger("Trainer") + + # 5. Data time + self.train_start_date = kwargs.get("train_start_date", None) + self.train_end_date = kwargs.get("train_end_date", None) + self.validate_start_date = kwargs.get("validate_start_date", None) + self.validate_end_date = kwargs.get("validate_end_date", None) + self.test_start_date = kwargs.get("test_start_date", None) + self.test_end_date = transform_end_date(kwargs.get("test_end_date", None)) + + @abstractmethod + def train(self): + """ + Implement this method indicating how to train a model. + """ + pass + + @abstractmethod + def load(self): + """ + Implement this method indicating how to restore a model and the data. + """ + pass + + @abstractmethod + def get_test_pred(self): + """ + Implement this method indicating how to get prediction result(s) from a model. + """ + pass + + @abstractmethod + def get_test_performance(self): + """ + Implement this method indicating how to get the performance of the model. + """ + pass + + def get_test_score(self): + """ + Override this method to transfer the predict result(s) into the score of the stock. + Note: If this is a multi-label training, you need to transfer predict labels into one score. + Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training. + We use the first column of the result of `get_test_pred()` as default method (regard it as one label training). 
+ """ + pred = self.get_test_pred() + pred_score = pd.DataFrame(index=pred.index) + pred_score["score"] = pred.iloc(axis=1)[0] + return pred_score + + +class StaticTrainer(BaseTrainer): + def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs): + super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs) + self.model = None + + split_data = self.data_handler.get_split_data( + self.train_start_date, + self.train_end_date, + self.validate_start_date, + self.validate_end_date, + self.test_start_date, + self.test_end_date, + ) + ( + self.x_train, + self.y_train, + self.x_validate, + self.y_validate, + self.x_test, + self.y_test, + ) = split_data + + def train(self): + TimeInspector.set_time_mark() + model = self.model_class(**self.model_args) + + if CONFIG_MANAGER.ex_config.finetune: + fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True) + loader_model = fetcher.get_experiment( + exp_name=CONFIG_MANAGER.ex_config.loader_name, + exp_id=CONFIG_MANAGER.ex_config.loader_id, + fields=["model"], + )["model"] + + if isinstance(loader_model, list): + model_index = ( + -1 + if CONFIG_MANAGER.ex_config.loader_model_index is None + else CONFIG_MANAGER.ex_config.loader_model_index + ) + loader_model = loader_model[model_index] + + model.load(loader_model) + model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate) + else: + model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate) + model.save(self.model_save_path) + self.ex.add_artifact(self.model_save_path) + self.model = model + TimeInspector.log_cost_time("Finished training model.") + + def load(self): + model = self.model_class(**self.model_args) + + # Load model + fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True) + loader_model = fetcher.get_experiment( + exp_name=CONFIG_MANAGER.ex_config.loader_name, + exp_id=CONFIG_MANAGER.ex_config.loader_id, + fields=["model"], + )["model"] + + if isinstance(loader_model, list): + model_index = ( + -1 + if CONFIG_MANAGER.ex_config.loader_model_index is None + else CONFIG_MANAGER.ex_config.loader_model_index + ) + loader_model = loader_model[model_index] + + model.load(loader_model) + + # Save model, after load, if you don't save the model, the result of this experiment will be no model + model.save(self.model_save_path) + self.ex.add_artifact(self.model_save_path) + self.model = model + + def get_test_pred(self): + pred = self.model.predict(self.x_test) + pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns) + return pred + + def get_test_performance(self): + model_score = self.model.score(self.x_test, self.y_test) + # Remove rows from x, y and w, which contain Nan in any columns in y_test. 
+ x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test) + pred_test = self.model.predict(x_test) + model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0] + + performance = {"model_score": model_score, "model_pearsonr": model_pearsonr} + return performance + + +class RollingTrainer(BaseTrainer): + def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs): + super(RollingTrainer, self).__init__( + model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs + ) + self.rolling_period = kwargs.get("rolling_period", 60) + self.models = [] + self.rolling_data = [] + self.all_x_test = [] + self.all_y_test = [] + for data in self.data_handler.get_rolling_data( + self.train_start_date, + self.train_end_date, + self.validate_start_date, + self.validate_end_date, + self.test_start_date, + self.test_end_date, + self.rolling_period, + ): + self.rolling_data.append(data) + __, __, __, __, x_test, y_test = data + self.all_x_test.append(x_test) + self.all_y_test.append(y_test) + + def train(self): + # 1. Get total data parts. + # total_data_parts = self.data_handler.total_data_parts + # self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts)) + if CONFIG_MANAGER.ex_config.finetune: + fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True) + loader_model = fetcher.get_experiment( + exp_name=CONFIG_MANAGER.ex_config.loader_name, + exp_id=CONFIG_MANAGER.ex_config.loader_id, + fields=["model"], + )["model"] + loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index + previous_model_path = "" + # 2. Rolling train. + for ( + index, + (x_train, y_train, x_validate, y_validate, x_test, y_test), + ) in enumerate(self.rolling_data): + TimeInspector.set_time_mark() + model = self.model_class(**self.model_args) + + if CONFIG_MANAGER.ex_config.finetune: + # Finetune model + if loader_model_index is None and isinstance(loader_model, list): + try: + model.load(loader_model[index]) + except IndexError: + # Load model by previous_model_path + with open(previous_model_path, "rb") as fp: + model.load(fp) + model.finetune(x_train, y_train, x_validate, y_validate) + else: + + if index == 0: + loader_model = ( + loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model + ) + model.load(loader_model) + else: + with open(previous_model_path, "rb") as fp: + model.load(fp) + + model.finetune(x_train, y_train, x_validate, y_validate) + + else: + model.fit(x_train, y_train, x_validate, y_validate) + + model_save_path = "{}_{}".format(self.model_save_path, index) + model.save(model_save_path) + previous_model_path = model_save_path + self.ex.add_artifact(model_save_path) + self.models.append(model) + TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1)) + + def load(self): + """ + Load the data and the model + """ + fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True) + loader_model = fetcher.get_experiment( + exp_name=CONFIG_MANAGER.ex_config.loader_name, + exp_id=CONFIG_MANAGER.ex_config.loader_id, + fields=["model"], + )["model"] + for index in range(len(self.all_x_test)): + model = self.model_class(**self.model_args) + + model.load(loader_model[index]) + + # Save model + model_save_path = "{}_{}".format(self.model_save_path, index) + model.save(model_save_path) + self.ex.add_artifact(model_save_path) + + self.models.append(model) + + def get_test_pred(self): + """ + Predict the score on test data 
with the models. + Please ensure the models and data are loaded before call this score. + + :return: the predicted scores for the pred + """ + pred_df_list = [] + y_test_columns = self.all_y_test[0].columns + # Start iteration. + for model, x_test in zip(self.models, self.all_x_test): + pred = model.predict(x_test) + pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns) + pred_df_list.append(pred_df) + return pd.concat(pred_df_list) + + def get_test_performance(self): + """ + Get the performances of the models + + :return: the performances of models + """ + pred_test_list = [] + y_test_list = [] + scorer = self.models[0]._scorer + for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test): + # Remove rows from x, y and w, which contain Nan in any columns in y_test. + x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test) + pred_test_list.append(model.predict(x_test)) + y_test_list.append(np.squeeze(y_test.values)) + + pred_test_array = np.concatenate(pred_test_list, axis=0) + y_test_array = np.concatenate(y_test_list, axis=0) + + model_score = scorer(y_test_array, pred_test_array) + model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0] + + performance = {"model_score": model_score, "model_pearsonr": model_pearsonr} + return performance diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py new file mode 100644 index 0000000000..4a25df4a02 --- /dev/null +++ b/qlib/contrib/evaluate.py @@ -0,0 +1,389 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import division +from __future__ import print_function + +import numpy as np +import pandas as pd +import inspect +from ..log import get_module_logger +from .strategy import TopkAmountStrategy, TopkWeightStrategy +from .strategy.strategy import BaseStrategy +from .backtest.exchange import Exchange +from .backtest.backtest import backtest as backtest_func, get_date_range + +from ..data import D +from ..config import C + +logger = get_module_logger("Evaluate") + + +def risk_analysis(r, N=252): + """Risk Analysis + + Parameters + ---------- + r : pandas.Series + daily return series + N: int + scaler for annualizing sharpe ratio (day: 250, week: 50, month: 12) + """ + mean = r.mean() + std = r.std(ddof=1) + annual = mean * N + sharpe = mean / std * np.sqrt(N) + mdd = (r.cumsum() - r.cumsum().cummax()).min() + data = {"mean": mean, "std": std, "annual": annual, "sharpe": sharpe, "mdd": mdd} + res = pd.Series(data, index=data.keys()).to_frame("risk") + return res + + +def get_strategy( + strategy=None, + topk=50, + margin=0.5, + risk_degree=0.95, + str_type="amount", + adjust_dates=None, +): + """get_strategy + + Parameters + ---------- + + strategy : Strategy() + strategy used in backtest + topk : int (Default value: 50) + top-N stocks to buy. 
+ margin : int or float(Default value: 0.5) + if isinstance(margin, int): + sell_limit = margin + else: + sell_limit = pred_in_a_day.count() * margin + buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit) + sell_limit should be no less than topk + risk_degree: float + 0-1, 0.95 for example, use 95% money to trade + str_type: 'amount' or 'weight' + strategy type: TopkAmountStrategy or TopkWeightStrategy + + Returns + ------- + :class: Strategy + an initialized strategy object + """ + if strategy is None: + logger.info("Create new strategy ") + if str_type == "amount": + str_cls = TopkAmountStrategy + elif str_type == "weight": + str_cls = TopkWeightStrategy + else: + raise ValueError("Unsupported strategy type") + strategy = str_cls( + topk=topk, + buffer_margin=margin, + risk_degree=risk_degree, + adjust_dates=adjust_dates, + ) + if not isinstance(strategy, BaseStrategy): + raise TypeError("Strategy not supported") + return strategy + + +def get_exchange( + pred, + exchange=None, + subscribe_fields=[], + open_cost=0.0015, + close_cost=0.0025, + min_cost=5.0, + trade_unit=None, + limit_threshold=None, + deal_price=None, + extract_codes=False, + shift=1, +): + """get_exchange + + Parameters + ---------- + + # exchange related arguments + exchange: Exchange() + subscribe_fields: list + subscribe fields + open_cost : float + open transaction cost + close_cost : float + close transaction cost + min_cost : float + min transaction cost + trade_unit : int + 100 for China A + deal_price: str + dealing price type: 'close', 'open', 'vwap' + limit_threshold : float + limit move 0.1 (10%) for example, long and short with same limit + extract_codes: bool + will we pass the codes extracted from the pred to the exchange. + NOTE: This will be faster with offline qlib. + + Returns + ------- + :class: Exchange + an initialized Exchange object + """ + + if trade_unit is None: + trade_unit = C.trade_unit + if limit_threshold is None: + limit_threshold = C.limit_threshold + if deal_price is None: + deal_price = C.deal_price + if exchange is None: + logger.info("Create new exchange") + # handle exception for deal_price + if deal_price[0] != "$": + deal_price = "$" + deal_price + if extract_codes: + codes = sorted(pred.index.get_level_values(0).unique()) + else: + codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks + + dates = sorted(pred.index.get_level_values(1).unique()) + dates = np.append(dates, get_date_range(dates[-1], shift=shift)) + + exchange = Exchange( + trade_dates=dates, + codes=codes, + deal_price=deal_price, + subscribe_fields=subscribe_fields, + limit_threshold=limit_threshold, + open_cost=open_cost, + close_cost=close_cost, + min_cost=min_cost, + trade_unit=trade_unit, + ) + return exchange + + +# This is the API for compatibility with legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs): + """This function will help you set up a reasonable Exchange and provide default values for the strategy + Parameters + ---------- + + # backtest workflow related or common arguments + pred : pandas.DataFrame + prediction results; should have an index and one `score` column + account : float + initial account value + shift : int + whether to shift prediction by one day + benchmark : str + benchmark code, the default is SH000905 (CSI 500) + verbose : bool + whether to print log + + # strategy related arguments + strategy : Strategy() + strategy used in backtest + topk : int (Default value: 50) + top-N stocks to buy.
+ margin : int or float(Default value: 0.5) + if isinstance(margin, int): + sell_limit = margin + else: + sell_limit = pred_in_a_day.count() * margin + buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit) + sell_limit should be no less than topk + risk_degree: float + 0-1, 0.95 for example, use 95% money to trade + str_type: 'amount' or 'weight' + strategy type: TopkAmountStrategy or TopkWeightStrategy + + # exchange related arguments + exchange: Exchange() + pass the exchange for speeding up. + subscribe_fields: list + subscribe fields + open_cost : float + open transaction cost. The default value is 0.002(0.2%). + close_cost : float + close transaction cost. The default value is 0.002(0.2%). + min_cost : float + min transaction cost + trade_unit : int + 100 for China A + deal_price: str + dealing price type: 'close', 'open', 'vwap' + limit_threshold : float + limit move 0.1 (10%) for example, long and short with same limit + extract_codes: bool + will we pass the codes extracted from the pred to the exchange. + NOTE: This will be faster with offline qlib. + """ + # check strategy: + spec = inspect.getfullargspec(get_strategy) + str_args = {k: v for k, v in kwargs.items() if k in spec.args} + strategy = get_strategy(**str_args) + + # init exchange: + spec = inspect.getfullargspec(get_exchange) + ex_args = {k: v for k, v in kwargs.items() if k in spec.args} + trade_exchange = get_exchange(pred, **ex_args) + + # run backtest + report_df, positions = backtest_func( + pred=pred, + strategy=strategy, + trade_exchange=trade_exchange, + shift=shift, + verbose=verbose, + account=account, + benchmark=benchmark, + ) + # for compatibility of the old api. return the dict positions + positions = {k: p.position for k, p in positions.items()} + return report_df, positions + + +def long_short_backtest( + pred, + topk=50, + deal_price=None, + shift=1, + open_cost=0, + close_cost=0, + trade_unit=None, + limit_threshold=None, + min_cost=5, + subscribe_fields=[], + extract_codes=False, +): + """ + A backtest for long-short strategy + + :param pred: The trading signal produced on day `T` + :param topk: The short topk securities and long topk securities + :param deal_price: The price to deal the trading + :param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1. + :param open_cost: open transaction cost + :param close_cost: close transaction cost + :param trade_unit: 100 for China A + :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit + :param min_cost: min transaction cost + :param subscribe_fields: subscribe fields + :param extract_codes: bool + will we pass the codes extracted from the pred to the exchange. + NOTE: This will be faster with offline qlib. + :return: The result of backtest, it is represented by a dict. 
+ { "long": long_returns(excess), + "short": short_returns(excess), + "long_short": long_short_returns} + """ + + if trade_unit is None: + trade_unit = C.trade_unit + if limit_threshold is None: + limit_threshold = C.limit_threshold + if deal_price is None: + deal_price = C.deal_price + if deal_price[0] != "$": + deal_price = "$" + deal_price + + subscribe_fields = subscribe_fields.copy() + profit_str = f"Ref({deal_price}, -1)/{deal_price} - 1" + subscribe_fields.append(profit_str) + + trade_exchange = get_exchange( + pred=pred, + deal_price=deal_price, + subscribe_fields=subscribe_fields, + limit_threshold=limit_threshold, + open_cost=open_cost, + close_cost=close_cost, + min_cost=min_cost, + trade_unit=trade_unit, + extract_codes=extract_codes, + shift=shift, + ) + + _pred_dates = pred.index.get_level_values(level="datetime") + predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) + trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift)) + + long_returns = {} + short_returns = {} + ls_returns = {} + + for pdate, date in zip(predict_dates, trade_dates): + score = pred.loc(axis=0)[:, pdate] + score = score.reset_index().sort_values(by="score", ascending=False) + + long_stocks = list(score.iloc[:topk]["instrument"]) + short_stocks = list(score.iloc[-topk:]["instrument"]) + + score = score.set_index(["instrument", "datetime"]).sort_index() + + long_profit = [] + short_profit = [] + all_profit = [] + + for stock in long_stocks: + if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): + continue + profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] + if np.isnan(profit): + long_profit.append(0) + else: + long_profit.append(profit) + + for stock in short_stocks: + if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): + continue + profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] + if np.isnan(profit): + short_profit.append(0) + else: + short_profit.append(-profit) + + for stock in list(score.loc(axis=0)[:, pdate].index.get_level_values(level=0)): + # exclude the suspend stock + if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date): + continue + profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str] + if np.isnan(profit): + all_profit.append(0) + else: + all_profit.append(profit) + + long_returns[date] = np.mean(long_profit) - np.mean(all_profit) + short_returns[date] = np.mean(short_profit) + np.mean(all_profit) + ls_returns[date] = np.mean(short_profit) + np.mean(long_profit) + + return dict( + zip( + ["long", "short", "long_short"], + map(pd.Series, [long_returns, short_returns, ls_returns]), + ) + ) + + +def t_run(): + pred_FN = "./check_pred.csv" + pred = pd.read_csv(pred_FN) + pred["datetime"] = pd.to_datetime(pred["datetime"]) + pred = pred.set_index([pred.columns[0], pred.columns[1]]) + pred = pred.iloc[:9000] + report_df, positions = backtest(pred=pred) + print(report_df.head()) + print(positions.keys()) + print(positions[list(positions.keys())[0]]) + return 0 + + +if __name__ == "__main__": + t_run() diff --git a/qlib/contrib/evaluate_portfolio.py b/qlib/contrib/evaluate_portfolio.py new file mode 100644 index 0000000000..04ddd8db04 --- /dev/null +++ b/qlib/contrib/evaluate_portfolio.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
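# Minimal usage sketch for the evaluate module above (file name, data location, and CSV layout
# are assumptions, not part of this diff): run the long-short backtest on a prediction frame
# indexed by (instrument, datetime) with a ``score`` column, then summarize each leg with
# ``risk_analysis``.
import pandas as pd
import qlib
from qlib.contrib.evaluate import long_short_backtest, risk_analysis

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")  # assumes data prepared at the default uri
pred = pd.read_csv("pred.csv", parse_dates=["datetime"]).set_index(["instrument", "datetime"])
reports = long_short_backtest(pred, topk=50)
for leg, daily_returns in reports.items():
    print(leg)
    print(risk_analysis(daily_returns))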
+ + +from __future__ import division +from __future__ import print_function + +import copy +import numpy as np +import pandas as pd +from scipy.stats import spearmanr, pearsonr + + +from ..data import D + +from collections import OrderedDict + + +def _get_position_value_from_df(evaluate_date, position, close_data_df): + """Get position value from an existing close data df + close_data_df: + pd.DataFrame + multi-index + close_data_df['$close'][stock_id][evaluate_date]: close price for (stock_id, evaluate_date) + position: + same as in get_position_value() + """ + value = 0 + for stock_id, report in position.items(): + if stock_id != "cash": + value += report["amount"] * close_data_df["$close"][stock_id][evaluate_date] + # value += report['amount'] * report['price'] + if "cash" in position: + value += position["cash"] + return value + + +def get_position_value(evaluate_date, position): + """sum of close*amount + + get value of position + + use close price + + positions: + { + Timestamp('2016-01-05 00:00:00'): + { + 'SH600022': + { + 'amount':100.00, + 'price':12.00 + }, + + 'cash':100000.0 + } + } + + It means holding 100.00 shares of 'SH600022' and 100000.0 RMB in cash on '2016-01-05' + """ + # load close price for position + # position should also consider cash + instruments = list(position.keys()) + instruments = list(set(instruments) - set(["cash"])) # filter 'cash' + fields = ["$close"] + close_data_df = D.features( + instruments, + fields, + start_time=evaluate_date, + end_time=evaluate_date, + freq="day", + disk_cache=0, + ) + value = _get_position_value_from_df(evaluate_date, position, close_data_df) + return value + + +def get_position_list_value(positions): + # generate instrument list and date range for all positions + instruments = set() + for day, position in positions.items(): + instruments.update(position.keys()) + instruments = list(set(instruments) - set(["cash"])) # filter 'cash' + instruments.sort() + day_list = list(positions.keys()) + day_list.sort() + start_date, end_date = day_list[0], day_list[-1] + # load data + fields = ["$close"] + close_data_df = D.features( + instruments, + fields, + start_time=start_date, + end_time=end_date, + freq="day", + disk_cache=0, + ) + # generate value + # return dict for time:position_value + value_dict = OrderedDict() + for day, position in positions.items(): + value = _get_position_value_from_df(evaluate_date=day, position=position, close_data_df=close_data_df) + value_dict[day] = value + return value_dict + + +def get_daily_return_series_from_positions(positions, init_asset_value): + """Generate daily return series from the position view + + positions: positions generated by strategy + init_asset_value : initial asset value + return: pd.Series of daily returns, return_series[date] = daily return rate + """ + value_dict = get_position_list_value(positions) + value_series = pd.Series(value_dict) + value_series = value_series.sort_index() # check date + return_series = value_series.pct_change() + return_series[value_series.index[0]] = ( + value_series[value_series.index[0]] / init_asset_value - 1 + ) # update daily return for the first date + return return_series + + +def get_annual_return_from_positions(positions, init_asset_value): + """Annualized Returns + + p_r = (p_end / p_start)^{(250/n)} - 1 + + p_r annual return + p_end final value + p_start init value + n days of backtest + + """ + date_range_list = sorted(list(positions.keys())) + end_time = date_range_list[-1] + p_end = get_position_value(end_time, positions[end_time]) + p_start = init_asset_value + n_period
= len(date_range_list) + annual = pow((p_end / p_start), (250 / n_period)) - 1 + + return annual + + +def get_annaul_return_from_return_series(r, method="ci"): + """Risk Analysis from daily return series + + Parameters + ---------- + r : pandas.Series + daily return series + method : str + interest calculation method, ci(compound interest)/si(simple interest) + """ + mean = r.mean() + annual = (1 + mean) ** 250 - 1 if method == "ci" else mean * 250 + + return annual + + +def get_sharpe_ratio_from_return_series(r, risk_free_rate=0.00, method="ci"): + """Risk Analysis + + Parameters + ---------- + r : pandas.Series + daily return series + method : str + interest calculation method, ci(compound interest)/si(simple interest) + risk_free_rate : float + risk_free_rate, default as 0.00, can set as 0.03 etc + """ + std = r.std(ddof=1) + annual = get_annaul_return_from_return_series(r, method=method) + sharpe = (annual - risk_free_rate) / std / np.sqrt(250) + + return sharpe + + +def get_max_drawdown_from_series(r): + """Risk Analysis from asset value + + cumprod way + + Parameters + ---------- + r : pandas.Series + daily return series + """ + # mdd = ((r.cumsum() - r.cumsum().cummax()) / (1 + r.cumsum().cummax())).min() + + mdd = (((1 + r).cumprod() - (1 + r).cumprod().cummax()) / ((1 + r).cumprod().cummax())).min() + + return mdd + + +def get_turnover_rate(): + # in backtest + pass + + +def get_beta(r, b): + """Risk Analysis beta + + Parameters + ---------- + r : pandas.Series + daily return series of strategy + b : pandas.Series + daily return series of baseline + """ + cov_r_b = np.cov(r, b) + var_b = np.var(b) + return cov_r_b / var_b + + +def get_alpha(r, b, risk_free_rate=0.03): + beta = get_beta(r, b) + annaul_r = get_annaul_return_from_return_series(r) + annaul_b = get_annaul_return_from_return_series(b) + + alpha = annaul_r - risk_free_rate - beta * (annaul_b - risk_free_rate) + + return alpha + + +def get_volatility_from_series(r): + return r.std(ddof=1) + + +def get_rank_ic(a, b): + """Rank IC + + Parameters + ---------- + r : pandas.Series + daily score series of feature + b : pandas.Series + daily return series + + """ + return spearmanr(a, b).correlation + + +def get_normal_ic(a, b): + return pearsonr(a, b).correlation diff --git a/qlib/contrib/model/__init__.py b/qlib/contrib/model/__init__.py new file mode 100644 index 0000000000..c639b57f53 --- /dev/null +++ b/qlib/contrib/model/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import warnings + +from .base import Model diff --git a/qlib/contrib/model/base.py b/qlib/contrib/model/base.py new file mode 100644 index 0000000000..b3ea917a52 --- /dev/null +++ b/qlib/contrib/model/base.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
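# A minimal, hedged sketch of the daily-return helpers defined above
# (``get_annaul_return_from_return_series``, ``get_sharpe_ratio_from_return_series``,
# ``get_max_drawdown_from_series``).  The return series is synthetic and the calls
# assume they run in the module where those helpers live; illustrative only.
import numpy as np
import pandas as pd

np.random.seed(0)
dates = pd.date_range("2019-01-01", periods=250, freq="B")
r = pd.Series(np.random.normal(0.0005, 0.01, len(dates)), index=dates)  # synthetic daily returns

annual = get_annaul_return_from_return_series(r, method="ci")  # (1 + mean) ** 250 - 1
sharpe = get_sharpe_ratio_from_return_series(r, risk_free_rate=0.03)
mdd = get_max_drawdown_from_series(r)  # cumprod-based maximum drawdown
print(annual, sharpe, mdd)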
+ + +from __future__ import division +from __future__ import print_function + +import abc +import six + + +@six.add_metaclass(abc.ABCMeta) +class Model(object): + """Model base class""" + + @property + def name(self): + return type(self).__name__ + + def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): + """fix train with cross-validation + Fit model when ex_config.finetune is False + + Parameters + ---------- + x_train : pd.dataframe + train data + y_train : pd.dataframe + train label + x_valid : pd.dataframe + valid data + y_valid : pd.dataframe + valid label + w_train : pd.dataframe + train weight + w_valid : pd.dataframe + valid weight + + Returns + ---------- + Model + trained model + """ + raise NotImplementedError() + + def score(self, x_test, y_test, w_test=None, **kwargs): + """evaluate model with test data/label + + Parameters + ---------- + x_test : pd.dataframe + test data + y_test : pd.dataframe + test label + w_test : pd.dataframe + test weight + + Returns + ---------- + float + evaluation score + """ + raise NotImplementedError() + + def predict(self, x_test, **kwargs): + """predict given test data + + Parameters + ---------- + x_test : pd.dataframe + test data + + Returns + ---------- + np.ndarray + test predict label + """ + raise NotImplementedError() + + def save(self, fname, **kwargs): + """save model + + Parameters + ---------- + fname : str + model filename + """ + # TODO: Currently need to save the model as a single file, otherwise the estimator may not be compatible + raise NotImplementedError() + + def load(self, buffer, **kwargs): + """load model + + Parameters + ---------- + buffer : bytes + binary data of model parameters + + Returns + ---------- + Model + loaded model + """ + raise NotImplementedError() + + def get_data_with_date(self, date, **kwargs): + """ + Will be called in online module + need to return the data that used to predict the label (score) of stocks at date. + + :param + date: pd.Timestamp + predict date + :return: + data: the input data that used to predict the label (score) of stocks at predict date. + """ + raise NotImplementedError("get_data_with_date for this model is not implemented.") + + def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): + """Finetune model + In `RollingTrainer`: + if loader.model_index is None: + If provide 'Static Model', based on the provided 'Static' model update. + If provide 'Rolling Model', skip the model of load, based on the last 'provided model' update. + + if loader.model_index is not None: + Based on the provided model(loader.model_index) update. + + In `StaticTrainer`: + If the load is 'static model': + Based on the 'static model' update + If the load is 'rolling model': + Based on the provided model(`loader.model_index`) update. If `loader.model_index` is None, use the last model. + + Parameters + ---------- + x_train : pd.dataframe + train data + y_train : pd.dataframe + train label + x_valid : pd.dataframe + valid data + y_valid : pd.dataframe + valid label + w_train : pd.dataframe + train weight + w_valid : pd.dataframe + valid weight + + Returns + ---------- + Model + finetune model + """ + raise NotImplementedError("Finetune for this model is not implemented.") diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py new file mode 100644 index 0000000000..e79945d8af --- /dev/null +++ b/qlib/contrib/model/gbdt.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
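# A hedged sketch of implementing the ``Model`` interface defined above.
# ``NaiveMeanModel`` is a hypothetical example (not part of this patch) that
# predicts the mean training label for every test row; it only illustrates the
# fit/predict/score contract.
import numpy as np
from qlib.contrib.model.base import Model


class NaiveMeanModel(Model):
    def __init__(self):
        self._mean = None

    def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
        # remember the average training label as the "prediction"
        self._mean = float(np.nanmean(y_train.values))
        return self

    def predict(self, x_test, **kwargs):
        if self._mean is None:
            raise ValueError("model is not fitted yet!")
        return np.full(len(x_test), self._mean)

    def score(self, x_test, y_test, w_test=None, **kwargs):
        # negative MSE, so larger is better
        preds = self.predict(x_test)
        return -float(np.nanmean((y_test.values.squeeze() - preds) ** 2))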
+ + +from __future__ import division +from __future__ import print_function + +import numpy as np +import lightgbm as lgb +from sklearn.metrics import roc_auc_score, mean_squared_error + +from .base import Model +from ...utils import drop_nan_by_y_index + + +class LGBModel(Model): + """LightGBM Model + + Parameters + ---------- + param_update : dict + training parameters + """ + + _params = dict() + + def __init__(self, loss="mse", **kwargs): + if loss not in {"mse", "binary"}: + raise NotImplementedError + self._scorer = mean_squared_error if loss == "mse" else roc_auc_score + self._params.update(objective=loss, **kwargs) + self._model = None + + def fit( + self, + x_train, + y_train, + x_valid, + y_valid, + w_train=None, + w_valid=None, + num_boost_round=1000, + early_stopping_rounds=50, + verbose_eval=20, + evals_result=dict(), + **kwargs + ): + #print("input featrue", x_train) + #print("input label", y_train) + #print("input weight", w_train) + # Lightgbm need 1D array as its label + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values) + else: + raise ValueError("LightGBM doesn't support multi-label training") + + w_train_weight = None if w_train is None else w_train.values + w_valid_weight = None if w_valid is None else w_valid.values + + dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight) + dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight) + self._model = lgb.train( + self._params, + dtrain, + num_boost_round=num_boost_round, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, + **kwargs + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + + def predict(self, x_test): + print("predict test", x_test) + if self._model is None: + raise ValueError("model is not fitted yet!") + return self._model.predict(x_test.values) + + def score(self, x_test, y_test, w_test=None): + # Remove rows from x, y and w, which contain Nan in any columns in y_test. + x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test) + preds = self.predict(x_test) + w_test_weight = None if w_test is None else w_test.values + return self._scorer(y_test.values, preds, sample_weight=w_test_weight) + + def save(self, filename): + if self._model is None: + raise ValueError("model is not fitted yet!") + self._model.save_model(filename) + + def load(self, buffer): + self._model = lgb.Booster(params={"model_str": buffer.decode("utf-8")}) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py new file mode 100644 index 0000000000..5402c25bd9 --- /dev/null +++ b/qlib/contrib/model/pytorch_nn.py @@ -0,0 +1,363 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
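# A hedged usage sketch for ``LGBModel`` above on synthetic data.  The import
# path follows this patch's layout; column names, hyper-parameters and the
# LightGBM version (one that still accepts ``early_stopping_rounds`` in
# ``lgb.train``) are assumptions for illustration.
import numpy as np
import pandas as pd
from qlib.contrib.model.gbdt import LGBModel

np.random.seed(0)
x = pd.DataFrame(np.random.randn(500, 10), columns=["f%d" % i for i in range(10)])
# the label must be a single-column DataFrame, otherwise fit() raises ValueError
y = pd.DataFrame({"label": x["f0"] * 0.5 + np.random.randn(500) * 0.1})

model = LGBModel(loss="mse", num_leaves=31, learning_rate=0.05)
model.fit(x[:400], y[:400], x[400:], y[400:], num_boost_round=100, early_stopping_rounds=20)
preds = model.predict(x[400:])
print(model.score(x[400:], y[400:]))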
+ + +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score, mean_squared_error +import logging +from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index +from ...log import get_module_logger, TimeInspector + +import torch +import torch.nn as nn +import torch.optim as optim + +from .base import Model + + +class DNNModelPytorch(Model): + """DNN Model + + Parameters + ---------- + input_dim : int + input dimension + output_dim : int + output dimension + layers : tuple + layer sizes + lr : float + learning rate + lr_decay : float + learning rate decay + lr_decay_steps : int + learning rate decay steps + optimizer : str + optimizer name + GPU : str + the GPU ID(s) used for training + """ + + def __init__( + self, + input_dim, + output_dim, + layers=(256, 256, 128), + lr=0.001, + max_steps=300, + batch_size=2000, + early_stop_rounds=50, + eval_steps=20, + lr_decay=0.96, + lr_decay_steps=100, + optimizer="gd", + loss="mse", + GPU="0", + **kwargs + ): + # Set logger. + self.logger = get_module_logger("DNNModelPytorch") + self.logger.info("DNN pytorch version...") + + # set hyper-parameters. + self.layers = layers + self.lr = lr + self.max_steps = max_steps + self.batch_size = batch_size + self.early_stop_rounds = early_stop_rounds + self.eval_steps = eval_steps + self.lr_decay = lr_decay + self.lr_decay_steps = lr_decay_steps + self.optimizer = optimizer.lower() + self.loss_type = loss + self.visible_GPU = GPU + + self.logger.info( + "DNN parameters setting:" + "\nlayers : {}" + "\nlr : {}" + "\nmax_steps : {}" + "\nbatch_size : {}" + "\nearly_stop_rounds : {}" + "\neval_steps : {}" + "\nlr_decay : {}" + "\nlr_decay_steps : {}" + "\noptimizer : {}" + "\nloss_type : {}" + "\neval_steps : {}" + "\nvisible_GPU : {}".format( + layers, + lr, + max_steps, + batch_size, + early_stop_rounds, + eval_steps, + lr_decay, + lr_decay_steps, + optimizer, + loss, + eval_steps, + GPU, + ) + ) + + if loss not in {"mse", "binary"}: + raise NotImplementedError("loss {} is not supported!".format(loss)) + self._scorer = mean_squared_error if loss == "mse" else roc_auc_score + + self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type) + if optimizer.lower() == "adam": + self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr) + elif optimizer.lower() == "gd": + self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr) + else: + raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + + # Reduce learning rate when loss has stopped decrease + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.train_optimizer, + mode="min", + factor=0.5, + patience=10, + verbose=True, + threshold=0.0001, + threshold_mode="rel", + cooldown=0, + min_lr=0.00001, + eps=1e-08, + ) + + self._fitted = False + self.dnn_model.cuda() + + # set the visible GPU + if self.visible_GPU: + os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU + + def fit( + self, + x_train, + y_train, + x_valid, + y_valid, + w_train=None, + w_valid=None, + evals_result=dict(), + verbose=True, + save_path=None, + ): + + if w_train is None: + w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) + if w_valid is None: + w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index) + + save_path = create_save_path(save_path) + stop_steps = 0 + train_loss = 0 + best_loss = np.inf + 
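+        # The loop below samples random mini-batches of size ``batch_size``;
+        # every ``eval_steps`` steps it evaluates on the full validation set,
+        # saves the best checkpoint to ``save_path``, steps the
+        # ``ReduceLROnPlateau`` scheduler on the validation loss, and stops
+        # early once ``early_stop_rounds`` evaluations pass without improvement.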
evals_result["train"] = [] + evals_result["valid"] = [] + + # train + self.logger.info("training...") + self._fitted = True + + # prepare training data + x_train_values = torch.from_numpy(x_train.values).float() + y_train_values = torch.from_numpy(y_train.values).float() + w_train_values = torch.from_numpy(w_train.values).float() + train_num = y_train_values.shape[0] + + # prepare validation data + x_val_cuda = torch.from_numpy(x_valid.values).float() + y_val_cuda = torch.from_numpy(y_valid.values).float() + w_val_cuda = torch.from_numpy(w_valid.values).float() + + x_val_cuda = x_val_cuda.cuda() + y_val_cuda = y_val_cuda.cuda() + w_val_cuda = w_val_cuda.cuda() + + for step in range(self.max_steps): + if stop_steps >= self.early_stop_rounds: + if verbose: + self.logger.info("\tearly stop") + break + loss = AverageMeter() + self.dnn_model.train() + self.train_optimizer.zero_grad() + + choice = np.random.choice(train_num, self.batch_size) + x_batch = x_train_values[choice] + y_batch = y_train_values[choice] + w_batch = w_train_values[choice] + + x_batch_cuda = x_batch.float().cuda() + y_batch_cuda = y_batch.float().cuda() + w_batch_cuda = w_batch.float().cuda() + + # forward + preds = self.dnn_model(x_batch_cuda) + + cur_loss = self.get_loss(preds, w_batch_cuda, y_batch_cuda, self.loss_type) + cur_loss.backward() + self.train_optimizer.step() + loss.update(cur_loss.item()) + + # validation + train_loss += loss.val + if step and step % self.eval_steps == 0: + stop_steps += 1 + train_loss /= self.eval_steps + + with torch.no_grad(): + self.dnn_model.eval() + loss_val = AverageMeter() + + # forward + preds = self.dnn_model(x_val_cuda) + cur_loss_val = self.get_loss(preds, w_val_cuda, y_val_cuda, self.loss_type) + loss_val.update(cur_loss_val.item()) + if verbose: + self.logger.info( + "[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val) + ) + evals_result["train"].append(train_loss) + evals_result["valid"].append(loss_val.val) + if loss_val.val < best_loss: + if verbose: + self.logger.info( + "\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format( + best_loss, loss_val.val + ) + ) + best_loss = loss_val.val + stop_steps = 0 + torch.save(self.dnn_model.state_dict(), save_path) + train_loss = 0 + # update learning rate + self.scheduler.step(cur_loss_val) + + # restore the optimal parameters after training + self.dnn_model.load_state_dict(torch.load(save_path)) + torch.cuda.empty_cache() + + def get_loss(self, pred, w, target, loss_type): + if loss_type == "mse": + sqr_loss = torch.mul(pred - target, pred - target) + loss = torch.mul(sqr_loss, w).mean() + return loss + elif loss_type == "binary": + loss = nn.BCELoss() + return loss(pred, target) + else: + raise NotImplementedError("loss {} is not supported!".format(loss_type)) + + def predict(self, x_test): + if not self._fitted: + raise ValueError("model is not fitted yet!") + x_test = torch.from_numpy(x_test.values).float().cuda() + self.dnn_model.eval() + preds = self.dnn_model(x_test).detach().cpu().numpy() + return preds + + def score(self, x_test, y_test, w_test=None): + # Remove rows from x, y and w, which contain Nan in any columns in y_test. 
+ x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test) + preds = self.predict(x_test) + w_test_weight = None if w_test is None else w_test.values + return self._scorer(y_test.values, preds, sample_weight=w_test_weight) + + def save(self, filename, **kwargs): + with save_multiple_parts_file(filename) as model_dir: + model_path = os.path.join(model_dir, os.path.split(model_dir)[-1]) + # Save model + torch.save(self.dnn_model.state_dict(), model_path) + + def load(self, buffer, **kwargs): + with unpack_archive_with_buffer(buffer) as model_dir: + # Get model name + _model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[ + 0 + ] + _model_path = os.path.join(model_dir, _model_name) + # Load model + self.dnn_model.load_state_dict(torch.load(_model_path)) + self._fitted = True + + def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): + self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class Net(nn.Module): + def __init__(self, input_dim, output_dim, layers=(256, 256, 256), loss="mse"): + super(Net, self).__init__() + layers = [input_dim] + list(layers) + self.hidden_layer_num = len(layers) + dnn_layers = [] + drop_input = nn.Dropout(0.1) + dnn_layers.append(drop_input) + for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])): + fc = nn.Linear(input_dim, hidden_units) + # drop = nn.Dropout(0.2) + # relu = nn.ReLU() + # activation = nn.Sigmoid() + activation = nn.Tanh() + bn = nn.BatchNorm1d(hidden_units) + seq = nn.Sequential(fc, activation, bn) + dnn_layers.append(seq) + + drop_output = nn.Dropout(0.1) + dnn_layers.append(drop_output) + self.dnn_layers = nn.ModuleList(dnn_layers) + + if loss == "mse": + fc = nn.Linear(hidden_units, output_dim) + self.output_layer = fc + + elif loss == "binary": + fc = nn.Linear(hidden_units, output_dim) + sigmoid = nn.Sigmoid() + self.output_layer = nn.Sequential(fc, sigmoid) + else: + raise NotImplementedError("loss {} is not supported!".format(loss)) + # optimizer + + self._weight_init() + + def _weight_init(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + m.weight = nn.init.xavier_normal_(m.weight) + + def forward(self, x): + cur_input = x + for i in range(self.hidden_layer_num): + output = self.dnn_layers[i](cur_input) + cur_input = output + output = self.output_layer(output) + return output diff --git a/qlib/contrib/online/__init__.py b/qlib/contrib/online/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qlib/contrib/online/executor.py b/qlib/contrib/online/executor.py new file mode 100644 index 0000000000..2bd0937a03 --- /dev/null +++ b/qlib/contrib/online/executor.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
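# A small, hedged sketch of the weighted MSE used by ``DNNModelPytorch.get_loss``
# above: each sample's squared error is scaled by its weight before averaging.
# The tensors are synthetic and illustrative only.
import torch

pred = torch.tensor([0.10, -0.20, 0.05])
target = torch.tensor([0.00, -0.10, 0.05])
weight = torch.tensor([1.0, 2.0, 0.5])

sqr_loss = (pred - target) ** 2        # per-sample squared error
loss = (sqr_loss * weight).mean()      # same as torch.mul(sqr_loss, w).mean()
print(loss.item())                     # (0.01 * 1 + 0.01 * 2 + 0.0 * 0.5) / 3 = 0.01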
+ + +import re +import json +import copy +import pathlib +import pandas as pd +from ...data import D +from ...utils import get_date_in_file_name +from ...utils import get_pre_trading_date +from ..backtest.order import Order + + +class BaseExecutor: + """ + # Strategy framework document + + class Executor(BaseExecutor): + """ + + def execute(self, trade_account, order_list, trade_date): + """ + return the executed result (trade_info) after trading at trade_date. + NOTICE: trade_account will not be modified after executing. + Parameter + --------- + trade_account : Account() + order_list : list + [Order()] + trade_date : pd.Timestamp + Return + --------- + trade_info : list + [Order(), float, float, float] + """ + raise NotImplementedError("get_execute_result for this model is not implemented.") + + def save_executed_file_from_trade_info(self, trade_info, user_path, trade_date): + """ + Save the trade_info to the .csv transaction file in disk + the columns of result file is + ['date', 'stock_id', 'direction', 'trade_val', 'trade_cost', 'trade_price', 'factor'] + Parameter + --------- + trade_info : list of [Order(), float, float, float] + (order, trade_val, trade_cost, trade_price), trade_info with out factor + user_path: str / pathlib.Path() + the sub folder to save user data + + transaction_path : string / pathlib.Path() + """ + YYYY, MM, DD = str(trade_date.date()).split("-") + folder_path = pathlib.Path(user_path) / "trade" / YYYY / MM + if not folder_path.exists(): + folder_path.mkdir(parents=True) + transaction_path = folder_path / "transaction_{}.csv".format(str(trade_date.date())) + columns = [ + "date", + "stock_id", + "direction", + "amount", + "trade_val", + "trade_cost", + "trade_price", + "factor", + ] + data = [] + for [order, trade_val, trade_cost, trade_price] in trade_info: + data.append( + [ + trade_date, + order.stock_id, + order.direction, + order.amount, + trade_val, + trade_cost, + trade_price, + order.factor, + ] + ) + df = pd.DataFrame(data, columns=columns) + df.to_csv(transaction_path, index=False) + + def load_trade_info_from_executed_file(self, user_path, trade_date): + YYYY, MM, DD = str(trade_date.date()).split("-") + file_path = pathlib.Path(user_path) / "trade" / YYYY / MM / "transaction_{}.csv".format(str(trade_date.date())) + if not file_path.exists(): + raise ValueError("File {} not exists!".format(file_path)) + + filedate = get_date_in_file_name(file_path) + transaction = pd.read_csv(file_path) + trade_info = [] + for i in range(len(transaction)): + date = transaction.loc[i]["date"] + if not date == filedate: + continue + # raise ValueError("date in transaction file {} not equal to it's file date{}".format(date, filedate)) + order = Order( + stock_id=transaction.loc[i]["stock_id"], + amount=transaction.loc[i]["amount"], + trade_date=transaction.loc[i]["date"], + direction=transaction.loc[i]["direction"], + factor=transaction.loc[i]["factor"], + ) + trade_val = transaction.loc[i]["trade_val"] + trade_cost = transaction.loc[i]["trade_cost"] + trade_price = transaction.loc[i]["trade_price"] + trade_info.append([order, trade_val, trade_cost, trade_price]) + return trade_info + + +class SimulatorExecutor(BaseExecutor): + def __init__(self, trade_exchange, verbose=False): + self.trade_exchange = trade_exchange + self.verbose = verbose + self.order_list = [] + + def execute(self, trade_account, order_list, trade_date): + """ + execute the order list, do the trading wil exchange at date. + Will not modify the trade_account. 
+ Parameter + trade_account : Account() + order_list : list + list or orders + trade_date : pd.Timestamp + :return: + trade_info : list of [Order(), float, float, float] + (order, trade_val, trade_cost, trade_price), trade_info with out factor + """ + account = copy.deepcopy(trade_account) + trade_info = [] + + for order in order_list: + # check holding thresh is done in strategy + # if order.direction==0: # sell order + # # checking holding thresh limit for sell order + # if trade_account.current.get_stock_count(order.stock_id) < thresh: + # # can not sell this code + # continue + # is order executable + # check order + if self.trade_exchange.check_order(order) is True: + # execute the order + trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=account) + trade_info.append([order, trade_val, trade_cost, trade_price]) + if self.verbose: + if order.direction == Order.SELL: # sell + print( + "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format( + trade_date, + order.stock_id, + trade_price, + order.deal_amount, + trade_val, + ) + ) + else: + print( + "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format( + trade_date, + order.stock_id, + trade_price, + order.deal_amount, + trade_val, + ) + ) + + else: + if self.verbose: + print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_date, order.stock_id)) + # do nothing + pass + return trade_info + + +def save_score_series(score_series, user_path, trade_date): + """Save the score_series into a .csv file. + The columns of saved file is + [stock_id, score] + + Parameter + --------- + order_list: [Order()] + list of Order() + date: pd.Timestamp + the date to save the order list + user_path: str / pathlib.Path() + the sub folder to save user data + """ + user_path = pathlib.Path(user_path) + YYYY, MM, DD = str(trade_date.date()).split("-") + folder_path = user_path / "score" / YYYY / MM + if not folder_path.exists(): + folder_path.mkdir(parents=True) + file_path = folder_path / "score_{}.csv".format(str(trade_date.date())) + score_series.to_csv(file_path) + + +def load_score_series(user_path, trade_date): + """Save the score_series into a .csv file. + The columns of saved file is + [stock_id, score] + + Parameter + --------- + order_list: [Order()] + list of Order() + date: pd.Timestamp + the date to save the order list + user_path: str / pathlib.Path() + the sub folder to save user data + """ + user_path = pathlib.Path(user_path) + YYYY, MM, DD = str(trade_date.date()).split("-") + folder_path = user_path / "score" / YYYY / MM + if not folder_path.exists(): + folder_path.mkdir(parents=True) + file_path = folder_path / "score_{}.csv".format(str(trade_date.date())) + score_series = pd.read_csv(file_path, index_col=0, header=None, names=["instrument", "score"]) + return score_series + + +def save_order_list(order_list, user_path, trade_date): + """ + Save the order list into a json file. + Will calculate the real amount in order according to factors at date. 
+ + The format in json file like + {"sell": {"stock_id": amount, ...} + ,"buy": {"stock_id": amount, ...}} + + :param + order_list: [Order()] + list of Order() + date: pd.Timestamp + the date to save the order list + user_path: str / pathlib.Path() + the sub folder to save user data + """ + user_path = pathlib.Path(user_path) + YYYY, MM, DD = str(trade_date.date()).split("-") + folder_path = user_path / "trade" / YYYY / MM + if not folder_path.exists(): + folder_path.mkdir(parents=True) + sell = {} + buy = {} + for order in order_list: + if order.direction == 0: # sell + sell[order.stock_id] = [order.amount, order.factor] + else: + buy[order.stock_id] = [order.amount, order.factor] + order_dict = {"sell": sell, "buy": buy} + file_path = folder_path / "orderlist_{}.json".format(str(trade_date.date())) + with file_path.open("w") as fp: + json.dump(order_dict, fp) + + +def load_order_list(user_path, trade_date): + user_path = pathlib.Path(user_path) + YYYY, MM, DD = str(trade_date.date()).split("-") + path = user_path / "trade" / YYYY / MM / "orderlist_{}.json".format(str(trade_date.date())) + if not path.exists(): + raise ValueError("File {} not exists!".format(path)) + # get orders + with path.open("r") as fp: + order_dict = json.load(fp) + order_list = [] + for stock_id in order_dict["sell"]: + amount, factor = order_dict["sell"][stock_id] + order = Order( + stock_id=stock_id, + amount=amount, + trade_date=pd.Timestamp(trade_date), + direction=Order.SELL, + factor=factor, + ) + order_list.append(order) + for stock_id in order_dict["buy"]: + amount, factor = order_dict["buy"][stock_id] + order = Order( + stock_id=stock_id, + amount=amount, + trade_date=pd.Timestamp(trade_date), + direction=Order.BUY, + factor=factor, + ) + order_list.append(order) + return order_list diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py new file mode 100644 index 0000000000..7e9c766e85 --- /dev/null +++ b/qlib/contrib/online/manager.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
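# A hedged sketch of the order-list round trip implemented above.  As written,
# ``save_order_list`` stores ``{"sell"/"buy": {stock_id: [amount, factor]}}`` as
# JSON under ``<user_path>/trade/YYYY/MM/`` and ``load_order_list`` rebuilds the
# ``Order`` objects from it.  The paths and values are illustrative; the ``Order``
# constructor call mirrors the one used in ``load_order_list``.
import pandas as pd
from qlib.contrib.backtest.order import Order
from qlib.contrib.online.executor import load_order_list, save_order_list

trade_date = pd.Timestamp("2019-01-07")
orders = [
    Order(stock_id="SH600000", amount=1000, trade_date=trade_date, direction=Order.BUY, factor=1.0),
    Order(stock_id="SH600022", amount=500, trade_date=trade_date, direction=Order.SELL, factor=1.0),
]
save_order_list(orders, user_path="./user_data/demo_user", trade_date=trade_date)
restored = load_order_list(user_path="./user_data/demo_user", trade_date=trade_date)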
+ +import os +import pickle +import yaml +import pathlib +import pandas as pd +import shutil +from ..backtest.account import Account +from ..backtest.exchange import Exchange +from .user import User +from .utils import load_instance +from .utils import save_instance, init_instance_by_config + + +class UserManager: + def __init__(self, user_data_path, save_report=True): + """ + This module is designed to manager the users in online system + all users' data were assumed to be saved in user_data_path + Parameter + user_data_path : string + data path that all users' data were saved in + + variables: + data_path : string + data path that all users' data were saved in + users_file : string + A path of the file record the add_date of users + save_report : bool + whether to save report after each trading process + users : dict{} + [user_id]->User() + the python dict save instances of User() for each user_id + user_record : pd.Dataframe + user_id(string), add_date(string) + indicate the add_date for each users + """ + self.data_path = pathlib.Path(user_data_path) + self.users_file = self.data_path / "users.csv" + self.save_report = save_report + self.users = {} + self.user_record = None + + def load_users(self): + """ + load all users' data into manager + """ + self.users = {} + self.user_record = pd.read_csv(self.users_file, index_col=0) + for user_id in self.user_record.index: + self.users[user_id] = self.load_user(user_id) + + def load_user(self, user_id): + """ + return a instance of User() represents a user to be processed + Parameter + user_id : string + :return + user : User() + """ + account_path = self.data_path / user_id + strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id) + model_file = self.data_path / user_id / "model_{}.pickle".format(user_id) + cur_user_list = [user_id for user_id in self.users] + if user_id in cur_user_list: + raise ValueError("User {} has been loaded".format(user_id)) + else: + trade_account = Account(0) + trade_account.load_account(account_path) + strategy = load_instance(strategy_file) + model = load_instance(model_file) + user = User(account=trade_account, strategy=strategy, model=model) + return user + + def save_user_data(self, user_id): + """ + save a instance of User() to user data path + Parameter + user_id : string + """ + if not user_id in self.users: + raise ValueError("Cannot find user {}".format(user_id)) + self.users[user_id].account.save_account(self.data_path / user_id) + save_instance( + self.users[user_id].strategy, + self.data_path / user_id / "strategy_{}.pickle".format(user_id), + ) + save_instance( + self.users[user_id].model, + self.data_path / user_id / "model_{}.pickle".format(user_id), + ) + + def add_user(self, user_id, config_file, add_date): + """ + add the new user {user_id} into user data + will create a new folder named "{user_id}" in user data path + Parameter + user_id : string + init_cash : int + config_file : str/pathlib.Path() + path of config file + """ + config_file = pathlib.Path(config_file) + if not config_file.exists(): + raise ValueError("Cannot find config file {}".format(config_file)) + user_path = self.data_path / user_id + if user_path.exists(): + raise ValueError("User data for {} already exists".format(user_id)) + + with config_file.open("r") as fp: + config = yaml.load(fp) + # load model + model = init_instance_by_config(config["model"]) + + # load strategy + strategy = init_instance_by_config(config["strategy"]) + init_args = strategy.get_init_args_from_model(model, add_date) + 
strategy.init(**init_args) + + # init Account + trade_account = Account(init_cash=config["init_cash"]) + + # save user + user_path.mkdir() + save_instance(model, self.data_path / user_id / "model_{}.pickle".format(user_id)) + save_instance(strategy, self.data_path / user_id / "strategy_{}.pickle".format(user_id)) + trade_account.save_account(self.data_path / user_id) + user_record = pd.read_csv(self.users_file, index_col=0) + user_record.loc[user_id] = [add_date] + user_record.to_csv(self.users_file) + + def remove_user(self, user_id): + """ + remove user {user_id} in current user dataset + will delete the folder "{user_id}" in user data path + :param + user_id : string + """ + user_path = self.data_path / user_id + if not user_path.exists(): + raise ValueError("Cannot find user data {}".format(user_id)) + shutil.rmtree(user_path) + user_record = pd.read_csv(self.users_file, index_col=0) + user_record.drop([user_id], inplace=True) + user_record.to_csv(self.users_file) diff --git a/qlib/contrib/online/online_model.py b/qlib/contrib/online/online_model.py new file mode 100644 index 0000000000..0e8c0cb19d --- /dev/null +++ b/qlib/contrib/online/online_model.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import random +import pandas as pd +from ...data import D +from ..model.base import Model + + +class ScoreFileModel(Model): + """ + This model will load a score file, and return score at date exists in score file. + """ + + def __init__(self, score_path): + pred_test = pd.read_csv(score_path, index_col=[0, 1], parse_dates=True, infer_datetime_format=True) + self.pred = pred_test + + def get_data_with_date(self, date, **kwargs): + score = self.pred.loc(axis=0)[:, date] # (stock_id, trade_date) multi_index, score in pdate + score_series = score.reset_index(level="datetime", drop=True)[ + "score" + ] # pd.Series ; index:stock_id, data: score + return score_series + + def predict(self, x_test, **kwargs): + return x_test + + def score(self, x_test, **kwargs): + return + + def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): + return + + def save(self, fname, **kwargs): + return diff --git a/qlib/contrib/online/operator.py b/qlib/contrib/online/operator.py new file mode 100644 index 0000000000..500e732ffd --- /dev/null +++ b/qlib/contrib/online/operator.py @@ -0,0 +1,317 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import fire +import pandas as pd +import pathlib +import qlib +import logging + +from ...data import D +from ...log import get_module_logger +from ...utils import get_pre_trading_date, is_tradable_date +from ..evaluate import risk_analysis +from ..backtest.backtest import update_account + +from .manager import UserManager +from .utils import prepare +from .utils import create_user_folder +from .executor import load_order_list, save_order_list +from .executor import SimulatorExecutor +from .executor import save_score_series, load_score_series + + +class Operator(object): + def __init__(self, client: str): + """ + Parameters + ---------- + client: str + The qlib client config file(.yaml) + """ + self.logger = get_module_logger("online operator", level=logging.INFO) + self.client = client + + @staticmethod + def init(client, path, date=None): + """Initial UserManager(), get predict date and trade date + Parameters + ---------- + client: str + The qlib client config file(.yaml) + path : str + Path to save user account. 
+ date : str (YYYY-MM-DD) + Trade date, when the generated order list will be traded. + Return + ---------- + um: UserManager() + pred_date: pd.Timestamp + trade_date: pd.Timestamp + """ + qlib.init_from_yaml_conf(client) + um = UserManager(user_data_path=pathlib.Path(path)) + um.load_users() + if not date: + trade_date, pred_date = None, None + else: + trade_date = pd.Timestamp(date) + if not is_tradable_date(trade_date): + raise ValueError("trade date is not tradable date".format(trade_date.date())) + pred_date = get_pre_trading_date(trade_date, future=True) + return um, pred_date, trade_date + + def add_user(self, id, config, path, date): + """Add a new user into the a folder to run 'online' module. + + Parameters + ---------- + id : str + User id, should be unique. + config : str + The file path (yaml) of user config + path : str + Path to save user account. + date : str (YYYY-MM-DD) + The date that user account was added. + """ + create_user_folder(path) + qlib.init_from_yaml_conf(self.client) + um = UserManager(user_data_path=path) + add_date = D.calendar(end_time=date)[-1] + if not is_tradable_date(add_date): + raise ValueError("add date is not tradable date".format(add_date.date())) + um.add_user(user_id=id, config_file=config, add_date=add_date) + + def remove_user(self, id, path): + """Remove user from folder used in 'online' module. + + Parameters + ---------- + id : str + User id, should be unique. + path : str + Path to save user account. + """ + um = UserManager(user_data_path=path) + um.remove_user(user_id=id) + + def generate(self, date, path): + """Generate order list that will be traded at 'date'. + + Parameters + ---------- + date : str (YYYY-MM-DD) + Trade date, when the generated order list will be traded. + path : str + Path to save user account. + """ + um, pred_date, trade_date = self.init(self.client, path, date) + for user_id, user in um.users.items(): + dates, trade_exchange = prepare(um, pred_date, user_id) + # get and save the score at predict date + input_data = user.model.get_data_with_date(pred_date) + score_series = user.model.predict(input_data) + save_score_series(score_series, (pathlib.Path(path) / user_id), trade_date) + + # update strategy (and model) + user.strategy.update(score_series, pred_date, trade_date) + + # generate and save order list + order_list = user.strategy.generate_order_list( + score_series=score_series, + current=user.account.current, + trade_exchange=trade_exchange, + trade_date=trade_date, + ) + save_order_list( + order_list=order_list, + user_path=(pathlib.Path(path) / user_id), + trade_date=trade_date, + ) + self.logger.info("Generate order list at {} for {}".format(trade_date, user_id)) + um.save_user_data(user_id) + + def execute(self, date, exchange_config, path): + """Execute the orderlist at 'date'. + + Parameters + ---------- + date : str (YYYY-MM-DD) + Trade date, that the generated order list will be traded. + exchange_config: str + The file path (yaml) of exchange config + path : str + Path to save user account. + """ + um, pred_date, trade_date = self.init(self.client, path, date) + for user_id, user in um.users.items(): + dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config) + executor = SimulatorExecutor(trade_exchange=trade_exchange) + if not str(dates[0].date()) == str(pred_date.date()): + raise ValueError( + "The account data is not newest! 
last trading date {}, today {}".format( + dates[0].date(), trade_date.date() + ) + ) + + # load and execute the order list + # will not modify the trade_account after executing + order_list = load_order_list(user_path=(pathlib.Path(path) / user_id), trade_date=trade_date) + trade_info = executor.execute(order_list=order_list, trade_account=user.account, trade_date=trade_date) + executor.save_executed_file_from_trade_info( + trade_info=trade_info, + user_path=(pathlib.Path(path) / user_id), + trade_date=trade_date, + ) + self.logger.info("execute order list at {} for {}".format(trade_date.date(), user_id)) + + def update(self, date, path, type="SIM"): + """Update account at 'date'. + + Parameters + ---------- + date : str (YYYY-MM-DD) + Trade date, that the generated order list will be traded. + path : str + Path to save user account. + type : str + which executor was been used to execute the order list + 'SIM': SimulatorExecutor() + """ + if type not in ["SIM", "YC"]: + raise ValueError("type is invalid, {}".format(type)) + um, pred_date, trade_date = self.init(self.client, path, date) + for user_id, user in um.users.items(): + dates, trade_exchange = prepare(um, trade_date, user_id) + if type == "SIM": + executor = SimulatorExecutor(trade_exchange=trade_exchange) + else: + raise ValueError("not found executor") + # dates[0] is the last_trading_date + if str(dates[0].date()) > str(pred_date.date()): + raise ValueError( + "The account data is not newest! last trading date {}, today {}".format( + dates[0].date(), trade_date.date() + ) + ) + # load trade info and update account + trade_info = executor.load_trade_info_from_executed_file( + user_path=(pathlib.Path(path) / user_id), trade_date=trade_date + ) + score_series = load_score_series((pathlib.Path(path) / user_id), trade_date) + update_account(user.account, trade_info, trade_exchange, trade_date) + + report = user.account.report.generate_report_dataframe() + self.logger.info(report) + um.save_user_data(user_id) + self.logger.info("Update account state {} for {}".format(trade_date, user_id)) + + def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"): + """Run the ( generate_order_list -> execute_order_list -> update_account) process everyday + from start date to end date. + + Parameters + ---------- + id : str + user id, need to be unique + config : str + The file path (yaml) of user config + exchange_config: str + The file path (yaml) of exchange config + start : str "YYYY-MM-DD" + The start date to run the online simulate + end : str "YYYY-MM-DD" + The end date to run the online simulate + path : str + Path to save user account. + bench : str + The benchmark that our result compared with. + 'SH000905' for csi500, 'SH000300' for csi300 + """ + # Clear the current user if exists, then add a new user. + create_user_folder(path) + um = self.init(self.client, path, None)[0] + start_date, end_date = pd.Timestamp(start), pd.Timestamp(end) + try: + um.remove_user(user_id=id) + except BaseException: + pass + um.add_user(user_id=id, config_file=config, add_date=pd.Timestamp(start_date)) + + # Do the online simulate + um.load_users() + user = um.users[id] + dates, trade_exchange = prepare(um, end_date, id, exchange_config) + executor = SimulatorExecutor(trade_exchange=trade_exchange) + for pred_date, trade_date in zip(dates[:-2], dates[1:-1]): + user_path = pathlib.Path(path) / id + + # 1. 
load and save score_series + input_data = user.model.get_data_with_date(pred_date) + score_series = user.model.predict(input_data) + save_score_series(score_series, (pathlib.Path(path) / id), trade_date) + + # 2. update strategy (and model) + user.strategy.update(score_series, pred_date, trade_date) + + # 3. generate and save order list + order_list = user.strategy.generate_order_list( + score_series=score_series, + current=user.account.current, + trade_exchange=trade_exchange, + trade_date=trade_date, + ) + save_order_list(order_list=order_list, user_path=user_path, trade_date=trade_date) + + # 4. auto execute order list + order_list = load_order_list(user_path=user_path, trade_date=trade_date) + trade_info = executor.execute(trade_account=user.account, order_list=order_list, trade_date=trade_date) + executor.save_executed_file_from_trade_info( + trade_info=trade_info, user_path=user_path, trade_date=trade_date + ) + # 5. update account state + trade_info = executor.load_trade_info_from_executed_file(user_path=user_path, trade_date=trade_date) + update_account(user.account, trade_info, trade_exchange, trade_date) + report = user.account.report.generate_report_dataframe() + self.logger.info(report) + um.save_user_data(id) + self.show(id, path, bench) + + def show(self, id, path, bench="SH000905"): + """show the newly report (mean, std, sharpe, annual) + + Parameters + ---------- + id : str + user id, need to be unique + path : str + Path to save user account. + bench : str + The benchmark that our result compared with. + 'SH000905' for csi500, 'SH000300' for csi300 + """ + um = self.init(self.client, path, None)[0] + if id not in um.users: + raise ValueError("Cannot find user ".format(id)) + bench = D.features([bench], ["$change"]).loc[bench, "$change"] + report = um.users[id].account.report.generate_report_dataframe() + report["bench"] = bench + analysis_result = {} + r = (report["return"] - report["bench"]).dropna() + analysis_result["sub_bench"] = risk_analysis(r) + r = (report["return"] - report["bench"] - report["cost"]).dropna() + analysis_result["sub_cost"] = risk_analysis(r) + print("Result:") + print("sub_bench:") + print(analysis_result["sub_bench"]) + print("sub_cost:") + print(analysis_result["sub_cost"]) + + +def run(): + fire.Fire(Operator) + + +if __name__ == "__main__": + run() diff --git a/qlib/contrib/online/user.py b/qlib/contrib/online/user.py new file mode 100644 index 0000000000..d8a8fdabe3 --- /dev/null +++ b/qlib/contrib/online/user.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging + +from ...log import get_module_logger +from ..evaluate import risk_analysis +from ...data import D + + +class User: + def __init__(self, account, strategy, model, verbose=False): + """ + A user in online system, which contains account, strategy and model three module. + Parameter + account : Account() + strategy : + a strategy instance + model : + a model instance + report_save_path : string + the path to save report. 
Will not save report if None + verbose : bool + Whether to print the info during the process + """ + self.logger = get_module_logger("User", level=logging.INFO) + self.account = account + self.strategy = strategy + self.model = model + self.verbose = verbose + + def init_state(self, date): + """ + init state when each trading date begin + Parameter + date : pd.Timestamp + """ + self.account.init_state(today=date) + self.strategy.init_state(trade_date=date, model=self.model, account=self.account) + return + + def get_latest_trading_date(self): + """ + return the latest trading date for user {user_id} + Parameter + user_id : string + :return + date : string (e.g '2018-10-08') + """ + if not self.account.last_trade_date: + return None + return str(self.account.last_trade_date.date()) + + def showReport(self, benchmark="SH000905"): + """ + show the newly report (mean, std, sharpe, annual) + Parameter + benchmark : string + bench that to be compared, 'SH000905' for csi500 + """ + bench = D.features([benchmark], ["$change"], disk_cache=True).loc[benchmark, "$change"] + report = self.account.report.generate_report_dataframe() + report["bench"] = bench + analysis_result = {"pred": {}, "sub_bench": {}, "sub_cost": {}} + r = (report["return"] - report["bench"]).dropna() + analysis_result["sub_bench"][0] = risk_analysis(r) + r = (report["return"] - report["bench"] - report["cost"]).dropna() + analysis_result["sub_cost"][0] = risk_analysis(r) + self.logger.info("Result of porfolio:") + self.logger.info("sub_bench:") + self.logger.info(analysis_result["sub_bench"][0]) + self.logger.info("sub_cost:") + self.logger.info(analysis_result["sub_cost"][0]) + return report diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py new file mode 100644 index 0000000000..cf08e4dbe9 --- /dev/null +++ b/qlib/contrib/online/utils.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
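# A hedged sketch of the daily "online" loop driven by ``Operator`` above: add a
# user once, then for each trading day generate the order list, execute it and
# update the account.  The YAML file names, paths and dates are assumptions for
# illustration and require an initialized qlib data environment.
from qlib.contrib.online.operator import Operator

op = Operator(client="client_config.yaml")  # qlib client config (assumed file name)
op.add_user(id="demo_user", config="user_config.yaml", path="./user_data", date="2019-01-04")

for day in ["2019-01-07", "2019-01-08"]:
    op.generate(date=day, path="./user_data")                                  # order list for `day`
    op.execute(date=day, exchange_config="exchange_config.yaml", path="./user_data")
    op.update(date=day, path="./user_data")                                    # type="SIM" by default
op.show(id="demo_user", path="./user_data", bench="SH000905")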
+ +import pathlib +import pickle +import yaml +import pandas as pd +from ...data import D +from ...log import get_module_logger +from ...utils import get_module_by_module_path +from ...utils import get_next_trading_date +from ..backtest.exchange import Exchange + +log = get_module_logger("utils") + + +def load_instance(file_path): + """ + load a pickle file + Parameter + file_path : string / pathlib.Path() + path of file to be loaded + :return + An instance loaded from file + """ + file_path = pathlib.Path(file_path) + if not file_path.exists(): + raise ValueError("Cannot find file {}".format(file_path)) + with file_path.open("rb") as fr: + instance = pickle.load(fr) + return instance + + +def save_instance(instance, file_path): + """ + save(dump) an instance to a pickle file + Parameter + instance : + data to te dumped + file_path : string / pathlib.Path() + path of file to be dumped + """ + file_path = pathlib.Path(file_path) + with file_path.open("wb") as fr: + pickle.dump(instance, fr) + + +def init_instance_by_config(config): + """ + generate an instance with settings in config + Parameter + config : dict + python dict indicate a init parameters to create an item + :return + An instance + """ + module = get_module_by_module_path(config["module_path"]) + instance_class = getattr(module, config["class"]) + instance = instance_class(**config["args"]) + return instance + + +def create_user_folder(path): + path = pathlib.Path(path) + if path.exists(): + return + path.mkdir(parents=True) + head = pd.DataFrame(columns=("user_id", "add_date")) + head.to_csv(path / "users.csv", index=None) + + +def prepare(um, today, user_id, exchange_config=None): + """ + 1. Get the dates that need to do trading till today for user {user_id} + dates[0] indicate the latest trading date of User{user_id}, + if User{user_id} haven't do trading before, than dates[0] presents the init date of User{user_id}. + 2. Set the exchange with exchange_config file + + Parameter + um : UserManager() + today : pd.Timestamp() + user_id : str + :return + dates : list of pd.Timestamp + trade_exchange : Exchange() + """ + # get latest trading date for {user_id} + # if is None, indicate it haven't traded, then last trading date is init date of {user_id} + latest_trading_date = um.users[user_id].get_latest_trading_date() + if not latest_trading_date: + latest_trading_date = um.user_record.loc[user_id][0] + + if str(today.date()) < latest_trading_date: + log.warning("user_id:{}, last trading date {} after today {}".format(user_id, latest_trading_date, today)) + return [pd.Timestamp(latest_trading_date)], None + + dates = D.calendar( + start_time=pd.Timestamp(latest_trading_date), + end_time=pd.Timestamp(today), + future=True, + ) + dates = list(dates) + dates.append(get_next_trading_date(dates[-1], future=True)) + if exchange_config: + with pathlib.Path(exchange_config).open("r") as fp: + exchange_paras = yaml.load(fp) + else: + exchange_paras = {} + trade_exchange = Exchange(trade_dates=dates, **exchange_paras) + return dates, trade_exchange diff --git a/qlib/contrib/report/__init__.py b/qlib/contrib/report/__init__.py new file mode 100644 index 0000000000..06309f4120 --- /dev/null +++ b/qlib/contrib/report/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
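# A hedged sketch of the config format consumed by ``init_instance_by_config``
# above: it expects ``module_path``, ``class`` and ``args`` keys.  The model
# class comes from this patch, but the score file path is an assumption.
from qlib.contrib.online.utils import init_instance_by_config

model_config = {
    "module_path": "qlib.contrib.online.online_model",
    "class": "ScoreFileModel",
    "args": {"score_path": "./pred_scores.csv"},
}
model = init_instance_by_config(model_config)  # -> ScoreFileModel(score_path="./pred_scores.csv")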
+ +GRAPH_NAME_LISt = [ + "analysis_position.report_graph", + "analysis_position.score_ic_graph", + "analysis_position.cumulative_return_graph", + "analysis_position.risk_analysis_graph", + "analysis_position.rank_label_graph", + "analysis_model.model_performance_graph", +] diff --git a/qlib/contrib/report/analysis_model/__init__.py b/qlib/contrib/report/analysis_model/__init__.py new file mode 100644 index 0000000000..496805d671 --- /dev/null +++ b/qlib/contrib/report/analysis_model/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .analysis_model_performance import model_performance_graph diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py new file mode 100644 index 0000000000..51bfcba073 --- /dev/null +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd + +import plotly.tools as tls +import plotly.graph_objs as go + +import statsmodels.api as sm +import matplotlib.pyplot as plt + +from scipy import stats + +from ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph + + +def _group_return( + pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs +) -> tuple: + """ + + :param pred_label: + :param reverse: + :param N: + :return: + """ + if reverse: + pred_label["score"] *= -1 + + pred_label = pred_label.sort_values("score", ascending=False) + + # Group1 ~ Group5 only consider the dropna values + pred_label_drop = pred_label.dropna(subset=["score"]) + + # Group + t_df = pd.DataFrame( + { + "Group-%d" + % (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply( + lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() + ) + for i in range(N) + } + ) + t_df.index = pd.to_datetime(t_df.index) + + # Long-Short + t_df["long-short"] = t_df["Group-1"] - t_df["Group-%d" % N] + + # Long-Average + t_df["long-average"] = ( + t_df["Group-1"] - pred_label.groupby(level="datetime")["label"].mean() + ) + + t_df = t_df.dropna(how="all") # for days which does not contain label + # FIXME: support HIGH-FREQ + t_df.index = t_df.index.strftime("%Y-%m-%d") + # Cumulative Return By Group + group_scatter_figure = ScatterGraph( + t_df.cumsum(), + layout=dict( + title="Cumulative Return", xaxis=dict(type="category", tickangle=45) + ), + ).figure + + t_df = t_df.loc[:, ["long-short", "long-average"]] + _bin_size = ((t_df.max() - t_df.min()) / 20).min() + group_hist_figure = SubplotsGraph( + t_df, + kind_map=dict(kind="DistplotGraph", kwargs=dict(bin_size=_bin_size)), + subplots_kwargs=dict( + rows=1, + cols=2, + print_grid=False, + subplot_titles=["long-short", "long-average"], + ), + ).figure + + return group_scatter_figure, group_hist_figure + + +def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure: + """ + + :param data: + :param dist: + :return: + """ + fig, ax = plt.subplots(figsize=(8, 5)) + _mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax) + return tls.mpl_to_plotly(_mpl_fig) + + +def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple: + """ + + :param pred_label: + :param rank: + :return: + """ + if rank: + ic = pred_label.groupby(level="datetime").apply( + lambda x: x["label"].rank(pct=True).corr(x["score"].rank(pct=True)) + ) + else: + ic = pred_label.groupby(level="datetime").apply( + lambda x: 
x["label"].corr(x["score"]) + ) + + _index = ( + ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6) + ) + _monthly_ic = ic.groupby(_index).mean() + _monthly_ic.index = pd.MultiIndex.from_arrays( + [_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)], + names=["year", "month"], + ) + + # fill month + _month_list = pd.date_range( + start=pd.Timestamp(f"{_index.min()[:4]}0101"), + end=pd.Timestamp(f"{_index.max()[:4]}1231"), + freq="1M", + ) + _years = [] + _month = [] + for _date in _month_list: + _date = _date.strftime("%Y%m%d") + _years.append(_date[:4]) + _month.append(_date[4:6]) + + fill_index = pd.MultiIndex.from_arrays([_years, _month], names=["year", "month"]) + + _monthly_ic = _monthly_ic.reindex(fill_index) + + _ic_df = ic.to_frame("ic") + ic_bar_figure = ic_figure(_ic_df, kwargs.get("show_nature_day", True)) + + ic_heatmap_figure = HeatmapGraph( + _monthly_ic.unstack(), + layout=dict(title="Monthly IC", yaxis=dict(tickformat=",d")), + graph_kwargs=dict(xtype="array", ytype="array"), + ).figure + + dist = stats.norm + _qqplot_fig = _plot_qq(ic, dist) + + if isinstance(dist, stats.norm.__class__): + dist_name = "Normal" + else: + dist_name = "Unknown" + + _bin_size = ((_ic_df.max() - _ic_df.min()) / 20).min() + _sub_graph_data = [ + ( + "ic", + dict( + row=1, + col=1, + name="", + kind="DistplotGraph", + graph_kwargs=dict(bin_size=_bin_size), + ), + ), + (_qqplot_fig, dict(row=1, col=2)), + ] + ic_hist_figure = SubplotsGraph( + _ic_df.dropna(), + kind_map=dict(kind="HistogramGraph", kwargs=dict()), + subplots_kwargs=dict( + rows=1, + cols=2, + print_grid=False, + subplot_titles=["IC", "IC %s Dist. Q-Q" % dist_name], + ), + sub_graph_data=_sub_graph_data, + layout=dict( + yaxis2=dict(title="Observed Quantile"), + xaxis2=dict(title=f"{dist_name} Distribution Quantile"), + ), + ).figure + + return ic_bar_figure, ic_heatmap_figure, ic_hist_figure + + +def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple: + pred = pred_label.copy() + pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag) + ac = pred.groupby(level="datetime").apply( + lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)) + ) + # FIXME: support HIGH-FREQ + _df = ac.to_frame("value") + _df.index = _df.index.strftime("%Y-%m-%d") + ac_figure = ScatterGraph( + _df, + layout=dict( + title="Auto Correlation", xaxis=dict(type="category", tickangle=45) + ), + ).figure + return (ac_figure,) + + +def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple: + pred = pred_label.copy() + pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag) + top = pred.groupby(level="datetime").apply( + lambda x: 1 + - x.nlargest(len(x) // N, columns="score") + .index.isin(x.nlargest(len(x) // N, columns="score_last").index) + .sum() + / (len(x) // N) + ) + bottom = pred.groupby(level="datetime").apply( + lambda x: 1 + - x.nsmallest(len(x) // N, columns="score") + .index.isin(x.nsmallest(len(x) // N, columns="score_last").index) + .sum() + / (len(x) // N) + ) + r_df = pd.DataFrame({"Top": top, "Bottom": bottom,}) + # FIXME: support HIGH-FREQ + r_df.index = r_df.index.strftime("%Y-%m-%d") + turnover_figure = ScatterGraph( + r_df, + layout=dict( + title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45) + ), + ).figure + return (turnover_figure,) + + +def ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure: + """IC figure + + :param ic_df: ic DataFrame + :param show_nature_day: 
whether to display the abscissa of non-trading day + :return: plotly.graph_objs.Figure + """ + if show_nature_day: + date_index = pd.date_range(ic_df.index.min(), ic_df.index.max()) + ic_df = ic_df.reindex(date_index) + # FIXME: support HIGH-FREQ + ic_df.index = ic_df.index.strftime("%Y-%m-%d") + ic_bar_figure = BarGraph( + ic_df, + layout=dict( + title="Information Coefficient (IC)", + xaxis=dict(type="category", tickangle=45), + ), + ).figure + return ic_bar_figure + + +def model_performance_graph( + pred_label: pd.DataFrame, + lag: int = 1, + N: int = 5, + reverse=False, + rank=False, + graph_names: list = ["group_return", "pred_ic", "pred_autocorr", "pred_turnover"], + show_notebook: bool = True, + show_nature_day=True, +) -> [list, tuple]: + """Model performance + + :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]** + + + .. code-block:: python + + instrument datetime score label + SH600004 2017-12-11 -0.013502 -0.013502 + 2017-12-12 -0.072367 -0.072367 + 2017-12-13 -0.068605 -0.068605 + 2017-12-14 0.012440 0.012440 + 2017-12-15 -0.102778 -0.102778 + + + :param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing. + :param N: group number, default 5 + :param reverse: if `True`, `pred['score'] *= -1` + :param rank: if **True**, calculate rank ic + :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'] + :param show_notebook: whether to display graphics in notebook, the default is `True` + :param show_nature_day: whether to display the abscissa of non-trading day + :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list + """ + figure_list = [] + for graph_name in graph_names: + fun_res = eval(f"_{graph_name}")( + pred_label=pred_label, + lag=lag, + N=N, + reverse=reverse, + rank=rank, + show_nature_day=show_nature_day, + ) + figure_list += fun_res + + if show_notebook: + BarGraph.show_graph_in_notebook(figure_list) + else: + return figure_list diff --git a/qlib/contrib/report/analysis_position/__init__.py b/qlib/contrib/report/analysis_position/__init__.py new file mode 100644 index 0000000000..86d8803dd2 --- /dev/null +++ b/qlib/contrib/report/analysis_position/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .cumulative_return import cumulative_return_graph +from .score_ic import score_ic_graph +from .report import report_graph +from .rank_label import rank_label_graph +from .risk_analysis import risk_analysis_graph diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py new file mode 100644 index 0000000000..c31179bab0 --- /dev/null +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -0,0 +1,281 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
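+
+# Overview: for each trading day the positions are split by trade status
+# (hold / buy / sell) and a weight-averaged label return is accumulated.
+# A rough sketch of the per-day aggregation done in
+# _get_cum_return_data_with_position below (illustrative only):
+#
+#     day_mean = (day_df["label"] * day_df["weight"]).sum() / day_df["weight"].sum()
+#     cumulative_curve = pd.Series(daily_means).cumsum()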
+ +import copy +from typing import Iterable + +import pandas as pd +import plotly.graph_objs as go + +from ..graph import BaseGraph, SubplotsGraph + +from ..analysis_position.parse_position import get_position_data + + +def _get_cum_return_data_with_position( + position: dict, + report_normal: pd.DataFrame, + label_data: pd.DataFrame, + start_date=None, + end_date=None, +): + """ + + :param position: + :param report_normal: + :param label_data: + :param start_date: + :param end_date: + :return: + """ + _cumulative_return_df = get_position_data( + position=position, + report_normal=report_normal, + label_data=label_data, + start_date=start_date, + end_date=end_date, + ).copy() + + _cumulative_return_df["label"] = ( + _cumulative_return_df["label"] - _cumulative_return_df["bench"] + ) + _cumulative_return_df = _cumulative_return_df.dropna() + df_gp = _cumulative_return_df.groupby(level="datetime") + result_list = [] + for gp in df_gp: + date = gp[0] + day_df = gp[1] + + _hold_df = day_df[day_df["status"] == 0] + _buy_df = day_df[day_df["status"] == 1] + _sell_df = day_df[day_df["status"] == -1] + + hold_value = (_hold_df["label"] * _hold_df["weight"]).sum() + hold_weight = _hold_df["weight"].sum() + hold_mean = (hold_value / hold_weight) if hold_weight else 0 + + sell_value = (_sell_df["label"] * _sell_df["weight"]).sum() + sell_weight = _sell_df["weight"].sum() + sell_mean = (sell_value / sell_weight) if sell_weight else 0 + + buy_value = (_buy_df["label"] * _buy_df["weight"]).sum() + buy_weight = _buy_df["weight"].sum() + buy_mean = (buy_value / buy_weight) if buy_weight else 0 + + result_list.append( + dict( + hold_value=hold_value, + hold_mean=hold_mean, + hold_weight=hold_weight, + buy_value=buy_value, + buy_mean=buy_mean, + buy_weight=buy_weight, + sell_value=sell_value, + sell_mean=sell_mean, + sell_weight=sell_weight, + buy_minus_sell_value=buy_value - sell_value, + buy_minus_sell_mean=buy_mean - sell_mean, + buy_plus_sell_weight=buy_weight + sell_weight, + date=date, + ) + ) + + r_df = pd.DataFrame(data=result_list) + r_df["cum_hold"] = r_df["hold_mean"].cumsum() + r_df["cum_buy"] = r_df["buy_mean"].cumsum() + r_df["cum_sell"] = r_df["sell_mean"].cumsum() + r_df["cum_buy_minus_sell"] = r_df["buy_minus_sell_mean"].cumsum() + return r_df + + +def _get_figure_with_position( + position: dict, + report_normal: pd.DataFrame, + label_data: pd.DataFrame, + start_date=None, + end_date=None, +) -> Iterable[go.Figure]: + """Get average analysis figures + + :param position: position + :param report_normal: + :param label_data: + :param start_date: + :param end_date: + :return: + """ + + cum_return_df = _get_cum_return_data_with_position( + position, report_normal, label_data, start_date, end_date + ) + cum_return_df = cum_return_df.set_index("date") + # FIXME: support HIGH-FREQ + cum_return_df.index = cum_return_df.index.strftime('%Y-%m-%d') + + # Create figures + for _t_name in ["buy", "sell", "buy_minus_sell", "hold"]: + sub_graph_data = [ + ( + "cum_{}".format(_t_name), + dict( + row=1, col=1, graph_kwargs={"mode": "lines+markers", "xaxis": "x3"} + ), + ), + ( + "{}_weight".format( + _t_name.replace("minus", "plus") if "minus" in _t_name else _t_name + ), + dict(row=2, col=1), + ), + ( + "{}_value".format(_t_name), + dict(row=1, col=2, kind="HistogramGraph", graph_kwargs={}), + ), + ] + + _default_xaxis = dict(showline=False, zeroline=True, tickangle=45) + _default_yaxis = dict(zeroline=True, showline=True, showticklabels=True) + sub_graph_layout = dict( + xaxis1=dict(**_default_xaxis, 
type="category", showticklabels=False), + xaxis3=dict(**_default_xaxis, type="category"), + xaxis2=_default_xaxis, + yaxis1=dict(**_default_yaxis, title=_t_name), + yaxis2=_default_yaxis, + yaxis3=_default_yaxis, + ) + + mean_value = cum_return_df["{}_value".format(_t_name)].mean() + layout = dict( + height=500, + title=f"{_t_name}(the red line in the histogram on the right represents the average)", + shapes=[ + { + "type": "line", + "xref": "x2", + "yref": "paper", + "x0": mean_value, + "y0": 0, + "x1": mean_value, + "y1": 1, + # NOTE: 'fillcolor': '#d3d3d3', 'opacity': 0.3, + "line": {"color": "red", "width": 1}, + }, + ], + ) + + kind_map = dict(kind="ScatterGraph", kwargs=dict(mode="lines+markers")) + specs = [ + [{"rowspan": 1}, {"rowspan": 2}], + [{"rowspan": 1}, None], + ] + subplots_kwargs = dict( + vertical_spacing=0.01, + rows=2, + cols=2, + row_width=[1, 2], + column_width=[3, 1], + print_grid=False, + specs=specs, + ) + yield SubplotsGraph( + cum_return_df, + layout=layout, + kind_map=kind_map, + sub_graph_layout=sub_graph_layout, + sub_graph_data=sub_graph_data, + subplots_kwargs=subplots_kwargs, + ).figure + + +def cumulative_return_graph( + position: dict, + report_normal: pd.DataFrame, + label_data: pd.DataFrame, + show_notebook=True, + start_date=None, + end_date=None, +) -> Iterable[go.Figure]: + """Backtest buy, sell, and holding cumulative return graph + + Example: + + + .. code-block:: python + + from qlib.data import D + from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest + from qlib.contrib.strategy import TopkAmountStrategy + + # backtest parameters + bparas = {} + bparas['limit_threshold'] = 0.095 + bparas['account'] = 1000000000 + + sparas = {} + sparas['topk'] = 50 + sparas['buffer_margin'] = 230 + strategy = TopkAmountStrategy(**sparas) + + report_normal_df, positions = backtest(pred_df, strategy, **bparas) + + pred_df_dates = pred_df.index.get_level_values(level='datetime') + features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max()) + features_df.columns = ['label'] + + qcr.cumulative_return_graph(positions, report_normal_df, features_df) + + + Graph desc: + - Axis X: Trading day + - Axis Y: + - Above axis Y: (((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum() + - Below axis Y: Daily weight sum + - In the sell graph, y < 0 stands for profit; in other cases, y > 0 stands for profit. + - In the buy_minus_sell graph, the y value of the weight graph at the bottom is buy_weight + sell_weight. + - In each graph, the red line in the histogram on the right represents the average. + + :param position: position data + :param report_normal: + + + .. code-block:: python + + return cost bench turnover + date + 2017-01-04 0.003421 0.000864 0.011693 0.576325 + 2017-01-05 0.000508 0.000447 0.000721 0.227882 + 2017-01-06 -0.003321 0.000212 -0.004322 0.102765 + 2017-01-09 0.006753 0.000212 0.006874 0.105864 + 2017-01-10 -0.000416 0.000440 -0.003350 0.208396 + + + :param label_data: `D.features` result; index is `pd.MultiIndex`, index name is [`instrument`, `datetime`]; columns names is [`label`]. + **The ``label`` T is the change from T to T+1**, it is recommended to use ``close``, example: D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1']) + + + .. 
code-block:: python + + label + instrument datetime + SH600004 2017-12-11 -0.013502 + 2017-12-12 -0.072367 + 2017-12-13 -0.068605 + 2017-12-14 0.012440 + 2017-12-15 -0.102778 + + + :param show_notebook: True or False. If True, show graph in notebook, else return figures + :param start_date: start date + :param end_date: end date + :return: + """ + position = copy.deepcopy(position) + report_normal = report_normal.copy() + label_data.columns = ["label"] + _figures = _get_figure_with_position( + position, report_normal, label_data, start_date, end_date + ) + if show_notebook: + BaseGraph.show_graph_in_notebook(_figures) + else: + return _figures diff --git a/qlib/contrib/report/analysis_position/parse_position.py b/qlib/contrib/report/analysis_position/parse_position.py new file mode 100644 index 0000000000..c3a7807e36 --- /dev/null +++ b/qlib/contrib/report/analysis_position/parse_position.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd + + +from ...backtest.profit_attribution import get_stock_weight_df + + +def parse_position(position: dict = None) -> pd.DataFrame: + """Parse position dict to position DataFrame + + :param position: position data + :return: position DataFrame; + + + .. code-block:: python + + position_df = parse_position(positions) + print(position_df.head()) + # status: 0-hold, -1-sell, 1-buy + + amount cash count price status weight + instrument datetime + SZ000547 2017-01-04 44.154290 211405.285654 1 205.189575 1 0.031255 + SZ300202 2017-01-04 60.638845 211405.285654 1 154.356506 1 0.032290 + SH600158 2017-01-04 46.531681 211405.285654 1 153.895142 1 0.024704 + SH600545 2017-01-04 197.173093 211405.285654 1 48.607037 1 0.033063 + SZ000930 2017-01-04 103.938300 211405.285654 1 80.759453 1 0.028958 + + + """ + + position_weight_df = get_stock_weight_df(position) + # If the day does not exist, use the last weight + position_weight_df.fillna(method="ffill", inplace=True) + + previous_data = {"date": None, "code_list": []} + + result_df = pd.DataFrame() + for _trading_date, _value in position.items(): + # pd_date type: pd.Timestamp + _cash = _value.pop("cash") + for _item in ["today_account_value"]: + if _item in _value: + _value.pop(_item) + + _trading_day_df = pd.DataFrame.from_dict(_value, orient="index") + _trading_day_df["weight"] = position_weight_df.loc[_trading_date] + _trading_day_df["cash"] = _cash + _trading_day_df["date"] = _trading_date + # status: 0-hold, -1-sell, 1-buy + _trading_day_df["status"] = 0 + + # T not exist, T-1 exist, T sell + _cur_day_sell = set(previous_data["code_list"]) - set(_trading_day_df.index) + # T exist, T-1 not exist, T buy + _cur_day_buy = set(_trading_day_df.index) - set(previous_data["code_list"]) + + # Trading day buy + _trading_day_df.loc[_trading_day_df.index.isin(_cur_day_buy), "status"] = 1 + + # Trading day sell + if not result_df.empty: + _trading_day_sell_df = result_df.loc[ + (result_df["date"] == previous_data["date"]) + & (result_df.index.isin(_cur_day_sell)) + ].copy() + if not _trading_day_sell_df.empty: + _trading_day_sell_df["status"] = -1 + _trading_day_sell_df["date"] = _trading_date + _trading_day_df = _trading_day_df.append( + _trading_day_sell_df, sort=False + ) + + result_df = result_df.append(_trading_day_df, sort=True) + + previous_data = dict( + date=_trading_date, + code_list=_trading_day_df[_trading_day_df["status"] != -1].index, + ) + + result_df.reset_index(inplace=True) + result_df.rename(columns={"date": "datetime", "index": 
"instrument"}, inplace=True) + return result_df.set_index(["instrument", "datetime"]) + + +def _add_label_to_position( + position_df: pd.DataFrame, label_data: pd.DataFrame +) -> pd.DataFrame: + """Concat position with custom label + + :param position_df: position DataFrame + :param label_data: + :return: concat result + """ + + _start_time = position_df.index.get_level_values(level="datetime").min() + _end_time = position_df.index.get_level_values(level="datetime").max() + label_data = label_data.loc(axis=0)[:, pd.to_datetime(_start_time) :] + _result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex( + label_data.index + ) + _result_df = _result_df.loc[_result_df.index.get_level_values(1) <= _end_time] + return _result_df + + +def _add_bench_to_position( + position_df: pd.DataFrame = None, bench: pd.Series = None +) -> pd.DataFrame: + """Concat position with bench + + :param position_df: position DataFrame + :param bench: report normal data + :return: concat result + """ + _temp_df = position_df.reset_index(level="instrument") + # FIXME: After the stock is bought and sold, the rise and fall of the next trading day are calculated. + _temp_df["bench"] = bench.shift(-1) + res_df = _temp_df.set_index(["instrument", _temp_df.index]) + return res_df + + +def _calculate_label_rank(df: pd.DataFrame) -> pd.DataFrame: + """calculate label rank + + :param df: + :return: + """ + _label_name = "label" + + def _calculate_day_value(g_df: pd.DataFrame): + g_df = g_df.copy() + g_df["rank_ratio"] = g_df[_label_name].rank(ascending=False) / len(g_df) * 100 + + # Sell: -1, Hold: 0, Buy: 1 + for i in [-1, 0, 1]: + g_df.loc[g_df["status"] == i, "rank_label_mean"] = g_df[ + g_df["status"] == i + ]["rank_ratio"].mean() + + g_df["excess_return"] = g_df[_label_name] - g_df[_label_name].mean() + return g_df + + return df.groupby(level="datetime").apply(_calculate_day_value) + + +def get_position_data( + position: dict, + label_data: pd.DataFrame, + report_normal: pd.DataFrame = None, + calculate_label_rank=False, + start_date=None, + end_date=None, +) -> pd.DataFrame: + """Concat position data with pred/report_normal + + :param position: position data + :param report_normal: report normal, must be container 'bench' column + :param label_data: + :param calculate_label_rank: + :param start_date: start date + :param end_date: end date + :return: concat result, + columns: ['amount', 'cash', 'count', 'price', 'status', 'weight', 'label', + 'rank_ratio', 'rank_label_mean', 'excess_return', 'score', 'bench'] + index: ['instrument', 'date'] + """ + _position_df = parse_position(position) + + # Add custom_label, rank_ratio, rank_mean, and excess_return field + _position_df = _add_label_to_position(_position_df, label_data) + + if calculate_label_rank: + _position_df = _calculate_label_rank(_position_df) + + if report_normal is not None: + # Add bench field + _position_df = _add_bench_to_position(_position_df, report_normal["bench"]) + + _date_list = _position_df.index.get_level_values(level="datetime") + start_date = _date_list.min() if start_date is None else start_date + end_date = _date_list.max() if end_date is None else end_date + _position_df = _position_df.loc[ + (start_date <= _date_list) & (_date_list <= end_date) + ] + return _position_df diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py new file mode 100644 index 0000000000..75b42f033f --- /dev/null +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -0,0 +1,127 @@ 
+# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import copy +from typing import Iterable + +import pandas as pd +import plotly.graph_objs as go + +from ..graph import ScatterGraph +from ..analysis_position.parse_position import get_position_data + + +def _get_figure_with_position( + position: dict, label_data: pd.DataFrame, start_date=None, end_date=None +) -> Iterable[go.Figure]: + """Get average analysis figures + + :param position: position + :param label_data: + :param start_date: + :param end_date: + :return: + """ + _position_df = get_position_data( + position, + label_data, + calculate_label_rank=True, + start_date=start_date, + end_date=end_date, + ) + + res_dict = dict() + _pos_gp = _position_df.groupby(level=1) + for _item in _pos_gp: + _date = _item[0] + _day_df = _item[1] + + _day_value = res_dict.setdefault(_date, {}) + for _i, _name in {0: "Hold", 1: "Buy", -1: "Sell"}.items(): + _temp_df = _day_df[_day_df["status"] == _i] + if _temp_df.empty: + _day_value[_name] = 0 + else: + _day_value[_name] = _temp_df["rank_label_mean"].values[0] + + _res_df = pd.DataFrame.from_dict(res_dict, orient="index") + # FIXME: support HIGH-FREQ + _res_df.index = _res_df.index.strftime('%Y-%m-%d') + for _col in _res_df.columns: + yield ScatterGraph( + _res_df.loc[:, [_col]], + layout=dict( + title=_col, + xaxis=dict(type="category", tickangle=45), + yaxis=dict(title="lable-rank-ratio: %"), + ), + graph_kwargs=dict(mode="lines+markers"), + ).figure + + +def rank_label_graph( + position: dict, + label_data: pd.DataFrame, + start_date=None, + end_date=None, + show_notebook=True, +) -> Iterable[go.Figure]: + """Ranking percentage of stocks buy, sell, and holding on the trading day. + Average rank-ratio(similar to **sell_df['label'].rank(ascending=False) / len(sell_df)**) of daily trading + + Example: + + + .. code-block:: python + + from qlib.data import D + from qlib.contrib.evaluate import backtest + from qlib.contrib.strategy import TopkAmountStrategy + + # backtest parameters + bparas = {} + bparas['limit_threshold'] = 0.095 + bparas['account'] = 1000000000 + + sparas = {} + sparas['topk'] = 50 + sparas['buffer_margin'] = 230 + strategy = TopkAmountStrategy(**sparas) + + _, positions = backtest(pred_df, strategy, **bparas) + + pred_df_dates = pred_df.index.get_level_values(level='datetime') + features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max()) + features_df.columns = ['label'] + + qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) + + + :param position: position data; **qlib.contrib.backtest.backtest.backtest** result + :param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**. + **The ``label`` T is the change from T to T+1**, it is recommended to use ``close``, example: D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1']) + + + .. code-block:: python + + label + instrument datetime + SH600004 2017-12-11 -0.013502 + 2017-12-12 -0.072367 + 2017-12-13 -0.068605 + 2017-12-14 0.012440 + 2017-12-15 -0.102778 + + + :param start_date: start date + :param end_date: end_date + :param show_notebook: **True** or **False**. 
If True, show graph in notebook, else return figures + :return: + """ + position = copy.deepcopy(position) + label_data.columns = ["label"] + _figures = _get_figure_with_position(position, label_data, start_date, end_date) + if show_notebook: + ScatterGraph.show_graph_in_notebook(_figures) + else: + return _figures diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py new file mode 100644 index 0000000000..680c6777fd --- /dev/null +++ b/qlib/contrib/report/analysis_position/report.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd + +from ..graph import SubplotsGraph, BaseGraph + + +def _calculate_maximum(df: pd.DataFrame, is_ex: bool = False): + """ + + :param df: + :param is_ex: + :return: + """ + if is_ex: + end_date = df["cum_ex_return_wo_cost_mdd"].idxmin() + start_date = df.loc[df.index <= end_date]["cum_ex_return_wo_cost"].idxmax() + else: + end_date = df["return_wo_mdd"].idxmin() + start_date = df.loc[df.index <= end_date]["cum_return_wo_cost"].idxmax() + return start_date, end_date + + +def _calculate_mdd(series): + """ + Calculate mdd + + :param series: + :return: + """ + return series - series.cummax() + + +def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame: + """ + + :param df: + :return: + """ + + df.index = df.index.strftime("%Y-%m-%d") + + report_df = pd.DataFrame() + + report_df["cum_bench"] = df["bench"].cumsum() + report_df["cum_return_wo_cost"] = df["return"].cumsum() + report_df["cum_return_w_cost"] = (df["return"] - df["cost"]).cumsum() + # report_df['cum_return'] - report_df['cum_return'].cummax() + report_df["return_wo_mdd"] = _calculate_mdd(report_df["cum_return_wo_cost"]) + report_df["return_w_cost_mdd"] = _calculate_mdd( + (df["return"] - df["cost"]).cumsum() + ) + + report_df["cum_ex_return_wo_cost"] = (df["return"] - df["bench"]).cumsum() + report_df["cum_ex_return_w_cost"] = ( + df["return"] - df["bench"] - df["cost"] + ).cumsum() + report_df["cum_ex_return_wo_cost_mdd"] = _calculate_mdd( + (df["return"] - df["bench"]).cumsum() + ) + report_df["cum_ex_return_w_cost_mdd"] = _calculate_mdd( + (df["return"] - df["cost"] - df["bench"]).cumsum() + ) + # return_wo_mdd , return_w_cost_mdd, cum_ex_return_wo_cost_mdd, cum_ex_return_w + + report_df["turnover"] = df["turnover"] + report_df.sort_index(ascending=True, inplace=True) + return report_df + + +def _report_figure(df: pd.DataFrame) -> [list, tuple]: + """ + + :param df: + :return: + """ + + # Get data + report_df = _calculate_report_data(df) + + # Maximum Drawdown + max_start_date, max_end_date = _calculate_maximum(report_df) + ex_max_start_date, ex_max_end_date = _calculate_maximum(report_df, True) + + _temp_df = report_df.reset_index() + _temp_df.loc[-1] = 0 + _temp_df = _temp_df.shift(1) + _temp_df.loc[0, "index"] = "T0" + _temp_df.set_index("index", inplace=True) + _temp_df.iloc[0] = 0 + report_df = _temp_df + + # Create figure + _default_kind_map = dict(kind="ScatterGraph", kwargs={"mode": "lines+markers"}) + _temp_fill_args = {"fill": "tozeroy", "mode": "lines+markers"} + _column_row_col_dict = [ + ("cum_bench", dict(row=1, col=1)), + ("cum_return_wo_cost", dict(row=1, col=1)), + ("cum_return_w_cost", dict(row=1, col=1)), + ("return_wo_mdd", dict(row=2, col=1, graph_kwargs=_temp_fill_args)), + ("return_w_cost_mdd", dict(row=3, col=1, graph_kwargs=_temp_fill_args)), + ("cum_ex_return_wo_cost", dict(row=4, col=1)), + ("cum_ex_return_w_cost", dict(row=4, col=1)), + ("turnover", 
dict(row=5, col=1)), + ("cum_ex_return_w_cost_mdd", dict(row=6, col=1, graph_kwargs=_temp_fill_args)), + ("cum_ex_return_wo_cost_mdd", dict(row=7, col=1, graph_kwargs=_temp_fill_args)), + ] + + _subplot_layout = dict( + xaxis=dict(showline=True, type="category", tickangle=45), + yaxis=dict(zeroline=True, showline=True, showticklabels=True), + ) + for i in range(2, 8): + # yaxis + _subplot_layout.update( + { + "yaxis{}".format(i): dict( + zeroline=True, showline=True, showticklabels=True + ) + } + ) + _layout_style = dict( + height=1200, + title=" ", + shapes=[ + { + "type": "rect", + "xref": "x", + "yref": "paper", + "x0": max_start_date, + "y0": 0.55, + "x1": max_end_date, + "y1": 1, + "fillcolor": "#d3d3d3", + "opacity": 0.3, + "line": {"width": 0,}, + }, + { + "type": "rect", + "xref": "x", + "yref": "paper", + "x0": ex_max_start_date, + "y0": 0, + "x1": ex_max_end_date, + "y1": 0.55, + "fillcolor": "#d3d3d3", + "opacity": 0.3, + "line": {"width": 0,}, + }, + ], + ) + + _subplot_kwargs = dict( + shared_xaxes=True, + vertical_spacing=0.01, + rows=7, + cols=1, + row_width=[1, 1, 1, 3, 1, 1, 3], + print_grid=False, + ) + figure = SubplotsGraph( + df=report_df, + layout=_layout_style, + sub_graph_data=_column_row_col_dict, + subplots_kwargs=_subplot_kwargs, + kind_map=_default_kind_map, + sub_graph_layout=_subplot_layout, + ).figure + return (figure,) + + +def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]: + """display backtest report + + Example: + + + .. code-block:: python + + from qlib.contrib.evaluate import backtest + from qlib.contrib.strategy import TopkAmountStrategy + + # backtest parameters + bparas = {} + bparas['limit_threshold'] = 0.095 + bparas['account'] = 1000000000 + + sparas = {} + sparas['topk'] = 50 + sparas['buffer_margin'] = 230 + strategy = TopkAmountStrategy(**sparas) + + report_normal_df, _ = backtest(pred_df, strategy, **bparas) + + qcr.report_graph(report_normal_df) + + :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench** + + + .. code-block:: python + + return cost bench turnover + date + 2017-01-04 0.003421 0.000864 0.011693 0.576325 + 2017-01-05 0.000508 0.000447 0.000721 0.227882 + 2017-01-06 -0.003321 0.000212 -0.004322 0.102765 + 2017-01-09 0.006753 0.000212 0.006874 0.105864 + 2017-01-10 -0.000416 0.000440 -0.003350 0.208396 + + + :param show_notebook: whether to display graphics in notebook, the default is **True** + :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list + """ + report_df = report_df.copy() + fig_list = _report_figure(report_df) + if show_notebook: + BaseGraph.show_graph_in_notebook(fig_list) + else: + return fig_list diff --git a/qlib/contrib/report/analysis_position/risk_analysis.py b/qlib/contrib/report/analysis_position/risk_analysis.py new file mode 100644 index 0000000000..46122e17d1 --- /dev/null +++ b/qlib/contrib/report/analysis_position/risk_analysis.py @@ -0,0 +1,271 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
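+
+# Overview: besides the overall risk table, this module also reports a monthly
+# breakdown. A simplified sketch of that grouping (the real logic lives in
+# _get_monthly_risk_analysis_with_report below):
+#
+#     monthly = report_normal_df.groupby(
+#         [report_normal_df.index.year, report_normal_df.index.month]
+#     ).apply(lambda g: risk_analysis(g["return"] - g["bench"]))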
+ +from typing import Iterable + +import pandas as pd + +import plotly.graph_objs as py + +from ...evaluate import risk_analysis + +from ..graph import SubplotsGraph, ScatterGraph + + +def _get_risk_analysis_data_with_report( + report_normal_df: pd.DataFrame, + report_long_short_df: pd.DataFrame, + date: pd.Timestamp, +) -> pd.DataFrame: + """Get risk analysis data with report + + :param report_normal_df: report data + :param report_long_short_df: report data + :param date: date string + :return: + """ + + analysis = dict() + if not report_long_short_df.empty: + analysis["pred_long"] = risk_analysis(report_long_short_df["long"]) + analysis["pred_short"] = risk_analysis(report_long_short_df["short"]) + analysis["pred_long_short"] = risk_analysis(report_long_short_df["long_short"]) + + if not report_normal_df.empty: + analysis["sub_bench"] = risk_analysis( + report_normal_df["return"] - report_normal_df["bench"] + ) + analysis["sub_cost"] = risk_analysis( + report_normal_df["return"] + - report_normal_df["bench"] + - report_normal_df["cost"] + ) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + analysis_df["date"] = date + return analysis_df + + +def _get_all_risk_analysis(risk_df: pd.DataFrame) -> pd.DataFrame: + """risk_df to standard + + :param risk_df: risk data + :return: + """ + if risk_df is None: + return pd.DataFrame() + risk_df = risk_df.unstack() + risk_df.columns = risk_df.columns.droplevel(0) + return risk_df.drop("mean", axis=1) + + +def _get_monthly_risk_analysis_with_report( + report_normal_df: pd.DataFrame, report_long_short_df: pd.DataFrame +) -> pd.DataFrame: + """Get monthly analysis data + + :param report_normal_df: + :param report_long_short_df: + :return: + """ + + # Group by month + report_normal_gp = report_normal_df.groupby( + [report_normal_df.index.year, report_normal_df.index.month] + ) + report_long_short_gp = report_long_short_df.groupby( + [report_long_short_df.index.year, report_long_short_df.index.month] + ) + + gp_month = sorted( + set(report_normal_gp.size().index) & set(report_long_short_gp.size().index) + ) + + _monthly_df = pd.DataFrame() + for gp_m in gp_month: + _m_report_normal = report_normal_gp.get_group(gp_m) + _m_report_long_short = report_long_short_gp.get_group(gp_m) + + if (len(_m_report_normal) < 3) and (len(_m_report_long_short) < 3): + # The month's data is less than 3, not displayed + # FIXME: If the trading day of a month is less than 3 days, a breakpoint will appear in the graph + continue + month_days = pd.Timestamp(year=gp_m[0], month=gp_m[1], day=1).days_in_month + _temp_df = _get_risk_analysis_data_with_report( + _m_report_normal, + _m_report_long_short, + pd.Timestamp(year=gp_m[0], month=gp_m[1], day=month_days), + ) + _monthly_df = _monthly_df.append(_temp_df, sort=False) + + return _monthly_df + + +def _get_monthly_analysis_with_feature( + monthly_df: pd.DataFrame, feature: str = "annual" +) -> pd.DataFrame: + """ + + :param monthly_df: + :param feature: + :return: + """ + _monthly_df_gp = monthly_df.reset_index().groupby(["level_1"]) + + _name_df = _monthly_df_gp.get_group(feature).set_index(["level_0", "level_1"]) + _temp_df = _name_df.pivot_table( + index="date", values=["risk"], columns=_name_df.index + ) + _temp_df.columns = map(lambda x: "_".join(x[-1]), _temp_df.columns) + _temp_df.index = _temp_df.index.strftime("%Y-%m") + + return _temp_df + + +def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]: + """Get analysis graph figure + + :param analysis_df: + :return: + """ + if analysis_df 
is None: + return [] + + _figure = SubplotsGraph( + _get_all_risk_analysis(analysis_df), kind_map=dict(kind="BarGraph", kwargs={}) + ).figure + return (_figure,) + + +def _get_monthly_risk_analysis_figure( + report_normal_df: pd.DataFrame, report_long_short_df: pd.DataFrame +) -> Iterable[py.Figure]: + """Get analysis monthly graph figure + + :param report_normal_df: + :param report_long_short_df: + :return: + """ + + if report_normal_df is None and report_long_short_df is None: + return [] + + if report_normal_df is None: + report_normal_df = pd.DataFrame(index=report_long_short_df.index) + + if report_long_short_df is None: + report_long_short_df = pd.DataFrame(index=report_normal_df.index) + + _monthly_df = _get_monthly_risk_analysis_with_report( + report_normal_df=report_normal_df, report_long_short_df=report_long_short_df + ) + + for _feature in ["annual", "mdd", "sharpe", "std"]: + _temp_df = _get_monthly_analysis_with_feature(_monthly_df, _feature) + yield ScatterGraph( + _temp_df, + layout=dict(title=_feature, xaxis=dict(type="category", tickangle=45)), + graph_kwargs={"mode": "lines+markers"}, + ).figure + + +def risk_analysis_graph( + analysis_df: pd.DataFrame = None, + report_normal_df: pd.DataFrame = None, + report_long_short_df: pd.DataFrame = None, + show_notebook: bool = True, +) -> Iterable[py.Figure]: + """Generate analysis graph and monthly analysis + + Example: + + + .. code-block:: python + + from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest + from qlib.contrib.strategy import TopkAmountStrategy + from qlib.contrib.report import analysis_position + + # backtest parameters + bparas = {} + bparas['limit_threshold'] = 0.095 + bparas['account'] = 1000000000 + + sparas = {} + sparas['topk'] = 50 + sparas['buffer_margin'] = 230 + strategy = TopkAmountStrategy(**sparas) + + report_normal_df, positions = backtest(pred_df, strategy, **bparas) + long_short_map = long_short_backtest(pred_df) + report_long_short_df = pd.DataFrame(long_short_map) + + analysis = dict() + analysis['pred_long'] = risk_analysis(report_long_short_df['long']) + analysis['pred_short'] = risk_analysis(report_long_short_df['short']) + analysis['pred_long_short'] = risk_analysis(report_long_short_df['long_short']) + analysis['sub_bench'] = risk_analysis(report_normal_df['return'] - report_normal_df['bench']) + analysis['sub_cost'] = risk_analysis(report_normal_df['return'] - report_normal_df['bench'] - report_normal_df['cost']) + analysis_df = pd.concat(analysis) + + analysis_position.risk_analysis_graph(analysis_df, report_normal_df, report_long_short_df) + + + + :param analysis_df: analysis data, index is **pd.MultiIndex**; columns names is **[risk]**. + + + .. code-block:: python + + risk + pred_long mean 0.002444 + std 0.004391 + annual 0.615868 + sharpe 8.835900 + mdd -0.016492 + pred_short mean 0.002655 + std 0.004241 + annual 0.669002 + sharpe 9.936303 + mdd -0.016071 + + + :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench** + + + .. code-block:: python + + return cost bench turnover + date + 2017-01-04 0.003421 0.000864 0.011693 0.576325 + 2017-01-05 0.000508 0.000447 0.000721 0.227882 + 2017-01-06 -0.003321 0.000212 -0.004322 0.102765 + 2017-01-09 0.006753 0.000212 0.006874 0.105864 + 2017-01-10 -0.000416 0.000440 -0.003350 0.208396 + + + :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short** + + + .. 
code-block:: python + + long short long_short + date + 2017-01-04 -0.001360 0.001394 0.000034 + 2017-01-05 0.002456 0.000058 0.002514 + 2017-01-06 0.000120 0.002739 0.002859 + 2017-01-09 0.001436 0.001838 0.003273 + 2017-01-10 0.000824 -0.001944 -0.001120 + + + :param show_notebook: Whether to display graphics in a notebook, default **True** + If True, show graph in notebook + If False, return graph figure + :return: + """ + _figure_list = list(_get_risk_analysis_figure(analysis_df)) + list( + _get_monthly_risk_analysis_figure(report_normal_df, report_long_short_df) + ) + if show_notebook: + ScatterGraph.show_graph_in_notebook(_figure_list) + else: + return _figure_list diff --git a/qlib/contrib/report/analysis_position/score_ic.py b/qlib/contrib/report/analysis_position/score_ic.py new file mode 100644 index 0000000000..bc6f8f5ff1 --- /dev/null +++ b/qlib/contrib/report/analysis_position/score_ic.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd + +from ..graph import ScatterGraph + + +def _get_score_ic(pred_label: pd.DataFrame): + """ + + :param pred_label: + :return: + """ + concat_data = pred_label.copy() + concat_data.dropna(axis=0, how="any", inplace=True) + _ic = concat_data.groupby(level="datetime").apply( + lambda x: x["label"].corr(x["score"]) + ) + _rank_ic = concat_data.groupby(level="datetime").apply( + lambda x: x["label"].corr(x["score"], method="spearman") + ) + return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic}) + + +def score_ic_graph( + pred_label: pd.DataFrame, show_notebook: bool = True +) -> [list, tuple]: + """score IC + + Example: + + + .. code-block:: python + + from qlib.data import D + from qlib.contrib.report import analysis_position + pred_df_dates = pred_df.index.get_level_values(level='datetime') + features_df = D.features(D.instruments('csi500'), ['Ref($close, -2)/Ref($close, -1)-1'], pred_df_dates.min(), pred_df_dates.max()) + features_df.columns = ['label'] + pred_label = pd.concat([features_df, pred], axis=1, sort=True).reindex(features_df.index) + analysis_position.score_ic_graph(pred_label) + + + :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]** + + + .. code-block:: python + + instrument datetime score label + SH600004 2017-12-11 -0.013502 -0.013502 + 2017-12-12 -0.072367 -0.072367 + 2017-12-13 -0.068605 -0.068605 + 2017-12-14 0.012440 0.012440 + 2017-12-15 -0.102778 -0.102778 + + + :param show_notebook: whether to display graphics in notebook, the default is **True** + :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list + """ + _ic_df = _get_score_ic(pred_label) + # FIXME: support HIGH-FREQ + _ic_df.index = _ic_df.index.strftime("%Y-%m-%d") + _figure = ScatterGraph( + _ic_df, + layout=dict(title="Score IC", xaxis=dict(type="category", tickangle=45)), + graph_kwargs={"mode": "lines+markers"}, + ).figure + if show_notebook: + ScatterGraph.show_graph_in_notebook([_figure]) + else: + return (_figure,) diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py new file mode 100644 index 0000000000..082eafa49c --- /dev/null +++ b/qlib/contrib/report/graph.py @@ -0,0 +1,370 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
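+
+# The classes below are thin wrappers around plotly figures. A typical call,
+# taken from qlib.contrib.report.analysis_position.score_ic, looks like:
+#
+#     figure = ScatterGraph(
+#         ic_df,
+#         layout=dict(title="Score IC", xaxis=dict(type="category", tickangle=45)),
+#         graph_kwargs={"mode": "lines+markers"},
+#     ).figure
+#     ScatterGraph.show_graph_in_notebook([figure])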
+ +import math +import importlib +from pathlib import Path +from typing import Iterable + +import pandas as pd + +import plotly.offline as py +import plotly.graph_objs as go + +from plotly.tools import make_subplots +from plotly.figure_factory import create_distplot + +from ...utils import get_module_by_module_path + + +class BaseGraph(object): + """""" + + _name = None + + def __init__( + self, df: pd.DataFrame = None, layout: dict = None, graph_kwargs: dict = None, name_dict: dict = None, **kwargs + ): + """ + + :param df: + :param layout: + :param graph_kwargs: + :param name_dict: + :param kwargs: + layout: dict + go.Layout parameters + graph_kwargs: dict + Graph parameters, eg: go.Bar(**graph_kwargs) + """ + self._df = df + + self._layout = dict() if layout is None else layout + self._graph_kwargs = dict() if graph_kwargs is None else graph_kwargs + self._name_dict = name_dict + + self.data = None + + self._init_parameters(**kwargs) + self._init_data() + + def _init_data(self): + """ + + :return: + """ + if self._df.empty: + raise ValueError("df is empty.") + + self.data = self._get_data() + + def _init_parameters(self, **kwargs): + """ + + :param kwargs + """ + + # Instantiate graphics parameters + self._graph_type = self._name.lower().capitalize() + + # Displayed column name + if self._name_dict is None: + self._name_dict = {_item: _item for _item in self._df.columns} + + @staticmethod + def get_instance_with_graph_parameters(graph_type: str = None, **kwargs): + """ + + :param graph_type: + :param kwargs: + :return: + """ + try: + _graph_module = importlib.import_module("plotly.graph_objs") + _graph_class = getattr(_graph_module, graph_type) + except AttributeError: + _graph_module = importlib.import_module("qlib.contrib.report.graph") + _graph_class = getattr(_graph_module, graph_type) + return _graph_class(**kwargs) + + @staticmethod + def show_graph_in_notebook(figure_list: Iterable[go.Figure] = None): + """ + + :param figure_list: + :return: + """ + py.init_notebook_mode() + for _fig in figure_list: + py.iplot(_fig) + + def _get_layout(self) -> go.Layout: + """ + + :return: + """ + return go.Layout(**self._layout) + + def _get_data(self) -> list: + """ + + :return: + """ + + _data = [ + self.get_instance_with_graph_parameters( + graph_type=self._graph_type, x=self._df.index, y=self._df[_col], name=_name, **self._graph_kwargs + ) + for _col, _name in self._name_dict.items() + ] + return _data + + @property + def figure(self) -> go.Figure: + """ + + :return: + """ + return go.Figure(data=self.data, layout=self._get_layout()) + + +class ScatterGraph(BaseGraph): + _name = "scatter" + + +class BarGraph(BaseGraph): + _name = "bar" + + +class DistplotGraph(BaseGraph): + _name = "distplot" + + def _get_data(self): + """ + + :return: + """ + _t_df = self._df.dropna() + _data_list = [_t_df[_col] for _col in self._name_dict] + _label_list = [_name for _name in self._name_dict.values()] + _fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs) + + return _fig["data"] + + +class HeatmapGraph(BaseGraph): + _name = "heatmap" + + def _get_data(self): + """ + + :return: + """ + _data = [ + self.get_instance_with_graph_parameters( + graph_type=self._graph_type, + x=self._df.columns, + y=self._df.index, + z=self._df.values.tolist(), + **self._graph_kwargs + ) + ] + return _data + + +class HistogramGraph(BaseGraph): + _name = "histogram" + + def _get_data(self): + """ + + :return: + """ + _data = [ + self.get_instance_with_graph_parameters( + 
graph_type=self._graph_type, x=self._df[_col], name=_name, **self._graph_kwargs + ) + for _col, _name in self._name_dict.items() + ] + return _data + + +class SubplotsGraph(object): + """Create subplots same as df.plot(subplots=True) + + Simple package for `plotly.tools.subplots` + """ + + def __init__( + self, + df: pd.DataFrame = None, + kind_map: dict = None, + layout: dict = None, + sub_graph_layout: dict = None, + sub_graph_data: list = None, + subplots_kwargs: dict = None, + **kwargs + ): + """ + + :param df: pd.DataFrame + + :param kind_map: dict, subplots graph kind and kwargs + eg: dict(kind='ScatterGraph', kwargs=dict()) + + :param layout: `go.Layout` parameters + + :param sub_graph_layout: Layout of each graphic, similar to 'layout' + + :param sub_graph_data: Instantiation parameters for each sub-graphic + eg: [(column_name, instance_parameters), ] + + column_name: str or go.Figure + + Instance_parameters: + + - row: int, the row where the graph is located + + - col: int, the col where the graph is located + + - name: str, show name, default column_name in 'df' + + - kind: str, graph kind, default `kind` param, eg: bar, scatter, ... + + - graph_kwargs: dict, graph kwargs, default {}, used in `go.Bar(**graph_kwargs)` + + :param subplots_kwargs: `plotly.tools.make_subplots` original parameters + + - shared_xaxes: bool, default False + + - shared_yaxes: bool, default False + + - vertical_spacing: float, default 0.3 / rows + + - subplot_titles: list, default [] + If `sub_graph_data` is None, will generate 'subplot_titles' according to `df.columns`, + this field will be discarded + + + - specs: list, see `make_subplots` docs + + - rows: int, Number of rows in the subplot grid, default 1 + If `sub_graph_data` is None, will generate 'rows' according to `df`, this field will be discarded + + - cols: int, Number of cols in the subplot grid, default 1 + If `sub_graph_data` is None, will generate 'cols' according to `df`, this field will be discarded + + + :param kwargs: + + """ + + self._df = df + self._layout = layout + self._sub_graph_layout = sub_graph_layout + + self._kind_map = kind_map + if self._kind_map is None: + self._kind_map = dict(kind="ScatterGraph", kwargs=dict()) + + self._subplots_kwargs = subplots_kwargs + if self._subplots_kwargs is None: + self._init_subplots_kwargs() + + self.__cols = self._subplots_kwargs.get("cols", 2) + self.__rows = self._subplots_kwargs.get("rows", math.ceil(len(self._df.columns) / self.__cols)) + + self._sub_graph_data = sub_graph_data + if self._sub_graph_data is None: + self._init_sub_graph_data() + + self._init_figure() + + def _init_sub_graph_data(self): + """ + + :return: + """ + self._sub_graph_data = list() + self._subplot_titles = list() + + for i, column_name in enumerate(self._df.columns): + row = math.ceil((i + 1) / self.__cols) + _temp = (i + 1) % self.__cols + col = _temp if _temp else self.__cols + res_name = column_name.replace("_", " ") + _temp_row_data = ( + column_name, + dict( + row=row, + col=col, + name=res_name, + kind=self._kind_map["kind"], + graph_kwargs=self._kind_map["kwargs"], + ), + ) + self._sub_graph_data.append(_temp_row_data) + self._subplot_titles.append(res_name) + + def _init_subplots_kwargs(self): + """ + + :return: + """ + # Default cols, rows + _cols = 2 + _rows = math.ceil(len(self._df.columns) / 2) + self._subplots_kwargs = dict() + self._subplots_kwargs["rows"] = _rows + self._subplots_kwargs["cols"] = _cols + self._subplots_kwargs["shared_xaxes"] = False + self._subplots_kwargs["shared_yaxes"] = False + 
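+        # Default vertical spacing shrinks as the row count grows, so stacked
+        # subplots stay readable when many columns are plotted.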
self._subplots_kwargs["vertical_spacing"] = 0.3 / _rows + self._subplots_kwargs["print_grid"] = False + self._subplots_kwargs["subplot_titles"] = self._df.columns.tolist() + + def _init_figure(self): + """ + + :return: + """ + self._figure = make_subplots(**self._subplots_kwargs) + + for column_name, column_map in self._sub_graph_data: + if isinstance(column_name, go.Figure): + _graph_obj = column_name + elif isinstance(column_name, str): + temp_name = column_map.get("name", column_name.replace("_", " ")) + kind = column_map.get("kind", self._kind_map.get("kind", "ScatterGraph")) + _graph_kwargs = column_map.get("graph_kwargs", self._kind_map.get("kwargs", {})) + _graph_obj = BaseGraph.get_instance_with_graph_parameters( + kind, + **dict( + df=self._df.loc[:, [column_name]], + name_dict={column_name: temp_name}, + graph_kwargs=_graph_kwargs, + ) + ) + else: + raise TypeError() + + row = column_map["row"] + col = column_map["col"] + + _graph_data = getattr(_graph_obj, "data") + # for _item in _graph_data: + # _item.pop('xaxis', None) + # _item.pop('yaxis', None) + + for _g_obj in _graph_data: + self._figure.append_trace(_g_obj, row=row, col=col) + + if self._sub_graph_layout is not None: + for k, v in self._sub_graph_layout.items(): + self._figure["layout"][k].update(v) + + self._figure["layout"].update(self._layout) + + @property + def figure(self): + return self._figure diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py new file mode 100644 index 0000000000..a67a7e7a55 --- /dev/null +++ b/qlib/contrib/strategy/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from .strategy import ( + TopkAmountStrategy, + TopkWeightStrategy, + TopkDropoutStrategy, + BaseStrategy, + WeightStrategyBase, +) diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py new file mode 100644 index 0000000000..58d299edbe --- /dev/null +++ b/qlib/contrib/strategy/cost_control.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from .strategy import StrategyWrapper, WeightStrategyBase +import copy + + +class SoftTopkStrategy(WeightStrategyBase): + def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"): + """Parameter + topk : int + top-N stocks to buy + buffer_margin : int + buffer margin, in single score_mode, continue holding stock if it is in nlargest(margin) + margin should be no less than topk + risk_degree : float + position percentage of total value + buy_method : + rank_fill: assign the weight stocks that rank high first(1/topk max) + average_fill: assign the weight to the stocks rank high averagely. + """ + super().__init__() + self.topk = topk + self.max_sold_weight = max_sold_weight + self.risk_degree = risk_degree + self.buy_method = buy_method + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. 
+ Dynamically risk_degree will result in Market timing + """ + # It will use 95% amoutn of your total value by default + return self.risk_degree + + def generate_target_weight_position(self, score, current, trade_date): + """Parameter: + score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current : current position, use Position() class + trade_date : trade date + generate target position from score for this date and the current position + The cache is not considered in the position + """ + # TODO: + # If the current stock list is more than topk(eg. The weights are modified + # by risk control), the weight will not be handled correctly. + buy_signal_stocks = set(score.sort_values(ascending=False).iloc[: self.topk].index) + cur_stock_weight = current.get_stock_weight_dict(only_stock=True) + + if len(cur_stock_weight) == 0: + final_stock_weight = {code: 1 / self.topk for code in buy_signal_stocks} + else: + final_stock_weight = copy.deepcopy(cur_stock_weight) + sold_stock_weight = 0.0 + for stock_id in final_stock_weight: + if stock_id not in buy_signal_stocks: + sw = min(self.max_sold_weight, final_stock_weight[stock_id]) + sold_stock_weight += sw + final_stock_weight[stock_id] -= sw + if self.buy_method == "first_fill": + for stock_id in buy_signal_stocks: + add_weight = min( + max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0), + sold_stock_weight, + ) + final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight + sold_stock_weight -= add_weight + elif self.buy_method == "average_fill": + for stock_id in buy_signal_stocks: + final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len( + buy_signal_stocks + ) + else: + raise ValueError("Buy method not found") + return final_stock_weight diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py new file mode 100644 index 0000000000..494981ecc0 --- /dev/null +++ b/qlib/contrib/strategy/order_generator.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This order generator is for strategies based on WeightStrategyBase +""" +from ..backtest.position import Position +from ..backtest.exchange import Exchange +import pandas as pd +import copy + + +class OrderGenerator: + def generate_order_list_from_target_weight_position( + self, + current: Position, + trade_exchange: Exchange, + target_weight_position: dict, + risk_degree: float, + pred_date: pd.Timestamp, + trade_date: pd.Timestamp, + ) -> list: + """generate_order_list_from_target_weight_position + + :param current: The current position + :type current: Position + :param trade_exchange: + :type trade_exchange: Exchange + :param target_weight_position: {stock_id : weight} + :type target_weight_position: dict + :param risk_degree: + :type risk_degree: float + :param pred_date: the date the score is predicted + :type pred_date: pd.Timestamp + :param trade_date: the date the stock is traded + :type trade_date: pd.Timestamp + + :rtype: list + """ + raise NotImplementedError() + + +class OrderGenWInteract(OrderGenerator): + """Order Generator With Interact""" + + def generate_order_list_from_target_weight_position( + self, + current: Position, + trade_exchange: Exchange, + target_weight_position: dict, + risk_degree: float, + pred_date: pd.Timestamp, + trade_date: pd.Timestamp, + ) -> list: + """generate_order_list_from_target_weight_position + + No adjustment for for the nontradable share. 
+ All the tadable value is assigned to the tadable stock according to the weight. + if interact == True, will use the price at trade date to generate order list + else, will only use the price before the trade date to generate order list + + :param current: + :type current: Position + :param trade_exchange: + :type trade_exchange: Exchange + :param target_weight_position: + :type target_weight_position: dict + :param risk_degree: + :type risk_degree: float + :param pred_date: + :type pred_date: pd.Timestamp + :param trade_date: + :type trade_date: pd.Timestamp + + :rtype: list + """ + # calculate current_tradable_value + current_amount_dict = current.get_stock_amount_dict() + current_total_value = trade_exchange.calculate_amount_position_value( + amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=False + ) + current_tradable_value = trade_exchange.calculate_amount_position_value( + amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=True + ) + # add cash + current_tradable_value += current.get_cash() + + reserved_cash = (1.0 - risk_degree) * (current_total_value + current.get_cash()) + current_tradable_value -= reserved_cash + + if current_tradable_value < 0: + # if you sell all the tradable stock can not meet the reserved + # value. Then just sell all the stocks + target_amount_dict = copy.deepcopy(current_amount_dict.copy()) + for stock_id in list(target_amount_dict.keys()): + if trade_exchange.is_stock_tradable(stock_id, trade_date): + del target_amount_dict[stock_id] + else: + # consider cost rate + current_tradable_value /= 1 + max(trade_exchange.close_cost, trade_exchange.open_cost) + + # strategy 1 : generate amount_position by weight_position + # Use API in Exchange() + target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( + weight_position=target_weight_position, + cash=current_tradable_value, + trade_date=trade_date, + ) + order_list = trade_exchange.generate_order_for_target_amount_position( + target_position=target_amount_dict, + current_position=current_amount_dict, + trade_date=trade_date, + ) + return order_list + + +class OrderGenWOInteract(OrderGenerator): + """Order Generator Without Interact""" + + def generate_order_list_from_target_weight_position( + self, + current: Position, + trade_exchange: Exchange, + target_weight_position: dict, + risk_degree: float, + pred_date: pd.Timestamp, + trade_date: pd.Timestamp, + ) -> list: + """generate_order_list_from_target_weight_position + + generate order list directly not using the information (e.g. whether can be traded, the accurate trade price) at trade date. + In target weight position, generating order list need to know the price of objective stock in trade date, but we cannot get that + value when do not interact with exchange, so we check the %close price at pred_date or price recorded in current position. 
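+
+        A sketch of the price rule applied below (illustrative only):
+
+        .. code-block:: python
+
+            if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=pred_date):
+                price = trade_exchange.get_close(stock_id, pred_date)
+            elif stock_id in current.get_stock_list():
+                price = current.get_stock_price(stock_id)
+            else:
+                continue  # neither tradable at pred_date nor currently held: skip
+            amount_dict[stock_id] = risk_total_value * target_weight_position[stock_id] / price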
+ + :param current: + :type current: Position + :param trade_exchange: + :type trade_exchange: Exchange + :param target_weight_position: + :type target_weight_position: dict + :param risk_degree: + :type risk_degree: float + :param pred_date: + :type pred_date: pd.Timestamp + :param trade_date: + :type trade_date: pd.Timestamp + + :rtype: list + """ + risk_total_value = risk_degree * current.calculate_value() + + current_stock = current.get_stock_list() + amount_dict = {} + for stock_id in target_weight_position: + # Current rule will ignore the stock that not hold and cannot be traded at predict date + if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=pred_date): + amount_dict[stock_id] = ( + risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date) + ) + elif stock_id in current_stock: + amount_dict[stock_id] = ( + risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id) + ) + else: + continue + order_list = trade_exchange.generate_order_for_target_amount_position( + target_position=amount_dict, + current_position=current.get_stock_amount_dict(), + trade_date=trade_date, + ) + return order_list diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py new file mode 100644 index 0000000000..64d7d97164 --- /dev/null +++ b/qlib/contrib/strategy/strategy.py @@ -0,0 +1,765 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import copy +import numpy as np +import pandas as pd + +from ..backtest.order import Order +from ...utils import get_pre_trading_date +from .order_generator import OrderGenWInteract + + +class BaseStrategy: + """ + # Strategy framework document + + class Strategy(BaseStrategy): + + def __init__(self): + # init for strategy + super(Strategy, self).__init__() + pass + + def generate_target_weight_position(self, score, current, trade_exchange, topk, buffer_margin, trade_date, risk_degree): + '''Parameter: + score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current : current position, use Position() class + trade_exchange : Exchange() + topk : topk + buffer_margin : buffer margin + trade_date : trade date + risk_degree : 0-1, 0.95 for example, use 95% money to trade + :return + target weight position + generate target position from score for this date and the current position + + ''' + # strategy 1 :select top k stocks by model scores, then equal-weight + new_stock_list = list(score.sort_values(ascending=False).iloc[:topk].index) + target_weight_position = {code: 1 / topk for code in new_stock_list} + + # strategy 2:select top buffer_margin stock as the buffer, for stock in current position: keep if in buffer, sell if not; then buy new stocks + buffer = score.sort_values(ascending=False).iloc[:buffer_margin] + current_stock_list = current.get_stock_list() + mask = buffer.index.isin(current_stock_list) + keep = set(buffer[mask].index) + new_stock_list = list(keep) + list(buffer[~mask].index[:topk-len(keep)]) + target_weight_position = {code : 1/topk for code in new_stock_list} + + return target_weight_position + + def generate_target_amount_position(self, score, current, target_weight_position ,topk, buffer_margin, trade_exchange, trade_date, risk_degree): + ''' + score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current : current position, use Position() class + target_weight_position : {stock_id : weight} + trade_exchange : Exchange() + topk : 
topk + buffer_margin : buffer margin + trade_date : trade date + risk_degree : 0-1, 0.95 for example, use 95% money to trade + :return: + ''' + # strategy 1 + # parameters : + # topk : int, select topk stocks + # buffer_margin : size of buffer margin + # + # description : + # hold topk stocks at each trade date + # when adjust position + # the score model will generate scores for each stock + # if the stock of current position not in top buffer_margin score, sell them out; + # then equally buy recommended stocks + target_amount_dict = {} + current_amount_dict = current.get_stock_amount_dict() + buffer = score.sort_values(ascending=False).iloc[:buffer_margin] + mask = buffer.index.isin(current_amount_dict) + keep = set(buffer[mask].index) + buy_stock_list = list(buffer[~mask].index[:topk - len(keep)]) + buy_cash = 0 + # calculate cash for buy order + for stock_id in current_amount_dict: + if stock_id in keep: + target_amount_dict[stock_id] = current_amount_dict[stock_id] + else: + # stock_id not in keep + # trade check + # assume all of them can be sold out + if trade_exchange.check_stock_suspended(stock_id=stock_id, trade_date=trade_date): + pass + else: + buy_cash += current_amount_dict[stock_id] * trade_exchange.get_deal_price(stock_id=stock_id, trade_date=trade_date) + # update close cost + buy_cash /= (1 + trade_exchange.close_cost) + # update cash + buy_cash += current.get_cash() + # update open cost + buy_cash /= (1 + trade_exchange.open_cost) + # consider risk degree + buy_cash *= risk_degree + # equally assigned + value = buy_cash / len(buy_stock_list) + # equally assigned + value = buy_cash / len(buy_stock_list) + for stock_id in buy_stock_list: + if trade_exchange.check_stock_suspended(stock_id=stock_id, trade_date=trade_date): + pass + else: + target_amount_dict[stock_id] = value / trade_exchange.get_deal_price(stock_id=stock_id, trade_date=trade_date) // trade_exchange.trade_unit * trade_exchange.trade_unit + return target_amount_dict + + + # strategy 2 : use trade_exchange.generate_amount_position_from_weight_position() + # calculate value for current position + current_amount_dict = current.get_stock_amount_dict() + current_tradable_value = trade_exchange.calculate_amount_position_value(amount_dict=current_amount_dict, + trade_date=trade_date, only_tradable=True) + # consider cost rate + current_tradable_value /= (1 + max(trade_exchange.close_cost, trade_exchange.open_cost)) + # consider risk degree + current_tradable_value *= risk_degree + target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( + weight_position=target_weight_position, cash=current_tradable_value, trade_date=trade_date) + + return target_amount_dict + + pass + + def generate_order_list_from_target_amount_position(self, current, trade_exchange, target_amount_position, trade_date): + '''Parameter: + current : Position() + trade_exchange : Exchange() + target_amount_position : {stock_id : amount} + trade_date : trade date + generate order list from weight_position + ''' + # strategy: + current_amount_dict = current.get_stock_amount_dict() + order_list = trade_exchange.generate_order_for_target_amount_position(target_position=target_amount_position, + current_position=current_amount_dict, + trade_date=trade_date) + return order_list + + def generate_order_list_from_target_weight_position(self, current, trade_exchange, target_weight_position, risk_degree ,trade_date, interact=True): + ''' + generate order_list from weight_position + use API from trade_exchage + current : Postion(), 
current position + trade_exchange : Exchange() + target_weight_position : {stock_id : weight} + risk_degree : 0-1, 0.95 for example, use 95% money to trade + trade_date : trade date + interact : bool + :return: order_list + ''' + # calculate value for current position + current_amount_dict = current.get_stock_amount_dict() + current_tradable_value = trade_exchange.calculate_amount_position_value(amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=True) + # add cash + current_tradable_value += current.get_cash() + # consider cost rate + current_tradable_value /= (1+max(trade_exchange.close_cost, trade_exchange.open_cost)) + # consider risk degree + current_tradable_value *= risk_degree + # Strategy 1 : generate amount_position from weight_position + # use API of trade_exchange + target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(weight_position=target_weight_position, cash=current_tradable_value, trade_date=trade_date) + order_list = trade_exchange.generate_order_for_target_amount_position(target_position=target_amount_dict, current_position=current_amount_dict, trade_date=trade_date) + + return order_list + + def generate_order_list(self, score_series, current, trade_exchange, trade_date, topk, margin, risk_degree): + ''' + score_series: pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current: Postion(), current position + trade_exchange: trade date + trade_date: + topk: topk + margin: buffer margin + risk_degree: risk_degree : 0-1, 0.95 for example, use 95% money to trade + :return: order list : list of Order() + ''' + # generate_order_list + # strategy 1,generate_target_weight_position() and xecute_target_weight_position_by_order_list() for order_list + if not self.is_adjust(trade_date): + return [] + target_weigth_position = self.generate_target_weight_position(score=score_series, + current=current, + trade_exchange=trade_exchange, + topk=topk, + buffer_margin=margin, + trade_date=trade_date, + risk_degree=risk_degree + ) + order_list = self.generate_order_list_from_target_weight_positione( current=current, + trade_exchange=trade_exchange, + target_weight_position=target_weigth_position, + risk_degree=risk_degree, + trade_date=trade_date) + + + # strategy 2 : amount_position's view generate_target_amount_position() and generate_order_list_from_target_amount_position() to generate order_list + target_amount_position = self.generate_target_amount_position(score=score_series, + current=current, + trade_exchange=trade_exchange, + target_weight_position=None, + topk=topk, + buffer_margin=margin, + trade_date=trade_date, + risk_degree=risk_degree + ) + order_list = self.generate_order_list_from_target_amount_position(current=current, + trade_exchange=trade_exchange, + target_amount_position=target_amount_position, + trade_date=trade_date) + return order_list + """ + + def __init__(self): + pass + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. 
+ Dynamically risk_degree will result in Market timing + """ + # It will use 95% amount of your total value by default + return 0.95 + + def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date): + """Parameter + score_series : pd.Seires + stock_id , score + current : Position() + current state of position + DO NOT directly change the state of current + trade_exchange : Exchange() + trade exchange + pred_date : pd.Timestamp + predict date + trade_date : pd.Timestamp + trade date + + DO NOT directly change the state of current + """ + pass + + def update(self, score_series, pred_date, trade_date): + """User can use this method to update strategy state each trade date. + Parameter + --------- + score_series : pd.Series + stock_id , score + pred_date : pd.Timestamp + oredict date + trade_date : pd.Timestamp + trade date + """ + pass + + def init(self, **kwargs): + """Some strategy need to be initial after been implemented, + User can use this method to init his strategy with parameters needed. + """ + pass + + def get_init_args_from_model(self, model, init_date): + """ + This method only be used in 'online' module, it will generate the *args to initial the strategy. + :param + mode : model used in 'online' module + """ + return {} + + +class StrategyWrapper: + """ + StrategyWrapper is a wrapper of another strategy. + By overriding some methods to make some changes on the basic strategy + Cost control and risk control will base on this class. + """ + + def __init__(self, inner_strategy): + """__init__ + + :param inner_strategy: set the inner strategy + """ + self.inner_strategy = inner_strategy + + def __getattr__(self, name): + """__getattr__ + + :param name: If no implementation in this method. Call the method in the innter_strategy by default. 
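+
+        For example (hypothetical usage), a wrapper passes any attribute it does not
+        override through to the wrapped strategy::
+
+            wrapped = StrategyWrapper(TopkDropoutStrategy(topk=50, n_drop=5))
+            wrapped.get_risk_degree(trade_date)  # resolved on the inner strategy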
+ """ + return getattr(self.inner_strategy, name) + + +class AdjustTimer: + """AdjustTimer + Responsible for timing of position adjusting + + This is designed as multiple inheritance mechanism due to + - the is_adjust may need access to the internel state of a strategyw + - it can be reguard as a enhancement to the existing strategy + """ + + # adjust position in each trade date + def is_adjust(self, trade_date): + """is_adjust + Return if the strategy can adjust positions on `trade_date` + Will normally be used in strategy do trading with trade frequency + """ + return True + + +class ListAdjustTimer(AdjustTimer): + def __init__(self, adjust_dates=None): + """__init__ + + :param adjust_dates: an iterable object, it will return a timelist for trading dates + """ + if adjust_dates is None: + # None indicates that all dates is OK for adjusting + self.adjust_dates = None + else: + self.adjust_dates = {pd.Timestamp(dt) for dt in adjust_dates} + + def is_adjust(self, trade_date): + if self.adjust_dates is None: + return True + return pd.Timestamp(trade_date) in self.adjust_dates + + +class WeightStrategyBase(BaseStrategy, AdjustTimer): + def __init__(self, order_generator_cls_or_obj=OrderGenWInteract, *args, **kwargs): + super().__init__(*args, **kwargs) + if isinstance(order_generator_cls_or_obj, type): + self.order_generator = order_generator_cls_or_obj() + else: + self.order_generator = order_generator_cls_or_obj + + def generate_target_weight_position(self, score, current, trade_date): + """Parameter: + score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current : current position, use Position() class + trade_exchange : Exchange() + trade_date : trade date + generate target position from score for this date and the current position + The cash is not considered in the position + """ + raise NotImplementedError() + + def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date): + """Parameter + score_series : pd.Seires + stock_id , score + current : Position() + current of account + trade_exchange : Exchange() + exchange + trade_date : pd.Timestamp + date + """ + # judge if to adjust + if not self.is_adjust(trade_date): + return [] + # generate_order_list + # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list + current_temp = copy.deepcopy(current) + target_weight_position = self.generate_target_weight_position( + score=score_series, current=current_temp, trade_date=trade_date + ) + + order_list = self.order_generator.generate_order_list_from_target_weight_position( + current=current_temp, + trade_exchange=trade_exchange, + risk_degree=self.get_risk_degree(trade_date), + target_weight_position=target_weight_position, + pred_date=pred_date, + trade_date=trade_date, + ) + return order_list + + +def get_sell_limit(score, buffer_margin): + """get_sell_limit + + :param score: pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + :param buffer_margin: int or float + """ + if isinstance(buffer_margin, int): + return buffer_margin + else: + if buffer_margin < 0.0 or buffer_margin > 1.0: + raise ValueError("Buffer margin should range in [0, 1]") + return int(score.count() * buffer_margin) + + +class MarginInterface: + def get_buffer_margin(self, trade_date): + """get_buffer_margin + Get the buffer margin dynamically for topk strategy. 
+ + :param trade_date: trading date + """ + raise NotImplementedError("Please implement the margin dynamically") + + +class TopkWeightStrategy(ListAdjustTimer, WeightStrategyBase, MarginInterface): + # NOTE: The list adjust Timer must be placed before WeightStrategyBase before. + def __init__(self, topk, buffer_margin, risk_degree=0.95, **kwargs): + """Parameter + topk : int + top-N stocks to buy + + buffer_margin : int or float + if isinstance(margin, int): + sell_limit = margin + else: + sell_limit = pred_in_a_day.count() * margin + buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit) + sell_limit should be no less than topk + + risk_degree : float + position percentage of total value + """ + WeightStrategyBase.__init__(self, **kwargs) + ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None)) + self.topk = topk + self.buffer_margin = buffer_margin + self.risk_degree = risk_degree + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. + Dynamically risk_degree will result in Market timing + """ + # It will use 95% amoutn of your total value by default + return self.risk_degree + + def get_buffer_margin(self, trade_date): + return self.buffer_margin + + def generate_target_weight_position(self, score, current, trade_date): + """Parameter: + score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current : current position, use Position() class + trade_exchange : Exchange() + trade_date : trade date + generate target position from score for this date and the current position + The cache is not considered in the position + """ + sell_limit = get_sell_limit(score, self.get_buffer_margin(trade_date)) + buffer = score.sort_values(ascending=False).iloc[:sell_limit] + if sell_limit <= self.topk: + # no buffer + target_weight_position = {code: 1 / self.topk for code in buffer.index} + else: + # buffer is considered + current_stock_list = current.get_stock_list() + mask = buffer.index.isin(current_stock_list) + keep = set(buffer[mask].index) + new_stock_list = list(keep) + if len(keep) < self.topk: + new_stock_list += list(buffer[~mask].index[: self.topk - len(keep)]) + else: + # truncate the stocks + new_stock_list.sort(key=score.get, reverse=True) + new_stock_list = new_stock_list[: self.topk] + target_weight_position = {code: 1 / self.topk for code in new_stock_list} + return target_weight_position + + +class TopkAmountStrategy(BaseStrategy, MarginInterface, ListAdjustTimer): + def __init__(self, topk, buffer_margin, risk_degree=0.95, thresh=1, hold_thresh=1, **kwargs): + """Parameter + topk : int + top-N stocks to buy + buffer_margin : int or float + if isinstance(margin, int): + sell_limit = margin + else: + sell_limit = pred_in_a_day.count() * margin + buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit) + sell_limit should be no less than topk + risk_degree : float + position percentage of total value + thresh : int + minimun holding days since last buy singal of the stock + hold_thresh : int + minimum holding days + before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh + """ + BaseStrategy.__init__(self) + ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None)) + self.topk = topk + self.buffer_margin = buffer_margin + self.risk_degree = risk_degree + self.thresh = thresh + # self.stock_count['code'] will be the days the stock has been hold + # since 
last buy signal. This is designed for thresh + self.stock_count = {} + + self.hold_thresh = hold_thresh + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. + Dynamically risk_degree will result in Market timing + """ + # It will use 95% amoutn of your total value by default + return self.risk_degree + + def get_buffer_margin(self, trade_date): + return self.buffer_margin + + def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date): + """Gnererate order list according to score_series at trade_date. + will not change current. + Parameter + score_series : pd.Seires + stock_id , score + current : Position() + current of account + trade_exchange : Exchange() + exchange + pred_date : pd.Timestamp + predict date + trade_date : pd.Timestamp + trade date + """ + if not self.is_adjust(trade_date): + return [] + # generate order list for this adjust date + current_temp = copy.deepcopy( + current + ) # this copy is necessary. Due to the trade_exchange.calc_deal_order will simulate the dealing process + + sell_order_list = [] + buy_order_list = [] + # load score + cash = current_temp.get_cash() + buffer = score_series.sort_values(ascending=False).iloc[ + : get_sell_limit(score_series, self.get_buffer_margin(trade_date)) + ] + current_stock_list = current_temp.get_stock_list() + mask = buffer.index.isin(current_stock_list) + keep = set(buffer[mask].index) + # stocks that get buy signals + buy = set(buffer.iloc[: self.topk].index) + new = buffer[~mask].index.get_level_values(0) # new stocks to buy + # sell stock not in keep + # sell mode: sell all + for code in current_stock_list: + if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + continue + if code not in keep: + # check hold limit + if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh: + # can not sell this code + # no buy signal, but the stock is kept + self.stock_count[code] += 1 + continue + # sell order + sell_amount = current_temp.get_stock_amount(code=code) + sell_order = Order( + stock_id=code, + amount=sell_amount, + trade_date=trade_date, + direction=Order.SELL, # 0 for sell, 1 for buy + factor=trade_exchange.get_factor(code, trade_date), + ) + # is order executable + if trade_exchange.check_order(sell_order): + sell_order_list.append(sell_order) + trade_val, trade_cost, trade_price = trade_exchange.deal_order(sell_order, position=current_temp) + # update cash + cash += trade_val - trade_cost + # sold + del self.stock_count[code] + else: + # no buy signal, but the stock is kept + self.stock_count[code] += 1 + elif code in buy: + # NOTE: This is different from the original version + # get new buy signal + # Only the stock fall in to topk will produce buy signal + # Only in margin will no produce buy signal + self.stock_count[code] = 1 + else: + self.stock_count[code] += 1 + # buy new stock + # note the current has been changed + current_stock_list = current_temp.get_stock_list() + n_buy = self.topk - len(current_stock_list) + value = cash * self.risk_degree / n_buy if n_buy > 0 else 0 + + # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it + # as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line + # value = value / (1+trade_exchange.open_cost) # set open_cost limit + for code in new[:n_buy]: + # check is stock supended + if not 
trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + continue + # buy order + buy_price = trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date) + buy_amount = value / buy_price + factor = trade_exchange.quote[(code, trade_date)]["$factor"] + buy_amount = trade_exchange.round_amount_by_trade_unit(buy_amount, factor) + buy_order = Order( + stock_id=code, + amount=buy_amount, + trade_date=trade_date, + direction=Order.BUY, # 1 for buy + factor=factor, + ) + buy_order_list.append(buy_order) + self.stock_count[code] = 1 + return sell_order_list + buy_order_list + + +class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer): + def __init__(self, topk, n_drop, method="bottom", risk_degree=0.95, thresh=1, hold_thresh=1, **kwargs): + """Parameter + topk : int + top-N stocks to buy + n_drop : int + number of stocks to be replaced + method : str + dropout method, random/bottom + risk_degree : float + position percentage of total value + thresh : int + minimun holding days since last buy singal of the stock + hold_thresh : int + minimum holding days + before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh + """ + super(TopkDropoutStrategy, self).__init__() + ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None)) + self.topk = topk + self.n_drop = n_drop + self.method = method + self.risk_degree = risk_degree + self.thresh = thresh + # self.stock_count['code'] will be the days the stock has been hold + # since last buy signal. This is designed for thresh + self.stock_count = {} + + self.hold_thresh = hold_thresh + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. + Dynamically risk_degree will result in Market timing + """ + # It will use 95% amoutn of your total value by default + return self.risk_degree + + def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date): + """Gnererate order list according to score_series at trade_date. + will not change current. 
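+
+        For example (illustrative): with ``topk=50`` and ``n_drop=5``, the 50 currently
+        held stocks and the 5 highest-scored unheld stocks are ranked together; held
+        stocks that fall into the bottom 5 of that combined ranking are sold, an equal
+        number of the top unheld candidates are bought, and the position stays around
+        ``topk`` names after each adjustment.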
+ Parameter + score_series : pd.Seires + stock_id , score + current : Position() + current of account + trade_exchange : Exchange() + exchange + pred_date : pd.Timestamp + predict date + trade_date : pd.Timestamp + trade date + """ + if not self.is_adjust(trade_date): + return [] + current_temp = copy.deepcopy(current) + # generate order list for this adjust date + sell_order_list = [] + buy_order_list = [] + # load score + cash = current_temp.get_cash() + current_stock_list = current_temp.get_stock_list() + last = score_series.reindex(current_stock_list).sort_values(ascending=False).index + today = ( + score_series[~score_series.index.isin(last)] + .sort_values(ascending=False) + .index[: self.n_drop + self.topk - len(last)] + ) + comb = score_series.reindex(last.union(today)).sort_values(ascending=False).index + if self.method == "bottom": + sell = last[last.isin(comb[-self.n_drop :])] + elif self.method == "random": + sell = pd.Index(np.random.choice(last, self.n_drop) if len(last) else []) + buy = today[: len(sell) + self.topk - len(last)] + for code in current_stock_list: + if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + continue + if code in sell: + # check hold limit + if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh: + # can not sell this code + # no buy signal, but the stock is kept + self.stock_count[code] += 1 + continue + # sell order + sell_amount = current_temp.get_stock_amount(code=code) + sell_order = Order( + stock_id=code, + amount=sell_amount, + trade_date=trade_date, + direction=Order.SELL, # 0 for sell, 1 for buy + factor=trade_exchange.get_factor(code, trade_date), + ) + # is order executable + if trade_exchange.check_order(sell_order): + sell_order_list.append(sell_order) + # excute the order + trade_val, trade_cost, trade_price = trade_exchange.calc_deal_order(sell_order) + # update cash + cash += trade_val - trade_cost + # updte current + current_temp.update_order(sell_order, trade_price) + # sold + del self.stock_count[code] + else: + # no buy signal, but the stock is kept + self.stock_count[code] += 1 + elif code in buy: + # NOTE: This is different from the original version + # get new buy signal + # Only the stock fall in to topk will produce buy signal + # Only in margin will no produce buy signal + self.stock_count[code] = 1 + else: + self.stock_count[code] += 1 + # buy new stock + # note the current has been changed + current_stock_list = current_temp.get_stock_list() + value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0 + + # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it + # as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line + # value = value / (1+trade_exchange.open_cost) # set open_cost limit + for code in buy: + # check is stock supended + if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + continue + # buy order + buy_price = trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date) + buy_amount = value / buy_price + factor = trade_exchange.quote[(code, trade_date)]["$factor"] + buy_amount = trade_exchange.round_amount_by_trade_unit(buy_amount, factor) + buy_order = Order( + stock_id=code, + amount=buy_amount, + trade_date=trade_date, + direction=Order.BUY, # 1 for buy + factor=factor, + ) + buy_order_list.append(buy_order) + self.stock_count[code] = 1 + return sell_order_list + buy_order_list diff --git 
a/qlib/contrib/tuner/__init__.py b/qlib/contrib/tuner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qlib/contrib/tuner/config.py b/qlib/contrib/tuner/config.py new file mode 100644 index 0000000000..28796bcf2f --- /dev/null +++ b/qlib/contrib/tuner/config.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import yaml +import copy +import os + + +class TunerConfigManager(object): + def __init__(self, config_path): + + if not config_path: + raise ValueError("Config path is invalid.") + self.config_path = config_path + + with open(config_path) as fp: + config = yaml.load(fp) + self.config = copy.deepcopy(config) + + self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self) + self.pipeline_config = config.get("tuner_pipeline", list()) + self.optim_config = OptimizationConfig(config.get("optimization_criteria", dict()), self) + + self.time_config = config.get("time_period", dict()) + self.data_config = config.get("data", dict()) + self.backtest_config = config.get("backtest", dict()) + self.qlib_client_config = config.get("qlib_client", dict()) + + +class PipelineExperimentConfig(object): + def __init__(self, config, TUNER_CONFIG_MANAGER): + """ + :param config: The config dict for tuner experiment + :param TUNER_CONFIG_MANAGER: The tuner config manager + """ + self.name = config.get("name", "tuner_experiment") + # The dir of the config + self.global_dir = config.get("dir", os.path.dirname(TUNER_CONFIG_MANAGER.config_path)) + # The dir of the result of tuner experiment + self.tuner_ex_dir = config.get("tuner_ex_dir", os.path.join(self.global_dir, self.name)) + if not os.path.exists(self.tuner_ex_dir): + os.makedirs(self.tuner_ex_dir) + # The dir of the results of all estimator experiments + self.estimator_ex_dir = config.get("estimator_ex_dir", os.path.join(self.tuner_ex_dir, "estimator_experiment")) + if not os.path.exists(self.estimator_ex_dir): + os.makedirs(self.estimator_ex_dir) + # Get the tuner type + self.tuner_module_path = config.get("tuner_module_path", "qlib.contrib.tuner.tuner") + self.tuner_class = config.get("tuner_class", "QLibTuner") + # Save the tuner experiment for further view + tuner_ex_config_path = os.path.join(self.tuner_ex_dir, "tuner_config.yaml") + with open(tuner_ex_config_path, "w") as fp: + yaml.dump(TUNER_CONFIG_MANAGER.config, fp) + + +class OptimizationConfig(object): + def __init__(self, config, TUNER_CONFIG_MANAGER): + + self.report_type = config.get("report_type", "pred_long") + if self.report_type not in [ + "pred_long", + "pred_long_short", + "pred_short", + "sub_bench", + "sub_cost", + "model", + ]: + raise ValueError( + "report_type should be one of pred_long, pred_long_short, pred_short, sub_bench, sub_cost and model" + ) + + self.report_factor = config.get("report_factor", "sharpe") + if self.report_factor not in [ + "annual", + "sharpe", + "mdd", + "mean", + "std", + "model_score", + "model_pearsonr", + ]: + raise ValueError( + "report_factor should be one of annual, sharpe, mdd, mean, std, model_pearsonr and model_score" + ) + + self.optim_type = config.get("optim_type", "max") + if self.optim_type not in ["min", "max", "correlation"]: + raise ValueError("optim_type should be min, max or correlation") diff --git a/qlib/contrib/tuner/launcher.py b/qlib/contrib/tuner/launcher.py new file mode 100644 index 0000000000..711658c9a6 --- /dev/null +++ b/qlib/contrib/tuner/launcher.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +# coding=utf-8 + +import argparse +import importlib +import os +import yaml + +from .config import TunerConfigManager + + +args_parser = argparse.ArgumentParser(prog="tuner") +args_parser.add_argument( + "-c", + "--config_path", + required=True, + type=str, + help="config path indicates where to load yaml config.", +) + +args = args_parser.parse_args() + +TUNER_CONFIG_MANAGER = TunerConfigManager(args.config_path) + + +def run(): + # 1. Get pipeline class. + tuner_pipeline_class = getattr(importlib.import_module(".pipeline", package="qlib.contrib.tuner"), "Pipeline") + # 2. Init tuner pipeline. + tuner_pipeline = tuner_pipeline_class(TUNER_CONFIG_MANAGER) + # 3. Begin to tune + tuner_pipeline.run() diff --git a/qlib/contrib/tuner/pipeline.py b/qlib/contrib/tuner/pipeline.py new file mode 100644 index 0000000000..3a76d071d2 --- /dev/null +++ b/qlib/contrib/tuner/pipeline.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import json +import logging +import importlib +from abc import abstractmethod + +from ...log import get_module_logger, TimeInspector +from ...utils import get_module_by_module_path + + +class Pipeline(object): + + GLOBAL_BEST_PARAMS_NAME = "global_best_params.json" + + def __init__(self, tuner_config_manager): + + self.logger = get_module_logger("Pipeline", sh_level=logging.INFO) + + self.tuner_config_manager = tuner_config_manager + + self.pipeline_ex_config = tuner_config_manager.pipeline_ex_config + self.optim_config = tuner_config_manager.optim_config + self.time_config = tuner_config_manager.time_config + self.pipeline_config = tuner_config_manager.pipeline_config + self.data_config = tuner_config_manager.data_config + self.backtest_config = tuner_config_manager.backtest_config + self.qlib_client_config = tuner_config_manager.qlib_client_config + + self.global_best_res = None + self.global_best_params = None + self.best_tuner_index = None + + def run(self): + + TimeInspector.set_time_mark() + for tuner_index, tuner_config in enumerate(self.pipeline_config): + tuner = self.init_tuner(tuner_index, tuner_config) + tuner.tune() + if self.global_best_res is None or self.global_best_res > tuner.best_res: + self.global_best_res = tuner.best_res + self.global_best_params = tuner.best_params + self.best_tuner_index = tuner_index + TimeInspector.log_cost_time("Finished tuner pipeline.") + + self.save_tuner_exp_info() + + def init_tuner(self, tuner_index, tuner_config): + """ + Implement this method to build the tuner by config + return: tuner + """ + # 1. Add experiment config in tuner_config + tuner_config["experiment"] = { + "name": "estimator_experiment_{}".format(tuner_index), + "id": tuner_index, + "dir": self.pipeline_ex_config.estimator_ex_dir, + "observer_type": "file_storage", + } + tuner_config["qlib_client"] = self.qlib_client_config + # 2. Add data config in tuner_config + tuner_config["data"] = self.data_config + # 3. Add backtest config in tuner_config + tuner_config["backtest"] = self.backtest_config + # 4. Update trainer in tuner_config + tuner_config["trainer"].update({"args": self.time_config}) + + # 5. Import Tuner class + tuner_module = get_module_by_module_path(self.pipeline_ex_config.tuner_module_path) + tuner_class = getattr(tuner_module, self.pipeline_ex_config.tuner_class) + # 6. 
Return the specific tuner + return tuner_class(tuner_config, self.optim_config) + + def save_tuner_exp_info(self): + + TimeInspector.set_time_mark() + save_path = os.path.join(self.pipeline_ex_config.tuner_ex_dir, Pipeline.GLOBAL_BEST_PARAMS_NAME) + with open(save_path, "w") as fp: + json.dump(self.global_best_params, fp) + TimeInspector.log_cost_time("Finished save global best tuner parameters.") + + self.logger.info("Best Tuner id: {}.".format(self.best_tuner_index)) + self.logger.info("Global best parameters: {}.".format(self.global_best_params)) + self.logger.info("You can check the best parameters at {}.".format(save_path)) diff --git a/qlib/contrib/tuner/space.py b/qlib/contrib/tuner/space.py new file mode 100644 index 0000000000..76f101671b --- /dev/null +++ b/qlib/contrib/tuner/space.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from hyperopt import hp + + +TopkAmountStrategySpace = { + "topk": hp.choice("topk", [30, 35, 40]), + "buffer_margin": hp.choice("buffer_margin", [200, 250, 300]), +} + +QLibDataLabelSpace = { + "labels": hp.choice( + "labels", + [["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["Ref($close, -5)/$close - 1"]], + ) +} diff --git a/qlib/contrib/tuner/tuner.py b/qlib/contrib/tuner/tuner.py new file mode 100644 index 0000000000..8da40bc695 --- /dev/null +++ b/qlib/contrib/tuner/tuner.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import yaml +import json +import copy +import pickle +import logging +import importlib +import subprocess +import pandas as pd +import numpy as np + +from abc import abstractmethod + +from ...log import get_module_logger, TimeInspector +from hyperopt import fmin, tpe +from hyperopt import STATUS_OK, STATUS_FAIL + + +class Tuner(object): + def __init__(self, tuner_config, optim_config): + + self.logger = get_module_logger("Tuner", sh_level=logging.INFO) + + self.tuner_config = tuner_config + self.optim_config = optim_config + + self.max_evals = self.tuner_config.get("max_evals", 10) + self.ex_dir = os.path.join( + self.tuner_config["experiment"]["dir"], + self.tuner_config["experiment"]["name"], + ) + + self.best_params = None + self.best_res = None + + self.space = self.setup_space() + + def tune(self): + + TimeInspector.set_time_mark() + fmin( + fn=self.objective, + space=self.space, + algo=tpe.suggest, + max_evals=self.max_evals, + ) + self.logger.info("Local best params: {} ".format(self.best_params)) + TimeInspector.log_cost_time( + "Finished searching best parameters in Tuner {}.".format(self.tuner_config["experiment"]["id"]) + ) + + self.save_local_best_params() + + @abstractmethod + def objective(self, params): + """ + Implement this method to give an optimization factor using parameters in space. + :return: {'loss': a factor for optimization, float type, + 'status': the status of this evaluation step, STATUS_OK or STATUS_FAIL}. + """ + pass + + @abstractmethod + def setup_space(self): + """ + Implement this method to setup the searching space of tuner. + :return: searching space, dict type. + """ + pass + + @abstractmethod + def save_local_best_params(self): + """ + Implement this method to save the best parameters of this tuner. + """ + pass + + +class QLibTuner(Tuner): + + ESTIMATOR_CONFIG_NAME = "estimator_config.yaml" + EXP_INFO_NAME = "exp_info.json" + EXP_RESULT_DIR = "sacred/{}" + EXP_RESULT_NAME = "analysis.pkl" + LOCAL_BEST_PARAMS_NAME = "local_best_params.json" + + def objective(self, params): + + # 1. 
Setup an config for a spcific estimator process + estimator_path = self.setup_estimator_config(params) + self.logger.info("Searching params: {} ".format(params)) + + # 2. Use subprocess to do the estimator program, this process will wait until subprocess finish + sub_fails = subprocess.call("estimator -c {}".format(estimator_path), shell=True) + if sub_fails: + # If this subprocess failed, ignore this evaluation step + self.logger.info("Estimator experiment failed when using this searching parameters") + return {"loss": np.nan, "status": STATUS_FAIL} + + # 3. Fetch the result of subprocess, and check whether the result is Nan + res = self.fetch_result() + if np.isnan(res): + status = STATUS_FAIL + else: + status = STATUS_OK + + # 4. Save the best score and params + if self.best_res is None or self.best_res > res: + self.best_res = res + self.best_params = params + + # 5. Return the result as optim objective + return {"loss": res, "status": status} + + def fetch_result(self): + + # 1. Get experiment information + exp_info_path = os.path.join(self.ex_dir, QLibTuner.EXP_INFO_NAME) + with open(exp_info_path) as fp: + exp_info = json.load(fp) + estimator_ex_id = exp_info["id"] + + # 2. Return model result if needed + if self.optim_config.report_type == "model": + if self.optim_config.report_factor == "model_score": + # if estimator experiment is multi-label training, user need to process the scores by himself + # Default method is return the average score + return np.mean(exp_info["performance"]["model_score"]) + elif self.optim_config.report_factor == "model_pearsonr": + # pearsonr is a correlation coefficient, 1 is the best + return np.abs(exp_info["performance"]["model_pearsonr"] - 1) + + # 3. Get backtest results + exp_result_dir = os.path.join(self.ex_dir, QLibTuner.EXP_RESULT_DIR.format(estimator_ex_id)) + exp_result_path = os.path.join(exp_result_dir, QLibTuner.EXP_RESULT_NAME) + with open(exp_result_path, "rb") as fp: + analysis_df = pickle.load(fp) + + # 4. Get the backtest factor which user want to optimize, if user want to maximize the factor, then reverse the result + res = analysis_df.loc[self.optim_config.report_type].loc[self.optim_config.report_factor] + # res = res.values[0] if self.optim_config.optim_type == 'min' else -res.values[0] + if self.optim_config == "min": + return res.values[0] + elif self.optim_config == "max": + return -res.values[0] + else: + # self.optim_config == 'correlation' + return np.abs(res.values[0] - 1) + + def setup_estimator_config(self, params): + + estimator_config = copy.deepcopy(self.tuner_config) + estimator_config["model"].update({"args": params["model_space"]}) + estimator_config["strategy"].update({"args": params["strategy_space"]}) + if params.get("data_label_space", None) is not None: + estimator_config["data"]["args"].update(params["data_label_space"]) + + estimator_path = os.path.join( + self.tuner_config["experiment"].get("dir", "../"), + QLibTuner.ESTIMATOR_CONFIG_NAME, + ) + + with open(estimator_path, "w") as fp: + yaml.dump(estimator_config, fp) + + return estimator_path + + def setup_space(self): + # 1. Setup model space + model_space_name = self.tuner_config["model"].get("space", None) + if model_space_name is None: + raise ValueError("Please give the search space of model.") + model_space = getattr( + importlib.import_module(".space", package="qlib.contrib.tuner"), + model_space_name, + ) + + # 2. 
Setup strategy space + strategy_space_name = self.tuner_config["strategy"].get("space", None) + if strategy_space_name is None: + raise ValueError("Please give the search space of strategy.") + strategy_space = getattr( + importlib.import_module(".space", package="qlib.contrib.tuner"), + strategy_space_name, + ) + + # 3. Setup data label space if given + if self.tuner_config.get("data_label", None) is not None: + data_label_space_name = self.tuner_config["data_label"].get("space", None) + if data_label_space_name is not None: + data_label_space = getattr( + importlib.import_module(".space", package="qlib.contrib.tuner"), + data_label_space_name, + ) + else: + data_label_space_name = None + + # 4. Combine the searching space + space = dict() + space.update({"model_space": model_space}) + space.update({"strategy_space": strategy_space}) + if data_label_space_name is not None: + space.update({"data_label_space": data_label_space}) + + return space + + def save_local_best_params(self): + + TimeInspector.set_time_mark() + local_best_params_path = os.path.join(self.ex_dir, QLibTuner.LOCAL_BEST_PARAMS_NAME) + with open(local_best_params_path, "w") as fp: + json.dump(self.best_params, fp) + TimeInspector.log_cost_time( + "Finished saving local best tuner parameters to: {} .".format(local_best_params_path) + ) diff --git a/qlib/data/__init__.py b/qlib/data/__init__.py new file mode 100644 index 0000000000..b6eb66468c --- /dev/null +++ b/qlib/data/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +from .data import ( + D, + CalendarProvider, + InstrumentProvider, + FeatureProvider, + ExpressionProvider, + DatasetProvider, + LocalCalendarProvider, + LocalInstrumentProvider, + LocalFeatureProvider, + LocalExpressionProvider, + LocalDatasetProvider, + ClientCalendarProvider, + ClientInstrumentProvider, + ClientDatasetProvider, + BaseProvider, + LocalProvider, + ClientProvider, +) + +from .cache import ( + ExpressionCache, + DatasetCache, + ServerExpressionCache, + ServerDatasetCache, + SimpleDatasetCache, + ClientDatasetCache, + ClientCalendarCache, +) diff --git a/qlib/data/_libs/__init__.py b/qlib/data/_libs/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/qlib/data/_libs/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
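+
+# Note: this sub-package holds the Cython extensions (``expanding.pyx`` / ``rolling.pyx``)
+# that implement incremental expanding/rolling statistics (mean, slope, rsquare, resi)
+# for the expression operators in ``qlib.data.ops``. They are assumed to be compiled
+# into importable modules (e.g. by the project's build script) before first use.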
diff --git a/qlib/data/_libs/expanding.pyx b/qlib/data/_libs/expanding.pyx new file mode 100644 index 0000000000..76b824c947 --- /dev/null +++ b/qlib/data/_libs/expanding.pyx @@ -0,0 +1,152 @@ +# cython: profile=False +# cython: boundscheck=False, wraparound=False, cdivision=True +cimport cython +cimport numpy as np +import numpy as np + +from libc.math cimport sqrt, isnan, NAN +from libcpp.vector cimport vector + + +cdef class Expanding(object): + """1-D array expanding""" + cdef vector[double] barv + cdef int na_count + def __init__(self): + self.na_count = 0 + + cdef double update(self, double val): + pass + + +cdef class Mean(Expanding): + """1-D array expanding mean""" + cdef double vsum + def __init__(self): + super(Mean, self).__init__() + self.vsum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + if isnan(val): + self.na_count += 1 + else: + self.vsum += val + return self.vsum / (self.barv.size() - self.na_count) + + +cdef class Slope(Expanding): + """1-D array expanding slope""" + cdef double x_sum + cdef double x2_sum + cdef double y_sum + cdef double xy_sum + def __init__(self): + super(Slope, self).__init__() + self.x_sum = 0 + self.x2_sum = 0 + self.y_sum = 0 + self.xy_sum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + cdef size_t size = self.barv.size() + if isnan(val): + self.na_count += 1 + else: + self.x_sum += size + self.x2_sum += size * size + self.y_sum += val + self.xy_sum += size * val + cdef int N = size - self.na_count + return (N*self.xy_sum - self.x_sum*self.y_sum) / \ + (N*self.x2_sum - self.x_sum*self.x_sum) + + +cdef class Resi(Expanding): + """1-D array expanding residuals""" + cdef double x_sum + cdef double x2_sum + cdef double y_sum + cdef double xy_sum + def __init__(self): + super(Resi, self).__init__() + self.x_sum = 0 + self.x2_sum = 0 + self.y_sum = 0 + self.xy_sum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + cdef size_t size = self.barv.size() + if isnan(val): + self.na_count += 1 + else: + self.x_sum += size + self.x2_sum += size * size + self.y_sum += val + self.xy_sum += size * val + cdef int N = size - self.na_count + slope = (N*self.xy_sum - self.x_sum*self.y_sum) / \ + (N*self.x2_sum - self.x_sum*self.x_sum) + x_mean = self.x_sum / N + y_mean = self.y_sum / N + interp = y_mean - slope*x_mean + return val - (slope*size + interp) + + +cdef class Rsquare(Expanding): + """1-D array expanding rsquare""" + cdef double x_sum + cdef double x2_sum + cdef double y_sum + cdef double y2_sum + cdef double xy_sum + def __init__(self): + super(Rsquare, self).__init__() + self.x_sum = 0 + self.x2_sum = 0 + self.y_sum = 0 + self.y2_sum = 0 + self.xy_sum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + cdef size_t size = self.barv.size() + if isnan(val): + self.na_count += 1 + else: + self.x_sum += size + self.x2_sum += size + self.y_sum += val + self.y2_sum += val * val + self.xy_sum += size * val + cdef int N = size - self.na_count + cdef double rvalue = (N*self.xy_sum - self.x_sum*self.y_sum) / \ + sqrt((N*self.x2_sum - self.x_sum*self.x_sum) * (N*self.y2_sum - self.y_sum*self.y_sum)) + return rvalue * rvalue + + +cdef np.ndarray[double, ndim=1] expanding(Expanding r, np.ndarray a): + cdef int i + cdef int N = len(a) + cdef np.ndarray[double, ndim=1] ret = np.empty(N) + for i in range(N): + ret[i] = r.update(a[i]) + return ret + +def expanding_mean(np.ndarray a): + cdef Mean r = Mean() + return expanding(r, a) + +def expanding_slope(np.ndarray a): + 
cdef Slope r = Slope() + return expanding(r, a) + +def expanding_rsquare(np.ndarray a): + cdef Rsquare r = Rsquare() + return expanding(r, a) + +def expanding_resi(np.ndarray a): + cdef Resi r = Resi() + return expanding(r, a) diff --git a/qlib/data/_libs/rolling.pyx b/qlib/data/_libs/rolling.pyx new file mode 100644 index 0000000000..37d27ffa4d --- /dev/null +++ b/qlib/data/_libs/rolling.pyx @@ -0,0 +1,207 @@ +# cython: profile=False +# cython: boundscheck=False, wraparound=False, cdivision=True +cimport cython +cimport numpy as np +import numpy as np + +from libc.math cimport sqrt, isnan, NAN +from libcpp.deque cimport deque + + +cdef class Rolling(object): + """1-D array rolling""" + cdef int window + cdef deque[double] barv + cdef int na_count + def __init__(self, int window): + self.window = window + self.na_count = window + cdef int i + for i in range(window): + self.barv.push_back(NAN) + + cdef double update(self, double val): + pass + + +cdef class Mean(Rolling): + """1-D array rolling mean""" + cdef double vsum + def __init__(self, int window): + super(Mean, self).__init__(window) + self.vsum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + if not isnan(self.barv.front()): + self.vsum -= self.barv.front() + else: + self.na_count -= 1 + self.barv.pop_front() + if isnan(val): + self.na_count += 1 + # return NAN + else: + self.vsum += val + return self.vsum / (self.window - self.na_count) + + +cdef class Slope(Rolling): + """1-D array rolling slope""" + cdef double i_sum # can be used as i2_sum + cdef double x_sum + cdef double x2_sum + cdef double y_sum + cdef double xy_sum + def __init__(self, int window): + super(Slope, self).__init__(window) + self.i_sum = 0 + self.x_sum = 0 + self.x2_sum = 0 + self.y_sum = 0 + self.xy_sum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + self.xy_sum = self.xy_sum - self.y_sum + self.x2_sum = self.x2_sum + self.i_sum - 2*self.x_sum + self.x_sum = self.x_sum - self.i_sum + cdef double _val + _val = self.barv.front() + if not isnan(_val): + self.i_sum -= 1 + self.y_sum -= _val + else: + self.na_count -= 1 + self.barv.pop_front() + if isnan(val): + self.na_count += 1 + # return NAN + else: + self.i_sum += 1 + self.x_sum += self.window + self.x2_sum += self.window * self.window + self.y_sum += val + self.xy_sum += self.window * val + cdef int N = self.window - self.na_count + return (N*self.xy_sum - self.x_sum*self.y_sum) / \ + (N*self.x2_sum - self.x_sum*self.x_sum) + + +cdef class Resi(Rolling): + """1-D array rolling residuals""" + cdef double i_sum # can be used as i2_sum + cdef double x_sum + cdef double x2_sum + cdef double y_sum + cdef double xy_sum + def __init__(self, int window): + super(Resi, self).__init__(window) + self.i_sum = 0 + self.x_sum = 0 + self.x2_sum = 0 + self.y_sum = 0 + self.xy_sum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + self.xy_sum = self.xy_sum - self.y_sum + self.x2_sum = self.x2_sum + self.i_sum - 2*self.x_sum + self.x_sum = self.x_sum - self.i_sum + cdef double _val + _val = self.barv.front() + if not isnan(_val): + self.i_sum -= 1 + self.y_sum -= _val + else: + self.na_count -= 1 + self.barv.pop_front() + if isnan(val): + self.na_count += 1 + # return NAN + else: + self.i_sum += 1 + self.x_sum += self.window + self.x2_sum += self.window * self.window + self.y_sum += val + self.xy_sum += self.window * val + cdef int N = self.window - self.na_count + slope = (N*self.xy_sum - self.x_sum*self.y_sum) / \ + (N*self.x2_sum - 
self.x_sum*self.x_sum) + x_mean = self.x_sum / N + y_mean = self.y_sum / N + interp = y_mean - slope*x_mean + return val - (slope*self.window + interp) + + +cdef class Rsquare(Rolling): + """1-D array rolling rsquare""" + cdef double i_sum + cdef double x_sum + cdef double x2_sum + cdef double y_sum + cdef double y2_sum + cdef double xy_sum + def __init__(self, int window): + super(Rsquare, self).__init__(window) + self.i_sum = 0 + self.x_sum = 0 + self.x2_sum = 0 + self.y_sum = 0 + self.y2_sum = 0 + self.xy_sum = 0 + + cdef double update(self, double val): + self.barv.push_back(val) + self.xy_sum = self.xy_sum - self.y_sum + self.x2_sum = self.x2_sum + self.i_sum - 2*self.x_sum + self.x_sum = self.x_sum - self.i_sum + cdef double _val + _val = self.barv.front() + if not isnan(_val): + self.i_sum -= 1 + self.y_sum -= _val + self.y2_sum -= _val * _val + else: + self.na_count -= 1 + self.barv.pop_front() + if isnan(val): + self.na_count += 1 + # return NAN + else: + self.i_sum += 1 + self.x_sum += self.window + self.x2_sum += self.window * self.window + self.y_sum += val + self.y2_sum += val * val + self.xy_sum += self.window * val + cdef int N = self.window - self.na_count + cdef double rvalue + rvalue = (N*self.xy_sum - self.x_sum*self.y_sum) / \ + sqrt((N*self.x2_sum - self.x_sum*self.x_sum) * (N*self.y2_sum - self.y_sum*self.y_sum)) + return rvalue * rvalue + + +cdef np.ndarray[double, ndim=1] rolling(Rolling r, np.ndarray a): + cdef int i + cdef int N = len(a) + cdef np.ndarray[double, ndim=1] ret = np.empty(N) + for i in range(N): + ret[i] = r.update(a[i]) + return ret + +def rolling_mean(np.ndarray a, int window): + cdef Mean r = Mean(window) + return rolling(r, a) + +def rolling_slope(np.ndarray a, int window): + cdef Slope r = Slope(window) + return rolling(r, a) + +def rolling_rsquare(np.ndarray a, int window): + cdef Rsquare r = Rsquare(window) + return rolling(r, a) + +def rolling_resi(np.ndarray a, int window): + cdef Resi r = Resi(window) + return rolling(r, a) diff --git a/qlib/data/base.py b/qlib/data/base.py new file mode 100644 index 0000000000..c357700c0c --- /dev/null +++ b/qlib/data/base.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
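+
+# Note: this module defines the expression system used for feature construction:
+# ``Expression`` is the abstract base (operator overloading means e.g. ``a / b``
+# builds a ``Div`` node), ``Feature`` loads raw ``$field`` data from the provider,
+# and ``ExpressionOps`` is the base class for operators composed of other expressions.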
+ + +from __future__ import division +from __future__ import print_function + +import abc +import six +import pandas as pd + + +@six.add_metaclass(abc.ABCMeta) +class Expression(object): + """Expression base class""" + + def __str__(self): + return type(self).__name__ + + def __repr__(self): + return str(self) + + def __gt__(self, other): + from .ops import Gt + + return Gt(self, other) + + def __ge__(self, other): + from .ops import Ge + + return Ge(self, other) + + def __lt__(self, other): + from .ops import Lt + + return Lt(self, other) + + def __le__(self, other): + from .ops import Le + + return Le(self, other) + + def __eq__(self, other): + from .ops import Eq + + return Eq(self, other) + + def __ne__(self, other): + from .ops import Ne + + return Ne(self, other) + + def __add__(self, other): + from .ops import Add + + return Add(self, other) + + def __radd__(self, other): + from .ops import Add + + return Add(other, self) + + def __sub__(self, other): + from .ops import Sub + + return Sub(self, other) + + def __rsub__(self, other): + from .ops import Sub + + return Sub(other, self) + + def __mul__(self, other): + from .ops import Mul + + return Mul(self, other) + + def __rmul__(self, other): + from .ops import Mul + + return Mul(self, other) + + def __div__(self, other): + from .ops import Div + + return Div(self, other) + + def __rdiv__(self, other): + from .ops import Div + + return Div(other, self) + + def __truediv__(self, other): + from .ops import Div + + return Div(self, other) + + def __rtruediv__(self, other): + from .ops import Div + + return Div(other, self) + + def __pow__(self, other): + from .ops import Power + + return Power(self, other) + + def __and__(self, other): + from .ops import And + + return And(self, other) + + def __rand__(self, other): + from .ops import And + + return And(other, self) + + def __or__(self, other): + from .ops import Or + + return Or(self, other) + + def __ror__(self, other): + from .ops import Or + + return Or(other, self) + + def load(self, instrument, start_index, end_index, freq): + """load feature + + Parameters + ---------- + instrument : str + instrument code + start_index : str + feature start index [in calendar] + end_index : str + feature end index [in calendar] + freq : str + feature frequency + + Returns + ---------- + pd.Series + feature series: The index of the series is the calendar index + """ + from .cache import H + + # cache + args = str(self), instrument, start_index, end_index, freq + if args in H["f"]: + return H["f"][args] + if start_index is None or end_index is None or start_index > end_index: + raise ValueError("Invalid index range: {} {}".format(start_index, end_index)) + series = self._load_internal(instrument, start_index, end_index, freq) + series.name = str(self) + H["f"][args] = series + return series + + @abc.abstractmethod + def _load_internal(self, instrument, start_index, end_index, freq): + pass + + @abc.abstractmethod + def get_longest_back_rolling(self): + """Get the longest length of historical data the feature has accessed + + This is designed for getting the needed range of the data to calculate + the features in specific range at first. However, situations like + Ref(Ref($close, -1), 1) can not be handled rightly. + + So this will only used for detecting the length of historical data needed. + """ + # TODO: forward operator like Ref($close, -1) is not supported yet. 
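+        # For example, a plain field such as ``$close`` needs no extra history
+        # (``Feature.get_longest_back_rolling`` returns 0 below), while expressions
+        # like ``Ref($close, 1)`` or ``Mean($close, 5)`` must look back extra bars,
+        # and nested expressions accumulate the look-back of their sub-expressions.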
+ raise NotImplementedError("This function must be implemented in your newly defined feature") + + @abc.abstractmethod + def get_extended_window_size(self): + """get_extend_window_size + + For to calculate this Operator in range[start_index, end_index] + We have to get the *leaf feature* in + range[start_index - lft_etd, end_index + rght_etd]. + + Returns + ---------- + (int, int) + lft_etd, rght_etd + """ + raise NotImplementedError("This function must be implemented in your newly defined feature") + + +class Feature(Expression): + """Static Expression + + This kind of feature will load data from provider + """ + + def __init__(self, name=None): + if name: + self._name = name.lower() + else: + self._name = type(self).__name__.lower() + + def __str__(self): + return "$" + self._name + + def _load_internal(self, instrument, start_index, end_index, freq): + # load + from .data import FeatureD + + return FeatureD.feature(instrument, str(self), start_index, end_index, freq) + + def get_longest_back_rolling(self): + return 0 + + def get_extended_window_size(self): + return 0, 0 + + +@six.add_metaclass(abc.ABCMeta) +class ExpressionOps(Expression): + """Operator Expression + + This kind of feature will use operator for feature + construction on the fly. + """ + + pass diff --git a/qlib/data/cache.py b/qlib/data/cache.py new file mode 100644 index 0000000000..a91b5d71ef --- /dev/null +++ b/qlib/data/cache.py @@ -0,0 +1,1143 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +import os +import sys +import stat +import time +import pickle +import traceback +import redis_lock +import contextlib +from pathlib import Path +import numpy as np +import pandas as pd +from collections import OrderedDict + +from ..config import C +from ..utils import ( + hash_args, + get_redis_connection, + read_bin, + parse_field, + remove_fields_space, + normalize_cache_fields, + normalize_cache_instruments, +) + +from ..log import get_module_logger +from .base import Feature + +from .ops import * + + +class MemCacheUnit(OrderedDict): + """Memory Cache Unit.""" + + # TODO: use min_heap to replace ordereddict for better performance + + def __init__(self, *args, **kwargs): + self.size_limit = kwargs.pop("size_limit", None) + # limit_type: check size_limit type, length(call fun: len) or size(call fun: sys.getsizeof) + self.limit_type = kwargs.pop("limit_type", "length") + super(MemCacheUnit, self).__init__(*args, **kwargs) + self._check_size_limit() + + def __setitem__(self, key, value): + super(MemCacheUnit, self).__setitem__(key, value) + self._check_size_limit() + + def __getitem__(self, key): + value = super(MemCacheUnit, self).__getitem__(key) + super(MemCacheUnit, self).__delitem__(key) + super(MemCacheUnit, self).__setitem__(key, value) + return value + + def _check_size_limit(self): + if self.size_limit is not None: + get_cur_size = lambda x: len(x) if self.limit_type == "length" else sum(map(sys.getsizeof, x.values())) + while get_cur_size(self) > self.size_limit: + self.popitem(last=False) + + +class MemCache(object): + """Memory cache.""" + + def __init__(self, mem_cache_size_limit=None, limit_type="length"): + """ + + Parameters + ---------- + mem_cache_size_limit: cache max size + limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof) + """ + if limit_type not in ["length", "sizeof"]: + raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}") + + 
self.__calendar_mem_cache = MemCacheUnit( + size_limit=C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit, + limit_type=limit_type, + ) + self.__instrument_mem_cache = MemCacheUnit( + size_limit=C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit, + limit_type=limit_type, + ) + self.__feature_mem_cache = MemCacheUnit( + size_limit=C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit, + limit_type=limit_type, + ) + + def __getitem__(self, key): + if key == "c": + return self.__calendar_mem_cache + elif key == "i": + return self.__instrument_mem_cache + elif key == "f": + return self.__feature_mem_cache + else: + raise KeyError("Unknown memcache unit") + + def clear(self): + self.__calendar_mem_cache.clear() + self.__instrument_mem_cache.clear() + self.__feature_mem_cache.clear() + + +class MemCacheExpire: + CACHE_EXPIRE = C.mem_cache_expire + + @staticmethod + def set_cache(mem_cache, key, value): + """set cache + + :param mem_cache: MemCache attribute('c'/'i'/'f') + :param key: cache key + :param value: cache value + """ + mem_cache[key] = value, time.time() + + @staticmethod + def get_cache(mem_cache, key): + """get mem cache + + :param mem_cache: MemCache attribute('c'/'i'/'f') + :param key: cache key + :return: cache value; if cache not exist, return None + """ + value = None + expire = False + if key in mem_cache: + value, latest_time = mem_cache[key] + expire = (time.time() - latest_time) > MemCacheExpire.CACHE_EXPIRE + return value, expire + + +class CacheUtils(object): + LOCK_ID = "QLIB" + + @staticmethod + def organize_meta_file(): + pass + + @staticmethod + def reset_lock(): + r = get_redis_connection() + redis_lock.reset_all(r) + + @staticmethod + def visit(cache_path): + # FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here + try: + with open(cache_path + ".meta", "rb") as f: + d = pickle.load(f) + with open(cache_path + ".meta", "wb") as f: + try: + d["meta"]["last_visit"] = str(time.time()) + d["meta"]["visits"] = d["meta"]["visits"] + 1 + except KeyError: + raise KeyError("Unknown meta keyword") + pickle.dump(d, f) + except Exception as e: + get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}") + + @staticmethod + @contextlib.contextmanager + def reader_lock(redis_t, lock_name): + lock_name = f"{C.provider_uri}:{lock_name}" + current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name) + current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name) + # make sure only one reader is entering + current_cache_rlock.acquire(timeout=60) + try: + current_cache_readers = redis_t.get("%s-reader" % lock_name) + if current_cache_readers is None or int(current_cache_readers) == 0: + current_cache_wlock.acquire() + redis_t.incr("%s-reader" % lock_name) + finally: + current_cache_rlock.release() + try: + yield + finally: + # make sure only one reader is leaving + current_cache_rlock.acquire(timeout=60) + try: + redis_t.decr("%s-reader" % lock_name) + if int(redis_t.get("%s-reader" % lock_name)) == 0: + redis_t.delete("%s-reader" % lock_name) + current_cache_wlock.reset() + finally: + current_cache_rlock.release() + + @staticmethod + @contextlib.contextmanager + def writer_lock(redis_t, lock_name): + lock_name = f"{C.provider_uri}:{lock_name}" + current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID) + current_cache_wlock.acquire() + try: + yield + 
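+            # While the lock is held the caller's block runs, e.g. (illustrative):
+            #     with CacheUtils.writer_lock(redis_conn, "expression-%s" % cache_uri):
+            #         ...  # generate or update cache files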
finally: + current_cache_wlock.release() + + +class BaseProviderCache(object): + """Provider cache base class""" + + def __init__(self, provider): + self.provider = provider + self.logger = get_module_logger(self.__class__.__name__) + + def __getattr__(self, attr): + return getattr(self.provider, attr) + + +class ExpressionCache(BaseProviderCache): + """Expression cache mechanism base class. + + This class is used to wrap expression provider with self-defined expression cache mechanism. + + .. note:: Override the `_uri` and `_expression` method to create your own expression cache mechanism. + """ + + def expression(self, instrument, field, start_time, end_time, freq): + """Get expression data. + + .. note:: Same interface as `expression` method in expression provider + """ + try: + return self._expression(instrument, field, start_time, end_time, freq) + except NotImplementedError: + return self.provider.expression(instrument, field, start_time, end_time, freq) + + def _uri(self, instrument, field, start_time, end_time, freq): + """Get expression cache file uri. + + Override this method to define how to get expression cache file uri corresponding to users' own cache mechanism. + """ + raise NotImplementedError("Implement this function to match your own cache mechanism") + + def _expression(self, instrument, field, start_time, end_time, freq): + """Get expression data using cache. + + Override this method to define how to get expression data corresponding to users' own cache mechanism. + """ + raise NotImplementedError("Implement this method if you want to use expression cache") + + def update(self, cache_uri): + """Update expression cache to latest calendar. + + Overide this method to define how to update expression cache corresponding to users' own cache mechanism. + + Parameters + ---------- + cache_uri : str + the complete uri of expression cache file (include dir path) + + Returns + ------- + int + 0(successful update)/ 1(no need to update)/ 2(update failure) + """ + raise NotImplementedError("Implement this method if you want to make expression cache up to date") + + +class DatasetCache(BaseProviderCache): + """Dataset cache mechanism base class. + + This class is used to wrap dataset provider with self-defined dataset cache mechanism. + + .. note:: Override the `_uri` and `_dataset` method to create your own dataset cache mechanism. + """ + + HDF_KEY = "df" + + def dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + ): + """Get feature dataset. + + .. note:: Same interface as `dataset` method in dataset provider + + .. note:: The server use redis_lock to make sure + read-write conflicts will not be triggered + but client readers are not considered. + """ + if disk_cache == 0: + # skip cache + return self.provider.dataset(instruments, fields, start_time, end_time, freq) + else: + # use and replace cache + try: + return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache) + except NotImplementedError: + return self.provider.dataset(instruments, fields, start_time, end_time, freq) + + def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs): + """Get dataset cache file uri. + + Override this method to define how to get dataset cache file uri corresponding to users' own cache mechanism. 
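+
+        A common choice is to hash the normalized arguments; for example, ``SimpleDatasetCache``
+        and ``ClientDatasetCache`` below build the uri with ``hash_args``.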
+ """ + raise NotImplementedError("Implement this function to match your own cache mechanism") + + def _dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + ): + """Get feature dataset using cache. + + Override this method to define how to get feature dataset corresponding to users' own cache mechanism. + """ + raise NotImplementedError("Implement this method if you want to use dataset feature cache") + + def _dataset_uri( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + ): + """Get a uri of feature dataset using cache. + specially: + disk_cache=1 means using data set cache and return the uri of cache file. + disk_cache=0 means client knows the path of expression cache, + server checks if the cache exists(if not, generate it), and client loads data by itself. + Override this method to define how to get feature dataset uri corresponding to users' own cache mechanism. + """ + raise NotImplementedError( + "Implement this method if you want to use dataset feature cache as a cache file for client" + ) + + def update(self, cache_uri): + """Update dataset cache to latest calendar. + + Overide this method to define how to update dataset cache corresponding to users' own cache mechanism. + + Parameters + ---------- + cache_uri : str + the complete uri of dataset cache file (include dir path) + + Returns + ------- + int + 0(successful update)/ 1(no need to update)/ 2(update failure) + """ + raise NotImplementedError("Implement this method if you want to make expression cache up to date") + + @staticmethod + def cache_to_origin_data(data, fields): + """cache data to origin data + + :param data: pd.DataFrame, cache data + :param fields: feature fields + :return: pd.DataFrame + """ + not_space_fields = remove_fields_space(fields) + data = data.loc[:, not_space_fields] + # set features fields + data.columns = list(fields) + return data + + @staticmethod + def normalize_uri_args(instruments, fields, freq): + """normalize uri args""" + instruments = normalize_cache_instruments(instruments) + fields = normalize_cache_fields(fields) + freq = freq.lower() + + return instruments, fields, freq + + +class ServerExpressionCache(ExpressionCache): + """Prepared cache mechanism for server.""" + + def __init__(self, provider, **kwargs): + super(ServerExpressionCache, self).__init__(provider) + self.r = get_redis_connection() + # remote==True means client is using this module, writing behaviour will not be allowed. 
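+        # The cache directory differs accordingly: clients read it from the mounted
+        # path (C.mount_path), while the server keeps it next to the raw data (C.provider_uri).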
+        self.remote = kwargs.get("remote", False)
+        if self.remote:
+            self.expr_cache_path = os.path.join(C.mount_path, C.features_cache_dir_name)
+        else:
+            self.expr_cache_path = os.path.join(C.provider_uri, C.features_cache_dir_name)
+        os.makedirs(self.expr_cache_path, exist_ok=True)
+
+    def _uri(self, instrument, field, start_time, end_time, freq):
+        field = remove_fields_space(field)
+        instrument = str(instrument).lower()
+        return hash_args(instrument, field, freq)
+
+    @staticmethod
+    def check_cache_exists(cache_path):
+        for p in [cache_path, cache_path + ".meta"]:
+            if not Path(p).exists():
+                return False
+        return True
+
+    def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
+        _cache_uri = self._uri(
+            instrument=instrument,
+            field=field,
+            start_time=None,
+            end_time=None,
+            freq=freq,
+        )
+        _instrument_dir = os.path.join(self.expr_cache_path, instrument.lower())
+        cache_path = os.path.join(_instrument_dir, _cache_uri)
+        # get calendar
+        from .data import Cal
+
+        _calendar = Cal.calendar(freq=freq)
+
+        _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
+
+        if self.check_cache_exists(cache_path):
+            """
+            In most cases we do not need the reader_lock,
+            because updating data is a low-probability event compared with reading data.
+            """
+            # FIXME: Removing the reader lock may result in conflicts.
+            # with CacheUtils.reader_lock(self.r, 'expression-%s' % _cache_uri):
+
+            # modify expression cache meta file
+            try:
+                # FIXME: Multiple readers may result in an incorrect visit count
+                if not self.remote:
+                    CacheUtils.visit(cache_path)
+                series = read_bin(cache_path, start_index, end_index)
+                return series
+            except Exception as e:
+                series = None
+                self.logger.error("reading %s file error : %s" % (cache_path, traceback.format_exc()))
+                return series
+        else:
+            # normalize field
+            field = remove_fields_space(field)
+            # cache unavailable, generate the cache
+            if not os.path.exists(_instrument_dir):
+                os.makedirs(_instrument_dir, exist_ok=True)
+            if not isinstance(eval(parse_field(field)), Feature):
+                # The expression is not a raw feature, so an expression cache is generated for it
+                series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
+                if not series.empty:
+                    # Only a non-empty series is cached; an empty one is returned as-is below.
+ with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri): + self.gen_expression_cache( + expression_data=series, + cache_path=cache_path, + instrument=instrument, + field=field, + freq=freq, + last_update=str(_calendar[-1]), + ) + return series.loc[start_index:end_index] + else: + return series + else: + # If the expression is a raw feature(such as $close, $open) + return self.provider.expression(instrument, field, start_time, end_time, freq) + + @staticmethod + def clear_cache(cache_path): + meta_path = cache_path + ".meta" + for p in [cache_path, meta_path]: + p = Path(p) + if p.exists(): + p.unlink() + + def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update): + """use bin file to save like feature-data.""" + # Make sure the cache runs right when the directory is deleted + # while running + meta = { + "info": { + "instrument": instrument, + "field": field, + "freq": freq, + "last_update": last_update, + }, + "meta": {"last_visit": time.time(), "visits": 1}, + } + self.logger.debug(f"generating expression cache: {meta}") + os.makedirs(self.expr_cache_path, exist_ok=True) + self.clear_cache(cache_path) + meta_path = cache_path + ".meta" + + with open(meta_path, "wb") as f: + pickle.dump(meta, f) + os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + df = expression_data.to_frame() + + r = np.hstack([df.index[0], expression_data]).astype(" 0: + start, stop = ( + index_data["start"].iloc[0].item(), + index_data["end"].iloc[-1].item(), + ) + else: + start = stop = 0 + + with pd.HDFStore(cache_path, mode="r") as store: + if "/{}".format(im.KEY) in store.keys(): + df = store.select(key=im.KEY, start=start, stop=stop) + df.reset_index(inplace=True) + df.set_index(["instrument", "datetime"], inplace=True) + df.sort_index(inplace=True) + # read cache and need to replace not-space fields to field + df = cls.cache_to_origin_data(df, fields) + + else: + df = pd.DataFrame(columns=fields) + return df + + def _dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + ): + + if disk_cache == 0: + # In this case, data_set cache is configured but will not be used. + return self.provider.dataset(instruments, fields, start_time, end_time, freq) + + _cache_uri = self._uri( + instruments=instruments, + fields=fields, + start_time=None, + end_time=None, + freq=freq, + disk_cache=disk_cache, + ) + + cache_path = os.path.join(self.dtst_cache_path, _cache_uri) + + features = pd.DataFrame() + gen_flag = False + + if self.check_cache_exists(cache_path): + if disk_cache == 1: + # use cache + with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri): + CacheUtils.visit(cache_path) + features = self.read_data_from_cache(cache_path, start_time, end_time, fields) + elif disk_cache == 2: + gen_flag = True + else: + gen_flag = True + + if gen_flag: + # cache unavailable, generate the cache + with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri): + features = self.gen_dataset_cache( + cache_path=cache_path, + instruments=instruments, + fields=fields, + freq=freq, + ) + if not features.empty: + features.reset_index(inplace=True) + features.set_index(["datetime", "instrument"], inplace=True) + features.sort_index(inplace=True) + features = features.loc[start_time:end_time] + return features + + def _dataset_uri( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + ): + if disk_cache == 0: + # In this case, server only checks the expression cache. 
+ # The client will load the cache data by itself. + from .data import LocalDatasetProvider + + LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq) + return "" + + _cache_uri = self._uri( + instruments=instruments, + fields=fields, + start_time=None, + end_time=None, + freq=freq, + disk_cache=disk_cache, + ) + cache_path = os.path.join(self.dtst_cache_path, _cache_uri) + + if self.check_cache_exists(cache_path): + self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly") + with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri): + CacheUtils.visit(cache_path) + return _cache_uri + else: + # cache unavailable, generate the cache + with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri): + self.gen_dataset_cache( + cache_path=cache_path, + instruments=instruments, + fields=fields, + freq=freq, + ) + return _cache_uri + + class IndexManager: + """ + The lock is not considered in the class. Please consider the lock outside the code. + This class is the proxy of the disk data. + """ + + KEY = "df" + + def __init__(self, cache_path): + self.index_path = cache_path + ".index" + self._data = None + self.logger = get_module_logger(self.__class__.__name__) + + def get_index(self, start_time=None, end_time=None): + # TODO: fast read index from the disk. + if self._data is None: + self.sync_from_disk() + return self._data.loc[start_time:end_time].copy() + + def sync_to_disk(self): + if self._data is None: + raise ValueError("No data to sync to disk.") + self._data.sort_index(inplace=True) + self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table") + # The index should be readable for all users + os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + + def sync_from_disk(self): + # The file will not be closed directly if we read_hdf from the disk directly + with pd.HDFStore(self.index_path, mode="r") as store: + if "/{}".format(self.KEY) in store.keys(): + self._data = pd.read_hdf(store, key=self.KEY) + else: + self._data = pd.DataFrame() + + def update(self, data, sync=True): + self._data = data.astype(np.int32).copy() + if sync: + self.sync_to_disk() + + def append_index(self, data, to_disk=True): + data = data.astype(np.int32).copy() + data.sort_index(inplace=True) + self._data = pd.concat([self._data, data]) + if to_disk: + with pd.HDFStore(self.index_path) as store: + store.append(self.KEY, data, append=True) + + @staticmethod + def build_index_from_data(data, start_index=0): + if data.empty: + return pd.DataFrame() + line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count() + line_data.sort_index(inplace=True) + index_end = line_data.cumsum() + index_start = index_end.shift(1).fillna(0) + + index_data = pd.DataFrame() + index_data["start"] = index_start + index_data["end"] = index_end + index_data += start_index + return index_data + + @staticmethod + def clear_cache(cache_path): + meta_path = cache_path + ".meta" + for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]: + p = Path(p) + if p.exists(): + p.unlink() + + def gen_dataset_cache(self, cache_path, instruments, fields, freq): + """gen_dataset_cache + + NOTE:This function does not consider the cache read write lock. Please + Aquire the lock outside this function + + The format the cache contains 3 parts(followed by typical filename). 
+ - index : cache/d41366901e25de3ec47297f12e2ba11d.index + - The content of the file may be in following format(pandas.Series) + start end + 1999-11-10 00:00:00 0 1 + 1999-11-11 00:00:00 1 2 + 1999-11-12 00:00:00 2 3 + ... + NOTE: The start is closed. The end is open!!!!! + - Each line contains two element + - It indicates the `end_index` of the data for `timestamp` + + - meta data: cache/d41366901e25de3ec47297f12e2ba11d.meta + - data : cache/d41366901e25de3ec47297f12e2ba11d + - This is a hdf file sorted by datetime + + :param cache_path: The path to store the cache + :param instruments: The instruments to store the cache + :param fields: The fields to store the cache + :param freq: The freq to store the cache + + :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function + """ + # get calendar + from .data import Cal + + _calendar = Cal.calendar(freq=freq) + self.logger.debug(f"Generating dataset cache {cache_path}") + # Make sure the cache runs right when the directory is deleted + # while running + os.makedirs(self.dtst_cache_path, exist_ok=True) + self.clear_cache(cache_path) + + features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq) + + # sort index by datetime + if not features.empty: + features.reset_index(inplace=True) + features.set_index(["datetime", "instrument"], inplace=True) + features.sort_index(inplace=True) + + # write cache data + with pd.HDFStore(cache_path + ".data") as store: + cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns)) + orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns))) + cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map) + # cache columns + cache_columns = sorted(cache_features.columns) + cache_features = cache_features.loc[:, cache_columns] + cache_features = cache_features.loc[:, ~cache_features.columns.duplicated()] + store.append(DatasetCache.HDF_KEY, cache_features, append=False) + # write meta file + meta = { + "info": { + "instruments": instruments, + "fields": cache_columns, + "freq": freq, + "last_update": str(_calendar[-1]), # The last_update to store the cache + }, + "meta": {"last_visit": time.time(), "visits": 1}, + } + with open(cache_path + ".meta", "wb") as f: + pickle.dump(meta, f) + os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + # write index file + im = ServerDatasetCache.IndexManager(cache_path) + index_data = im.build_index_from_data(features) + im.update(index_data) + + # rename the file after the cache has been generated + # this doesn't work well on windows, but our server won't use windows + # temporarily + os.replace(cache_path + ".data", cache_path) + # the fields of the cached features are converted to the original fields + return features + + def update(self, cache_uri): + cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri) + + if not self.check_cache_exists(cp_cache_uri): + self.logger.info(f"The cache {cp_cache_uri} has corrupted. 
It will be removed") + self.clear_cache(cp_cache_uri) + return 2 + + im = ServerDatasetCache.IndexManager(cp_cache_uri) + with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri): + with open(cp_cache_uri + ".meta", "rb") as f: + d = pickle.load(f) + instruments = d["info"]["instruments"] + fields = d["info"]["fields"] + freq = d["info"]["freq"] + last_update_time = d["info"]["last_update"] + index_data = im.get_index() + + self.logger.debug("Updating dataset: {}".format(d)) + from .data import Inst + + if Inst.get_inst_type(instruments) == Inst.DICT: + self.logger.info(f"The file {cache_uri} has dict cache. Skip updating") + return 1 + + # get newest calendar + from .data import Cal + + whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq) + # The calendar since last updated + new_calendar = Cal.calendar(start_time=last_update_time, end_time=None, freq=freq) + + # get append data + if len(new_calendar) <= 1: + # Including last updated calendar, we only get 1 item. + # No future updating is needed. + return 1 + else: + # get the data needed after the historical data are removed. + # The start index of new data + current_index = len(whole_calendar) - len(new_calendar) + 1 + + # To avoid recursive import + from .data import ExpressionD + + # The existing data length + lft_etd = rght_etd = 0 + for field in fields: + expr = ExpressionD.get_expression_instance(field) + l, r = expr.get_extended_window_size() + lft_etd = max(lft_etd, l) + rght_etd = max(rght_etd, r) + # remove the period that should be updated. + if index_data.empty: + # We don't have any data for such dataset. Nothing to remove + rm_n_period = rm_lines = 0 + else: + rm_n_period = min(rght_etd, index_data.shape[0]) + rm_lines = ( + (index_data["end"] - index_data["start"]) + .loc[whole_calendar[current_index - rm_n_period] :] + .sum() + .item() + ) + + data = self.provider.dataset( + instruments, + fields, + whole_calendar[current_index - rm_n_period], + new_calendar[-1], + freq, + ) + + if not data.empty: + data.reset_index(inplace=True) + data.set_index(["datetime", "instrument"], inplace=True) + data.sort_index(inplace=True) + else: + return 0 # No data to update cache + + store = pd.HDFStore(cp_cache_uri) + # FIXME: + # Because the feature cache are stored as .bin file. + # So the series read from features are all float32. + # However, the first dataset cache is calulated based on the + # raw data. So the data type may be float64. 
+ # Different data type will result in failure of appending data + if "/{}".format(DatasetCache.HDF_KEY) in store.keys(): + schema = store.select(DatasetCache.HDF_KEY, start=0, stop=0) + for col, dtype in schema.dtypes.items(): + data[col] = data[col].astype(dtype) + if rm_lines > 0: + store.remove(key=im.KEY, start=-rm_lines) + store.append(DatasetCache.HDF_KEY, data) + store.close() + + # update index file + new_index_data = im.build_index_from_data( + data.loc(axis=0)[whole_calendar[current_index] :, :], + start_index=0 if index_data.empty else index_data["end"].iloc[-1], + ) + im.append_index(new_index_data) + + # update meta file + d["info"]["last_update"] = str(new_calendar[-1]) + with open(cp_cache_uri + ".meta", "wb") as f: + pickle.dump(d, f) + return 0 + + +class SimpleDatasetCache(DatasetCache): + """Simple dataset cache that can be used locally or on client.""" + + def __init__(self, provider): + super(SimpleDatasetCache, self).__init__(provider) + try: + self.local_cache_path = C["local_cache_path"] + except KeyError as e: + self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism") + + def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs): + instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq) + local_cache_path = str(Path(self.local_cache_path).expanduser().resolve()) + return hash_args( + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + local_cache_path, + ) + + def _dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + ): + if disk_cache == 0: + # In this case, data_set cache is configured but will not be used. + return self.provider.dataset(instruments, fields, start_time, end_time, freq) + os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True) + cache_file = os.path.join( + self.local_cache_path, + self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache), + ) + gen_flag = False + + if os.path.exists(cache_file): + if disk_cache == 1: + # use cache + df = pd.read_pickle(cache_file) + return self.cache_to_origin_data(df, fields) + elif disk_cache == 2: + # replace cache + gen_flag = True + else: + gen_flag = True + + if gen_flag: + data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq) + data.to_pickle(cache_file) + return self.cache_to_origin_data(data, fields) + + +class ClientDatasetCache(DatasetCache): + """Prepared cache mechanism for server.""" + + def __init__(self, provider): + super(ClientDatasetCache, self).__init__(provider) + + def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs): + return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache) + + def dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + ): + + if "local" in C.dataset_provider.lower(): + # use LocalDatasetProvider + return self.provider.dataset(instruments, fields, start_time, end_time, freq) + + if disk_cache == 0: + # do not use data_set cache, load data from remote expression cache directly + return self.provider.dataset( + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + return_uri=False, + ) + + # use ClientDatasetProvider + feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache) + value, expire = MemCacheExpire.get_cache(H["f"], feature_uri) + mnt_feature_uri = 
os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri) + if value is None or expire or not os.path.exists(mnt_feature_uri): + df, uri = self.provider.dataset( + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + return_uri=True, + ) + # cache uri + MemCacheExpire.set_cache(H["f"], uri, uri) + # cache DataFrame + # HZ['f'][uri] = df.copy() + get_module_logger("cache").debug(f"get feature from {C.dataset_provider}") + else: + mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri) + df = ServerDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) + get_module_logger("cache").debug("get feature from uri cache") + + return df + + +class CalendarCache(BaseProviderCache): + pass + + +class ClientCalendarCache(CalendarCache): + def calendar(self, start_time=None, end_time=None, freq="day", future=False): + uri = self._uri(start_time, end_time, freq, future) + result, expire = MemCacheExpire.get_cache(H["c"], uri) + if result is None or expire: + + result = self.provider.calendar(start_time, end_time, freq, future) + MemCacheExpire.set_cache(H["c"], uri, result) + + get_module_logger("data").debug(f"get calendar from {C.calendar_provider}") + else: + get_module_logger("data").debug("get calendar from local cache") + + return result + + +# MemCache sizeof +HZ = MemCache(C.mem_cache_space_limit, limit_type="sizeof") +# MemCache length +H = MemCache(limit_type="length") diff --git a/qlib/data/client.py b/qlib/data/client.py new file mode 100644 index 0000000000..2e83726d19 --- /dev/null +++ b/qlib/data/client.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +import socketio + +from .. import __version__ +from ..log import get_module_logger +import pickle + + +class Client(object): + """A client class + + Provide the connection tool functions for ClientProvider. + """ + + def __init__(self, host, port): + super(Client, self).__init__() + self.sio = socketio.Client() + self.server_host = host + self.server_port = port + self.logger = get_module_logger(self.__class__.__name__) + # bind connect/disconnect callbacks + self.sio.on( + "connect", + lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)), + ) + self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!")) + + def connect_server(self): + """Connect to server.""" + try: + self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port)) + except socketio.exceptions.ConnectionError: + self.logger.error("Cannot connect to server - check your network or server status") + + def disconnect(self): + """Disconnect from server.""" + try: + self.sio.eio.disconnect(True) + except Exception as e: + self.logger.error("Cannot disconnect from server : %s" % e) + + def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None): + """Send a certain request to server. 
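+
+        The request is emitted as a ``"<request_type>_request"`` socket.io event; the reply
+        arrives on the matching ``"<request_type>_response"`` event and, after optional
+        processing by ``msg_proc_func``, is put into ``msg_queue``.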
+ + Parameters + ---------- + request_type : str + type of proposed request, 'calendar'/'instrument'/'feature' + request_content : dict + records the information of the request + msg_proc_func : func + the function to process the message when receiving response, should have arg `*args` + msg_queue: Queue + The queue to pass the messsage after callback + """ + head_info = {"version": __version__} + + def request_callback(*args): + """callback_wrapper + + :param *args: args[0] is the response content + """ + # args[0] is the response content + self.logger.debug("receive data and enter queue") + msg = dict(args[0]) + if msg["detailed_info"] is not None: + if msg["status"] != 0: + self.logger.error(msg["detailed_info"]) + else: + self.logger.info(msg["detailed_info"]) + if msg["status"] != 0: + ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}") + msg_queue.put(ex) + else: + if msg_proc_func is not None: + try: + ret = msg_proc_func(msg["result"]) + except Exception as e: + self.logger.exception("Error when processing message.") + ret = e + else: + ret = msg["result"] + msg_queue.put(ret) + self.disconnect() + self.logger.debug("disconnected") + + self.logger.debug("try connecting") + self.connect_server() + self.logger.debug("connected") + # The pickle is for passing some parameters with special type(such as + # pd.Timestamp) + request_content = {"head": head_info, "body": pickle.dumps(request_content)} + self.sio.on(request_type + "_response", request_callback) + self.logger.debug("try sending") + self.sio.emit(request_type + "_request", request_content) + self.sio.wait() diff --git a/qlib/data/data.py b/qlib/data/data.py new file mode 100644 index 0000000000..5420f8efa2 --- /dev/null +++ b/qlib/data/data.py @@ -0,0 +1,1111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +import os +import abc +import six +import time +import queue +import bisect +import logging +import importlib +import traceback +import numpy as np +import pandas as pd +from multiprocessing import Pool + +from .cache import H +from ..config import C +from .ops import * +from ..log import get_module_logger +from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields +from .base import Feature +from .cache import ServerDatasetCache, ServerExpressionCache + + +@six.add_metaclass(abc.ABCMeta) +class CalendarProvider(object): + """Calendar provider base class + + Provide calendar data. + """ + + @abc.abstractmethod + def calendar(self, start_time=None, end_time=None, freq="day", future=False): + """Get calendar of certain market in given time range. + + Parameters + ---------- + start_time : str + start of the time range + end_time : str + end of the time range + freq : str + time frequency, available: year/quarter/month/week/day + future : bool + whether including future trading day + + Returns + ---------- + list + calendar list + """ + raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method") + + def locate_index(self, start_time, end_time, freq, future): + """Locate the start time index and end time index in a calendar under certain frequency. 
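+
+        If ``start_time`` does not fall on a calendar entry it is rolled forward to the next
+        trading timestamp, and ``end_time`` is rolled backward to the previous one, before
+        their indices are looked up.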
+ + Parameters + ---------- + start_time : str + start of the time range + end_time : str + end of the time range + freq : str + time frequency, available: year/quarter/month/week/day + future : bool + whether including future trading day + + Returns + ------- + pd.Timestamp + the real start time + pd.Timestamp + the real end time + int + the index of start time + int + the index of end time + """ + start_time = pd.Timestamp(start_time) + end_time = pd.Timestamp(end_time) + calendar, calendar_index = self._get_calendar(freq=freq, future=future) + if start_time not in calendar_index: + try: + start_time = calendar[bisect.bisect_left(calendar, start_time)] + except IndexError: + raise IndexError( + "`start_time` uses a future date, if you want to get future trading days, you can use: `future=True`" + ) + start_index = calendar_index[start_time] + if end_time not in calendar_index: + end_time = calendar[bisect.bisect_right(calendar, end_time) - 1] + end_index = calendar_index[end_time] + return start_time, end_time, start_index, end_index + + def _get_calendar(self, freq, future): + """Load calendar using memcache. + + Parameters + ---------- + freq : str + frequency of read calendar file + future : bool + whether including future trading day + + Returns + ------- + list + list of timestamps + dict + dict composed by timestamp as key and index as value for fast search + """ + flag = f"{freq}_future_{future}" + if flag in H["c"]: + _calendar, _calendar_index = H["c"][flag] + else: + _calendar = np.array(self._load_calendar(freq, future)) + _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search + H["c"][flag] = _calendar, _calendar_index + return _calendar, _calendar_index + + def _uri(self, start_time, end_time, freq, future=False): + """Get the uri of calendar generation task.""" + return hash_args(start_time, end_time, freq, future) + + +@six.add_metaclass(abc.ABCMeta) +class InstrumentProvider(object): + """Instrument provider base class + + Provide instrument data. + """ + + @staticmethod + def instruments(market="all", filter_pipe=None): + """Get the general config dictionary for a base market adding several dynamic filters. + + Parameters + ---------- + market : str + market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500 + filter_pipe : list + the list of dynamic filters + + Returns + ---------- + dict + dict of stockpool config + {`market`=>base market name, `filter_pipe`=>list of filters} + + example : + {'market': 'csi500', + 'filter_pipe': [{'filter_type': 'ExpressionDFilter', + 'rule_expression': '$open<40', + 'filter_start_time': None, + 'filter_end_time': None, + 'keep': False}, + {'filter_type': 'NameDFilter', + 'name_rule_re': 'SH[0-9]{4}55', + 'filter_start_time': None, + 'filter_end_time': None}]} + """ + if filter_pipe is None: + filter_pipe = [] + config = {"market": market, "filter_pipe": []} + # the order of the filters will affect the result, so we need to keep + # the order + for filter_t in filter_pipe: + config["filter_pipe"].append(filter_t.to_config()) + return config + + @abc.abstractmethod + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + """List the instruments based on a certain stockpool config. 
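+
+        Depending on ``as_list``, the result is either a plain list of instrument codes or
+        a dict mapping each instrument to the time spans in which it is listed within the
+        requested range.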
+ + Parameters + ---------- + instruments : dict + stockpool config + start_time : str + start of the time range + end_time : str + end of the time range + as_list : bool + return instruments as list or dict + + Returns + ------- + dict or list + instruments list or dictionary with time spans + """ + raise NotImplementedError("Subclass of InstrumentProvider must implement `list_instruments` method") + + def _uri(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + return hash_args(instruments, start_time, end_time, freq, as_list) + + # instruments type + LIST = "LIST" + DICT = "DICT" + CONF = "CONF" + + @classmethod + def get_inst_type(cls, inst): + if "market" in inst: + return cls.CONF + if isinstance(inst, dict): + return cls.DICT + if isinstance(inst, (list, tuple, pd.Index, np.ndarray)): + return cls.LIST + raise ValueError(f"Unknown instrument type {inst}") + + +@six.add_metaclass(abc.ABCMeta) +class FeatureProvider(object): + """Feature provider class + + Provide feature data. + """ + + @abc.abstractmethod + def feature(self, instrument, field, start_time, end_time, freq): + """Get feature data. + + Parameters + ---------- + instrument : str + a certain instrument + field : str + a certain field of feature + start_time : str + start of the time range + end_time : str + end of the time range + freq : str + time frequency, available: year/quarter/month/week/day + + Returns + ------- + pd.Series + data of a certain feature + """ + raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method") + + +@six.add_metaclass(abc.ABCMeta) +class ExpressionProvider(object): + """Expression provider class + + Provide Expression data. + """ + + def __init__(self): + self.expression_instance_cache = {} + + def get_expression_instance(self, field): + try: + if field in self.expression_instance_cache: + expression = self.expression_instance_cache[field] + else: + expression = eval(parse_field(field)) + self.expression_instance_cache[field] = expression + except NameError as e: + get_module_logger("data").exception( + "ERROR: field [%s] contains invalid operator/variable [%s]" % (str(field), str(e).split()[1]) + ) + raise + except SyntaxError: + get_module_logger("data").exception("ERROR: field [%s] contains invalid syntax" % str(field)) + raise + return expression + + @abc.abstractmethod + def expression(self, instrument, field, start_time=None, end_time=None, freq="day"): + """Get Expression data. + + Parameters + ---------- + instrument : str + a certain instrument + field : str + a certain field of feature + start_time : str + start of the time range + end_time : str + end of the time range + freq : str + time frequency, available: year/quarter/month/week/day + + Returns + ------- + pd.Series + data of a certain expression + """ + raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method") + + +@six.add_metaclass(abc.ABCMeta) +class DatasetProvider(object): + """Dataset provider class + + Provide Dataset data. + """ + + @abc.abstractmethod + def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"): + """Get dataset data. 
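+
+        The returned frame is indexed by instrument and datetime (the exact level order
+        depends on the provider or cache in use), and its columns follow the order of
+        ``fields``.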
+ + Parameters + ---------- + instruments : list or dict + list/dict of instruments or dict of stockpool config + fields : list + list of feature instances + start_time : str + start of the time range + end_time : str + end of the time range + freq : str + time frequency + + Returns + ---------- + pd.DataFrame + a pandas dataframe with index + """ + raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method") + + def _uri( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + **kwargs, + ): + """Get task uri, used when generating rabbitmq task in qlib_server + + Parameters + ---------- + instruments : list or dict + list/dict of instruments or dict of stockpool config + fields : list + list of feature instances + start_time : str + start of the time range + end_time : str + end of the time range + freq : str + time frequency + disk_cache : int + whether to skip(0)/use(1)/replace(2) disk_cache + + """ + return ServerDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache) + + @staticmethod + def get_instruments_d(instruments, freq): + """ + Parse different types of input instruments to output instruments_d + Wrong format of input instruments will lead to exception. + + """ + if isinstance(instruments, dict): + if "market" in instruments: + # dict of stockpool config + instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False) + else: + # dict of instruments and timestamp + instruments_d = instruments + elif isinstance(instruments, (list, tuple, pd.Index, np.ndarray)): + # list or tuple of a group of instruments + instruments_d = list(instruments) + else: + raise ValueError("Unsupported input type for param `instrument`") + return instruments_d + + @staticmethod + def get_column_names(fields): + """ + Get column names from input fields + + """ + if len(fields) == 0: + raise ValueError("fields cannot be empty") + fields = fields.copy() + column_names = [str(f) for f in fields] + return column_names + + @staticmethod + def parse_fields(fields): + # parse and check the input fields + return [ExpressionD.get_expression_instance(f) for f in fields] + + @staticmethod + def dataset_processor(instruments_d, column_names, start_time, end_time, freq): + """ + Load and process the data, return the data set. + - default using multi-kernel method. + + """ + normalize_column_names = normalize_cache_fields(column_names) + data = dict() + # One process for one task, so that the memory will be freed quicker. 
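+        # C.maxtasksperchild, when set, recycles every worker process after that many
+        # tasks, keeping the memory footprint of long-running pools bounded.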
+ if C.maxtasksperchild is None: + p = Pool(processes=C.kernels) + else: + p = Pool(processes=C.kernels, maxtasksperchild=C.maxtasksperchild) + + if isinstance(instruments_d, dict): + for inst, spans in instruments_d.items(): + data[inst] = p.apply_async( + DatasetProvider.expression_calculator, + args=( + inst, + start_time, + end_time, + freq, + normalize_column_names, + spans, + C, + ), + ) + else: + for inst in instruments_d: + data[inst] = p.apply_async( + DatasetProvider.expression_calculator, + args=( + inst, + start_time, + end_time, + freq, + normalize_column_names, + None, + C, + ), + ) + + p.close() + p.join() + + new_data = dict() + for inst in sorted(data.keys()): + if len(data[inst].get()) > 0: + # NOTE: Python version >= 3.6; in versions after python3.6, dict will always guarantee the insertion order + new_data[inst] = data[inst].get() + + if len(new_data) > 0: + data = pd.concat(new_data, names=["instrument"], sort=False) + data = ServerDatasetCache.cache_to_origin_data(data, column_names) + else: + data = pd.DataFrame(columns=column_names) + + return data + + @staticmethod + def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, C=None): + """ + Calculate the expressions for one instrument, return a df result. + If the expression has been calculated before, load from cache. + + return value: A data frame with index 'datetime' and other data columns. + + """ + # NOTE: This place is compatible with windows, windows multi-process is spawn + if getattr(ExpressionD, "_provider", None) is None: + register_all_wrappers() + + obj = dict() + for field in column_names: + # The client does not have expression provider, the data will be loaded from cache using static method. + obj[field] = ExpressionD.expression(inst, field, start_time, end_time, freq) + + data = pd.DataFrame(obj) + _calendar = Cal.calendar(freq=freq) + data.index = _calendar[data.index.values.astype(np.int)] + data.index.names = ["datetime"] + + if spans is None: + return data + else: + mask = np.zeros(len(data), dtype=np.bool) + for begin, end in spans: + mask |= (data.index >= begin) & (data.index <= end) + return data[mask] + + +class LocalCalendarProvider(CalendarProvider): + """Local calendar data provider class + + Provide calendar data from local data source. + """ + + def __init__(self, **kwargs): + self.remote = kwargs.get("remote", False) + + @property + def _uri_cal(self): + """Calendar file uri.""" + if self.remote: + return os.path.join(C.mount_path, "calendars", "{}.txt") + else: + return os.path.join(C.provider_uri, "calendars", "{}.txt") + + def _load_calendar(self, freq, future): + """Load original calendar timestamp from file. 
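+
+        The calendar is stored as a plain text file (``<freq>.txt``, or ``<freq>_future.txt``
+        for the future calendar) with one timestamp per line.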
+ + Parameters + ---------- + freq : str + frequency of read calendar file + + Returns + ---------- + list + list of timestamps + """ + if future: + fname = self._uri_cal.format(freq + "_future") + # if future calendar not exists, return current calendar + if not os.path.exists(fname): + get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + fname = self._uri_cal.format(freq) + else: + fname = self._uri_cal.format(freq) + if not os.path.exists(fname): + raise ValueError("calendar not exists for freq " + freq) + with open(fname) as f: + return [pd.Timestamp(x.strip()) for x in f] + + def calendar(self, start_time=None, end_time=None, freq="day", future=False): + _calendar, _calendar_index = self._get_calendar(freq, future) + if start_time == "None": + start_time = None + if end_time == "None": + end_time = None + # strip + if start_time: + start_time = pd.Timestamp(start_time) + if start_time > _calendar[-1]: + return np.array([]) + else: + start_time = _calendar[0] + if end_time: + end_time = pd.Timestamp(end_time) + if end_time < _calendar[0]: + return np.array([]) + else: + end_time = _calendar[-1] + _, _, si, ei = self.locate_index(start_time, end_time, freq, future) + return _calendar[si : ei + 1] + + +class LocalInstrumentProvider(InstrumentProvider): + """Local instrument data provider class + + Provide instrument data from local data source. + """ + + def __init__(self): + pass + + @property + def _uri_inst(self): + """Instrument file uri.""" + return os.path.join(C.provider_uri, "instruments", "{}.txt") + + def _load_instruments(self, market): + fname = self._uri_inst.format(market) + print(fname) + if not os.path.exists(fname): + raise ValueError("instruments not exists for market " + market) + _instruments = dict() + with open(fname) as f: + for line in f: + inst_time = line.strip().split() + inst = inst_time[0] + if len(inst_time) == 3: + # `day` + begin = inst_time[1] + end = inst_time[2] + elif len(inst_time) == 5: + # `1min` + begin = inst_time[1] + " " + inst_time[2] + end = inst_time[3] + " " + inst_time[4] + _instruments.setdefault(inst, []).append((pd.Timestamp(begin), pd.Timestamp(end))) + return _instruments + + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + market = instruments["market"] + if market in H["i"]: + _instruments = H["i"][market] + else: + _instruments = self._load_instruments(market) + H["i"][market] = _instruments + # strip + # use calendar boundary + cal = Cal.calendar(freq=freq) + start_time = pd.Timestamp(start_time or cal[0]) + end_time = pd.Timestamp(end_time or cal[-1]) + _instruments_filtered = { + inst: list( + filter( + lambda x: x[0] <= x[1], + [(max(start_time, x[0]), min(end_time, x[1])) for x in spans], + ) + ) + for inst, spans in _instruments.items() + } + _instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value} + # filter + filter_pipe = instruments["filter_pipe"] + for filter_config in filter_pipe: + from . import filter as F + + filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config) + _instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq) + # as list + if as_list: + return list(_instruments_filtered) + return _instruments_filtered + + +class LocalFeatureProvider(FeatureProvider): + """Local feature data provider class + + Provide feature data from local data source. 
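+
+    Each feature is stored as a flat binary file under
+    ``features/<instrument>/<field>.<freq>.bin`` and sliced with ``read_bin``.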
+ """ + + def __init__(self, **kwargs): + self.remote = kwargs.get("remote", False) + + @property + def _uri_data(self): + """Static feature file uri.""" + if self.remote: + return os.path.join(C.mount_path, "features", "{}", "{}.{}.bin") + else: + return os.path.join(C.provider_uri, "features", "{}", "{}.{}.bin") + + def feature(self, instrument, field, start_index, end_index, freq): + # validate + field = str(field).lower()[1:] + uri_data = self._uri_data.format(instrument.lower(), field, freq) + if not os.path.exists(uri_data): + get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) + return pd.Series() + # raise ValueError('uri_data not found: ' + uri_data) + # load + series = read_bin(uri_data, start_index, end_index) + return series + + +class LocalExpressionProvider(ExpressionProvider): + """Local expression data provider class + + Provide expression data from local data source. + """ + + def __init__(self): + super().__init__() + + def expression(self, instrument, field, start_time=None, end_time=None, freq="day"): + expression = self.get_expression_instance(field) + start_time = pd.Timestamp(start_time) + end_time = pd.Timestamp(end_time) + _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False) + lft_etd, rght_etd = expression.get_extended_window_size() + series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) + # Ensure that each column type is consistent + # FIXME: The stock data is currently float. If there is other types of data, this part needs to be re-implemented. + try: + series = series.astype(float) + except ValueError: + pass + if not series.empty: + series = series.loc[start_index:end_index] + return series + + +class LocalDatasetProvider(DatasetProvider): + """Local dataset data provider class + + Provide dataset data from local data source. + """ + + def __init__(self): + pass + + def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"): + instruments_d = self.get_instruments_d(instruments, freq) + column_names = self.get_column_names(fields) + cal = Cal.calendar(start_time, end_time, freq) + if len(cal) == 0: + return pd.DataFrame(columns=column_names) + start_time = cal[0] + end_time = cal[-1] + + data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) + + return data + + @staticmethod + def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq="day"): + """ + This method is used to prepare the expression cache for the client. + Then the client will load the data from expression cache by itself. + + """ + instruments_d = DatasetProvider.get_instruments_d(instruments, freq) + column_names = DatasetProvider.get_column_names(fields) + cal = Cal.calendar(start_time, end_time, freq) + if len(cal) == 0: + return + start_time = cal[0] + end_time = cal[-1] + + if C.maxtasksperchild is None: + p = Pool(processes=C.kernels) + else: + p = Pool(processes=C.kernels, maxtasksperchild=C.maxtasksperchild) + + for inst in instruments_d: + p.apply_async( + LocalDatasetProvider.cache_walker, + args=( + inst, + start_time, + end_time, + freq, + column_names, + ), + ) + + p.close() + p.join() + + @staticmethod + def cache_walker(inst, start_time, end_time, freq, column_names): + """ + If the expressions of one instrument haven't been calculated before, + calculate it and write it into expression cache. 
+ + """ + for field in column_names: + ExpressionD.expression(inst, field, start_time, end_time, freq) + + +class ClientCalendarProvider(CalendarProvider): + """Client calendar data provider class + + Provide calendar data by requesting data from server as a client. + """ + + def __init__(self): + self.conn = None + self.queue = queue.Queue() + + def set_conn(self, conn): + self.conn = conn + + def calendar(self, start_time=None, end_time=None, freq="day", future=False): + self.conn.send_request( + request_type="calendar", + request_content={ + "start_time": str(start_time), + "end_time": str(end_time), + "freq": freq, + "future": future, + }, + msg_queue=self.queue, + msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content], + ) + result = self.queue.get(timeout=C["timeout"]) + return result + + +class ClientInstrumentProvider(InstrumentProvider): + """Client instrument data provider class + + Provide instrument data by requesting data from server as a client. + """ + + def __init__(self): + self.conn = None + self.queue = queue.Queue() + + def set_conn(self, conn): + self.conn = conn + + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def inst_msg_proc_func(response_content): + if isinstance(response_content, dict): + instrument = { + i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items() + } + else: + instrument = response_content + return instrument + + self.conn.send_request( + request_type="instrument", + request_content={ + "instruments": instruments, + "start_time": str(start_time), + "end_time": str(end_time), + "freq": freq, + "as_list": as_list, + }, + msg_queue=self.queue, + msg_proc_func=inst_msg_proc_func, + ) + result = self.queue.get(timeout=C["timeout"]) + if isinstance(result, Exception): + raise result + get_module_logger("data").debug("get result") + return result + + +class ClientDatasetProvider(DatasetProvider): + """Client dataset data provider class + + Provide dataset data by requesting data from server as a client. + """ + + def __init__(self): + self.conn = None + + def set_conn(self, conn): + self.conn = conn + self.queue = queue.Queue() + + def dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + return_uri=False, + ): + if Inst.get_inst_type(instruments) == Inst.DICT: + get_module_logger("data").warning( + "Getting features from a dict of instruments is not recommended because the features will not be " + "cached! " + "The dict of instruments will be cleaned every day." + ) + + if disk_cache == 0: + """ + Call the server to generate the expression cache. + Then load the data from the expression cache directly. + - default using multi-kernel method. 
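+            - the value returned by the server here is not used to read data; the frame is
+              recomputed locally from the expression cache (the uri is only passed through
+              when ``return_uri=True``).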
+ + """ + self.conn.send_request( + request_type="feature", + request_content={ + "instruments": instruments, + "fields": fields, + "start_time": start_time, + "end_time": end_time, + "freq": freq, + "disk_cache": 0, + }, + msg_queue=self.queue, + ) + feature_uri = self.queue.get(timeout=C["timeout"]) + if isinstance(feature_uri, Exception): + raise feature_uri + else: + instruments_d = self.get_instruments_d(instruments, freq) + column_names = self.get_column_names(fields) + cal = Cal.calendar(start_time, end_time, freq) + if len(cal) == 0: + return pd.DataFrame(columns=column_names) + start_time = cal[0] + end_time = cal[-1] + + data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) + if return_uri: + return data, feature_uri + else: + return data + else: + + """ + Call the server to generate the data-set cache, get the uri of the cache file. + Then load the data from the file on NFS directly. + - using single-process implementation. + + """ + self.conn.send_request( + request_type="feature", + request_content={ + "instruments": instruments, + "fields": fields, + "start_time": start_time, + "end_time": end_time, + "freq": freq, + "disk_cache": 1, + }, + msg_queue=self.queue, + ) + # - Done in callback + feature_uri = self.queue.get(timeout=C["timeout"]) + if isinstance(feature_uri, Exception): + raise feature_uri + get_module_logger("data").debug("get result") + try: + # pre-mound nfs, used for demo + mnt_feature_uri = os.path.join(C.mount_path, C.dataset_cache_dir_name, feature_uri) + df = ServerDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) + get_module_logger("data").debug("finish slicing data") + if return_uri: + return df, feature_uri + return df + except AttributeError: + raise IOError("Unable to fetch instruments from remote server!") + + +class BaseProvider: + """Local provider class + + To keep compatible with old qlib provider. + """ + + def calendar(self, start_time=None, end_time=None, freq="day", future=False): + return Cal.calendar(start_time, end_time, freq, future=future) + + def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None): + if start_time is not None or end_time is not None: + get_module_logger("Provider").warning( + "The instruments corresponds to a stock pool. " + "Parameters `start_time` and `end_time` does not take effect now." + ) + return InstrumentProvider.instruments(market, filter_pipe) + + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + return Inst.list_instruments(instruments, start_time, end_time, freq, as_list) + + def features( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=None, + ): + """ + disk_cache : int + whether to skip(0)/use(1)/replace(2) disk_cache + + This function will try to use cache method which has a keyword `disk_cache`, + and will use provider method if a type error is raised because the DatasetD instance + is a provider class. + """ + disk_cache = C.default_disk_cache if disk_cache is None else disk_cache + if C.disable_disk_cache: + disk_cache = False + try: + return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache) + except TypeError: + return DatasetD.dataset(instruments, fields, start_time, end_time, freq) + + +class LocalProvider(BaseProvider): + def _uri(self, type, **kwargs): + """_uri + The server hope to get the uri of the request. The uri will be decided + by the dataprovider. 
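+
+        A minimal illustrative call (the market, fields and dates are examples only):
+
+        >>> D.features(D.instruments("all"), ["$close", "Ref($close, 1)"],
+        ...            start_time="2017-01-01", end_time="2017-12-31", freq="day")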
For ex, different cache layer has different uri. + + :param type: The type of resource for the uri + :param **kwargs: + """ + if type == "calendar": + return Cal._uri(**kwargs) + elif type == "instrument": + return Inst._uri(**kwargs) + elif type == "feature": + return DatasetD._uri(**kwargs) + + def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1): + """features_uri + + Return the uri of the generated cache of features/dataset + + :param disk_cache: + :param instruments: + :param fields: + :param start_time: + :param end_time: + :param freq: + """ + return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache) + + +class ClientProvider(BaseProvider): + """Client Provider + + Requesting data from server as a client. Can propose requests: + - Calendar : Directly respond a list of calendars + - Instruments (without filter): Directly respond a list/dict of instruments + - Instruments (with filters): Respond a list/dict of instruments + - Features : Respond a cache uri + The general workflow is described as follows: + When the user use client provider to propose a request, the client provider will connect the server and send the request. The client will start to wait for the response. The response will be made instantly indicating whether the cache is available. The waiting procedure will terminate only when the client get the reponse saying `feature_available` is true. + `BUG` : Everytime we make request for certain data we need to connect to the server, wait for the response and disconnect from it. We can't make a sequence of requests within one connection. You can refer to https://python-socketio.readthedocs.io/en/latest/client.html for documentation of python-socketIO client. + """ + + def __init__(self): + from .client import Client + + self.client = Client(C.flask_server, C.flask_port) + self.logger = get_module_logger(self.__class__.__name__) + if isinstance(Cal, ClientCalendarProvider): + Cal.set_conn(self.client) + Inst.set_conn(self.client) + if hasattr(DatasetD, "provider"): + DatasetD.provider.set_conn(self.client) + else: + DatasetD.set_conn(self.client) + + +class Wrapper(object): + """Data Provider Wrapper""" + + def __init__(self): + self._provider = None + + def register(self, provider): + self._provider = provider + + def __getattr__(self, key): + if self._provider is None: + raise AttributeError("Please run qlib.init() first using qlib") + return getattr(self._provider, key) + + +def get_cls_from_name(cls_name): + return getattr(importlib.import_module(".data", package="qlib"), cls_name) + + +def get_provider_obj(config, **params): + if isinstance(config, dict): + params.update(config["kwargs"]) + config = config["class"] + return get_cls_from_name(config)(**params) + + +def register_wrapper(wrapper, cls_or_obj): + """register_wrapper + + :param wrapper: A wrapper of all kinds of providers + :param cls_or_obj: A class or class name or object instance in data/data.py + """ + if isinstance(cls_or_obj, str): + cls_or_obj = get_cls_from_name(cls_or_obj) + obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj + wrapper.register(obj) + + +Cal = Wrapper() +Inst = Wrapper() +FeatureD = Wrapper() +ExpressionD = Wrapper() +DatasetD = Wrapper() +D = Wrapper() + + +def register_all_wrappers(): + """register_all_wrappers""" + logger = get_module_logger("data") + + _calendar_provider = get_provider_obj(C.calendar_provider) + if getattr(C, "calendar_cache", None) is not None: + _calendar_provider = 
get_provider_obj(C.calendar_cache, provider=_calendar_provider) + register_wrapper(Cal, _calendar_provider) + logger.debug(f"registering Cal {C.calendar_provider}-{C.calenar_cache}") + + register_wrapper(Inst, C.instrument_provider) + logger.debug(f"registering Inst {C.instrument_provider}") + + if getattr(C, "feature_provider", None) is not None: + feature_provider = get_provider_obj(C.feature_provider) + register_wrapper(FeatureD, feature_provider) + logger.debug(f"registering FeatureD {C.feature_provider}") + + if getattr(C, "expression_provider", None) is not None: + # This provider is unnecessary in client provider + _eprovider = get_provider_obj(C.expression_provider) + if getattr(C, "expression_cache", None) is not None: + _eprovider = get_provider_obj(C.expression_cache, provider=_eprovider) + register_wrapper(ExpressionD, _eprovider) + logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}") + + _dprovider = get_provider_obj(C.dataset_provider) + if getattr(C, "dataset_cache", None) is not None: + _dprovider = get_provider_obj(C.dataset_cache, provider=_dprovider) + register_wrapper(DatasetD, _dprovider) + logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}") + + register_wrapper(D, C.provider) + logger.debug(f"registering D {C.provider}") diff --git a/qlib/data/filter.py b/qlib/data/filter.py new file mode 100644 index 0000000000..1552aeee7c --- /dev/null +++ b/qlib/data/filter.py @@ -0,0 +1,375 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import print_function +from abc import abstractmethod + +import re +import pandas as pd +import numpy as np +import six +import abc + +from .data import Cal, DatasetD + + +@six.add_metaclass(abc.ABCMeta) +class BaseDFilter(object): + """Dynamic Instruments Filter Abstract class + + Users can override this class to construct their own filter + + Override __init__ to input filter regulations + + Override filter_main to use the regulations to filter instruments + """ + + def __init__(self): + pass + + @staticmethod + def from_config(config): + """Construct an instance from config dict. + + Parameters + ---------- + config : dict + dict of config parameters + """ + raise NotImplementedError("Subclass of BaseDFilter must reimplement `from_config` method") + + @abstractmethod + def to_config(self): + """Construct an instance from config dict. + + Returns + ---------- + dict + return the dict of config parameters + """ + raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method") + + +@six.add_metaclass(abc.ABCMeta) +class SeriesDFilter(BaseDFilter): + """Dynamic Instruments Filter Abstract class to filter a series of certain features + + Filters should provide parameters: + + - filter start time + - filter end time + - filter rule + + Override __init__ to assign a certain rule to filter the series. + + Override _getFilterSeries to use the rule to filter the series and get a dict of {inst => series}, or override filter_main for more advanced series filter rule + """ + + def __init__(self, fstart_time=None, fend_time=None): + """Init function for filter base class. + Filter a set of instruments based on a certain rule within a certain period assigned by fstart_time and fend_time. 
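        If ``fstart_time``/``fend_time`` are left as None, the rule is applied over the whole
        time range covered by the input instruments.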
+ + Parameters + ---------- + fstart_time: str + the time for the filter rule to start filter the instruments + fend_time: str + the time for the filter rule to stop filter the instruments + """ + super(SeriesDFilter, self).__init__() + self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None + self.filter_end_time = pd.Timestamp(fend_time) if fend_time else None + + def _getTimeBound(self, instruments): + """Get time bound for all instruments. + + Parameters + ---------- + instruments: dict + the dict of instruments in the form {instrument_name => list of timestamp tuple} + + Returns + ---------- + pd.Timestamp, pd.Timestamp + the lower time bound and upper time bound of all the instruments + """ + trange = Cal.calendar(freq=self.filter_freq) + ubound, lbound = trange[0], trange[-1] + for _, timestamp in instruments.items(): + if timestamp: + lbound = timestamp[0][0] if timestamp[0][0] < lbound else lbound + ubound = timestamp[-1][-1] if timestamp[-1][-1] > ubound else ubound + return lbound, ubound + + def _toSeries(self, time_range, target_timestamp): + """Convert the target timestamp to a pandas series of bool value within a time range. + Make the time inside the target_timestamp range TRUE, others FALSE. + + Parameters + ---------- + time_range : D.calendar + the time range of the instruments + target_timestamp : list + the list of tuple (timestamp, timestamp) + + Returns + ---------- + pd.Series + the series of bool value for an instrument + """ + # Construct a whole dict of {date => bool} + timestamp_series = {timestamp: False for timestamp in time_range} + # Convert to pd.Series + timestamp_series = pd.Series(timestamp_series) + # Fill the date within target_timestamp with TRUE + for start, end in target_timestamp: + timestamp_series[Cal.calendar(start_time=start, end_time=end, freq=self.filter_freq)] = True + return timestamp_series + + def _filterSeries(self, timestamp_series, filter_series): + """Filter the timestamp series with filter series by using element-wise AND operation of the two series + + Parameters + ---------- + timestamp_series : pd.Series + the series of bool value indicating existing time + filter_series : pd.Series + the series of bool value indicating filter feature + + Returns + ---------- + pd.Series + the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp + """ + fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1] + timestamp_series[fstart:fend] = timestamp_series[fstart:fend] & filter_series + return timestamp_series + + def _toTimestamp(self, timestamp_series): + """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE + + Parameters + ---------- + timestamp_series: pd.Series + the series of bool value after being filtered + + Returns + ---------- + list + the list of tuple (timestamp, timestamp) + """ + # sort the timestamp_series according to the timestamps + timestamp_series.sort_index() + timestamp = [] + _lbool = None + _ltime = None + for _ts, _bool in timestamp_series.items(): + # there is likely to be NAN when the filter series don't have the + # bool value, so we just change the NAN into False + if _bool == np.nan: + _bool = False + if _lbool is None: + _cur_start = _ts + _lbool = _bool + _ltime = _ts + continue + if (_lbool, _bool) == (True, False): + if _cur_start: + timestamp.append((_cur_start, _ltime)) + elif (_lbool, _bool) == (False, True): + _cur_start = _ts + _lbool = _bool + _ltime 
= _ts + if _lbool: + timestamp.append((_cur_start, _ltime)) + return timestamp + + def __call__(self, instruments, start_time=None, end_time=None, freq="day"): + """Call this filter to get filtered instruments list""" + self.filter_freq = freq + return self.filter_main(instruments, start_time, end_time) + + @abstractmethod + def _getFilterSeries(self, instruments, fstart, fend): + """Get filter series based on the rules assigned during the initialization and the input time range. + + Parameters + ---------- + instruments : dict + the dict of instruments to be filtered + fstart : pd.Timestamp + start time of filter + fend : pd.Timestamp + end time of filter + + .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time + + Returns + ---------- + pd.Dataframe + a series of {pd.Timestamp => bool} + """ + raise NotImplementedError("Subclass of SeriesDFilter must reimplement `getFilterSeries` method") + + def filter_main(self, instruments, start_time=None, end_time=None): + """Implement this method to filter the instruments. + + Parameters + ---------- + instruments: dict + input instruments to be filtered + start_time: str + start of the time range + end_time: str + end of the time range + + Returns + ---------- + dict + filtered instruments, same structure as input instruments + """ + lbound, ubound = self._getTimeBound(instruments) + start_time = pd.Timestamp(start_time or lbound) + end_time = pd.Timestamp(end_time or ubound) + _instruments_filtered = {} + _all_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=self.filter_freq) + _filter_calendar = Cal.calendar( + start_time=self.filter_start_time and max(self.filter_start_time, _all_calendar[0]) or _all_calendar[0], + end_time=self.filter_end_time and min(self.filter_end_time, _all_calendar[-1]) or _all_calendar[-1], + freq=self.filter_freq, + ) + _all_filter_series = self._getFilterSeries(instruments, _filter_calendar[0], _filter_calendar[-1]) + for inst, timestamp in instruments.items(): + # Construct a whole map of date + _timestamp_series = self._toSeries(_all_calendar, timestamp) + # Get filter series + if inst in _all_filter_series: + _filter_series = _all_filter_series[inst] + else: + if self.keep: + _filter_series = pd.Series({timestamp: True for timestamp in _filter_calendar}) + else: + _filter_series = pd.Series({timestamp: False for timestamp in _filter_calendar}) + # Calculate bool value within the range of filter + _timestamp_series = self._filterSeries(_timestamp_series, _filter_series) + # Reform the map to (start_timestamp, end_timestamp) format + _timestamp = self._toTimestamp(_timestamp_series) + # Remove empty timestamp + if _timestamp: + _instruments_filtered[inst] = _timestamp + return _instruments_filtered + + +class NameDFilter(SeriesDFilter): + """Name dynamic instrument filter + + Filter the instruments based on a regulated name format. + + A name rule regular expression is required. 
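
    Examples
    ----------
    - *illustrative name rule* : ``name_rule_re = 'SH[0-9]{4}55'`` keeps only the instruments
      whose code matches the regular expression (matching is done with ``re.match``)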
+ """ + + def __init__(self, name_rule_re, fstart_time=None, fend_time=None): + """Init function for name filter class + + params: + ------ + name_rule_re: str + regular expression for the name rule + """ + super(NameDFilter, self).__init__(fstart_time, fend_time) + self.name_rule_re = name_rule_re + + def _getFilterSeries(self, instruments, fstart, fend): + all_filter_series = {} + filter_calendar = Cal.calendar(start_time=fstart, end_time=fend, freq=self.filter_freq) + for inst, timestamp in instruments.items(): + if re.match(self.name_rule_re, inst): + _filter_series = pd.Series({timestamp: True for timestamp in filter_calendar}) + else: + _filter_series = pd.Series({timestamp: False for timestamp in filter_calendar}) + all_filter_series[inst] = _filter_series + return all_filter_series + + @staticmethod + def from_config(config): + return NameDFilter( + name_rule_re=config["name_rule_re"], + fstart_time=config["filter_start_time"], + fend_time=config["filter_end_time"], + ) + + def to_config(self): + return { + "filter_type": "NameDFilter", + "name_rule_re": self.name_rule_re, + "filter_start_time": str(self.filter_start_time) if self.filter_start_time else self.filter_start_time, + "filter_end_time": str(self.filter_end_time) if self.filter_end_time else self.filter_end_time, + } + + +class ExpressionDFilter(SeriesDFilter): + """Expression dynamic instrument filter + + Filter the instruments based on a certain expression. + + An expression rule indicating a certain feature field is required. + + Examples + ---------- + - *basic features filter* : rule_expression = '$close/$open>5' + - *cross-sectional features filter* : rule_expression = '$rank($close)<10' + - *time-sequence features filter* : rule_expression = '$Ref($close, 3)>100' + """ + + def __init__(self, rule_expression, fstart_time=None, fend_time=None, keep=False): + """Init function for expression filter class + + params: + ------ + fstart_time: str + filter the feature starting from this time + fend_time: str + filter the feature ending by this time + rule_expression: str + an input expression for the rule + keep: bool + whether to keep the instruments of which features don't exist in the filter time span + """ + super(ExpressionDFilter, self).__init__(fstart_time, fend_time) + self.rule_expression = rule_expression + self.keep = keep + + def _getFilterSeries(self, instruments, fstart, fend): + # do not use dataset cache + try: + _features = DatasetD.dataset( + instruments, + [self.rule_expression], + fstart, + fend, + freq=self.filter_freq, + disk_cache=0, + ) + except TypeError: + # use LocalDatasetProvider + _features = DatasetD.dataset(instruments, [self.rule_expression], fstart, fend, freq=self.filter_freq) + rule_expression_field_name = list(_features.keys())[0] + all_filter_series = _features[rule_expression_field_name] + return all_filter_series + + def from_config(config): + return ExpressionDFilter( + rule_expression=config["rule_expression"], + fstart_time=config["filter_start_time"], + fend_time=config["filter_end_time"], + keep=config["keep"], + ) + + def to_config(self): + return { + "filter_type": "ExpressionDFilter", + "rule_expression": self.rule_expression, + "filter_start_time": str(self.filter_start_time) if self.filter_start_time else self.filter_start_time, + "filter_end_time": str(self.filter_end_time) if self.filter_end_time else self.filter_end_time, + "keep": self.keep, + } diff --git a/qlib/data/ops.py b/qlib/data/ops.py new file mode 100644 index 0000000000..104296a0ea --- /dev/null +++ 
b/qlib/data/ops.py @@ -0,0 +1,1405 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +import numpy as np +import pandas as pd + +from .base import Expression, ExpressionOps +from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi +from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi +from ..log import get_module_logger + +__all__ = ( + "Ref", + "Max", + "Min", + "Sum", + "Mean", + "Std", + "Var", + "Skew", + "Kurt", + "Med", + "Mad", + "Slope", + "Rsquare", + "Resi", + "Rank", + "Quantile", + "Count", + "EMA", + "WMA", + "Corr", + "Cov", + "Delta", + "Abs", + "Sign", + "Log", + "Power", + "Add", + "Sub", + "Mul", + "Div", + "Greater", + "Less", + "And", + "Or", + "Not", + "Gt", + "Ge", + "Lt", + "Le", + "Eq", + "Ne", + "Mask", + "IdxMax", + "IdxMin", + "If", +) + +np.seterr(invalid="ignore") + +#################### Element-Wise Operator #################### + + +class ElemOperator(ExpressionOps): + """Element-wise Operator + + Parameters + ---------- + feature : Expression + feature instance + func : str + feature operation method + + Returns + ---------- + Expression + feature operation output + """ + + def __init__(self, feature, func): + self.feature = feature + self.func = func + + def __str__(self): + return "{}({})".format(type(self).__name__, self.feature) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return getattr(np, self.func)(series) + + def get_longest_back_rolling(self): + return self.feature.get_longest_back_rolling() + + def get_extended_window_size(self): + return self.feature.get_extended_window_size() + + +class Abs(ElemOperator): + """Feature Absolute Value + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + a feature instance with absolute output + """ + + def __init__(self, feature): + super(Abs, self).__init__(feature, "abs") + + +class Sign(ElemOperator): + """Feature Sign + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + a feature instance with sign + """ + + def __init__(self, feature): + super(Sign, self).__init__(feature, "sign") + + +class Log(ElemOperator): + """Feature Log + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + a feature instance with log + """ + + def __init__(self, feature): + super(Log, self).__init__(feature, "log") + + +class Power(ElemOperator): + """Feature Power + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + Expression + a feature instance with power + """ + + def __init__(self, feature, exponent): + super(Power, self).__init__(feature, "power") + self.exponent = exponent + + def __str__(self): + return "{}({},{})".format(type(self).__name__, self.feature, self.exponent) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return getattr(np, self.func)(series, self.exponent) + + +class Mask(ElemOperator): + """Feature Mask + + Parameters + ---------- + feature : Expression + feature instance + instrument : str + instrument mask + + Returns + ---------- + Expression + a feature instance with masked instrument + """ + + def __init__(self, feature, instrument): + super(Mask, 
self).__init__(feature, "mask") + self.instrument = instrument + + def __str__(self): + return "{}({},{})".format(type(self).__name__, self.feature, self.instrument.lower()) + + def _load_internal(self, instrument, start_index, end_index, freq): + return self.feature.load(self.instrument, start_index, end_index, freq) + + +class Not(ElemOperator): + """Not Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + feature elementwise not output + """ + + def __init__(self, feature): + super(Not, self).__init__(feature, "bitwise_not") + + +#################### Pair-Wise Operator #################### +class PairOperator(ExpressionOps): + """Pair-wise operator + + Parameters + ---------- + feature_left : Expression + feature instance or numeric value + feature_right : Expression + feature instance or numeric value + func : str + operator function + + Returns + ---------- + Feature: + two features' operation output + """ + + def __init__(self, feature_left, feature_right, func): + self.feature_left = feature_left + self.feature_right = feature_right + self.func = func + + def __str__(self): + return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right) + + def _load_internal(self, instrument, start_index, end_index, freq): + assert any( + [isinstance(self.feature_left, Expression), self.feature_right, Expression] + ), "at least one of two inputs is Expression instance" + if isinstance(self.feature_left, Expression): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + else: + series_left = self.feature_left # numeric value + if isinstance(self.feature_right, Expression): + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + else: + series_right = self.feature_right + return getattr(np, self.func)(series_left, series_right) + + def get_longest_back_rolling(self): + if isinstance(self.feature_left, Expression): + left_br = self.feature_left.get_longest_back_rolling() + else: + left_br = 0 + + if isinstance(self.feature_right, Expression): + right_br = self.feature_right.get_longest_back_rolling() + else: + right_br = 0 + return max(left_br, right_br) + + def get_extended_window_size(self): + if isinstance(self.feature_left, Expression): + ll, lr = self.feature_left.get_extended_window_size() + else: + ll, lr = 0, 0 + + if isinstance(self.feature_right, Expression): + rl, rr = self.feature_right.get_extended_window_size() + else: + rl, rr = 0, 0 + return max(ll, rl), max(lr, rr) + + +class Add(PairOperator): + """Add Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + two features' sum + """ + + def __init__(self, feature_left, feature_right): + super(Add, self).__init__(feature_left, feature_right, "add") + + +class Sub(PairOperator): + """Subtract Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + two features' subtraction + """ + + def __init__(self, feature_left, feature_right): + super(Sub, self).__init__(feature_left, feature_right, "subtract") + + +class Mul(PairOperator): + """Multiply Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + two 
features' product + """ + + def __init__(self, feature_left, feature_right): + super(Mul, self).__init__(feature_left, feature_right, "multiply") + + +class Div(PairOperator): + """Division Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + two features' division + """ + + def __init__(self, feature_left, feature_right): + super(Div, self).__init__(feature_left, feature_right, "divide") + + +class Greater(PairOperator): + """Greater Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + greater elements taken from the input two features + """ + + def __init__(self, feature_left, feature_right): + super(Greater, self).__init__(feature_left, feature_right, "maximum") + + +class Less(PairOperator): + """Less Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + smaller elements taken from the input two features + """ + + def __init__(self, feature_left, feature_right): + super(Less, self).__init__(feature_left, feature_right, "minimum") + + +class Gt(PairOperator): + """Greater Than Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + bool series indicate `left > right` + """ + + def __init__(self, feature_left, feature_right): + super(Gt, self).__init__(feature_left, feature_right, "greater") + + +class Ge(PairOperator): + """Greater Equal Than Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + bool series indicate `left >= right` + """ + + def __init__(self, feature_left, feature_right): + super(Ge, self).__init__(feature_left, feature_right, "greater_equal") + + +class Lt(PairOperator): + """Less Than Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + bool series indicate `left < right` + """ + + def __init__(self, feature_left, feature_right): + super(Lt, self).__init__(feature_left, feature_right, "less") + + +class Le(PairOperator): + """Less Equal Than Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + bool series indicate `left <= right` + """ + + def __init__(self, feature_left, feature_right): + super(Le, self).__init__(feature_left, feature_right, "less_equal") + + +class Eq(PairOperator): + """Equal Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + bool series indicate `left == right` + """ + + def __init__(self, feature_left, feature_right): + super(Eq, self).__init__(feature_left, feature_right, "equal") + + +class Ne(PairOperator): + """Not Equal Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + bool series indicate `left != right` + """ + + def __init__(self, feature_left, feature_right): + super(Ne, self).__init__(feature_left, 
feature_right, "not_equal") + + +class And(PairOperator): + """And Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + two features' row by row & output + """ + + def __init__(self, feature_left, feature_right): + super(And, self).__init__(feature_left, feature_right, "bitwise_and") + + +class Or(PairOperator): + """Or Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + + Returns + ---------- + Feature: + two features' row by row | outputs + """ + + def __init__(self, feature_left, feature_right): + super(Or, self).__init__(feature_left, feature_right, "bitwise_or") + + +#################### Triple-wise Operator #################### +class If(ExpressionOps): + """If Operator + + Parameters + ---------- + condition : Expression + feature instance with bool values as condition + feature_left : Expression + feature instance + feature_right : Expression + feature instance + """ + + def __init__(self, condition, feature_left, feature_right): + self.condition = condition + self.feature_left = feature_left + self.feature_right = feature_right + + def __str__(self): + return "If({},{},{})".format(self.condition, self.feature_left, self.feature_right) + + def _load_internal(self, instrument, start_index, end_index, freq): + series_cond = self.condition.load(instrument, start_index, end_index, freq) + if isinstance(self.feature_left, Expression): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + else: + series_left = self.feature_left + if isinstance(self.feature_right, Expression): + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + else: + series_right = self.feature_right + series = pd.Series(np.where(series_cond, series_left, series_right), index=series_cond.index) + return series + + def get_longest_back_rolling(self): + if isinstance(self.feature_left, Expression): + left_br = self.feature_left.get_longest_back_rolling() + else: + left_br = 0 + + if isinstance(self.feature_right, Expression): + right_br = self.feature_right.get_longest_back_rolling() + else: + right_br = 0 + + if isinstance(self.condition, Expression): + c_br = self.condition.get_longest_back_rolling() + else: + c_br = 0 + return max(left_br, right_br, c_br) + + def get_extended_window_size(self): + if isinstance(self.feature_left, Expression): + ll, lr = self.feature_left.get_extended_window_size() + else: + ll, lr = 0, 0 + + if isinstance(self.feature_right, Expression): + rl, rr = self.feature_right.get_extended_window_size() + else: + rl, rr = 0, 0 + + if isinstance(self.condition, Expression): + cl, cr = self.condition.get_extended_window_size() + else: + cl, cr = 0, 0 + return max(ll, rl, cl), max(lr, rr, cr) + + +#################### Rolling #################### +# NOTE: methods like `rolling.mean` are optimized with cython, +# and are super faster than `rolling.apply(np.mean)` + + +class Rolling(ExpressionOps): + """Rolling Operator + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + func : str + rolling method + + Returns + ---------- + Expression + rolling outputs + """ + + def __init__(self, feature, N, func): + self.feature = feature + self.N = N + self.func = func + + def __str__(self): + return "{}({},{})".format(type(self).__name__, self.feature, self.N) + + def _load_internal(self, instrument, 
start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + # NOTE: remove all null check, + # now it's user's responsibility to decide whether use features in null days + # isnull = series.isnull() # NOTE: isnull = NaN, inf is not null + if self.N == 0: + series = getattr(series.expanding(min_periods=1), self.func)() + else: + series = getattr(series.rolling(self.N, min_periods=1), self.func)() + # series.iloc[:self.N-1] = np.nan + # series[isnull] = np.nan + return series + + def get_longest_back_rolling(self): + if self.N == 0: + return np.inf + return self.feature.get_longest_back_rolling() + self.N - 1 + + def get_extended_window_size(self): + if self.N == 0: + # FIXME: How to make this accurate and efficiently? Or should we + # remove such support for N == 0? + get_module_logger(self.__class__.__name__).warning("The Rolling(ATTR, 0) will not be accurately calculated") + return self.feature.get_extended_window_size() + else: + lft_etd, rght_etd = self.feature.get_extended_window_size() + lft_etd = max(lft_etd + self.N - 1, lft_etd) + return lft_etd, rght_etd + + +class Ref(Rolling): + """Feature Reference + + Parameters + ---------- + feature : Expression + feature instance + N : int + N = 0, retrieve the first data; N > 0, retrieve data of N periods ago; N < 0, future data + + Returns + ---------- + Expression + a feature instance with target reference + """ + + def __init__(self, feature, N): + super(Ref, self).__init__(feature, N, "ref") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + # N = 0, return first day + if series.empty: + return series # Pandas bug, see: https://github.com/pandas-dev/pandas/issues/21049 + elif self.N == 0: + series = pd.Series(series.iloc[0], index=series.index) + else: + series = series.shift(self.N) # copy + return series + + def get_longest_back_rolling(self): + if self.N == 0: + return np.inf + return self.feature.get_longest_back_rolling() + self.N + + def get_extended_window_size(self): + if self.N == 0: + get_module_logger(self.__class__.__name__).warning("The Ref(ATTR, 0) will not be accurately calculated") + return self.feature.get_extended_window_size() + else: + lft_etd, rght_etd = self.feature.get_extended_window_size() + lft_etd = max(lft_etd + self.N, lft_etd) + rght_etd = max(rght_etd - self.N, rght_etd) + return lft_etd, rght_etd + + +class Mean(Rolling): + """Rolling Mean (MA) + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling average + """ + + def __init__(self, feature, N): + super(Mean, self).__init__(feature, N, "mean") + + +class Sum(Rolling): + """Rolling Sum + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling sum + """ + + def __init__(self, feature, N): + super(Sum, self).__init__(feature, N, "sum") + + +class Std(Rolling): + """Rolling Std + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling std + """ + + def __init__(self, feature, N): + super(Std, self).__init__(feature, N, "std") + + +class Var(Rolling): + """Rolling Variance + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + 
Returns + ---------- + Expression + a feature instance with rolling variance + """ + + def __init__(self, feature, N): + super(Var, self).__init__(feature, N, "var") + + +class Skew(Rolling): + """Rolling Skewness + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling skewness + """ + + def __init__(self, feature, N): + super(Skew, self).__init__(feature, N, "skew") + + +class Kurt(Rolling): + """Rolling Kurtosis + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling kurtosis + """ + + def __init__(self, feature, N): + super(Kurt, self).__init__(feature, N, "kurt") + + +class Max(Rolling): + """Rolling Max + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling max + """ + + def __init__(self, feature, N): + super(Max, self).__init__(feature, N, "max") + + +class IdxMax(Rolling): + """Rolling Max Index + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling max index + """ + + def __init__(self, feature, N): + super(IdxMax, self).__init__(feature, N, "idxmax") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = series.expanding(min_periods=1).apply(lambda x: x.argmax() + 1, raw=True) + else: + series = series.rolling(self.N, min_periods=1).apply( + lambda x: x.argmax() + 1, + raw=True, + ) + return series + + +class Min(Rolling): + """Rolling Min + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling min + """ + + def __init__(self, feature, N): + super(Min, self).__init__(feature, N, "min") + + +class IdxMin(Rolling): + """Rolling Min Index + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling min index + """ + + def __init__(self, feature, N): + super(IdxMin, self).__init__(feature, N, "idxmin") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = series.expanding(min_periods=1).apply(lambda x: x.argmin() + 1, raw=True) + else: + series = series.rolling(self.N, min_periods=1).apply( + lambda x: x.argmin() + 1, + raw=True, + ) + return series + + +class Quantile(Rolling): + """Rolling Quantile + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling quantile + """ + + def __init__(self, feature, N, qscore): + super(Quantile, self).__init__(feature, N, "quantile") + self.qscore = qscore + + def __str__(self): + return "{}({},{},{})".format(type(self).__name__, self.feature, self.N, self.qscore) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = series.expanding(min_periods=1).quantile(self.qscore) + else: + series = 
series.rolling(self.N, min_periods=1).quantile(self.qscore) + return series + + +class Med(Rolling): + """Rolling Median + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling median + """ + + def __init__(self, feature, N): + super(Med, self).__init__(feature, N, "median") + + +class Mad(Rolling): + """Rolling Mean Absolute Deviation + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling mean absolute deviation + """ + + def __init__(self, feature, N): + super(Mad, self).__init__(feature, N, "mad") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + # TODO: implement in Cython + + def mad(x): + x1 = x[~np.isnan(x)] + return np.mean(np.abs(x1 - x1.mean())) + + if self.N == 0: + series = series.expanding(min_periods=1).apply(mad, raw=True) + else: + series = series.rolling(self.N, min_periods=1).apply(mad, raw=True) + return series + + +class Rank(Rolling): + """Rolling Rank (Percentile) + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling rank + """ + + def __init__(self, feature, N): + super(Rank, self).__init__(feature, N, "rank") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + # TODO: implement in Cython + + def rank(x): + if np.isnan(x[-1]): + return np.nan + x1 = x[~np.isnan(x)] + if x1.shape[0] == 0: + return np.nan + return (x1.argsort()[-1] + 1) / len(x1) + + if self.N == 0: + series = series.expanding(min_periods=1).apply(rank, raw=True) + else: + series = series.rolling(self.N, min_periods=1).apply(rank, raw=True) + return series + + +class Count(Rolling): + """Rolling Count + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling count of number of non-NaN elements + """ + + def __init__(self, feature, N): + super(Count, self).__init__(feature, N, "count") + + +class Delta(Rolling): + """Rolling Delta + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with end minus start in rolling window + """ + + def __init__(self, feature, N): + super(Delta, self).__init__(feature, N, "delta") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = series - series.iloc[0] + else: + series = series - series.shift(self.N) + return series + + +# TODO: +# support pair-wise rolling like `Slope(A, B, N)` +class Slope(Rolling): + """Rolling Slope + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with regression slope of given window + """ + + def __init__(self, feature, N): + super(Slope, self).__init__(feature, N, "slope") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = 
pd.Series(expanding_slope(series.values), index=series.index) + else: + series = pd.Series(rolling_slope(series.values, self.N), index=series.index) + return series + + +class Rsquare(Rolling): + """Rolling R-value Square + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with regression r-value square of given window + """ + + def __init__(self, feature, N): + super(Rsquare, self).__init__(feature, N, "rsquare") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = pd.Series(expanding_rsquare(series.values), index=series.index) + else: + series = pd.Series(rolling_rsquare(series.values, self.N), index=series.index) + return series + + +class Resi(Rolling): + """Rolling Regression Residuals + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with regression residuals of given window + """ + + def __init__(self, feature, N): + super(Resi, self).__init__(feature, N, "resi") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = pd.Series(expanding_resi(series.values), index=series.index) + else: + series = pd.Series(rolling_resi(series.values, self.N), index=series.index) + return series + + +class WMA(Rolling): + """Rolling WMA + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with weighted moving average output + """ + + def __init__(self, feature, N): + super(WMA, self).__init__(feature, N, "wma") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + # TODO: implement in Cython + + def weighted_mean(x): + w = np.arange(len(x)) + w /= w.sum() + return np.nanmean(w * x) + + if self.N == 0: + series = series.expanding(min_periods=1).apply(weighted_mean, raw=True) + else: + series = series.rolling(self.N, min_periods=1).apply(weighted_mean, raw=True) + return series + + +class EMA(Rolling): + """Rolling Exponential Mean (EMA) + + Parameters + ---------- + feature : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with regression r-value square of given window + """ + + def __init__(self, feature, N): + super(EMA, self).__init__(feature, N, "ema") + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + + def exp_weighted_mean(x): + a = 1 - 2 / (1 + len(x)) + w = a ** np.arange(len(x))[::-1] + w /= w.sum() + return np.nansum(w * x) + + if self.N == 0: + series = series.expanding(min_periods=1).apply(exp_weighted_mean, raw=True) + else: + series = series.ewm(span=self.N, min_periods=1).mean() + return series + + +#################### Pair-Wise Rolling #################### +class PairRolling(ExpressionOps): + """Pair Rolling Operator + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling output of two input features + """ + + 
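    # NOTE: as with the single-feature Rolling operators above, N == 0 switches to an
    # expanding window over the whole history instead of a fixed-size rolling window.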
def __init__(self, feature_left, feature_right, N, func): + self.feature_left = feature_left + self.feature_right = feature_right + self.N = N + self.func = func + + def __str__(self): + return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N) + + def _load_internal(self, instrument, start_index, end_index, freq): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + if self.N == 0: + series = getattr(series_left.expanding(min_periods=1), self.func)(series_right) + else: + series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right) + return series + + def get_longest_back_rolling(self): + if self.N == 0: + return np.inf + return ( + max( + self.feature_left.get_longest_back_rolling(), + self.feature_right.get_longest_back_rolling(), + ) + + self.N + - 1 + ) + + def get_extended_window_size(self): + if self.N == 0: + get_module_logger(self.__class__.__name__).warning( + "The PairRolling(ATTR, 0) will not be accurately calculated" + ) + return self.feature.get_extended_window_size() + else: + ll, lr = self.feature_left.get_extended_window_size() + rl, rr = self.feature_right.get_extended_window_size() + return max(ll, rl) + self.N - 1, max(lr, rr) + + +class Corr(PairRolling): + """Rolling Correlation + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling correlation of two input features + """ + + def __init__(self, feature_left, feature_right, N): + super(Corr, self).__init__(feature_left, feature_right, N, "corr") + + +class Cov(PairRolling): + """Rolling Covariance + + Parameters + ---------- + feature_left : Expression + feature instance + feature_right : Expression + feature instance + N : int + rolling window size + + Returns + ---------- + Expression + a feature instance with rolling max of two input features + """ + + def __init__(self, feature_left, feature_right, N): + super(Cov, self).__init__(feature_left, feature_right, N, "cov") diff --git a/qlib/log.py b/qlib/log.py new file mode 100644 index 0000000000..bc87fc5796 --- /dev/null +++ b/qlib/log.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import os +import re +import logging +from time import time +import logging.handlers +from logging import config as logging_config + +from .config import C + + +def get_module_logger(module_name, level=None): + """ + Get a logger for a specific module. + + :param module_name: str + Logic module name. + :param level: int + :param sh_level: int + Stream handler log level. + :param log_format: str + :return: Logger + Logger object. + """ + if level is None: + level = C.logging_level + + module_name = "qlib.{}".format(module_name) + # Get logger. + module_logger = logging.getLogger(module_name) + module_logger.setLevel(level) + return module_logger + + +class TimeInspector(object): + + timer_logger = get_module_logger("timer", level=logging.WARNING) + + time_marks = [] + + @classmethod + def set_time_mark(cls): + """ + Set a time mark with current time, and this time mark will push into a stack. + :return: float + A timestamp for current time. + """ + _time = time() + cls.time_marks.append(_time) + return _time + + @classmethod + def pop_time_mark(cls): + """ + Pop last time mark from stack. 
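        :return: float
            The time mark removed from the stack.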
+ """ + return cls.time_marks.pop() + + @classmethod + def get_cost_time(cls): + """ + Get last time mark from stack, calculate time diff with current time. + :return: float + Time diff calculated by last time mark with current time. + """ + cost_time = time() - cls.time_marks.pop() + return cost_time + + @classmethod + def log_cost_time(cls, info="Done"): + """ + Get last time mark from stack, calculate time diff with current time, and log time diff and info. + :param info: str + Info that will be log into stdout. + """ + cost_time = time() - cls.time_marks.pop() + cls.timer_logger.info("Time cost: {0:.5f} | {1}".format(cost_time, info)) + + +def set_log_with_config(log_config: dict): + """set log with config + + :param log_config: + :return: + """ + logging_config.dictConfig(log_config) + + +class LogFilter(logging.Filter): + def __init__(self, param=None): + self.param = param + + @staticmethod + def match_msg(filter_str, msg): + match = False + try: + if re.match(filter_str, msg): + match = True + except Exception: + pass + return match + + def filter(self, record): + allow = True + if isinstance(self.param, str): + allow = not self.match_msg(self.param, record.msg) + elif isinstance(self.param, list): + allow = not any([self.match_msg(p, record.msg) for p in self.param]) + return allow diff --git a/qlib/utils.py b/qlib/utils.py new file mode 100644 index 0000000000..225c03aba0 --- /dev/null +++ b/qlib/utils.py @@ -0,0 +1,547 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +import os +import re +import copy +import json +import yaml +import redis +import bisect +import shutil +import difflib +import hashlib +import datetime +import requests +import tempfile +import importlib +import contextlib +import numpy as np +import pandas as pd +from pathlib import Path + +from .config import C +from .log import get_module_logger + +log = get_module_logger("utils") + + +#################### Server #################### +def get_redis_connection(): + """get redis connection instance.""" + return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db) + + +#################### Data #################### +def read_bin(file_path, start_index, end_index): + with open(file_path, "rb") as f: + # read start_index + ref_start_index = int(np.frombuffer(f.read(4), dtype=" end_index: + return pd.Series() + # calculate offset + f.seek(4 * (si - ref_start_index) + 4) + # read nbytes + count = end_index - si + 1 + data = np.frombuffer(f.read(4 * count), dtype="= data[mid][level]: + left = mid + 1 + else: + right = mid + return left + + +#################### HTTP #################### +def requests_with_retry(url, retry=5, **kwargs): + while retry > 0: + retry -= 1 + try: + res = requests.get(url, timeout=1, **kwargs) + assert res.status_code in {200, 206} + return res + except AssertionError: + continue + except Exception as e: + log.warning("exception encountered {}".format(e)) + continue + raise Exception("ERROR: requests failed!") + + +#################### Parse #################### +def parse_config(config): + # Check whether need parse, all object except str do not need to be parsed + if not isinstance(config, str): + return config + # Check whether config is file + if os.path.exists(config): + with open(config, "r") as f: + return yaml.load(f) + # Check whether the str can be parsed + try: + return yaml.load(config) + except BaseException: + raise ValueError("cannot parse config!") 
+ + +#################### Other #################### +def drop_nan_by_y_index(x, y, weight=None): + # x, y, weight: DataFrame + # Find index of rows which do not contain Nan in all columns from y. + mask = ~y.isna().any(axis=1) + # Get related rows from x, y, weight. + x = x[mask] + y = y[mask] + if weight is not None: + weight = weight[mask] + return x, y, weight + + +def hash_args(*args): + # json.dumps will keep the dict keys always sorted. + string = json.dumps(args, sort_keys=True, default=str) # frozenset + return hashlib.md5(string.encode()).hexdigest() + + +def parse_field(field): + # Following patterns will be matched: + # - $close -> Feature("close") + # - $close5 -> Feature("close5") + # - $open+$close -> Feature("open")+Feature("close") + if not isinstance(field, str): + field = str(field) + return re.sub(r"\$(\w+)", r'Feature("\1")', field) + + +def get_module_by_module_path(module_path): + """Load module path + + :param module_path: + :return: + """ + + if ".py" in module_path: + module_spec = importlib.util.spec_from_file_location("", module_path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + else: + module = importlib.import_module(module_path) + + return module + + +def compare_dict_value(src_data: dict, dst_data: dict): + """Compare dict value + + :param src_data: + :param dst_data: + :return: + """ + + class DateEncoder(json.JSONEncoder): + # FIXME: This class can only be accurate to the day. If it is a minute, + # there may be a bug + def default(self, o): + if isinstance(o, (datetime.datetime, datetime.date)): + return o.strftime("%Y-%m-%d %H:%M:%S") + return json.JSONEncoder.default(self, o) + + src_data = json.dumps(src_data, indent=4, sort_keys=True, cls=DateEncoder) + dst_data = json.dumps(dst_data, indent=4, sort_keys=True, cls=DateEncoder) + diff = difflib.ndiff(src_data, dst_data) + changes = [line for line in diff if line.startswith("+ ") or line.startswith("- ")] + return changes + + +def create_save_path(save_path=None): + """Create save path + + :param save_path: + :return: + """ + if save_path: + if not os.path.exists(save_path): + os.makedirs(save_path) + else: + temp_dir = os.path.expanduser("~/tmp") + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + _, save_path = tempfile.mkstemp(dir=temp_dir) + return save_path + + +@contextlib.contextmanager +def save_multiple_parts_file(filename, format="gztar"): + """Save multiple parts file + + Implementation process: + 1. get the absolute path to 'filename' + 2. create a 'filename' directory + 3. user does something with file_path('filename/') + 4. remove 'filename' directory + 5. make_archive 'filename' directory, and rename 'archive file' to filename + + :param filename: result model path + :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar" + :return: real model path + + Usage:: + + >>> # The following code will create an archive file('~/tmp/test_file') containing 'test_doc_i'(i is 0-10) files. + >>> with save_multiple_parts_file('~/tmp/test_file') as filename_dir: + ... for i in range(10): + ... temp_path = os.path.join(filename_dir, 'test_doc_{}'.format(str(i))) + ... with open(temp_path) as fp: + ... fp.write(str(i)) + ... 
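    Note: the target ``filename`` must not already exist, otherwise a ``FileExistsError`` is raised.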
+ + """ + + if filename.startswith("~"): + filename = os.path.expanduser(filename) + + file_path = os.path.abspath(filename) + + # Create model dir + if os.path.exists(file_path): + raise FileExistsError("ERROR: file exists: {}, cannot be create the directory.".format(file_path)) + + os.makedirs(file_path) + + # return model dir + yield file_path + + # filename dir to filename.tar.gz file + tar_file = shutil.make_archive(file_path, format=format, root_dir=file_path) + + # Remove filename dir + if os.path.exists(file_path): + shutil.rmtree(file_path) + + # filename.tar.gz rename to filename + os.rename(tar_file, file_path) + + +@contextlib.contextmanager +def unpack_archive_with_buffer(buffer, format="gztar"): + """Unpack archive with archive buffer + After the call is finished, the archive file and directory will be deleted. + + Implementation process: + 1. create 'tempfile' in '~/tmp/' and directory + 2. 'buffer' write to 'tempfile' + 3. unpack archive file('tempfile') + 4. user does something with file_path('tempfile/') + 5. remove 'tempfile' and 'tempfile directory' + + :param buffer: bytes + :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar" + :return: unpack archive directory path + + Usage:: + + >>> # The following code is to print all the file names in 'test_unpack.tar.gz' + >>> with open('test_unpack.tar.gz') as fp: + ... buffer = fp.read() + ... + >>> with unpack_archive_with_buffer(buffer) as temp_dir: + ... for f_n in os.listdir(temp_dir): + ... print(f_n) + ... + + """ + temp_dir = os.path.expanduser("~/tmp") + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + with tempfile.NamedTemporaryFile("wb", delete=False, dir=temp_dir) as fp: + fp.write(buffer) + file_path = fp.name + + try: + tar_file = file_path + ".tar.gz" + os.rename(file_path, tar_file) + # Create dir + os.makedirs(file_path) + shutil.unpack_archive(tar_file, format=format, extract_dir=file_path) + + # Return temp dir + yield file_path + + except Exception as e: + log.error(str(e)) + finally: + # Remove temp tar file + if os.path.exists(tar_file): + os.unlink(tar_file) + + # Remove temp model dir + if os.path.exists(file_path): + shutil.rmtree(file_path) + + +@contextlib.contextmanager +def get_tmp_file_with_buffer(buffer): + temp_dir = os.path.expanduser("~/tmp") + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + with tempfile.NamedTemporaryFile("wb", delete=True, dir=temp_dir) as fp: + fp.write(buffer) + file_path = fp.name + yield file_path + + +def remove_repeat_field(fields): + """remove repeat field + + :param fields: list; features fields + :return: list + """ + fields = copy.deepcopy(fields) + _fields = set(fields) + return sorted(_fields, key=fields.index) + + +def remove_fields_space(fields: [list, str, tuple]): + """remove fields space + + :param fields: features fields + :return: list or str + """ + if isinstance(fields, str): + return fields.replace(" ", "") + return [i.replace(" ", "") for i in fields if isinstance(i, str)] + + +def normalize_cache_fields(fields: [list, tuple]): + """normalize cache fields + + :param fields: features fields + :return: list + """ + return sorted(remove_repeat_field(remove_fields_space(fields))) + + +def normalize_cache_instruments(instruments): + """normalize cache instruments + + :return: list or dict + """ + if isinstance(instruments, (list, tuple, pd.Index, np.ndarray)): + instruments = sorted(list(instruments)) + else: + # dict type stockpool + if "market" in instruments: + pass + else: + instruments = {k: sorted(v) 
for k, v in instruments.items()}
+    return instruments
+
+
+def is_tradable_date(cur_date):
+    """judge whether a date is a tradable (trading-calendar) date
+    ----------
+    cur_date : pandas.Timestamp
+        current date
+    """
+    from .data import D
+
+    return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
+
+
+def get_date_range(trading_date, shift, future=False):
+    """get a range of trading dates by shifting from trading_date
+
+    :param trading_date: pandas.Timestamp, a date on the trading calendar
+    :param shift: int, number of trading days to shift (positive for future, negative for past)
+    :param future: bool, whether to use the future calendar
+    :return: list of trading dates covered by the shift
+    """
+    from .data import D
+
+    calendar = D.calendar(future=future)
+    if pd.to_datetime(trading_date) not in list(calendar):
+        raise ValueError("{} is not a trading day!".format(str(trading_date)))
+    day_index = bisect.bisect_left(calendar, trading_date)
+    if 0 <= (day_index + shift) < len(calendar):
+        if shift > 0:
+            return calendar[day_index + 1 : day_index + 1 + shift]
+        else:
+            return calendar[day_index + shift : day_index]
+    else:
+        return calendar
+
+
+def get_date_by_shift(trading_date, shift, future=False):
+    """get the trading date shifted by ``shift`` trading days relative to trading_date
+        e.g. : shift == 1,  return the next trading date
+               shift == -1, return the previous trading date
+    ----------
+    trading_date : pandas.Timestamp
+        current date
+    shift : int
+    """
+    return get_date_range(trading_date, shift, future)[0 if shift < 0 else -1] if shift != 0 else trading_date
+
+
+def get_next_trading_date(trading_date, future=False):
+    """get the next trading date
+    ----------
+    trading_date : pandas.Timestamp
+        current date
+    """
+    return get_date_by_shift(trading_date, 1, future=future)
+
+
+def get_pre_trading_date(trading_date, future=False):
+    """get the previous trading date
+    ----------
+    trading_date : pandas.Timestamp
+        current date
+    """
+    return get_date_by_shift(trading_date, -1, future=future)
+
+
+def transform_end_date(end_date=None, freq="day"):
+    """transform the end date into a valid trading date
+    If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
+ Otherwise, returns the end_date + ---------- + end_date: str + end trading date + date : pandas.Timestamp + current date + """ + from .data import D + + last_date = D.calendar(freq=freq)[-1] + if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)): + log.warning( + "\nInfo: the end_date in the configuration file is {}, " + "so the default last date {} is used.".format(end_date, last_date) + ) + end_date = last_date + return end_date + + +def get_date_in_file_name(file_name): + """Get the date(YYYY-MM-DD) written in file name + Parameter + file_name : str + :return + date : str + 'YYYY-MM-DD' + """ + pattern = "[0-9]{4}-[0-9]{2}-[0-9]{2}" + date = re.search(pattern, str(file_name)).group() + return date + + +def split_pred(pred, number=None, split_date=None): + """split the score file into two part + Parameter + --------- + pred : pd.DataFrame (index:) + A score file of stocks + number: the number of dates for pred_left + split_date: the last date of the pred_left + Return + ------- + pred_left : pd.DataFrame (index:) + The first part of original score file + pred_right : pd.DataFrame (index:) + The second part of original score file + """ + if number is None and split_date is None: + raise ValueError("`number` and `split date` cannot both be None") + dates = sorted(pred.index.get_level_values("datetime").unique()) + dates = list(map(pd.Timestamp, dates)) + if split_date is None: + date_left_end = dates[number - 1] + date_right_begin = dates[number] + date_left_start = None + else: + split_date = pd.Timestamp(split_date) + date_left_end = split_date + date_right_begin = split_date + pd.Timedelta(days=1) + if number is None: + date_left_start = None + else: + end_idx = bisect.bisect_right(dates, split_date) + date_left_start = dates[end_idx - number] + pred_temp = pred.sort_index() + pred_left = pred_temp.loc(axis=0)[:, date_left_start:date_left_end] + pred_right = pred_temp.loc(axis=0)[:, date_right_begin:] + return pred_left, pred_right + + +def can_use_cache(): + res = True + r = get_redis_connection() + try: + r.client() + except redis.exceptions.ConnectionError: + res = False + finally: + r.close() + return res + + +def exists_qlib_data(qlib_dir): + qlib_dir = Path(qlib_dir).expanduser() + if not qlib_dir.exists(): + return False + + calendars_dir = qlib_dir.joinpath("calendars") + instruments_dir = qlib_dir.joinpath("instruments") + features_dir = qlib_dir.joinpath("features") + # check dir + for _dir in [calendars_dir, instruments_dir, features_dir]: + if not (_dir.exists() and list(_dir.iterdir())): + return False + # check calendar bin + for _calendar in calendars_dir.iterdir(): + if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")): + return False + + # check instruments + code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir())) + _instrument = instruments_dir.joinpath("all.txt") + miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) + if miss_code and any(map(lambda x: "sht" not in x, miss_code)): + return False + + return True diff --git a/qlib/version.py b/qlib/version.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py new file mode 100644 index 0000000000..d972f6318e --- /dev/null +++ b/scripts/dump_bin.py @@ -0,0 +1,250 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
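+
+# Minimal usage sketch (paths are illustrative; the canonical examples live in the
+# method docstrings below). A full dump runs calendars, features and instruments
+# in that order:
+#
+#   python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data \
+#       --include_fields open,close,high,low,volume,factor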
+ +import shutil +from pathlib import Path +from functools import partial +from concurrent.futures import ThreadPoolExecutor + +import fire +import numpy as np +import pandas as pd +from tqdm import tqdm +from loguru import logger + + +class DumpData(object): + FILE_SUFFIX = ".csv" + + def __init__( + self, + csv_path: str, + qlib_dir: str, + backup_dir: str = None, + freq: str = "day", + works: int = None, + date_field_name: str = "date", + ): + """ + + Parameters + ---------- + csv_path: str + stock data path or directory + qlib_dir: str + qlib(dump) data director + backup_dir: str, default None + if backup_dir is not None, backup qlib_dir to backup_dir + freq: str, default "day" + transaction frequency + works: int, default None + number of threads + date_field_name: str, default "date" + the name of the date field in the csv + """ + csv_path = Path(csv_path).expanduser() + self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path]) + self.qlib_dir = Path(qlib_dir).expanduser() + self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser() + if backup_dir is not None: + self._backup_qlib_dir(Path(backup_dir).expanduser()) + + self.freq = freq + self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S" + + self.works = works + self.date_field_name = date_field_name + + self._calendars_dir = self.qlib_dir.joinpath("calendars") + self._features_dir = self.qlib_dir.joinpath("features") + self._instruments_dir = self.qlib_dir.joinpath("instruments") + + self._calendars_list = [] + self._include_fields = () + self._exclude_fields = () + + def _backup_qlib_dir(self, target_dir: Path): + shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve())) + + def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False): + df = pd.read_csv(str(file_path.resolve())) + if df.empty or self.date_field_name not in df.columns.tolist(): + return [] + if is_begin_end: + return [df[self.date_field_name].min(), df[self.date_field_name].max()] + return df[self.date_field_name].tolist() + + def _get_source_data(self, file_path: Path): + df = pd.read_csv(str(file_path.resolve())) + df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64) + return df + + def _file_to_bin(self, file_path: Path = None): + code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower() + features_dir = self._features_dir.joinpath(code) + features_dir.mkdir(parents=True, exist_ok=True) + calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name]) + calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64) + # read csv file + df = self._get_source_data(file_path) + cal_df = calendars_df[ + (calendars_df[self.date_field_name] >= df[self.date_field_name].min()) + & (calendars_df[self.date_field_name] <= df[self.date_field_name].max()) + ] + cal_df.set_index(self.date_field_name, inplace=True) + df.set_index(self.date_field_name, inplace=True) + r_df = df.reindex(cal_df.index) + date_index = self._calendars_list.index(r_df.index.min()) + for field in ( + self._include_fields + if self._include_fields + else set(r_df.columns) - set(self._exclude_fields) + if self._exclude_fields + else r_df.columns + ): + + bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin") + if field not in r_df.columns: + continue + r = np.hstack([date_index, r_df[field]]).astype(" --qlib_dir + + Examples + --------- + + # dump all stock + python dump_bin.py dump_features 
--csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name + # dump one stock + python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name + """ + logger.info("start dump features......") + if calendar_path is not None: + # read calendar from calendar file + self._calendars_list = self._read_calendar(Path(calendar_path)) + + if not self._calendars_list: + self.dump_calendars() + + self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields + self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields + with tqdm(total=len(self.csv_files)) as p_bar: + with ThreadPoolExecutor(max_workers=self.works) as executor: + for _ in executor.map(self._file_to_bin, self.csv_files): + p_bar.update() + + logger.info("end of features dump.\n") + + def dump_calendars(self): + """dump calendars + + Notes + --------- + python dump_bin.py dump_calendars --csv_path --qlib_dir + + Examples + --------- + python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data + """ + logger.info("start dump calendars......") + calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) + all_datetime = set() + with tqdm(total=len(self.csv_files)) as p_bar: + with ThreadPoolExecutor(max_workers=self.works) as executor: + for temp_datetime in executor.map(self._get_date_for_df, self.csv_files): + all_datetime = all_datetime | set(temp_datetime) + p_bar.update() + + self._calendars_list = sorted(map(pd.Timestamp, all_datetime)) + self._calendars_dir.mkdir(parents=True, exist_ok=True) + result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list)) + np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8") + logger.info("end of calendars dump.\n") + + def dump_instruments(self): + """dump instruments + + Notes + --------- + python dump_bin.py dump_instruments --csv_path --qlib_dir + + Examples + --------- + python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data + """ + logger.info("start dump instruments......") + symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files)) + _result_list = [] + _fun = partial(self._get_date_for_df, is_begin_end=True) + with tqdm(total=len(symbol_list)) as p_bar: + with ThreadPoolExecutor(max_workers=self.works) as execute: + for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)): + if res: + begin_time = res[0] + end_time = res[-1] + _result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}") + p_bar.update() + + self._instruments_dir.mkdir(parents=True, exist_ok=True) + to_path = str(self._instruments_dir.joinpath("all.txt").resolve()) + np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8") + logger.info("end of instruments dump.\n") + + def dump(self, include_fields: str = None, exclude_fields: tuple = None): + """dump data + + Parameters + ---------- + include_fields: str + dump fields + + exclude_fields: str + fields not dumped + + Examples + --------- + python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor + python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name + """ + if 
isinstance(exclude_fields, str): + exclude_fields = exclude_fields.split(",") + if isinstance(include_fields, str): + include_fields = include_fields.split(",") + self.dump_calendars() + self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields) + self.dump_instruments() + + +if __name__ == "__main__": + fire.Fire(DumpData) diff --git a/scripts/get_data.py b/scripts/get_data.py new file mode 100644 index 0000000000..9e3a99556b --- /dev/null +++ b/scripts/get_data.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import fire +import zipfile +import requests +from tqdm import tqdm +from pathlib import Path +from loguru import logger + + +class GetData: + REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads" + + def _download_data(self, file_name: str, target_dir: [Path, str]): + target_dir = Path(target_dir).expanduser() + target_dir.mkdir(exist_ok=True, parents=True) + + url = f"{self.REMOTE_URL}/{file_name}" + target_path = target_dir.parent.joinpath(file_name) + + resp = requests.get(url, stream=True) + if resp.status_code != 200: + raise requests.exceptions.HTTPError() + + chuck_size = 1024 + logger.info(f"{file_name} downloading......") + with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar: + with target_path.open("wb") as fp: + for chuck in resp.iter_content(chunk_size=chuck_size): + fp.write(chuck) + p_bar.update(chuck_size) + + self._unzip(target_path, target_dir) + + @staticmethod + def _unzip(file_path: Path, target_dir: Path): + logger.info(f"{file_path} unzipping......") + with zipfile.ZipFile(str(file_path.resolve()), "r") as zp: + for _file in tqdm(zp.namelist()): + zp.extract(_file, str(target_dir.resolve())) + + def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data"): + """download cn qlib data from remote + + Parameters + ---------- + target_dir: str + data save directory + + Examples + --------- + python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + ------- + + """ + file_name = "qlib_data_cn.zip" + self._download_data(file_name, target_dir) + + def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): + """download cn csv data from remote + + Parameters + ---------- + target_dir: str + data save directory + + Examples + --------- + python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data + ------- + + """ + file_name = "csv_data_cn.zip" + self._download_data(file_name, target_dir) + + +if __name__ == "__main__": + fire.Fire(GetData) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..c07b952b23 --- /dev/null +++ b/setup.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# -*- coding: utf-8 -*- +import io +import os +import numpy + +from setuptools import find_packages, setup, Extension + +# Package meta-data. +NAME = "qlib" +DESCRIPTION = "A Quantitative-research Library" +REQUIRES_PYTHON = ">=3.5.0" +VERSION = "0.4.6.dev" + +# Detect Cython +try: + import Cython + + ver = Cython.__version__ + _CYTHON_INSTALLED = ver >= "0.28" +except ImportError: + _CYTHON_INSTALLED = False + +if not _CYTHON_INSTALLED: + print("Required Cython version >= 0.28 is not detected!") + print('Please run "pip install --upgrade cython" first.') + exit(-1) + +# What packages are required for this module to be executed? +# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. 
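+# A build/install sketch (assumes a source checkout; the Cython extensions declared
+# below are why cython and numpy must be importable before this file runs):
+#
+#   pip install "cython>=0.28" numpy
+#   pip install .   # builds qlib.data._libs.rolling and qlib.data._libs.expanding
+#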
+REQUIRED = [ + "numpy>=1.12.0", + "pandas>=0.25.1", + "scipy>=1.0.0", + "requests>=2.18.0", + "sacred>=0.7.4", + "pymongo==3.7.2", + "python-socketio==3.1.2", + "redis>=3.0.1", + "python-redis-lock>=3.3.1", + "schedule>=0.6.0", + "cvxpy==1.0.21", + "hyperopt==0.1.1", + "fire>=0.2.1", + "statsmodels", + "xlrd>=1.0.0", + "plotly==3.5.0", + "matplotlib==3.1.3", + "tables>=3.6.1", + "pyyaml>=5.3.1", + "tqdm", + "loguru", + "lightgbm", + "tornado", +] + +# Numpy include +NUMPY_INCLUDE = numpy.get_include() + +here = os.path.abspath(os.path.dirname(__file__)) + +with io.open(os.path.join(here, "README.rst"), encoding="utf-8") as f: + long_description = "\n" + f.read() + +# Cython Extensions +extensions = [ + Extension( + "qlib.data._libs.rolling", + ["qlib/data/_libs/rolling.pyx"], + language="c++", + include_dirs=[NUMPY_INCLUDE], + ), + Extension( + "qlib.data._libs.expanding", + ["qlib/data/_libs/expanding.pyx"], + language="c++", + include_dirs=[NUMPY_INCLUDE], + ), +] + +# Where the magic happens: +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + long_description=long_description, + python_requires=REQUIRES_PYTHON, + packages=find_packages(exclude=("tests",)), + # if your package is a single module, use this instead of 'packages': + # py_modules=['qlib'], + entry_points={ + # 'console_scripts': ['mycli=mymodule:cli'], + "console_scripts": [ + "estimator=qlib.contrib.estimator.launcher:run", + "tuner=qlib.contrib.tuner.launcher:run", + ], + }, + ext_modules=extensions, + install_requires=REQUIRED, + include_package_data=True, + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + # 'License :: OSI Approved :: MIT License', + "Development Status :: 3 - Alpha", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + ], +) diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py new file mode 100644 index 0000000000..d779929fe8 --- /dev/null +++ b/tests/test_all_pipeline.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
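+
+# The flow below (train -> backtest -> analyze) can also be driven interactively;
+# a minimal sketch, assuming the cn_data bundle is already available:
+#
+#   import qlib
+#   from qlib.config import REG_CN
+#   qlib.init(mount_path="~/.qlib/qlib_data/cn_data", region=REG_CN)
+#   pred_score, perf = train()                            # LGBModel on QLibDataHandlerV1 splits
+#   report, positions, long_short = backtest(pred_score)  # TopkAmountStrategy + long/short backtest
+#   analyze(report, long_short)                           # risk_analysis tables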
+ +import sys +import unittest +from pathlib import Path + +import numpy as np +import pandas as pd +from scipy.stats import pearsonr + +import qlib +from qlib.config import REG_CN +from qlib.utils import drop_nan_by_y_index +from qlib.contrib.model.gbdt import LGBModel +from qlib.contrib.estimator.handler import QLibDataHandlerV1 +from qlib.contrib.strategy.strategy import TopkAmountStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + long_short_backtest, + risk_analysis, +) +from qlib.utils import exists_qlib_data + + +DATA_HANDLER_CONFIG = { + "dropna_label": True, + "start_date": "2007-01-01", + "end_date": "2020-08-01", + "market": "CSI500", +} + +MODEL_CONFIG = { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, +} + +TRAINER_CONFIG = { + "train_start_date": "2007-01-01", + "train_end_date": "2014-12-31", + "validate_start_date": "2015-01-01", + "validate_end_date": "2016-12-31", + "test_start_date": "2017-01-01", + "test_end_date": "2020-08-01", +} + +STRATEGY_CONFIG = { + "topk": 50, + "buffer_margin": 230, +} + +BACKTEST_CONFIG = { + "verbose": False, + "limit_threshold": 0.095, + "account": 100000000, + "benchmark": "SH000905", + "deal_price": "vwap", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, +} + + +# train +def train(): + """train model + + Returns + ------- + pred_score: pandas.DataFrame + predict scores + performance: dict + model performance + """ + # get data + x_train, y_train, x_validate, y_validate, x_test, y_test = QLibDataHandlerV1(**DATA_HANDLER_CONFIG).get_split_data( + **TRAINER_CONFIG + ) + + # train + model = LGBModel(**MODEL_CONFIG) + model.fit(x_train, y_train, x_validate, y_validate) + _pred = model.predict(x_test) + _pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns) + pred_score = pd.DataFrame(index=_pred.index) + pred_score["score"] = _pred.iloc(axis=1)[0] + + # get performance + model_score = model.score(x_test, y_test) + # Remove rows from x, y and w, which contain Nan in any columns in y_test. 
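+    # (drop_nan_by_y_index also returns the filtered sample weights; none are passed
+    # here, so the third return value is discarded as `__`.)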
+ x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test) + pred_test = model.predict(x_test) + model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0] + + return pred_score, {"model_score": model_score, "model_pearsonr": model_pearsonr} + + +def backtest(pred): + """backtest + + Parameters + ---------- + pred: pandas.DataFrame + predict scores + + Returns + ------- + report_normal: pandas.DataFrame + + positions_normal: dict + + long_short_reports: dict + + """ + strategy = TopkAmountStrategy(**STRATEGY_CONFIG) + _report_normal, _positions_normal = normal_backtest(pred, strategy=strategy, **BACKTEST_CONFIG) + _long_short_reports = long_short_backtest(pred, topk=50) + return _report_normal, _positions_normal, _long_short_reports + + +def analyze(report_normal, long_short_reports): + _analysis = dict() + _analysis["pred_long"] = risk_analysis(long_short_reports["long"]) + _analysis["pred_short"] = risk_analysis(long_short_reports["short"]) + _analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"]) + _analysis["sub_bench"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + _analysis["sub_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"]) + analysis_df = pd.concat(_analysis) # type: pd.DataFrame + print(analysis_df) + return analysis_df + + +class TestAllFlow(unittest.TestCase): + PRED_SCORE = None + REPORT_NORMAL = None + POSITIONS = None + LONG_SHORT_REPORTS = None + + @classmethod + def setUpClass(cls) -> None: + # use default data + mount_path = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(mount_path): + print(f"Qlib data is not found in {mount_path}") + sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) + from get_data import GetData + + GetData().qlib_data_cn(mount_path) + qlib.init(mount_path=mount_path, region=REG_CN) + + def test_0_train(self): + TestAllFlow.PRED_SCORE, model_pearsonr = train() + self.assertGreaterEqual(model_pearsonr["model_pearsonr"], 0, "train failed") + + def test_1_backtest(self): + TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS, TestAllFlow.LONG_SHORT_REPORTS = backtest( + TestAllFlow.PRED_SCORE + ) + analyze_df = analyze(TestAllFlow.REPORT_NORMAL, TestAllFlow.LONG_SHORT_REPORTS) + self.assertGreaterEqual( + analyze_df.loc(axis=0)["sub_cost", "annual"].values[0], + 0.10, + "backtest failed", + ) + + +def suite(): + _suite = unittest.TestSuite() + _suite.addTest(TestAllFlow("test_0_train")) + _suite.addTest(TestAllFlow("test_1_backtest")) + return _suite + + +if __name__ == "__main__": + runner = unittest.TextTestRunner() + runner.run(suite()) diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py new file mode 100644 index 0000000000..b871fcc9a9 --- /dev/null +++ b/tests/test_dump_data.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
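+
+# Sketch of the dump/read round trip exercised by this test (directories are the
+# test's own temporary paths):
+#
+#   dump = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
+#   dump.dump_calendars()                      # -> calendars/day.txt
+#   dump.dump_instruments()                    # -> instruments/all.txt
+#   dump.dump_features(include_fields=FIELDS)  # -> features/<code>/<field>.day.bin
+#   D.features(STOCK_NAMES, QLIB_FIELDS)       # read the dumped bins back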
+
+
+import sys
+import shutil
+import unittest
+from pathlib import Path
+
+import qlib
+import numpy as np
+import pandas as pd
+from qlib.data import D
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
+from get_data import GetData
+from dump_bin import DumpData
+
+
+DATA_DIR = Path(__file__).parent.joinpath("test_data")
+DATA_DIR.mkdir(exist_ok=True, parents=True)
+SOURCE_DIR = DATA_DIR.joinpath("source")
+QLIB_DIR = DATA_DIR.joinpath("qlib")
+QLIB_DIR.mkdir(exist_ok=True, parents=True)
+
+
+class TestDumpData(unittest.TestCase):
+    FIELDS = "open,close,high,low,volume,vwap,factor,change,money".split(",")
+    QLIB_FIELDS = list(map(lambda x: f"${x}", FIELDS))
+    DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
+    STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.iterdir()))
+
+    # simple data
+    SIMPLE_DATA = None
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        GetData().csv_data_cn(SOURCE_DIR)
+        mount_path = provider_uri = str(QLIB_DIR.resolve())
+        qlib.init(
+            mount_path=mount_path,
+            provider_uri=provider_uri,
+            expression_cache=None,
+            dataset_cache=None,
+        )
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        shutil.rmtree(str(DATA_DIR.resolve()))
+
+    def test_0_dump_calendars(self):
+        self.DUMP_DATA.dump_calendars()
+        ori_calendars = set(
+            map(
+                pd.Timestamp,
+                pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,
+            )
+        )
+        res_calendars = set(D.calendar())
+        assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"
+
+    def test_1_dump_instruments(self):
+        self.DUMP_DATA.dump_instruments()
+        ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.iterdir()))
+        res_ins = set(D.list_instruments(D.instruments("all"), as_list=True))
+        assert len(ori_ins - res_ins) == len(res_ins - ori_ins) == 0, "dump instruments failed"
+
+    def test_2_dump_features(self):
+        self.DUMP_DATA.dump_features(include_fields=self.FIELDS)
+        df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)
+        TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :]
+        self.assertFalse(df.dropna().empty, "features data failed")
+        self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed")
+
+    def test_3_dump_features_simple(self):
+        stock = self.STOCK_NAMES[0]
+        dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.upper()}.csv"), qlib_dir=QLIB_DIR)
+        dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt"))
+
+        df = D.features([stock], self.QLIB_FIELDS)
+
+        self.assertEqual(len(df), len(TestDumpData.SIMPLE_DATA), "dump features simple failed")
+        self.assertTrue(np.isclose(df.dropna(), self.SIMPLE_DATA.dropna()).all(), "dump features simple failed")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_get_data.py b/tests/test_get_data.py
new file mode 100644
index 0000000000..b7fdc274a4
--- /dev/null
+++ b/tests/test_get_data.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
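+
+# Equivalent command-line calls for the two downloads exercised below (taken from
+# the GetData docstrings in scripts/get_data.py):
+#
+#   python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
+#   python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data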
+ +import sys +import shutil +import unittest +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) +from get_data import GetData + +import qlib +from qlib.data import D + +DATA_DIR = Path(__file__).parent.joinpath("test_data") +SOURCE_DIR = DATA_DIR.joinpath("source") +SOURCE_DIR.mkdir(exist_ok=True, parents=True) +QLIB_DIR = DATA_DIR.joinpath("qlib") +QLIB_DIR.mkdir(exist_ok=True, parents=True) + + +class TestGetData(unittest.TestCase): + FIELDS = "$open,$close,$high,$low,$volume,$vwap,$factor,$change,$money".split(",") + + @classmethod + def setUpClass(cls) -> None: + mount_path = provider_uri = str(QLIB_DIR.resolve()) + qlib.init( + mount_path=mount_path, + provider_uri=provider_uri, + expression_cache=None, + dataset_cache=None, + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(str(DATA_DIR.resolve())) + + def test_0_qlib_data(self): + + GetData().qlib_data_cn(QLIB_DIR) + df = D.features(D.instruments("sse50"), self.FIELDS) + self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed") + self.assertFalse(df.dropna().empty, "get qlib data failed") + + def test_1_csv_data(self): + GetData().csv_data_cn(SOURCE_DIR) + stock_name = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.iterdir())) + self.assertEqual(len(stock_name), 300, "get csv data failed") + + +if __name__ == "__main__": + unittest.main()
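+
+# These test modules can also be run with the standard unittest runner from the
+# repository root, e.g.:
+#
+#   python -m unittest tests.test_get_data
+#   python -m unittest tests.test_dump_data
+#   python -m unittest tests.test_all_pipeline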