Skip to content

Commit

Permalink
Refine Qlib RL data format (#1480)
Browse files Browse the repository at this point in the history
* wip

* wip

* wip

* Fix naming errors

* Backtest test passed

* Why training stuck?

* Minor

* Refine train configs

* Use dummy in training

* Remove pickle_dataframe

* CI

* CI

* Add more strict condition to filter orders

* Pass test

* Add TODO in example

---------

Co-authored-by: Young <afe.young@gmail.com>
  • Loading branch information
lihuoran and you-n-g committed Apr 26, 2023
1 parent 46264df commit 7f1e8c5
Show file tree
Hide file tree
Showing 17 changed files with 236 additions and 249 deletions.
6 changes: 3 additions & 3 deletions examples/rl_order_execution/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region

To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):

[//]: # (TODO: Instead of dumping dataframe with different format &#40;like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`&#41;, we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)

```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/collect_pickle_dataframe.py
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```
Expand All @@ -27,8 +28,7 @@ When finished, the structure under `data/` should be:
data
├── bin
├── orders
├── pickle
└── pickle_dataframe
└── pickle
```

## Training
Expand Down
17 changes: 5 additions & 12 deletions examples/rl_order_execution/exp_configs/backtest_opds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@ start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
Expand Down Expand Up @@ -45,10 +36,12 @@ strategies:
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
Expand Down
17 changes: 5 additions & 12 deletions examples/rl_order_execution/exp_configs/backtest_ppo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@ start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
Expand Down Expand Up @@ -45,10 +36,12 @@ strategies:
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:
Expand Down
9 changes: 0 additions & 9 deletions examples/rl_order_execution/exp_configs/backtest_twap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@ start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
Expand Down
17 changes: 11 additions & 6 deletions examples/rl_order_execution/exp_configs/train_opds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
Expand All @@ -18,10 +18,13 @@ state_interpreter:
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
Expand All @@ -32,7 +35,9 @@ reward:
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
Expand Down
17 changes: 11 additions & 6 deletions examples/rl_order_execution/exp_configs/train_ppo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
Expand All @@ -18,10 +18,13 @@ state_interpreter:
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PPOReward
Expand All @@ -33,7 +36,9 @@ reward:
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235
Expand Down
26 changes: 0 additions & 26 deletions examples/rl_order_execution/scripts/collect_pickle_dataframe.py

This file was deleted.

27 changes: 19 additions & 8 deletions examples/rl_order_execution/scripts/gen_training_orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

from pathlib import Path

DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))


def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
df = dataset.handler.fetch(level=None).reset_index()
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
return False

df["date"] = df["datetime"].dt.date.astype("datetime64")
df = df.set_index(["instrument", "datetime", "date"])
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")

order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
Expand All @@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
os.makedirs(path, exist_ok=True)
if len(order) > 0:
order.to_pickle(path / f"{stock}.pkl.target")
return True


np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [f.replace(".pkl", "") for f in file_list]
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
for stock in tqdm(stocks):
generate_order(stock, 0, 240 // 5 - 1)
np.random.shuffle(stocks)

cnt = 0
for stock in stocks:
if generate_order(stock, 0, 240 // 5 - 1):
cnt += 1
if cnt == 100:
break
14 changes: 2 additions & 12 deletions qlib/rl/contrib/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,7 @@ def single_with_simulator(
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
init_qlib(backtest_config["qlib"])

stocks = orders.instrument.unique().tolist()

Expand Down Expand Up @@ -253,12 +248,7 @@ def single_with_collect_data_loop(
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""

if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
init_qlib(backtest_config["qlib"])

trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()
Expand Down
33 changes: 18 additions & 15 deletions qlib/rl/contrib/train_onpolicy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import argparse
import os
import random
Expand All @@ -9,13 +11,12 @@

import numpy as np
import pandas as pd
import qlib
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.data.native import load_handler_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
Expand Down Expand Up @@ -49,19 +50,17 @@ def _read_orders(order_dir: Path) -> pd.DataFrame:
class LazyLoadDataset(Dataset):
def __init__(
self,
data_dir: str,
order_file_path: Path,
data_dir: Path,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index

self._order_file_path = order_file_path
self._order_df = _read_orders(order_file_path).reset_index()

self._data_dir = data_dir
self._ticks_index: Optional[pd.DatetimeIndex] = None
self._data_dir = Path(data_dir)

def __len__(self) -> int:
return len(self._order_df)
Expand All @@ -74,12 +73,17 @@ def __getitem__(self, index: int) -> Order:
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
# TODO: of all dates.
backtest_data = load_simple_intraday_backtest_data(

data = load_handler_intraday_processed_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
feature_columns_today=[],
feature_columns_yesterday=[],
backtest=True,
index_only=True,
)
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
self._ticks_index = [t - date for t in data.today.index]

order = Order(
stock_id=row["instrument"],
Expand All @@ -104,19 +108,18 @@ def train_and_test(
run_training: bool,
run_backtest: bool,
) -> None:
qlib.init()

order_root_path = Path(data_config["source"]["order_dir"])

data_granularity = simulator_config.get("data_granularity", 1)

def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
return SingleAssetOrderExecutionSimple(
order=order,
data_dir=Path(data_config["source"]["data_dir"]),
ticks_per_step=simulator_config["time_per_step"],
data_dir=data_config["source"]["feature_root_dir"],
feature_columns_today=data_config["source"]["feature_columns_today"],
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
data_granularity=data_granularity,
deal_price_type=data_config["source"].get("deal_price_column", "close"),
ticks_per_step=simulator_config["time_per_step"],
vol_threshold=simulator_config["vol_limit"],
)

Expand All @@ -126,8 +129,8 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
Expand Down Expand Up @@ -178,8 +181,8 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:

if run_backtest:
test_dataset = LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / "test",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
Expand Down
Loading

0 comments on commit 7f1e8c5

Please sign in to comment.