In [None]:
# Automatically re-import edited modules (e.g. the project's `trainer.*` package) before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
from stable_baselines3.common.base_class import BaseAlgorithm

from trainer.asset.base import DailyAsset
from trainer.asset.stock import Stock
from trainer.trading_platform import PositionType, TradingPlatform, calc_earning


# Look-back length passed positionally to `TradingPlatform` below
# (a number of days, judging by the name — TODO confirm in TradingPlatform).
HISTORICAL_DAYS_NUM = 90

# Cut-off date for the training period, as a plain `datetime.date` (time part dropped).
LAST_TRAINING_DATE = (
    datetime.datetime
    .strptime("2019-12-31", "%Y-%m-%d")
    .date()
)


def generate_asset_pool(
    symbols: List[str], data_dir: Path = Path("../data/stock/us"),
) -> Dict[str, DailyAsset]:
    """Build a ``symbol -> asset`` mapping for the given stock symbols.

    Args:
        symbols: Ticker symbols to load, e.g. ``["AAPL"]``. Duplicate symbols
            collapse into a single entry because the result is a dict.
        data_dir: Directory handed to ``Stock`` for every symbol. Defaults to the
            path this notebook has always used. Because it is relative, it is
            resolved against the kernel's current working directory.

    Returns:
        A dict mapping each symbol to ``Stock(data_dir, symbol)``.
    """
    # `Path` is immutable, so one instance can safely be shared by every Stock.
    return {symbol: Stock(data_dir, symbol) for symbol in symbols}


def trade(
    env: TradingPlatform,
    model: Optional[BaseAlgorithm] = None, max_step: Optional[int] = None, stop_at_termination: bool = True,
    rng: Optional[random.Random] = None,
) -> Tuple[Any, Tuple[float, ...]]:
    """Run one episode on ``env`` and return its rendering plus a recomputed balance.

    With a ``model``, actions come from ``model.predict(obs, deterministic=True)``.
    Without one, every step picks ``PositionType.LONG`` or ``PositionType.SHORT``
    uniformly at random.

    Args:
        env: Trading environment to drive. Its ``render_mode`` is set to
            ``"rgb_array"`` and left that way, so ``env.render()`` is asked for an image.
        model: Optional trained agent. ``None`` means random actions.
        max_step: Optional cap on the number of steps. ``None`` means no cap.
        stop_at_termination: If true, the loop also ends when ``terminated`` is set.
            If false, only ``truncated`` (or ``max_step``) ends it.
        rng: Randomness source for the no-model case. ``None`` uses the global
            ``random`` module, which is the original behaviour. Pass a seeded
            ``random.Random`` to get a reproducible random baseline.

    Returns:
        ``(rendered, (self_calculated_balance, actual_price_change))``:
        ``rendered`` is the value returned by ``env.render()``.
        ``self_calculated_balance`` is ``env._initial_balance`` plus the earning
        that ``calc_earning`` derives from ``env._positions`` and the last entry of
        ``env._prices``.
    """
    env.render_mode = "rgb_array"
    # Loop-invariant: the random baseline always chooses between the same two positions.
    random_actions = [PositionType.LONG, PositionType.SHORT]
    chooser = random if rng is None else rng
    obs, _ = env.reset()
    # Run one episode
    step = 0
    while max_step is None or step < max_step:
        if model is not None:
            action, _ = model.predict(obs, deterministic=True)
        else:
            action = int(chooser.choice(random_actions))
        obs, _, terminated, truncated, _ = env.step(action)
        if (stop_at_termination and terminated) or truncated:
            break
        step += 1
    rendered = env.render()
    # Recompute the final balance from the recorded positions instead of trusting
    # the platform's bookkeeping. `+` (not `+=`) avoids any in-place update of
    # `env._initial_balance`.
    initial_balance = env._initial_balance
    earning, actual_price_change = calc_earning(
        env._positions, env._prices[-1],
        position_opening_fee=env._position_opening_fee,
    )
    self_calculated_balance = initial_balance + earning
    # NOTE(review): this is expected to equal the platform's own balance when
    # `stop_at_termination` is true. The platform balance is not returned here,
    # so the caller must do that comparison.
    return rendered, (self_calculated_balance, actual_price_change)


def show_image(image: Any):
    """Draw ``image`` without axes in a 10x6-inch, 200-dpi figure and show it."""
    fig, ax = plt.subplots(figsize=(10, 6), dpi=200)
    ax.imshow(image)
    # Hide ticks, labels and spines so only the image remains.
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
from stable_baselines3 import DQN


# Fixed seed so training is reproducible under Restart & Run All. SB3 seeds
# Python's `random`, NumPy and PyTorch from this value, which also makes the
# global-`random` baseline in `trade` repeatable.
SEED = 42

env = TradingPlatform(
    generate_asset_pool(["AAPL"]), HISTORICAL_DAYS_NUM, LAST_TRAINING_DATE,
    # NOTE(review): interpretation of these limits lives in TradingPlatform —
    # presumably fee as a fraction, max loss as a fraction of the initial balance; confirm.
    position_opening_fee=0.01,
    max_balance_loss=0.2, min_positions_num=20, min_steps_num=180,
)
model = DQN("MultiInputPolicy", env, exploration_fraction=0.2, verbose=1, seed=SEED)

In [None]:
# Presumably restricts the platform to the training period (up to LAST_TRAINING_DATE) — TODO confirm.
env.is_training_mode = True
# `learn` returns the model itself; the trailing `;` stops its repr from being
# echoed under the training log.
model.learn(total_timesteps=200_000, log_interval=100);

In [None]:
env.is_training_mode = False
# `trade` returns (image, (balance, actual price change)). The first number is the
# self-calculated *balance* (initial balance + earning), not the earning itself.
# With stop_at_termination=False it may differ from the platform's own balance.
rendered, (self_calculated_balance, actual_price_change) = trade(env, model, stop_at_termination=False)
print(f"self-calculated balance: {self_calculated_balance}, actual price change: {actual_price_change}")
show_image(rendered)