In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3.common.base_class import BaseAlgorithm

from trainer.asset.stock import Stock
from trainer.env.asset_pool import AssetPool
from trainer.env.trading_platform import PositionType, TradingPlatform, calc_earning


# Rough count of tradable days in a calendar year.
YEARLY_TRADABLE_DAYS_NUM = 250
# Twenty trading years' worth of days; handed to each Stock as `max_days_num`.
MAX_DAYS_NUM = 20 * YEARLY_TRADABLE_DAYS_NUM


def generate_asset_pool(symbols: List[str]) -> AssetPool:
    """Build an `AssetPool` containing one `Stock` per ticker symbol.

    Every stock is created from the `../data/stock/us` directory with
    `max_days_num=MAX_DAYS_NUM`; the pool uses `polarity_temperature=2.0`.

    :param symbols: Ticker symbols to wrap, in the order they should appear in the pool
    :return: A new asset pool holding the created stocks
    """
    stocks = []
    for ticker in symbols:
        stock = Stock(ticker, Path("../data/stock/us"), max_days_num=MAX_DAYS_NUM)
        stocks.append(stock)
    return AssetPool(stocks, polarity_temperature=2.0)


def trade(
    env: TradingPlatform,
    model: Optional[BaseAlgorithm] = None,
    max_step: Optional[int] = None,
    stop_when_done: bool = True,
) -> Tuple[Any, Tuple[float, ...]]:
    """Roll out one trading run on `env` and report the resulting balance.

    The environment is switched to the "rgb_array" render mode and reset. Each
    step then takes the deterministic prediction of `model`, or, when no model
    is given, a uniformly random choice between LONG and SHORT.

    The rollout ends when any of the following holds:
      * `max_step` is set and that many steps have completed without stopping,
      * the step's `info["is_end_of_date"]` is truthy,
      * `stop_when_done` is true and the step was terminated or truncated.

    :param env: Trading environment to run; its render mode is overwritten
    :param model: Policy used to pick actions, or None for random actions
    :param max_step: Maximum number of steps, or None for no limit
    :param stop_when_done: Whether a terminated/truncated step ends the rollout
    :return: `(rendered, (balance, actual_price_change))`, where `rendered` is
        whatever `env.render()` returns and `balance` is the initial balance
        plus the earning computed by `calc_earning`
    """
    env.render_mode = "rgb_array"
    observation, _ = env.reset()

    # Run one episode
    steps_taken = 0
    while True:
        if max_step is not None and steps_taken >= max_step:
            break
        if model is None:
            action = int(np.random.choice([PositionType.LONG, PositionType.SHORT]))
        else:
            action, _ = model.predict(observation, deterministic=True)
        observation, _, terminated, truncated, info = env.step(action)
        episode_over = terminated or truncated
        if info["is_end_of_date"] or (stop_when_done and episode_over):
            break
        steps_taken += 1
    frame = env.render()

    # Recompute the balance from the positions and the latest price.
    # The opening fee is only charged when the environment is in training mode.
    starting_balance = env._initial_balance
    earning, actual_price_change = calc_earning(
        env._positions,
        env._prices[-1],
        position_opening_fee=env._position_opening_fee if env.is_training_mode else 0,
    )
    # This value is expected to equal the platform's own balance (not checked here)
    balance = starting_balance + earning
    return frame, (balance, actual_price_change)


