Skip to content

Commit

Permalink
Make static prediction easier
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Oct 15, 2021
1 parent 2e49a5f commit ac08468
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 57 deletions.
4 changes: 2 additions & 2 deletions examples/nested_decision_execution/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def backtest(self):
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
Expand Down Expand Up @@ -189,7 +189,7 @@ def collect_data(self):
backtest_config["benchmark"] = self.benchmark
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
Expand Down
2 changes: 1 addition & 1 deletion examples/workflow_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
Expand Down
83 changes: 83 additions & 0 deletions qlib/backtest/signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from ..model.base import BaseModel
from ..data.dataset import Dataset
from ..data.dataset.utils import convert_index_format
from ..utils.resam import resam_ts_data
import pandas as pd
import abc


class Signal(metaclass=abc.ABCMeta):
"""
Some trading strategy make decisions based on other prediction signals
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
This
"""

@abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
"""
get the signal at the end of the decision step(from `start_time` to `end_time`)
Returns
-------
Union[pd.Series, pd.DataFrame, None]:
returns None if no signal in the specific day
"""
...


class SignalWCache(Signal):
"""
Signal With pandas with based Cache
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
"""

def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
"""
Parameters
----------
signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
instrument datetime
SH600000 2008-01-02 0.079704
2008-01-03 0.120125
2008-01-04 0.878860
2008-01-07 0.505539
2008-01-08 0.395004
"""
self.signal_cache = convert_index_format(signal, level="datetime")

def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy
# so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
return signal


class ModelSignal(SignalWCache):
...

def __init__(self, model: BaseModel, dataset: Dataset):
self.model = model
self.dataset = dataset
pred_scores = self.model.predict(dataset)
if isinstance(pred_scores, pd.DataFrame):
pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores)

def _update_model(self):
"""
When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
# TODO: this method is not included in the framework and could be refactor later
raise NotImplementedError("_update_model is not implemented!")
2 changes: 1 addition & 1 deletion qlib/contrib/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.


from .model_strategy import (
from .signal_strategy import (
TopkDropoutStrategy,
WeightStrategyBase,
)
Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/strategy/cost_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


from .order_generator import OrderGenWInteract
from .model_strategy import WeightStrategyBase
from .signal_strategy import WeightStrategyBase
import copy


Expand Down
2 changes: 2 additions & 0 deletions qlib/contrib/strategy/rule_strategy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import warnings
import numpy as np
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
import copy
from qlib.backtest.signal import ModelSignal, Signal, SignalWCache
from typing import Union
from qlib.data.dataset import Dataset
from qlib.model.base import BaseModel
from qlib.backtest.position import Position
import warnings
import numpy as np
import pandas as pd

from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ...strategy.base import BaseStrategy
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO

from .order_generator import OrderGenWInteract


class TopkDropoutStrategy(ModelStrategy):
class TopkDropoutStrategy(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
topk,
n_drop,
model: BaseModel = None,
dataset: Dataset = None,
signal: Union[pd.DataFrame, pd.Series] = None,
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
Expand Down Expand Up @@ -64,7 +70,7 @@ def __init__(
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
self.topk = topk
self.n_drop = n_drop
Expand All @@ -73,6 +79,8 @@ def __init__(
self.risk_degree = risk_degree
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
assert signal is not None or dataset is not None and model is not None
self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal)

def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Expand All @@ -87,7 +95,7 @@ def generate_trade_decision(self, execute_result=None):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
if self.only_tradable:
Expand Down Expand Up @@ -235,15 +243,17 @@ def filter_stock(l):
return TradeDecisionWO(sell_order_list + buy_order_list, self)


class WeightStrategyBase(ModelStrategy):
class WeightStrategyBase(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
model: BaseModel = None,
dataset: Dataset = None,
signal: Union[pd.DataFrame, pd.Series] = None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra=None,
Expand All @@ -260,12 +270,14 @@ def __init__(
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
assert signal is not None or dataset is not None and model is not None
self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal)

def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Expand Down Expand Up @@ -298,7 +310,7 @@ def generate_trade_decision(self, execute_result=None):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
current_temp = copy.deepcopy(self.trade_position)
Expand Down
41 changes: 1 addition & 40 deletions qlib/strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..backtest.decision import BaseTradeDecision

__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]


class BaseStrategy:
Expand Down Expand Up @@ -194,45 +194,6 @@ def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]:
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])


class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""

def __init__(
self,
model: BaseModel,
dataset: DatasetH,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
model : BaseModel
the model used in when making predictions
dataset : DatasetH
provide test data for model
kwargs : dict
arguments that will be passed into `reset` method
"""
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.model = model
self.dataset = dataset
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
if isinstance(self.pred_scores, pd.DataFrame):
self.pred_scores = self.pred_scores.iloc[:, 0]

def _update_model(self):
"""
When using online data, pdate model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
raise NotImplementedError("_update_model is not implemented!")


class RLStrategy(BaseStrategy):
"""RL-based strategy"""

Expand Down
2 changes: 1 addition & 1 deletion tests/test_all_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
Expand Down

0 comments on commit ac08468

Please sign in to comment.