Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine RL todos #1332

Merged
merged 12 commits into from Nov 10, 2022
12 changes: 5 additions & 7 deletions qlib/backtest/__init__.py
Expand Up @@ -10,7 +10,6 @@
import pandas as pd

from .account import Account
from .report import Indicator, PortfolioMetrics

if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
Expand All @@ -20,7 +19,7 @@
from ..config import C
from ..log import get_module_logger
from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop
from .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .utils import CommonInfrastructure
Expand Down Expand Up @@ -223,7 +222,7 @@ def backtest(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution

Expand Down Expand Up @@ -256,9 +255,9 @@ def backtest(

Returns
-------
portfolio_metrics_dict: Dict[PortfolioMetrics]
portfolio_dict: PORT_METRIC
it records the trading portfolio_metrics information
indicator_dict: Dict[Indicator]
indicator_dict: INDICATOR_METRIC
it computes the trading indicator
It is organized in a dict format

Expand All @@ -273,8 +272,7 @@ def backtest(
exchange_kwargs,
pos_type=pos_type,
)
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
return portfolio_metrics, indicator
return backtest_loop(start_time, end_time, trade_strategy, trade_executor)


def collect_data(
Expand Down
46 changes: 27 additions & 19 deletions qlib/backtest/backtest.py
Expand Up @@ -3,12 +3,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast

import pandas as pd

from qlib.backtest.decision import BaseTradeDecision
from qlib.backtest.report import Indicator, PortfolioMetrics
from qlib.backtest.report import Indicator

if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
Expand All @@ -19,30 +19,35 @@
from ..utils.time import Freq


PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]


def backtest_loop(
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
) -> Tuple[PortfolioMetrics, Indicator]:
lihuoran marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution

please refer to the docs of `collect_data_loop`

Returns
-------
portfolio_metrics: PortfolioMetrics
portfolio_dict: PORT_METRIC
it records the trading portfolio_metrics information
indicator: Indicator
indicator_dict: INDICATOR_METRIC
it computes the trading indicator
"""
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass

portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict"))
indicator_dict = cast(INDICATOR_METRIC, return_value.get("indicator_dict"))

return portfolio_dict, indicator_dict


def collect_data_loop(
Expand Down Expand Up @@ -89,14 +94,17 @@ def collect_data_loop(

if return_value is not None:
all_executors = trade_executor.get_all_executors()
all_portfolio_metrics = {
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
for _executor in all_executors
if _executor.trade_account.is_port_metr_enabled()
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})

portfolio_dict: PORT_METRIC = {}
indicator_dict: INDICATOR_METRIC = {}

for executor in all_executors:
key = "{}{}".format(*Freq.parse(executor.time_per_step))
if executor.trade_account.is_port_metr_enabled():
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()

indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
indicator_obj = executor.trade_account.get_trade_indicator()
indicator_dict[key] = (indicator_df, indicator_obj)
lihuoran marked this conversation as resolved.
Show resolved Hide resolved

return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict})
63 changes: 51 additions & 12 deletions qlib/backtest/exchange.py
Expand Up @@ -26,6 +26,15 @@


class Exchange:
# `quote_df` is a pd.DataFrame class that contains basic information for backtesting
# After some processing, the data will later be maintained by `quote_cls` object for faster data retriving.
# Some conventions for `quote_df`
# - $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is reguarded as suspended.
# - $factor is for rounding to the trading unit;
# - if any $factor is missing when $close exists, trading unit rounding will be disabled
quote_df: pd.DataFrame

def __init__(
self,
freq: str = "day",
Expand Down Expand Up @@ -159,6 +168,7 @@ def __init__(
self.codes = codes
# Necessary fields
# $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is reguarded as suspended.
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock

Expand Down Expand Up @@ -199,7 +209,7 @@ def get_quote_from_qlib(self) -> None:
self.end_time,
freq=self.freq,
disk_cache=True,
).dropna(subset=["$close"])
)
self.quote_df.columns = self.all_fields

# check buy_price data and sell_price data
Expand All @@ -209,7 +219,7 @@ def get_quote_from_qlib(self) -> None:
self.logger.warning("{} field data contains nan.".format(pstr))

# update trade_w_adj_price
if self.quote_df["$factor"].isna().any():
if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any():
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
Expand Down Expand Up @@ -245,9 +255,9 @@ def get_quote_from_qlib(self) -> None:
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)

LT_TP_EXP = "(exp)" # Tuple[str, str]
LT_FLT = "float" # float
LT_NONE = "none" # none
LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression.
LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold`
LT_NONE = "none" # none: there is no trading limitation

def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type"""
Expand All @@ -261,20 +271,25 @@ def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")

def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
# $close is may contains NaN, the nan indicates that the stock is not tradable at that timestamp
suspended = self.quote_df["$close"].isna()
# check limit_threshold
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
self.quote_df["limit_buy"] = False
self.quote_df["limit_sell"] = False
self.quote_df["limit_buy"] = suspended
self.quote_df["limit_sell"] = suspended
elif limit_type == self.LT_TP_EXP:
# set limit
limit_threshold = cast(tuple, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
# astype bool is necessary, because quote_df is an expression and could be float
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended
elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended
self.quote_df["limit_sell"] = (
self.quote_df["$change"].le(-limit_threshold) | suspended
) # pylint: disable=E1130

@staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
Expand Down Expand Up @@ -338,8 +353,18 @@ def check_stock_limit(
- if direction is None, check if tradable for buying and selling.
- if direction == Order.BUY, check the if tradable for buying
- if direction == Order.SELL, check the sell limit for selling.

Returns
-------
True: the trading of the stock is limted (maybe hit the highest/lowest price), hence the stock is not tradable
False: the trading of the stock is not limited, hence the stock may be tradable
"""
# NOTE:
# **all** is used when checking limitation.
# For example, the stock trading is limited in a day if every miniute is limited in a day if every miniute is limited.
if direction is None:
# The trading limitation is related to the trading direction
# if the direction is not provided, then any limitation from buy or sell will result in trading limitation
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return bool(buy_limit or sell_limit)
Expand All @@ -356,10 +381,24 @@ def check_stock_suspended(
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> bool:
"""if stock is suspended(hence not tradable), True will be returned"""
# is suspended
if stock_id in self.quote.get_all_stock():
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
# suspended stocks are represented by None $close stock
# The $close may contains NaN,
close = self.quote.get_data(stock_id, start_time, end_time, "$close")
if close is None:
# if no close record exists
return True
elif isinstance(close, IndexData):
# **any** non-NaN $close represents trading opportunity may exists
# if all returned is nan, then the stock is suspended
return cast(bool, cast(IndexData, close).isna().all())
else:
# it is single value, make sure is is not None
return np.isnan(close)
else:
# if the stock is not in the stock list, then it is not tradable and regarded as suspended
return True

def is_stock_tradable(
Expand Down