def show_image(image: Any, text: str):
    """Display `image` in a 10x6 inch, 200 dpi figure with `text` drawn at data position (0, 0).

    :param image: Image data accepted by `Axes.imshow`
    :param text: Caption written onto the axes
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=200)
    ax.imshow(image)
    ax.text(0, 0, text)
    # Hide ticks and frame so that only the image and caption remain
    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
from stable_baselines3 import DQN


# Date boundaries used to slice the data per environment
LAST_TRAINING_DATE = datetime.date(2019, 12, 31)
LAST_VALIDATION_DATE = datetime.date(2022, 12, 31)
# Second positional argument of every TradingPlatform below
HISTORICAL_DAYS_NUM = 90
POSITION_OPENING_FEE = 0.01

# --- Training environment: data up to LAST_TRAINING_DATE, explicit balance/position/step limits
train_env = TradingPlatform(
    generate_asset_pool(["AAPL"]),
    HISTORICAL_DAYS_NUM,
    position_opening_fee=POSITION_OPENING_FEE,
    max_balance_loss=0.5,
    max_balance_gain=0.5,
    max_positions_num=50,
    max_steps_num=YEARLY_TRADABLE_DAYS_NUM,
)
train_env.is_training_mode = True
train_env.apply_date_range(max_date=LAST_TRAINING_DATE)

# --- Validation environment: LAST_TRAINING_DATE .. LAST_VALIDATION_DATE, platform-default limits
val_env = TradingPlatform(
    generate_asset_pool(["AAPL"]),
    HISTORICAL_DAYS_NUM,
    position_opening_fee=POSITION_OPENING_FEE,
)
val_env.is_training_mode = False
val_env.apply_date_range(
    min_date=LAST_TRAINING_DATE,
    max_date=LAST_VALIDATION_DATE,
    exclude_historical=False,
)

# --- Test environment
test_env = TradingPlatform(
    generate_asset_pool(["AAPL"]),
    HISTORICAL_DAYS_NUM,
    position_opening_fee=POSITION_OPENING_FEE,
)
test_env.is_training_mode = False
# NOTE(review): the test environment currently replays the *training* date range.
# The out-of-sample variant is kept here for reference - confirm which one is intended:
# test_env.apply_date_range(min_date=LAST_VALIDATION_DATE, exclude_historical=False)
test_env.apply_date_range(max_date=LAST_TRAINING_DATE)

# --- Model: DQN with the multi-input policy, trained on the training environment
model = DQN("MultiInputPolicy", train_env, verbose=1)

In [None]:
from stable_baselines3.common.callbacks import BaseCallback


def eval_model(model: BaseAlgorithm, val_env: TradingPlatform, test_env: Optional[TradingPlatform] = None):
    """Run `model` through the validation (and optionally test) environment and plot each run.

    Every environment is traded with `stop_when_done=False`, and the rendered
    result is shown with a caption containing the two values returned by `trade`.

    NOTE(review): the first value returned by `trade` is the initial balance plus
    the earning, yet the caption labels it `earning` - confirm the intended label.

    :param model: Policy to evaluate
    :param val_env: Environment evaluated first
    :param test_env: Optional environment evaluated second; skipped when None
    """
    environments = [val_env] if test_env is None else [val_env, test_env]
    for environment in environments:
        rendered, (earning, actual_price_change) = trade(environment, model=model, stop_when_done=False)
        caption = f"earning={earning:.2f}, actual_price_change={actual_price_change:.2f}"
        show_image(rendered, caption)


class FullEvalCallback(BaseCallback):
    """Callback that runs `eval_model` before, during and after training.

    The evaluation is triggered once when training starts, again each time at
    least `eval_freq` training episodes have finished since the previous
    periodic evaluation, and one final time when training ends.

    :param eval_freq: Number of finished training episodes between two periodic evaluations (>= 1)
    :param val_env: Environment handed to `eval_model` as the validation environment
    :param test_env: Optional environment handed to `eval_model` as the test environment
    :param verbose: Verbosity level forwarded to `BaseCallback`
    :raises ValueError: If `eval_freq` is smaller than 1
    """

    _eval_freq: int
    _val_env: TradingPlatform
    _test_env: Optional[TradingPlatform]
    _ep_count: int  # Episodes finished since the last periodic evaluation

    def __init__(self, eval_freq: int, val_env: TradingPlatform, test_env: Optional[TradingPlatform] = None, verbose: int = 0):
        # Fail early: a non-positive frequency cannot describe a meaningful evaluation schedule
        if eval_freq < 1:
            raise ValueError(f"eval_freq must be at least 1, got {eval_freq}")
        super().__init__(verbose)
        self._eval_freq = eval_freq
        self._val_env = val_env
        self._test_env = test_env
        self._ep_count = 0

    def _on_training_start(self) -> None:
        """Evaluate the untrained model once as a baseline."""
        self.__eval_model()
        return super()._on_training_start()

    def _on_step(self) -> bool:
        """Count finished episodes and evaluate once enough of them have accumulated.

        :return: The value returned by `BaseCallback._on_step`
        """
        # See: https://github.com/DLR-RM/stable-baselines3/blob/v2.3.2/stable_baselines3/common/callbacks.py#L590-L631
        self._ep_count += np.sum(self.locals["dones"]).item()
        # Threshold test instead of `% eval_freq == 0`: the sum over `dones` can exceed 1
        # in a single step, so the counter may jump past an exact multiple of `eval_freq`.
        if self._ep_count >= self._eval_freq:
            self.__eval_model()
            self._ep_count = 0  # Restart the count so the following steps do not re-trigger the evaluation
        return super()._on_step()

    def _on_training_end(self) -> None:
        """Evaluate the final model once training has finished."""
        self.__eval_model()
        return super()._on_training_end()

    def __eval_model(self):
        """Evaluate the model being trained on the stored validation/test environments."""
        eval_model(self.model, self._val_env, test_env=self._test_env)


# Train for two million timesteps, evaluating on the validation and test
# environments every 1000 finished episodes.
model.learn(
    total_timesteps=2_000_000,
    callback=FullEvalCallback(eval_freq=1000, val_env=val_env, test_env=test_env),
    log_interval=1000,
)