In [None]:
# Reload edited modules (e.g. the local `trainer` package) before each cell runs,
# so changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3 import DQN
from stable_baselines3.common.base_class import BaseAlgorithm

from trainer.asset.stock import Stock
from trainer.env.asset_pool import AssetPool
from trainer.env.trading_platform import PositionType, TradingPlatform, calc_earning


# Boundary date: training is restricted to dates up to it, evaluation to dates from it.
LAST_TRAINING_DATE = datetime.date(2019, 12, 31)
# Look-back window handed to TradingPlatform (presumably days of history per observation).
HISTORICAL_DAYS_NUM = 90
# Cap passed to Stock(max_days_num=...): twelve 360-day "years".
MAX_DAYS_NUM = 12 * 360


def generate_asset_pool(symbols: List[str], data_dir: Path = Path("../data/stock/us")) -> AssetPool:
    """Build an ``AssetPool`` holding one ``Stock`` per ticker symbol.

    Args:
        symbols: Ticker symbols to include, e.g. ``["AAPL"]``.
        data_dir: Directory passed to every ``Stock`` (presumably where its price
            data is stored). Defaults to the previously hard-coded
            ``../data/stock/us``, which resolves relative to the kernel's working
            directory. A ``str`` is also accepted and converted to ``Path``.

    Returns:
        An ``AssetPool`` built from the ``Stock`` objects, created in ``symbols``
        order, each capped with ``max_days_num=MAX_DAYS_NUM``.
    """
    directory = Path(data_dir)
    assets = [Stock(symbol, directory, max_days_num=MAX_DAYS_NUM) for symbol in symbols]
    return AssetPool(assets)


def trade(
    env: TradingPlatform,
    model: Optional[BaseAlgorithm] = None, max_step: Optional[int] = None, stop_at_termination: bool = True,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[Any, Tuple[float, ...]]:
    """Run one episode on ``env`` and return its rendering plus a self-checked balance.

    Actions come from ``model.predict(obs, deterministic=True)`` when ``model`` is
    given; otherwise every step picks ``PositionType.LONG`` or ``PositionType.SHORT``
    uniformly at random (a random baseline).

    Args:
        env: The trading environment. Side effect: its ``render_mode`` is set to
            ``"rgb_array"`` and is not restored afterwards.
        model: Trained agent; ``None`` selects the random baseline.
        max_step: Upper bound on the number of ``env.step`` calls; ``None`` means
            no bound.
        stop_at_termination: If true, stop as soon as the env reports
            ``terminated``. If false, keep stepping past termination and stop only
            on ``truncated`` (or after ``max_step`` steps).
        rng: Optional NumPy generator used by the random baseline so runs can be
            reproduced. ``None`` keeps using the global ``np.random`` state.
            Ignored when ``model`` is given.

    Returns:
        ``(rendered, (self_calculated_balance, actual_price_change))`` where
        ``rendered`` is the result of ``env.render()`` and
        ``self_calculated_balance`` is the env's initial balance plus the earning
        that ``calc_earning`` computes from the env's positions.
    """
    env.render_mode = "rgb_array"
    obs, _ = env.reset()
    # Candidate actions for the random baseline (loop-invariant).
    random_actions = [PositionType.LONG, PositionType.SHORT]
    # Run one episode.
    # NOTE(review): with max_step=None and stop_at_termination=False the loop ends
    # only when the env reports `truncated` — assumes the env always truncates eventually.
    step = 0
    while max_step is None or step < max_step:
        if model is not None:
            action, _ = model.predict(obs, deterministic=True)
        else:
            sampler = np.random if rng is None else rng
            action = int(sampler.choice(random_actions))
        obs, _, terminated, truncated, _ = env.step(action)
        if (stop_at_termination and terminated) or truncated:
            break
        step += 1
    rendered = env.render()
    # Recompute the balance independently from the env's internal (private) state.
    # NOTE(review): presumably `_prices[-1]` is the latest price seen in the episode — confirm.
    earning, actual_price_change = calc_earning(
        env._positions, env._prices[-1],
        position_opening_fee=env._position_opening_fee,
    )
    # Plain sum (not `+=` on an alias) so env._initial_balance can never be mutated in place.
    self_calculated_balance = env._initial_balance + earning
    # Platform balance and self-calculated balance should be equal if `stop_at_termination` is true
    return rendered, (self_calculated_balance, actual_price_change)


def show_image(image: Any):
    """Display ``image`` (anything ``imshow`` accepts) in a 10x6-inch, 200-dpi figure without axes."""
    fig, ax = plt.subplots(figsize=(10, 6), dpi=200)
    ax.imshow(image)
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
# Single-asset (AAPL) environment; HISTORICAL_DAYS_NUM is the look-back window argument.
env = TradingPlatform(
    generate_asset_pool(["AAPL"]), HISTORICAL_DAYS_NUM,
    max_balance_loss=0.2, min_positions_num=20,
)
# DQN with a dict-observation policy ("MultiInputPolicy"). Epsilon-greedy exploration
# decays from 0.4 to 0.1 over the first 20% of the training timesteps.
model = DQN(
    "MultiInputPolicy", env,
    exploration_fraction=0.2, exploration_initial_eps=0.4, exploration_final_eps=0.1,
    verbose=1,
)

In [None]:
# Training phase: restrict the environment to dates up to LAST_TRAINING_DATE.
# NOTE(review): evaluation uses min_date=LAST_TRAINING_DATE; if apply_date_range bounds
# are inclusive, that single date belongs to both windows — confirm.
env.is_training_mode = True
env.apply_date_range(max_date=LAST_TRAINING_DATE)
model.learn(
    total_timesteps=2_000_000,
    log_interval=1_000,
)

In [None]:
# Evaluation phase: dates from LAST_TRAINING_DATE onwards. With stop_at_termination=False
# and no max_step, `trade` keeps stepping until the env reports truncation.
env.is_training_mode = False
env.apply_date_range(min_date=LAST_TRAINING_DATE)
# `trade` returns the self-calculated *balance* (initial balance + earning), not the raw earning.
rendered, (final_balance, actual_price_change) = trade(env, model=model, stop_at_termination=False)
print("final balance:", final_balance, "| actual price change:", actual_price_change)
show_image(rendered)