Qlib RL framework (stage 2) - trainer #1125

Merged
merged 34 commits on Jun 28, 2022
Changes from all commits (34 commits)
542d295
checkpoint
ultmaster May 5, 2022
d5e15ac
Not a workable version
ultmaster May 6, 2022
4acb0c2
vessel
ultmaster Jun 7, 2022
2d1d8cb
ckpt
ultmaster Jun 8, 2022
1f85487
.
ultmaster Jun 8, 2022
ea40fdf
vessel
ultmaster Jun 8, 2022
6816c7e
.
ultmaster Jun 9, 2022
319766f
.
ultmaster Jun 10, 2022
f2c02e0
checkpoint callback
ultmaster Jun 10, 2022
4db8567
.
ultmaster Jun 10, 2022
163bc2a
cleanup
ultmaster Jun 10, 2022
b76b810
logger
ultmaster Jun 11, 2022
647dc76
.
ultmaster Jun 12, 2022
fc7eb9a
test
ultmaster Jun 12, 2022
38ecc21
.
ultmaster Jun 12, 2022
67a53fb
add test
ultmaster Jun 12, 2022
6b391de
.
ultmaster Jun 13, 2022
c73ec3a
.
ultmaster Jun 13, 2022
5980e45
.
ultmaster Jun 13, 2022
85710ef
.
ultmaster Jun 13, 2022
26883c8
New reward
ultmaster Jun 13, 2022
30afc6c
Add train API
ultmaster Jun 13, 2022
68716e4
fix mypy
ultmaster Jun 13, 2022
cbf2577
fix lint
ultmaster Jun 13, 2022
8e479c2
Merge branch 'main' of https://github.com/microsoft/qlib into rl-trai…
ultmaster Jun 13, 2022
4aa421f
More comment
ultmaster Jun 13, 2022
54d2342
3.7 compat
ultmaster Jun 13, 2022
f432525
fix test
ultmaster Jun 13, 2022
5344846
fix test
ultmaster Jun 14, 2022
3123f1a
.
ultmaster Jun 14, 2022
1bb307d
Resolve comments
ultmaster Jun 27, 2022
e8c6f4f
Merge branch 'main' of https://github.com/microsoft/qlib into rl-trai…
ultmaster Jun 28, 2022
5fe8bff
fix typehint
ultmaster Jun 28, 2022
4b5fcb0
Merge branch 'main' of https://github.com/microsoft/qlib into rl-trai…
ultmaster Jun 28, 2022
7 changes: 0 additions & 7 deletions qlib/rl/entries/__init__.py

This file was deleted.

99 changes: 0 additions & 99 deletions qlib/rl/entries/test.py

This file was deleted.

4 changes: 0 additions & 4 deletions qlib/rl/entries/train.py

This file was deleted.

1 change: 1 addition & 0 deletions qlib/rl/order_execution/__init__.py
@@ -9,4 +9,5 @@
from .interpreter import *
from .network import *
from .policy import *
from .reward import *
from .simulator_simple import *
46 changes: 46 additions & 0 deletions qlib/rl/order_execution/reward.py
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import cast

import numpy as np
from qlib.rl.reward import Reward

from .simulator_simple import SAOEState, SAOEMetrics

__all__ = ["PAPenaltyReward"]


class PAPenaltyReward(Reward[SAOEState]):
"""Encourage higher PAs, but penalize stacking all the amounts within a very short time.
Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

Parameters
----------
penalty
The penalty for large volume in a short time.
"""

def __init__(self, penalty: float = 100.0):
self.penalty = penalty

def reward(self, simulator_state: SAOEState) -> float:
whole_order = simulator_state.order.amount
assert whole_order > 0
last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
pa = last_step["pa"] * last_step["amount"] / whole_order

# Inspect the "break-down" of the latest step: trading amount at every tick
last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()

reward = pa + penalty

# Throw an error in case of NaN or Inf
assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"

self.log("reward/pa", pa)
self.log("reward/penalty", penalty)
return reward
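
To make the arithmetic in PAPenaltyReward concrete, the following standalone sketch (not part of the PR; all numbers and variable names are made up) reproduces the same two terms on toy data:

import pandas as pd

# One step of 300 shares out of a 1000-share order, executed in three ticks.
whole_order = 1000.0
step_pa = 2.0                                     # price advantage reported for the step
step_amount = 300.0                               # amount traded within the step
tick_amounts = pd.Series([100.0, 100.0, 100.0])   # the step's per-tick breakdown
penalty_coef = 100.0                              # the `penalty` constructor argument

pa_term = step_pa * step_amount / whole_order                             # 0.6
penalty_term = -penalty_coef * ((tick_amounts / whole_order) ** 2).sum()  # -3.0
reward = pa_term + penalty_term                                           # -2.4

Trading the same 300 shares in a single tick would yield a penalty term of -9.0 instead of -3.0, which illustrates the intended incentive to spread the amount over time.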
7 changes: 5 additions & 2 deletions qlib/rl/order_execution/simulator_simple.py
@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""

history_exec: pd.DataFrame
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
Index is ``datetime``.
"""

history_steps: pd.DataFrame
"""Positions at each step. The position before first step is also recorded.
See :class:`SAOEMetrics` for available columns."""
See :class:`SAOEMetrics` for available columns.
Index is ``datetime``, which is the **starting** time of each step."""

metrics: SAOEMetrics | None
"""Metrics. Only available when done."""
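
As an aside on the ``datetime`` indices documented above, the snippet below (toy frames only; the real simulator produces richer columns) shows how the index of history_steps can be used to slice the per-tick rows of history_exec belonging to the latest step, which is what PAPenaltyReward does with ``.loc[last_step["datetime"]:]``:

import pandas as pd

# Hypothetical miniature versions of the two frames.
history_steps = pd.DataFrame(
    {"amount": [300.0, 200.0]},
    index=pd.to_datetime(["2022-06-28 09:30", "2022-06-28 10:00"]),
)
history_exec = pd.DataFrame(
    {"amount": [100.0, 200.0, 150.0, 50.0]},
    index=pd.to_datetime(
        ["2022-06-28 09:30", "2022-06-28 09:45", "2022-06-28 10:00", "2022-06-28 10:15"]
    ),
)

last_step_start = history_steps.index[-1]                 # starting time of the latest step
last_step_breakdown = history_exec.loc[last_step_start:]  # its per-tick breakdown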
9 changes: 9 additions & 0 deletions qlib/rl/trainer/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Train, test, inference utilities."""

from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase
118 changes: 118 additions & 0 deletions qlib/rl/trainer/api.py
@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Callable, Sequence, cast, Any

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
reward: Reward,
vessel_kwargs: dict[str, Any],
trainer_kwargs: dict[str, Any],
) -> None:
"""Train a policy with the parallelism provided by RL framework.

Experimental API. Parameters are subject to change.

Parameters
----------
simulator_fn
Callable that receives an initial state (seed) and returns a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to be trained.
reward
Reward function.
vessel_kwargs
Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
trainer_kwargs
Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
"""

vessel = TrainingVessel(
simulator_fn=simulator_fn,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
train_initial_states=initial_states,
reward=reward,
**vessel_kwargs,
)
trainer = Trainer(**trainer_kwargs)
trainer.fit(vessel)


def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
reward: Reward | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
) -> None:
"""Backtest with the parallelism provided by RL framework.

Experimental API. Parameters are subject to change.

Parameters
----------
simulator_fn
Callable that receives an initial state (seed) and returns a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to test against.
logger
Logger to record the backtest results. A logger is required because, without one,
all backtest information would be lost.
reward
Optional reward function. In backtest, it is used only to compute and log rewards.
finite_env_type
Type of finite env implementation.
concurrency
Parallel workers.
"""

vessel = TrainingVessel(
simulator_fn=simulator_fn,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
test_initial_states=initial_states,
reward=cast(Reward, reward), # ignore none
)
trainer = Trainer(
finite_env_type=finite_env_type,
concurrency=concurrency,
loggers=logger,
)
trainer.test(vessel)
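
To show how these pieces fit together, here is a rough sketch of a call to train. MySimulator, MyStateInterpreter, MyActionInterpreter, my_policy and list_of_orders are hypothetical user-supplied objects (they are not defined in this PR), and the keyword values are merely examples of the ``vessel_kwargs`` / ``trainer_kwargs`` named in the docstrings:

from qlib.rl.trainer import train
from qlib.rl.order_execution import PAPenaltyReward

# MySimulator, My*Interpreter, my_policy and list_of_orders are placeholders
# for user-defined components; this outlines the call shape only.
train(
    simulator_fn=lambda init_state: MySimulator(init_state),
    state_interpreter=MyStateInterpreter(),
    action_interpreter=MyActionInterpreter(),
    initial_states=list_of_orders,
    policy=my_policy,
    reward=PAPenaltyReward(penalty=100.0),
    vessel_kwargs={"episode_per_iter": 100},
    trainer_kwargs={"finite_env_type": "subproc", "concurrency": 4},
)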