From f925b8a88bef04d58ec842d71489fb245ba4ac5b Mon Sep 17 00:00:00 2001 From: lwwang Date: Wed, 5 Jul 2023 23:26:16 +0800 Subject: [PATCH 1/5] Update algorithm trading demo. --- .../exp_configs/train_at_opds.yml | 62 ++++ qlib/rl/algorithm_trading/__init__.py | 29 ++ qlib/rl/algorithm_trading/interpreter.py | 174 ++++++++++ qlib/rl/algorithm_trading/network.py | 141 ++++++++ qlib/rl/algorithm_trading/policy.py | 187 ++++++++++ qlib/rl/algorithm_trading/reward.py | 39 +++ qlib/rl/algorithm_trading/simulator_simple.py | 321 ++++++++++++++++++ qlib/rl/algorithm_trading/state.py | 92 +++++ qlib/rl/contrib/train_at_onpolicy.py | 261 ++++++++++++++ 9 files changed, 1306 insertions(+) create mode 100755 examples/rl_algorithm_trading/exp_configs/train_at_opds.yml create mode 100644 qlib/rl/algorithm_trading/__init__.py create mode 100644 qlib/rl/algorithm_trading/interpreter.py create mode 100644 qlib/rl/algorithm_trading/network.py create mode 100644 qlib/rl/algorithm_trading/policy.py create mode 100644 qlib/rl/algorithm_trading/reward.py create mode 100644 qlib/rl/algorithm_trading/simulator_simple.py create mode 100644 qlib/rl/algorithm_trading/state.py create mode 100644 qlib/rl/contrib/train_at_onpolicy.py diff --git a/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml b/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml new file mode 100755 index 0000000000..3875119217 --- /dev/null +++ b/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml @@ -0,0 +1,62 @@ +simulator: + data_granularity: 5 + time_per_step: 30 + vol_limit: null + fee_rate: 0.002 +env: + concurrency: 24 + parallel_mode: shmem +action_interpreter: + class: CategoricalATActionInterpreter + kwargs: + values: [-1, 0, 1] + max_step: 8 + module_path: qlib.rl.algorithm_trading.interpreter +state_interpreter: + class: FullHistoryATStateInterpreter + kwargs: + data_dim: 5 + data_ticks: 48 # 48 = 240 min / 5 min + max_step: 8 + processed_data_provider: + class: PickleProcessedDataProvider + module_path: qlib.rl.data.pickle_styled + kwargs: + data_dir: ./data/pickle_dataframe/feature + module_path: qlib.rl.algorithm_trading.interpreter +reward: + class: LongShortReward + kwargs: + trans_fee: 0.002 + scale: 1000 + module_path: qlib.rl.algorithm_trading.reward +data: + source: + task_dir: ./data/tasks + data_dir: ./data/pickle_dataframe/backtest + total_time: 240 + default_start_time_index: 0 + default_end_time_index: 235 + proc_data_dim: 5 + num_workers: 0 + queue_size: 20 +network: + class: Recurrent + module_path: qlib.rl.algorithm_trading.network +policy: + class: PPO + kwargs: + lr: 0.0001 + module_path: qlib.rl.order_execution.policy +runtime: + seed: 42 + use_cuda: false +trainer: + max_epoch: 500 + repeat_per_collect: 20 + earlystop_patience: 5 + episode_per_collect: 10000 + batch_size: 1024 + val_every_n_epoch: 5 + checkpoint_path: ./outputs/algorithm_trading + checkpoint_every_n_iters: 1 diff --git a/qlib/rl/algorithm_trading/__init__.py b/qlib/rl/algorithm_trading/__init__.py new file mode 100644 index 0000000000..f7a7001695 --- /dev/null +++ b/qlib/rl/algorithm_trading/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Currently it supports single-asset order execution. +Multi-asset is on the way. +""" + +from .interpreter import ( + FullHistoryATStateInterpreter, + CategoricalATActionInterpreter, +) +from .network import Recurrent +from .policy import AllOne, PPO +from .reward import LongShortReward +from .simulator_simple import SingleAssetAlgorithmTradingSimple +from .state import SAATMetrics, SAATState + +__all__ = [ + "FullHistoryATStateInterpreter", + "CategoricalATActionInterpreter", + "Recurrent", + "AllOne", + "PPO", + "LongShortReward", + "SingleAssetAlgorithmTradingSimple", + "SAATMetrics", + "SAATState", +] diff --git a/qlib/rl/algorithm_trading/interpreter.py b/qlib/rl/algorithm_trading/interpreter.py new file mode 100644 index 0000000000..44af2e2f20 --- /dev/null +++ b/qlib/rl/algorithm_trading/interpreter.py @@ -0,0 +1,174 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import math +from typing import Any, List, Optional, cast + +import numpy as np +import pandas as pd +from gym import spaces + +from qlib.constant import EPS +from qlib.rl.data.base import ProcessedDataProvider +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.algorithm_trading.state import SAATState +from qlib.typehint import TypedDict + +__all__ = [ + "FullHistoryATStateInterpreter", + "CategoricalATActionInterpreter", + "FullHistoryATObs", +] + +from qlib.utils import init_instance_by_config + + +def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict: + """To 32-bit numeric types. Recursively.""" + if isinstance(value, pd.DataFrame): + return value.to_numpy() + if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"): + return np.array(value, dtype=np.float32) + elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"): + return np.array(value, dtype=np.int32) + elif isinstance(value, dict): + return {k: canonicalize(v) for k, v in value.items()} + else: + return value + + +class FullHistoryATObs(TypedDict): + data_processed: Any + data_processed_prev: Any + cur_tick: Any + cur_step: Any + num_step: Any + position: Any + position_history: Any + + +class FullHistoryATStateInterpreter(StateInterpreter[SAATState, FullHistoryATObs]): + """The observation of all the history, including today (until this moment), and yesterday. + + Parameters + ---------- + max_step + Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps. + data_ticks + Equal to the total number of records. For example, in SAAT per minute, + the total ticks is the length of day in minutes. + data_dim + Number of dimensions in data. + processed_data_provider + Provider of the processed data. + """ + + def __init__( + self, + max_step: int, + data_ticks: int, + data_dim: int, + processed_data_provider: dict | ProcessedDataProvider, + ) -> None: + super().__init__() + + self.max_step = max_step + self.data_ticks = data_ticks + self.data_dim = data_dim + self.processed_data_provider: ProcessedDataProvider = init_instance_by_config( + processed_data_provider, + accept_types=ProcessedDataProvider, + ) + + def interpret(self, state: SAATState) -> FullHistoryATObs: + processed = self.processed_data_provider.get_data( + stock_id=state.task.stock_id, + date=pd.Timestamp(state.task.start_time.date()), + feature_dim=self.data_dim, + time_index=state.ticks_index, + ) + + position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) # Initialize position is 0 + position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() + + # The min, slice here are to make sure that indices fit into the range, + # even after the final step of the simulator (in the done step), + # to make network in policy happy. + return cast( + FullHistoryATObs, + canonicalize( + { + "data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)), + "data_processed_prev": np.array(processed.yesterday), + "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), + "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)), + "num_step": _to_int32(self.max_step), + "position": _to_float32(state.position), + "position_history": _to_float32(position_history[: self.max_step]), + }, + ), + ) + + @property + def observation_space(self) -> spaces.Dict: + space = { + "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), + "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), + "cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32), + "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32), + # TODO: support arbitrary length index + "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32), + "position": spaces.Box(-np.inf, np.inf, shape=()), + "position_history": spaces.Box(-np.inf, np.inf, shape=(self.max_step,)), + } + return spaces.Dict(space) + + @staticmethod + def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame: + arr = arr.copy(deep=True) + arr.loc[current:] = 0.0 # mask out data after this moment (inclusive) + return arr + + +class CategoricalATActionInterpreter(ActionInterpreter[SAATState, int, float]): + """Convert a discrete policy action to a continuous action, then multiplied by ``task.cash``. + + Parameters + ---------- + values + It can be a list of length $L$: $[a_1, a_2, \\ldots, a_L]$. + Then when policy givens decision $x$, $a_x$ times order amount is the output. + It can also be an integer $n$, in which case the list of length $n+1$ is auto-generated, + i.e., $[0, 1/n, 2/n, \\ldots, n/n]$. + max_step + Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps. + """ + + def __init__(self, values: List[int], max_step: Optional[int] = None) -> None: + super().__init__() + + self.action_values = values + self.max_step = max_step + + @property + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(len(self.action_values)) + + def interpret(self, state: SAATState, action: int) -> str: + assert 0 <= action < len(self.action_values) + if self.action_values[action] == -1: + return "short" + elif self.action_values[action] == 1: + return "long" + else: + return "hold" + + +def _to_int32(val): + return np.array(int(val), dtype=np.int32) + + +def _to_float32(val): + return np.array(val, dtype=np.float32) diff --git a/qlib/rl/algorithm_trading/network.py b/qlib/rl/algorithm_trading/network.py new file mode 100644 index 0000000000..fc2fffa842 --- /dev/null +++ b/qlib/rl/algorithm_trading/network.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import List, Tuple, cast + +import torch +import torch.nn as nn +from tianshou.data import Batch + +from qlib.typehint import Literal + +from .interpreter import FullHistoryATObs + +__all__ = ["Recurrent"] + + +class Recurrent(nn.Module): + """The network architecture proposed in `OPD `_. + + At every time step the input of policy network is divided into two parts, + the public variables and the private variables. which are handled by ``raw_rnn`` + and ``pri_rnn`` in this network, respectively. + + One minor difference is that, in this implementation, we don't assume the direction to be fixed. + Thus, another ``dire_fc`` is added to produce an extra direction-related feature. + """ + + def __init__( + self, + obs_space: FullHistoryATObs, + hidden_dim: int = 64, + output_dim: int = 32, + rnn_type: Literal["rnn", "lstm", "gru"] = "gru", + rnn_num_layers: int = 1, + ) -> None: + super().__init__() + + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_sources = 3 + + rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU} + + self.rnn_class = rnn_classes[rnn_type] + self.rnn_layers = rnn_num_layers + + self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + + self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU()) + self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU()) + self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + + self._init_extra_branches() + + self.fc = nn.Sequential( + nn.Linear(hidden_dim * self.num_sources, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim), + nn.ReLU(), + ) + + def _init_extra_branches(self) -> None: + pass + + def _source_features(self, obs: FullHistoryATObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]: + bs, _, data_dim = obs["data_processed"].size() + data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1) + cur_step = obs["cur_step"].long() + cur_tick = obs["cur_tick"].long() + bs_indices = torch.arange(bs, device=device) + + # position = obs["position_history"] / obs["target"].unsqueeze(-1) # [bs, num_step] + position = obs["position_history"].sign() + steps = ( + torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float() + / obs["num_step"].unsqueeze(-1).float() + ) # [bs, num_step] + priv = torch.stack((position.float(), steps), -1) + + data_in = self.raw_fc(data) + data_out, _ = self.raw_rnn(data_in) + # as it is padded with zero in front, this should be last minute + data_out_slice = data_out[bs_indices, cur_tick] + + priv_in = self.pri_fc(priv) + priv_out = self.pri_rnn(priv_in)[0] + priv_out = priv_out[bs_indices, cur_step] + + sources = [data_out_slice, priv_out] + + dir_out = self.dire_fc(torch.stack((obs["position"], -obs["position"]), -1).float()) + sources.append(dir_out) + + return sources, data_out + + def forward(self, batch: Batch) -> torch.Tensor: + """ + Input should be a dict (at least) containing: + + - data_processed: [N, T, C] + - cur_step: [N] (int) + - cur_time: [N] (int) + - position_history: [N, S] (S is number of steps) + - target: [N] + - num_step: [N] (int) + - acquiring: [N] (0 or 1) + """ + + inp = cast(FullHistoryATObs, batch) + device = inp["data_processed"].device + + sources, _ = self._source_features(inp, device) + assert len(sources) == self.num_sources + + out = torch.cat(sources, -1) + return self.fc(out) + + +class Attention(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.q_net = nn.Linear(in_dim, out_dim) + self.k_net = nn.Linear(in_dim, out_dim) + self.v_net = nn.Linear(in_dim, out_dim) + + def forward(self, Q, K, V): + q = self.q_net(Q) + k = self.k_net(K) + v = self.v_net(V) + + attn = torch.einsum("ijk,ilk->ijl", q, k) + attn = attn.to(Q.device) + attn_prob = torch.softmax(attn, dim=-1) + + attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v) + + return attn_vec diff --git a/qlib/rl/algorithm_trading/policy.py b/qlib/rl/algorithm_trading/policy.py new file mode 100644 index 0000000000..2102ff6ab9 --- /dev/null +++ b/qlib/rl/algorithm_trading/policy.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast + +import gym +import numpy as np +import torch +import torch.nn as nn +from gym.spaces import Discrete +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import BasePolicy, PPOPolicy + +from qlib.rl.trainer.trainer import Trainer + +__all__ = ["AllOne", "PPO"] + + +# baselines # + + +class NonLearnablePolicy(BasePolicy): + """Tianshou's BasePolicy with empty ``learn`` and ``process_fn``. + + This could be moved outside in future. + """ + + def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None: + super().__init__() + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]: + return {} + + def process_fn( + self, + batch: Batch, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> Batch: + return Batch({}) + + +class AllOne(NonLearnablePolicy): + """Forward returns a batch full of 1. + + Useful when implementing some baselines (e.g., TWAP). + """ + + def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None: + super().__init__(obs_space, action_space) + + self.fill_value = fill_value + + def forward( + self, + batch: Batch, + state: dict | Batch | np.ndarray = None, + **kwargs: Any, + ) -> Batch: + return Batch(act=np.full(len(batch), self.fill_value), state=state) + + +# ppo # + + +class PPOActor(nn.Module): + def __init__(self, extractor: nn.Module, action_dim: int) -> None: + super().__init__() + self.extractor = extractor + self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1)) + + def forward( + self, + obs: torch.Tensor, + state: torch.Tensor = None, + info: dict = {}, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + feature = self.extractor(to_torch(obs, device=auto_device(self))) + out = self.layer_out(feature) + return out, state + + +class PPOCritic(nn.Module): + def __init__(self, extractor: nn.Module) -> None: + super().__init__() + self.extractor = extractor + self.value_out = nn.Linear(cast(int, extractor.output_dim), 1) + + def forward( + self, + obs: torch.Tensor, + state: torch.Tensor = None, + info: dict = {}, + ) -> torch.Tensor: + feature = self.extractor(to_torch(obs, device=auto_device(self))) + return self.value_out(feature).squeeze(dim=-1) + + +class PPO(PPOPolicy): + """A wrapper of tianshou PPOPolicy. + + Differences: + + - Auto-create actor and critic network. Supports discrete action space only. + - Dedup common parameters between actor network and critic network + (not sure whether this is included in latest tianshou or not). + - Support a ``weight_file`` that supports loading checkpoint. + - Some parameters' default values are different from original. + """ + + def __init__( + self, + network: nn.Module, + obs_space: gym.Space, + action_space: gym.Space, + lr: float, + weight_decay: float = 0.0, + discount_factor: float = 1.0, + max_grad_norm: float = 100.0, + reward_normalization: bool = True, + eps_clip: float = 0.3, + value_clip: bool = True, + vf_coef: float = 1.0, + gae_lambda: float = 1.0, + max_batch_size: int = 256, + deterministic_eval: bool = True, + weight_file: Optional[Path] = None, + ) -> None: + assert isinstance(action_space, Discrete) + actor = PPOActor(network, action_space.n) + critic = PPOCritic(network) + optimizer = torch.optim.Adam( + chain_dedup(actor.parameters(), critic.parameters()), + lr=lr, + weight_decay=weight_decay, + ) + super().__init__( + actor, + critic, + optimizer, + torch.distributions.Categorical, + discount_factor=discount_factor, + max_grad_norm=max_grad_norm, + reward_normalization=reward_normalization, + eps_clip=eps_clip, + value_clip=value_clip, + vf_coef=vf_coef, + gae_lambda=gae_lambda, + max_batchsize=max_batch_size, + deterministic_eval=deterministic_eval, + observation_space=obs_space, + action_space=action_space, + ) + if weight_file is not None: + set_weight(self, Trainer.get_policy_state_dict(weight_file)) + + +# utilities: these should be put in a separate (common) file. # + + +def auto_device(module: nn.Module) -> torch.device: + for param in module.parameters(): + return param.device + return torch.device("cpu") # fallback to cpu + + +def set_weight(policy: nn.Module, loaded_weight: OrderedDict) -> None: + try: + policy.load_state_dict(loaded_weight) + except RuntimeError: + # try again by loading the converted weight + # https://github.com/thu-ml/tianshou/issues/468 + for k in list(loaded_weight): + loaded_weight["_actor_critic." + k] = loaded_weight[k] + policy.load_state_dict(loaded_weight) + + +def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]: + seen = set() + for iterable in iterables: + for i in iterable: + if i not in seen: + seen.add(i) + yield i diff --git a/qlib/rl/algorithm_trading/reward.py b/qlib/rl/algorithm_trading/reward.py new file mode 100644 index 0000000000..05cd5e0491 --- /dev/null +++ b/qlib/rl/algorithm_trading/reward.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import cast + +import numpy as np + +from qlib.rl.algorithm_trading.state import SAATMetrics, SAATState +from qlib.rl.reward import Reward + +__all__ = ["LongShortReward"] + + +class LongShortReward(Reward[SAATState]): + """Encourage higher return considering transaction cost with both long and short operation. + Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`. + + Parameters + ---------- + trans_fee + The cost for opening or closing a position. + """ + + def __init__(self, trans_fee: float = 0.0015, scale: float = 10.0) -> None: + self.trans_fee = trans_fee + self.scale = scale + + def reward(self, simulator_state: SAATState) -> float: + last_step = cast(SAATMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict()) + self.log("reward/ret_with_transfee", last_step["ret"]) + self.log("reward/trans_fee", last_step["swap_value"] * self.trans_fee) + reward = last_step["ret"] / last_step["total_cash"] + self.log("reward_total", reward) + # Throw error in case of NaN + assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}" + + return reward * self.scale diff --git a/qlib/rl/algorithm_trading/simulator_simple.py b/qlib/rl/algorithm_trading/simulator_simple.py new file mode 100644 index 0000000000..ae71a9e811 --- /dev/null +++ b/qlib/rl/algorithm_trading/simulator_simple.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from pathlib import Path +from typing import Any, cast, Optional + +import numpy as np +import pandas as pd +from qlib.backtest.decision import Task +from qlib.constant import EPS, EPS_T +from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data +from qlib.rl.simulator import Simulator +from qlib.rl.utils import LogLevel + +from .state import SAATMetrics, SAATState + +__all__ = ["SingleAssetAlgorithmTradingSimple"] + + +class SingleAssetAlgorithmTradingSimple(Simulator[Task, SAATState, float]): + """Single-asset algorithm trading (SAAT) simulator. + + As there's no "calendar" in the simple simulator, ticks are used to trade. + A tick is a record (a line) in the pickle-styled data file. + Each tick is considered as a individual trading opportunity. + If such fine granularity is not needed, use ``ticks_per_step`` to + lengthen the ticks for each step. + + In each step, the traded amount are "equally" separated to each tick, + then bounded by volume maximum execution volume (i.e., ``vol_threshold``), + and if it's the last step, try to ensure all the amount to be executed. + + Parameters + ---------- + task + The seed to start an SAAT simulator is an task. + data_granularity + Number of ticks between consecutive data entries. + ticks_per_step + How many ticks per step. + data_dir + Path to load backtest data + """ + + history_exec: pd.DataFrame + """All execution history at every possible time ticks. See :class:`SAATMetrics` for available columns. + Index is ``datetime``. + """ + + history_steps: pd.DataFrame + """Positions at each step. The position before first step is also recorded. + See :class:`SAATMetrics` for available columns. + Index is ``datetime``, which is the **starting** time of each step.""" + + metrics: Optional[SAATMetrics] + """Metrics. Only available when done.""" + + ticks_index: pd.DatetimeIndex + """All available ticks for the day (not restricted to task).""" + + ticks_for_trading: pd.DatetimeIndex + """Ticks that is available for trading (sliced by task).""" + + def __init__( + self, + task: Task, + data_dir: Path, + fee_rate: float, + data_granularity: int = 1, + ticks_per_step: int = 30, + deal_price_type: DealPriceType = "close", + ) -> None: + super().__init__(initial=task) + + assert ticks_per_step % data_granularity == 0 + + self.task = task + self.ticks_per_step: int = ticks_per_step // data_granularity + self.deal_price_type = deal_price_type + self.data_dir = data_dir + self.fee_rate = fee_rate + self.backtest_data = load_simple_intraday_backtest_data( + self.data_dir, + task.stock_id, + pd.Timestamp(task.start_time.date()), + self.deal_price_type, + 2, + ) + + self.ticks_index = self.backtest_data.get_time_index() + + # Get time index available for trading + self.ticks_for_trading = self._get_ticks_slice(self.task.start_time, self.task.end_time) + + self.cur_time = self.ticks_for_trading[0] + self.cur_step = 0 + # NOTE: astype(float) is necessary in some systems. + # this will align the precision with `.to_numpy()` in `_split_exec_vol` + self.current_cash = task.cash + self.total_cash = task.cash + self.position = 0 + + metric_keys = list(SAATMetrics.__annotations__.keys()) # pylint: disable=no-member + # NOTE: can empty dataframe contain index? + self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.metrics = None + + self.market_price: Optional[np.ndarray] = None + self.market_vol: Optional[np.ndarray] = None + self.market_vol_limit: Optional[np.ndarray] = None + + def step(self, action: str) -> None: + """Execute one step or SAAT. + + Parameters + ---------- + amount + The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt. + """ + + assert not self.done() + self.market_price = self.market_vol = None # avoid misuse + trading_value = self.take_action(action) + assert self.market_price is not None + assert self.market_vol is not None + + if abs(self.position) < 1e-6: + self.position = 0.0 + if abs(self.current_cash) < 1e-6: + self.current_cash = 0.0 + if trading_value < 1e-6: + trading_value = 0 + + ret = self.position * (self.market_price[-1] - self.market_price[0]) + + # Get time index available for this step + time_index = self._get_ticks_slice(self.cur_time, self._next_time()) + + self.history_exec = self._dataframe_append( + self.history_exec, + SAATMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). + stock_id=self.task.stock_id, + datetime=time_index, + direction=2, # other: 2 + market_volume=self.market_vol, + market_price=self.market_price, + action=action, + cash=self.current_cash, + total_cash=self.total_cash, + position=self.position, + trade_price=self.market_price[0], + ret=ret, + swap_value=trading_value, + ), + ) + + self.history_steps = self._dataframe_append( + self.history_steps, + [ + SAATMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). + stock_id=self.task.stock_id, + datetime=time_index, + direction=2, # other: 2 + market_volume=self.market_vol, + market_price=self.market_price, + action=action, + trading_value=trading_value, + cash=self.current_cash, + total_cash=self.total_cash, + position=self.position, + trade_price=self.market_price[0], + ret=ret, + swap_value=trading_value, + ) + ], + ) + + if self.done(): + if self.env is not None: + self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG) + self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG) + + self.metrics = ( + SAATMetrics( + stock_id=self.task.stock_id, + datetime=time_index, + direction=self.task.direction, # other: 2 + market_volume=self.history_steps["market_vol"].sum(), + market_price=self.market_price[0], + action=action, + trading_value=self.history_steps["trading_value"].sum(), + cash=self.current_cash, + position=self.position, + trade_price=self.history_steps["trade_price"].mean(), + ret=self.history_steps["ret"].sum(), + swap_value=self.history_steps["trading_value"].sum(), + ), + ) + + # NOTE (yuge): It looks to me that it's the "correct" decision to + # put all the logs here, because only components like simulators themselves + # have the knowledge about what could appear in the logs, and what's the format. + # But I admit it's not necessarily the most convenient way. + # I'll rethink about it when we have the second environment + # Maybe some APIs like self.logger.enable_auto_log() ? + + if self.env is not None: + for key, value in self.metrics.items(): + if isinstance(value, float): + self.env.logger.add_scalar(key, value) + else: + self.env.logger.add_any(key, value) + + self.cur_time = self._next_time() + self.cur_step += 1 + + def get_state(self) -> SAATState: + return SAATState( + task=self.task, + cur_time=self.cur_time, + cur_step=self.cur_step, + position=self.position, + cash=self.current_cash, + history_exec=self.history_exec, + history_steps=self.history_steps, + metrics=self.metrics, + backtest_data=self.backtest_data, + ticks_per_step=self.ticks_per_step, + ticks_index=self.ticks_index, + ticks_for_trading=self.ticks_for_trading, + ) + + def done(self) -> bool: + return self.cur_time >= self.task.end_time + + def _next_time(self) -> pd.Timestamp: + """The "current time" (``cur_time``) for next step.""" + # Look for next time on time index + current_loc = self.ticks_index.get_loc(self.cur_time) + next_loc = current_loc + self.ticks_per_step + + # Calibrate the next location to multiple of ticks_per_step. + # This is to make sure that: + # as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon. + next_loc = next_loc - next_loc % self.ticks_per_step + + if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.task.end_time: + return self.ticks_index[next_loc] + else: + return self.task.end_time + + def _cur_duration(self) -> pd.Timedelta: + """The "duration" of this step (step that is about to happen).""" + return self._next_time() - self.cur_time + + def take_action(self, action: str) -> np.ndarray: + """ + Split the volume in each step into minutes, considering possible constraints. + This follows TWAP strategy. + """ + next_time = self._next_time() + + # get the backtest data for next interval + self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - EPS_T].to_numpy() + self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - EPS_T].to_numpy() + + assert self.market_vol is not None and self.market_price is not None + + if next_time >= self.task.end_time and not self.position: + trading_value = abs(self.market_price[-1] * self.position) + self.current_cash += trading_value - self.fee_rate * trading_value + self.position = 0 + + if self.position == 0: + if action == "long": + trading_value = self.current_cash + self.position = self.current_cash * (1 - self.fee_rate) / self.market_price[0] + self.current_cash = 0 + elif action == "short": + trading_value = self.current_cash + self.position = -self.current_cash * (1 - self.fee_rate) / self.market_price[0] + self.current_cash = 0 + else: + trading_value = 0 + elif self.position > 0: + if action == "long" or action == "hold": + trading_value = 0 + else: + trading_value = 2 * abs(self.market_price[0] * self.position) + self.position = -self.position * (1 - self.fee_rate) ** 2 + self.current_cash = 0 + else: + if action == "short" or action == "hold": + trading_value = 0 + else: + trading_value = 2 * abs(self.market_price[0] * self.position) + self.position = -self.position * (1 - self.fee_rate) ** 2 + self.current_cash = 0 + + return trading_value + + def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex: + if not include_end: + end = end - EPS_T + return self.ticks_index[self.ticks_index.slice_indexer(start, end)] + + @staticmethod + def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: + # dataframe.append is deprecated + other_df = pd.DataFrame(other).set_index("datetime") + other_df.index.name = "datetime" + return pd.concat([df, other_df], axis=0) diff --git a/qlib/rl/algorithm_trading/state.py b/qlib/rl/algorithm_trading/state.py new file mode 100644 index 0000000000..0868bea9ee --- /dev/null +++ b/qlib/rl/algorithm_trading/state.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import typing +from typing import NamedTuple, Optional + +import numpy as np +import pandas as pd +from qlib.backtest.decision import Task +from qlib.typehint import TypedDict + +if typing.TYPE_CHECKING: + from qlib.rl.data.base import BaseIntradayBacktestData + + +class SAATMetrics(TypedDict): + """Metrics for SAAT accumulated for a "period". + It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute. + + Warnings + -------- + The type hints are for single elements. In lots of times, they can be vectorized. + For example, ``market_volume`` could be a list of float (or ndarray) rather tahn a single float. + """ + + stock_id: str + """Stock ID of this record.""" + datetime: pd.Timestamp | pd.DatetimeIndex + """Datetime of this record (this is index in the dataframe).""" + direction: int + """Direction to support reuse order. 0 for sell, 1 for buy, 2 for algorithm trading.""" + + # Market information. + market_volume: np.ndarray | float + """(total) market volume traded in the period.""" + market_price: np.ndarray | float + """Deal price. If it's a period of time, this is the average market deal price.""" + + # Strategy records. + action: np.ndarray | float + """Next step action.""" + trade_price: np.ndarray | float + """The average deal price for this strategy.""" + trading_value: np.ndarray | float + """Total worth of trading. In the simple simulation, trade_value = deal_amount * price.""" + position: np.ndarray | float + """Position after this step.""" + cash: np.ndarray | float + """Cash after this step.""" + total_cash: np.ndarray | float + """Total cash used for trading.""" + + # Accumulated metrics + ret: np.ndarray | float + """Return.""" + swap_value: np.ndarray | int + """Swap value for calculating transaction fee.""" + + +class SAATState(NamedTuple): + """Data structure holding a state for SAAT simulator.""" + + task: Task + """The stock we are dealing with.""" + cur_time: pd.Timestamp + """Current time, e.g., 9:30.""" + cur_step: int + """Current step, e.g., 0.""" + cash: float + """Current remaining cash can be used.""" + position: float + """Current position.""" + history_exec: pd.DataFrame + """See :attr:`SingleAssetAlgorithmTrading.history_exec`.""" + history_steps: pd.DataFrame + """See :attr:`SingleAssetAlgorithmTrading.history_steps`.""" + metrics: Optional[SAATMetrics] + """Daily metric, only available when the trading is in "done" state.""" + backtest_data: BaseIntradayBacktestData + """Backtest data is included in the state. + Actually, only the time index of this data is needed, at this moment. + I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented. + Interpreter can use this as they wish, but they should be careful not to leak future data. + """ + ticks_per_step: int + """How many ticks for each step.""" + ticks_index: pd.DatetimeIndex + """Trading ticks in all day, NOT sliced by task (defined in data). e.g., [9:30, 9:31, ..., 14:59].""" + ticks_for_trading: pd.DatetimeIndex + """Trading ticks sliced by trading, e.g., [9:45, 9:46, ..., 14:44].""" diff --git a/qlib/rl/contrib/train_at_onpolicy.py b/qlib/rl/contrib/train_at_onpolicy.py new file mode 100644 index 0000000000..8f0e379d3d --- /dev/null +++ b/qlib/rl/contrib/train_at_onpolicy.py @@ -0,0 +1,261 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import argparse +import os +import random +import warnings +from pathlib import Path +from typing import cast, List, Optional + +import numpy as np +import pandas as pd +import qlib +import torch +import yaml +from qlib.backtest import Order +from qlib.backtest.decision import OrderDir +from qlib.backtest.decision import Task +from qlib.constant import ONE_MIN +from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.algorithm_trading import SingleAssetAlgorithmTradingSimple +from qlib.rl.reward import Reward +from qlib.rl.trainer import Checkpoint, backtest, train +from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter +from qlib.rl.utils.log import CsvWriter +from qlib.utils import init_instance_by_config +from tianshou.policy import BasePolicy +from torch.utils.data import Dataset + + +def seed_everything(seed: int) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def _read_tasks(task_dir: Path) -> pd.DataFrame: + if os.path.isfile(task_dir): + return pd.read_pickle(task_dir) + else: + tasks = [] + for file in task_dir.iterdir(): + task_data = pd.read_pickle(file) + tasks.append(task_data) + return pd.concat(tasks) + + +class LazyLoadDataset(Dataset): + def __init__( + self, + task_file_path: Path, + data_dir: Path, + default_start_time_index: int, + default_end_time_index: int, + ) -> None: + self._default_start_time_index = default_start_time_index + self._default_end_time_index = default_end_time_index + + self._task_file_path = task_file_path + self._task_df = _read_tasks(task_file_path).reset_index() + + self._data_dir = data_dir + self._ticks_index: Optional[pd.DatetimeIndex] = None + + def __len__(self) -> int: + return len(self._task_df) + + def __getitem__(self, index: int) -> Task: + row = self._task_df.iloc[index] + date = pd.Timestamp(str(row["date"])) + + if self._ticks_index is None: + # TODO: We only load ticks index once based on the assumption that ticks index of different dates + # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index + # TODO: of all dates. + backtest_data = load_simple_intraday_backtest_data( + data_dir=self._data_dir, + stock_id=row["instrument"], + date=date, + ) + self._ticks_index = [t - date for t in backtest_data.get_time_index()] + # treat the + task = Task( + stock_id=row["instrument"], + cash=row["amount"], + start_time=date + self._ticks_index[self._default_start_time_index], + end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, + ) + + return task + + +def train_and_test( + env_config: dict, + simulator_config: dict, + trainer_config: dict, + data_config: dict, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + policy: BasePolicy, + reward: Reward, + run_training: bool, + run_backtest: bool, +) -> None: + qlib.init() + + task_root_path = Path(data_config["source"]["task_dir"]) + + data_granularity = simulator_config.get("data_granularity", 1) + + def _simulator_factory_simple(task: Task) -> SingleAssetAlgorithmTradingSimple: + return SingleAssetAlgorithmTradingSimple( + task=task, + data_dir=Path(data_config["source"]["data_dir"]), + ticks_per_step=simulator_config["time_per_step"], + data_granularity=data_granularity, + fee_rate=simulator_config["fee_rate"], + deal_price_type=data_config["source"].get("deal_price_column", "close"), + ) + + assert data_config["source"]["default_start_time_index"] % data_granularity == 0 + assert data_config["source"]["default_end_time_index"] % data_granularity == 0 + + if run_training: + train_dataset, valid_dataset = [ + LazyLoadDataset( + task_file_path=task_root_path / tag, + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + for tag in ("train", "valid") + ] + + callbacks: List[Callback] = [] + if "checkpoint_path" in trainer_config: + callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) + callbacks.append( + Checkpoint( + dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", + every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1), + save_latest="copy", + ), + ) + if "earlystop_patience" in trainer_config: + callbacks.append( + EarlyStopping( + patience=trainer_config["earlystop_patience"], + monitor="val/reward", + ) + ) + + train( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + policy=policy, + reward=reward, + initial_states=cast(List[Task], train_dataset), + trainer_kwargs={ + "max_iters": trainer_config["max_epoch"], + "finite_env_type": env_config["parallel_mode"], + "concurrency": env_config["concurrency"], + "val_every_n_iters": trainer_config.get("val_every_n_epoch", None), + "callbacks": callbacks, + }, + vessel_kwargs={ + "episode_per_iter": trainer_config["episode_per_collect"], + "update_kwargs": { + "batch_size": trainer_config["batch_size"], + "repeat": trainer_config["repeat_per_collect"], + }, + "val_initial_states": valid_dataset, + }, + ) + + if run_backtest: + test_dataset = LazyLoadDataset( + task_file_path=task_root_path / "test", + data_dir=Path(data_config["source"]["data_dir"]), + default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + ) + + backtest( + simulator_fn=_simulator_factory_simple, + state_interpreter=state_interpreter, + action_interpreter=action_interpreter, + initial_states=test_dataset, + policy=policy, + logger=CsvWriter(Path(trainer_config["checkpoint_path"])), + reward=reward, + finite_env_type=env_config["parallel_mode"], + concurrency=env_config["concurrency"], + ) + + +def main(config: dict, run_training: bool, run_backtest: bool) -> None: + if not run_training and not run_backtest: + warnings.warn("Skip the entire job since training and backtest are both skipped.") + return + + if "seed" in config["runtime"]: + seed_everything(config["runtime"]["seed"]) + + state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) + action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) + reward: Reward = init_instance_by_config(config["reward"]) + + additional_policy_kwargs = { + "obs_space": state_interpreter.observation_space, + "action_space": action_interpreter.action_space, + } + + # Create torch network + if "network" in config: + if "kwargs" not in config["network"]: + config["network"]["kwargs"] = {} + config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) + additional_policy_kwargs["network"] = init_instance_by_config(config["network"]) + + # Create policy + if "kwargs" not in config["policy"]: + config["policy"]["kwargs"] = {} + config["policy"]["kwargs"].update(additional_policy_kwargs) + policy: BasePolicy = init_instance_by_config(config["policy"]) + + use_cuda = config["runtime"].get("use_cuda", False) + if use_cuda: + policy.cuda() + + train_and_test( + env_config=config["env"], + simulator_config=config["simulator"], + data_config=config["data"], + trainer_config=config["trainer"], + action_interpreter=action_interpreter, + state_interpreter=state_interpreter, + policy=policy, + reward=reward, + run_training=run_training, + run_backtest=run_backtest, + ) + + +if __name__ == "__main__": + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") + parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") + args = parser.parse_args() + + with open(args.config_path, "r") as input_stream: + config = yaml.safe_load(input_stream) + + main(config, run_training=not args.no_training, run_backtest=args.run_backtest) From 29f94f15707ce38463fe3e179bd3d3a9b4d92819 Mon Sep 17 00:00:00 2001 From: lwwang Date: Wed, 5 Jul 2023 23:32:02 +0800 Subject: [PATCH 2/5] Fixed. --- examples/rl_algorithm_trading/exp_configs/train_at_opds.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml b/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml index 3875119217..3f191e99dd 100755 --- a/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml +++ b/examples/rl_algorithm_trading/exp_configs/train_at_opds.yml @@ -47,7 +47,7 @@ policy: class: PPO kwargs: lr: 0.0001 - module_path: qlib.rl.order_execution.policy + module_path: qlib.rl.algorithm_trading.policy runtime: seed: 42 use_cuda: false From 6ae734acea72f698efcf6ff5abb2030dcff3be4c Mon Sep 17 00:00:00 2001 From: lwwang Date: Wed, 5 Jul 2023 23:35:21 +0800 Subject: [PATCH 3/5] Fixed. --- qlib/rl/algorithm_trading/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/qlib/rl/algorithm_trading/__init__.py b/qlib/rl/algorithm_trading/__init__.py index f7a7001695..eadc3e6125 100644 --- a/qlib/rl/algorithm_trading/__init__.py +++ b/qlib/rl/algorithm_trading/__init__.py @@ -1,10 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -""" -Currently it supports single-asset order execution. -Multi-asset is on the way. -""" from .interpreter import ( FullHistoryATStateInterpreter, From e6e6baf3fe5d32a87f2b163eceaf60f3c2a824d1 Mon Sep 17 00:00:00 2001 From: lwwang Date: Wed, 5 Jul 2023 23:40:50 +0800 Subject: [PATCH 4/5] Remove unused imports. --- qlib/rl/contrib/train_at_onpolicy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/qlib/rl/contrib/train_at_onpolicy.py b/qlib/rl/contrib/train_at_onpolicy.py index 8f0e379d3d..95c01ea7e4 100644 --- a/qlib/rl/contrib/train_at_onpolicy.py +++ b/qlib/rl/contrib/train_at_onpolicy.py @@ -12,8 +12,6 @@ import qlib import torch import yaml -from qlib.backtest import Order -from qlib.backtest.decision import OrderDir from qlib.backtest.decision import Task from qlib.constant import ONE_MIN from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data From e4e0d3f4f7c6a70f2fd2e90a513bbcde70b396e7 Mon Sep 17 00:00:00 2001 From: Litzy Date: Fri, 7 Jul 2023 03:57:01 +0000 Subject: [PATCH 5/5] Fix a bug. --- qlib/backtest/decision.py | 17 +++++++++++++++++ qlib/rl/algorithm_trading/interpreter.py | 2 -- qlib/rl/contrib/train_at_onpolicy.py | 10 +++++----- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 7188bec7a5..3f939863f4 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -150,6 +150,23 @@ def date(self) -> pd.Timestamp: """Date of the order.""" return pd.Timestamp(self.start_time.replace(hour=0, minute=0, second=0)) +@dataclass +class Task: + """ + stock_id : str + cash : float + start_time : pd.Timestamp + closed start time for order trading + end_time : pd.Timestamp + closed end time for order trading + factor : float + presents the weight factor assigned in Exchange() + """ + + stock_id: str + cash: float + start_time: pd.Timestamp + end_time: pd.Timestamp class OrderHelper: """ diff --git a/qlib/rl/algorithm_trading/interpreter.py b/qlib/rl/algorithm_trading/interpreter.py index 44af2e2f20..2f9e8dda80 100644 --- a/qlib/rl/algorithm_trading/interpreter.py +++ b/qlib/rl/algorithm_trading/interpreter.py @@ -3,14 +3,12 @@ from __future__ import annotations -import math from typing import Any, List, Optional, cast import numpy as np import pandas as pd from gym import spaces -from qlib.constant import EPS from qlib.rl.data.base import ProcessedDataProvider from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.algorithm_trading.state import SAATState diff --git a/qlib/rl/contrib/train_at_onpolicy.py b/qlib/rl/contrib/train_at_onpolicy.py index 95c01ea7e4..48d259cae6 100644 --- a/qlib/rl/contrib/train_at_onpolicy.py +++ b/qlib/rl/contrib/train_at_onpolicy.py @@ -9,9 +9,12 @@ import numpy as np import pandas as pd -import qlib -import torch import yaml +import torch +from tianshou.policy import BasePolicy +from torch.utils.data import Dataset + +import qlib from qlib.backtest.decision import Task from qlib.constant import ONE_MIN from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data @@ -22,9 +25,6 @@ from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter from qlib.rl.utils.log import CsvWriter from qlib.utils import init_instance_by_config -from tianshou.policy import BasePolicy -from torch.utils.data import Dataset - def seed_everything(seed: int) -> None: torch.manual_seed(seed)