Skip to content

Commit

Permalink
Order execution open source (#1447)
Browse files Browse the repository at this point in the history
* Waiting for bin data

* Complete readme

* CI

* Add inst filter by time

* Update qlib/data/dataset/processor.py

* typo

* Fix time filter bug

* Add Filter and set Universe

* Complete data pipeline

* Fix Provider Logger Info Args

* Add DQN; a minor bugfix in ppo reward.

* update readme. modify assertion logic in strategy check.

* Fix Doc issues and fix black

* Fix pylint Error

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 13, 2023
1 parent f98e04c commit 653c082
Show file tree
Hide file tree
Showing 24 changed files with 742 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -27,6 +27,8 @@ examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/
examples/rl_order_execution/data/
examples/rl_order_execution/outputs/

*.egg-info/

Expand Down
4 changes: 2 additions & 2 deletions examples/benchmarks/LightGBM/multi_freq_handler.py
Expand Up @@ -29,13 +29,13 @@ def __init__(
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
inst_processors=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = Avg15minLoader(
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors
)
super().__init__(
instruments=instruments,
Expand Down
Expand Up @@ -18,7 +18,7 @@ data_handler_config: &data_handler_config
label: day
feature: 1min
# with label as reference
inst_processor:
inst_processors:
feature:
- class: Resample1minProcessor
module_path: features_sample.py
Expand Down
Expand Up @@ -19,7 +19,7 @@ data_handler_config: &data_handler_config
feature_15min: 1min
feature_day: day
# with label as reference
inst_processor:
inst_processors:
feature_15min:
- class: ResampleNProcessor
module_path: features_resample_N.py
Expand Down
100 changes: 100 additions & 0 deletions examples/rl_order_execution/README.md
@@ -0,0 +1,100 @@
# RL Example for Order Execution

This folder comprises an example of Reinforcement Learning (RL) workflows for order execution scenario, including both training workflows and backtest workflows.

## Data Processing

### Get Data

```
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
```

### Generate Pickle-Style Data

To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):

```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/collect_pickle_dataframe.py
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```

When finished, the structure under `data/` should be:

```
data
├── bin
├── orders
├── pickle
└── pickle_dataframe
```

## Training

Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:

- **PPO**: Method proposed by IJCAL 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)".
- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)".

The main differece between these two methods is their reward functions. Please see their config files for details.

Take OPDS as an example, to run the training workflow, run:

```
python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest
```

Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).

## Backtest

Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow:

1. Uncomment the `weight_file` parameter in `exp_configs/train_opds.yml` (it is commented by default). While it is possible to run the backtesting workflow without setting a checkpoint, this will lead to randomly initialized model results, thus making them meaningless.
2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.

The backtest result is stored in `outputs/checkpoints/backtest_result.csv`.

In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.

### Gap between backtest and training pipeline's testing

It is worthy to notice that the results of the backtesting process may differ from the results of the testing process used during training.
This is because different simulators are used to simulate market conditions during training and backtesting.
In training pipeline, the simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.
`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts.
No matter what the amount of the order is, it can be completely executed.
However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit).
As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.

If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run training pipeline with only backtest phrase.
In order to do this:
- Modify the training config. Add the path of the checkpoint you want to use (see following for an example).
- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`

```yaml
...
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
weight_file: PATH/TO/CHECKPOINT
module_path: qlib.rl.order_execution.policy
...
```

## Benchmarks (TBD)

To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows:

| **Model** | **PA mean with std.** |
|-----------------------------|-----------------------|
| OPDS (with PPO policy) | 0.4785 ± 0.7815 |
| OPDS (with DQN policy) | -0.0114 ± 0.5780 |
| PPO | -1.0935 ± 0.0922 |
| TWAP | ≈ 0.0 ± 0.0 |

The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0, however, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results.
59 changes: 59 additions & 0 deletions examples/rl_order_execution/exp_configs/backtest_opds.yml
@@ -0,0 +1,59 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/opds/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/opds/
59 changes: 59 additions & 0 deletions examples/rl_order_execution/exp_configs/backtest_ppo.yml
@@ -0,0 +1,59 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: SAOEIntStrategy
kwargs:
data_granularity: 5
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
max_step: 8
values: 4
module_path: qlib.rl.order_execution.interpreter
network:
class: Recurrent
kwargs: {}
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
# weight_file: outputs/ppo/checkpoints/latest.pth
module_path: qlib.rl.order_execution.policy
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/ppo/
29 changes: 29 additions & 0 deletions examples/rl_order_execution/exp_configs/backtest_twap.yml
@@ -0,0 +1,29 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
volume_threshold: null
strategies:
1day:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
30min:
class: TWAPStrategy
kwargs: {}
module_path: qlib.contrib.strategy.rule_strategy
concurrency: 16
output_dir: outputs/twap/
61 changes: 61 additions & 0 deletions examples/rl_order_execution/exp_configs/train_opds.yml
@@ -0,0 +1,61 @@
simulator:
data_granularity: 5
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
values: 4
max_step: 8
module_path: qlib.rl.order_execution.interpreter
state_interpreter:
class: FullHistoryStateInterpreter
kwargs:
data_dim: 5
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
kwargs:
penalty: 4.0
scale: 0.01
module_path: qlib.rl.order_execution.reward
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
proc_data_dim: 5
num_workers: 0
queue_size: 20
network:
class: Recurrent
module_path: qlib.rl.order_execution.network
policy:
class: PPO # PPO, DQN
kwargs:
lr: 0.0001
module_path: qlib.rl.order_execution.policy
runtime:
seed: 42
use_cuda: false
trainer:
max_epoch: 500
repeat_per_collect: 25
earlystop_patience: 50
episode_per_collect: 10000
batch_size: 1024
val_every_n_epoch: 4
checkpoint_path: ./outputs/opds
checkpoint_every_n_iters: 1

0 comments on commit 653c082

Please sign in to comment.