Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint

(cherry picked from commit 1a8e0bd)

* Not a workable version

(cherry picked from commit 3498e18)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
ultmaster committed Jun 28, 2022
1 parent 2ca0d88 commit 25ecb11
Showing 17 changed files with 1,410 additions and 145 deletions.
7 changes: 0 additions & 7 deletions qlib/rl/entries/__init__.py

This file was deleted.

99 changes: 0 additions & 99 deletions qlib/rl/entries/test.py

This file was deleted.

4 changes: 0 additions & 4 deletions qlib/rl/entries/train.py

This file was deleted.

1 change: 1 addition & 0 deletions qlib/rl/order_execution/__init__.py
@@ -9,4 +9,5 @@
 from .interpreter import *
 from .network import *
 from .policy import *
+from .reward import *
 from .simulator_simple import *
46 changes: 46 additions & 0 deletions qlib/rl/order_execution/reward.py
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import cast

import numpy as np
from qlib.rl.reward import Reward

from .simulator_simple import SAOEState, SAOEMetrics

__all__ = ["PAPenaltyReward"]


class PAPenaltyReward(Reward[SAOEState]):
    """Encourage higher PAs, but penalize stacking all the amounts within a very short time.
    Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

    Parameters
    ----------
    penalty
        The penalty for large volume in a short time.
    """

    def __init__(self, penalty: float = 100.0):
        self.penalty = penalty

    def reward(self, simulator_state: SAOEState) -> float:
        whole_order = simulator_state.order.amount
        assert whole_order > 0
        last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
        pa = last_step["pa"] * last_step["amount"] / whole_order

        # Inspect the "break-down" of the latest step: trading amount at every tick
        last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
        penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()

        reward = pa + penalty

        # Raise an error in case of NaN or Inf
        assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"

        self.log("reward/pa", pa)
        self.log("reward/penalty", penalty)
        return reward
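
For context (not part of the diff): a minimal usage sketch of the new reward class. The numbers in the comments are made up purely to illustrate the formula above.

from qlib.rl.order_execution.reward import PAPenaltyReward

reward_fn = PAPenaltyReward(penalty=100.0)  # a larger penalty discourages bursty execution

# Toy arithmetic for one step (assumed numbers, not real simulator output):
#   step PA = 2.0, step amount = 300, whole order = 1000,
#   the step traded 100 units in each of 3 ticks, penalty coefficient = 100
#   pa term      = 2.0 * 300 / 1000             = 0.6
#   penalty term = -100 * 3 * (100 / 1000) ** 2 = -3.0
#   reward       = 0.6 + (-3.0)                 = -2.4
# During a rollout, the framework is expected to call reward_fn.reward(simulator_state)
# with an SAOEState after every step and feed the returned float back to the policy.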
7 changes: 5 additions & 2 deletions qlib/rl/order_execution/simulator_simple.py
@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
     """
 
     history_exec: pd.DataFrame
-    """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
+    """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
+    Index is ``datetime``.
+    """
 
     history_steps: pd.DataFrame
     """Positions at each step. The position before first step is also recorded.
-    See :class:`SAOEMetrics` for available columns."""
+    See :class:`SAOEMetrics` for available columns.
+    Index is ``datetime``, which is the **starting** time of each step."""
 
     metrics: SAOEMetrics | None
     """Metrics. Only available when done."""
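
The docstring changes above pin down the indexing convention that the new reward relies on. Below is a toy sketch (made-up data, not actual simulator output) of how a step's per-tick break-down is recovered from history_exec:

import pandas as pd

# Toy stand-in for history_exec: one row per tick, indexed by datetime
# (only the "amount" column of SAOEMetrics is shown).
history_exec = pd.DataFrame(
    {"amount": [100.0, 100.0, 100.0, 200.0]},
    index=pd.DatetimeIndex(
        ["2022-06-28 09:30", "2022-06-28 09:31", "2022-06-28 09:32", "2022-06-28 09:33"],
        name="datetime",
    ),
)

# history_steps is indexed by the *starting* datetime of each step, so slicing
# history_exec from that timestamp onwards yields the last step's per-tick break-down,
# exactly what PAPenaltyReward does above with .loc[last_step["datetime"]:].
step_start = pd.Timestamp("2022-06-28 09:32")
last_step_breakdown = history_exec.loc[step_start:]  # rows at 09:32 and 09:33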
9 changes: 9 additions & 0 deletions qlib/rl/trainer/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Train, test, inference utilities."""

from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase
118 changes: 118 additions & 0 deletions qlib/rl/trainer/api.py
@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Callable, Sequence, cast, Any

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy with the parallelism provided by the RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable receiving the initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to train against.
    reward
        Reward function.
    vessel_kwargs
        Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
        **vessel_kwargs,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(vessel)
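

# --- Illustrative usage, not part of this commit ---
# A hedged sketch of calling `train`. `MySimulator`, the interpreters, `my_policy`
# and `my_initial_states` are hypothetical placeholders; only the keyword arguments
# mirror the signature documented above.
#
# train(
#     simulator_fn=lambda seed: MySimulator(seed),
#     state_interpreter=MyStateInterpreter(),
#     action_interpreter=MyActionInterpreter(),
#     initial_states=my_initial_states,
#     policy=my_policy,
#     reward=PAPenaltyReward(),  # e.g. the reward added in this commit
#     vessel_kwargs={"episode_per_iter": 100},
#     trainer_kwargs={"finite_env_type": "subproc", "concurrency": 2},
# )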


def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by the RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable receiving the initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test against.
    logger
        Logger to record the backtest results. A logger must be provided;
        otherwise all information will be lost.
    reward
        Optional reward function. In backtest, rewards are computed only for
        testing and logging purposes.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Parallel workers.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=cast(Reward, reward),  # ignore none
    )
    trainer = Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    )
    trainer.test(vessel)
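
To round out the picture, a hedged sketch of invoking the new backtest entry point. The simulator, interpreters, policy, state list, and log writer are hypothetical placeholders; only the parameter names come from the signature above.

from qlib.rl.trainer import backtest
from qlib.rl.order_execution.reward import PAPenaltyReward

backtest(
    simulator_fn=lambda seed: MySimulator(seed),  # hypothetical simulator factory
    state_interpreter=MyStateInterpreter(),       # hypothetical interpreters
    action_interpreter=MyActionInterpreter(),
    initial_states=my_test_states,                # every state is run exactly once
    policy=my_trained_policy,
    logger=my_log_writer,                         # any LogWriter (or list of them); required, or results are lost
    reward=PAPenaltyReward(),                     # optional: rewards are only computed for logging in backtest
    finite_env_type="subproc",
    concurrency=2,
)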
(Diffs for the remaining 9 changed files are not shown here.)