In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads modules before each cell executes,
# so edits to project code (e.g. the `trainer` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import logging
import random
from typing import Any, Tuple

from stable_baselines3.common.base_class import BaseAlgorithm

from trainer.asset import ComposedSinusoid
from trainer.trading_platform import TradingPlatform
from trainer.util import find_min_tradable_start_date


# Tiny constant, presumably a numerical tolerance.
EPSILON = 1e-14
# Earliest timestamp used as the base for every synthetic asset's start time.
MIN_INITIAL_TIME = datetime.datetime(2015, 1, 1, 8, 0, 0)
# Last calendar date (a `datetime.date`, not a datetime) passed to the trading platform.
LAST_TRAINING_DATE = datetime.date(2019, 12, 31)
# Number of days of history handed to the platform; also subtracted from the
# range of random asset start offsets below.
HISTORICAL_DAYS_NUM = 90

# Emit INFO-level (and above) records from the root logger.
logging.root.setLevel(logging.INFO)


# The earliest tradable start date depends only on MIN_INITIAL_TIME, so compute it once
# instead of once per generated asset.
_min_tradable_start_date = find_min_tradable_start_date(MIN_INITIAL_TIME)
# Largest random offset (in days) for an asset's start time, leaving HISTORICAL_DAYS_NUM days
# before LAST_TRAINING_DATE.
_max_start_offset_days = (LAST_TRAINING_DATE - _min_tradable_start_date).days - HISTORICAL_DAYS_NUM
if _max_start_offset_days < 0:
    # random.randint would fail here anyway with an opaque "empty range" error.
    raise ValueError(
        f"No room for {HISTORICAL_DAYS_NUM} historical days between "
        f"{_min_tradable_start_date} and {LAST_TRAINING_DATE}"
    )

# 20 synthetic assets, each with a random start time and randomly drawn sinusoid parameters.
asset_pool = [
    ComposedSinusoid(
        MIN_INITIAL_TIME + datetime.timedelta(days=random.randint(0, _max_start_offset_days)),
        # NOTE(review): presumably the number of composed sinusoid terms — confirm against ComposedSinusoid.
        random.randint(4, 8),
        # Wavelength: λ = 2π/α.
        # Close prices are sampled once per unit of time. By the Nyquist–Shannon sampling theorem,
        # a component can be reconstructed from its samples only if it is sampled more than twice
        # per period, i.e. its wavelength exceeds 2 units, which means α < π.
        # The chosen range is far below that bound.
        alpha_range=(0.1, 0.2),
        beta_range=(1, 2),
        gamma1_range=(1, 2),
        # Intended to be greater than 1 (the minimum of sine is -1) so the price never goes negative.
        # NOTE(review): whether this range guarantees positive prices depends on how ComposedSinusoid
        # combines gamma1 and its multiple terms — confirm.
        gamma2_range=(1, 5),
    )
    for _ in range(20)
]
env = TradingPlatform(
    asset_pool, HISTORICAL_DAYS_NUM, LAST_TRAINING_DATE,
    # NOTE(review): meaning of these positional rates (fees / limits?) is defined by TradingPlatform — confirm.
    0.01, 0.002, 0.2,
    min_positions_num=5, min_steps_num=60,
)

In [None]:
from stable_baselines3 import DQN


# Train a DQN agent on the trading environment using SB3's multi-input
# (dictionary-observation) policy.
model = DQN(policy="MultiInputPolicy", env=env, verbose=1)
model.learn(
    total_timesteps=100_000,  # total environment steps collected during training
    log_interval=100,  # how often training statistics are logged
)

In [None]:
import matplotlib.pyplot as plt


def trade(env: TradingPlatform, model: BaseAlgorithm, step: int) -> Tuple[Any, Tuple[float, float]]:
    """Replay ``model`` deterministically in ``env`` and return the final frame and earning.

    The environment is reset, and then the model's deterministic action is applied for at most
    ``step`` steps. The loop stops early once the environment reports the episode as
    terminated or truncated, because stepping a finished episode is not valid.

    Args:
        env: Trading environment to run. It is reset, and its ``render_mode`` is set to
            ``"rgb_array"`` (left set afterwards) so ``render()`` returns an image.
        model: Trained model whose ``predict`` chooses the actions.
        step: Maximum number of environment steps to take.

    Returns:
        ``(rendered, (earning, price_change))``. ``rendered`` is the value of ``env.render()``.
        The pair is what ``env.calc_earning`` returns for ``env._positions`` at the time of the
        last price in ``env._prices``.
    """
    env.render_mode = "rgb_array"
    obs, _ = env.reset()
    # Run at most one episode.
    for _ in range(step):
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = env.step(action)
        logging.debug("%s %f %f", env._prices[-1].time, env._prices[-1].actual_price, env._balance)
        if terminated or truncated:
            break
    rendered = env.render()
    # Calculate the earning at the last observed price time.
    last_time = env._prices[-1].time
    earning, price_change = env.calc_earning(env._positions, last_time)
    logging.debug("%s earning=%f price_change=%f", last_time, earning, price_change)
    return rendered, (earning, price_change)


# Replay the trained model for up to 600 steps and report the resulting earning.
rendered, (earning, price_change) = trade(env, model, 600)
print(earning, price_change)

plt.imshow(rendered)
# Hide the axes and tighten margins only after the image is drawn,
# so the layout is computed for the actual content.
plt.axis("off")
plt.tight_layout()
plt.show()