In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules (e.g. the
# project's `trainer.*` package) before each cell runs, so edits to those modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import logging
import random
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
from stable_baselines3.common.base_class import BaseAlgorithm

from trainer.asset.base import DailyAsset
from trainer.asset.sinusoid import ComposedSinusoid
from trainer.trading_platform import PositionType, TradingPlatform, calc_position_net_ratio, calc_earning


# Absolute tolerance used when comparing the platform-reported balance with the
# balance recomputed independently from the position history.
EPSILON = 1e-14
# Timestamp handed to every generated asset (first argument of ComposedSinusoid).
PUBLISHED_TIME = datetime.datetime(2015, 10, 1, 8, 0, 0)
# Date passed to TradingPlatform as its third positional argument
# (presumably the last day of the training window — TODO confirm).
LAST_TRAINING_DATE = datetime.date(2019, 12, 31)
# Number of historical days passed to TradingPlatform.
HISTORICAL_DAYS_NUM = 90

# Show INFO-level records from the root logger; the `logging.debug` calls in
# `trade` stay hidden unless this level is lowered.
logging.root.setLevel(logging.INFO)


def generate_asset_pool(asset_num: int) -> List[DailyAsset]:
    """Build ``asset_num`` randomly parameterised ``ComposedSinusoid`` assets.

    Every asset shares ``PUBLISHED_TIME`` and receives a random integer in
    [4, 8] as its second positional argument (presumably the number of composed
    sine components — TODO confirm against ``ComposedSinusoid``).

    Args:
        asset_num: How many assets to create. A non-positive value yields ``[]``.

    Returns:
        A list of ``asset_num`` freshly generated assets.
    """
    pool = []
    for _ in range(asset_num):
        component_num = random.randint(4, 8)
        pool.append(
            ComposedSinusoid(
                PUBLISHED_TIME,
                component_num,
                # Wavelength: λ = 2π/α. Close prices are sampled once per unit of time, so by the
                # Nyquist–Shannon sampling theorem each wavelength must span more than 2 units
                # for the samples to capture the oscillation, i.e. α < π. (0.1, 0.2) is well below that.
                alpha_range=(0.1, 0.2),
                beta_range=(1, 2),
                gamma1_range=(1, 2),
                # Intended to keep prices non-negative: the sine's minimum is -1, so this offset
                # should be at least 1. NOTE(review): whether that suffices depends on how
                # ComposedSinusoid combines its terms — confirm.
                gamma2_range=(1, 5),
            )
        )
    return pool


def trade(
    env: TradingPlatform,
    model: Optional[BaseAlgorithm] = None, max_step: Optional[int] = None, stop_at_termination: bool = True,
) -> Tuple[Any, Tuple[float, ...]]:
    """Run one episode on ``env`` and return its rendering plus balance figures.

    Actions come from ``model.predict(obs, deterministic=True)`` when a model is
    given; otherwise a random LONG/SHORT position type is picked at every step.

    Args:
        env: Platform to trade on. Its ``render_mode`` is switched to
            ``"rgb_array"`` and it is reset before the episode starts.
        model: Optional trained agent; ``None`` means a random policy.
        max_step: Upper bound on the number of ``env.step`` calls. ``None``
            means no bound, so the loop only ends through truncation (or
            termination when ``stop_at_termination`` is true).
        stop_at_termination: When true, a ``terminated`` signal ends the
            episode; when false, only ``truncated`` or ``max_step`` ends it.

    Returns:
        ``(rendered, (platform_balance, self_calculated_balance, actual_price_change))``:

        - ``rendered`` is whatever ``env.render()`` returns.
        - ``platform_balance`` is ``env._balance`` plus the net value of the
          last position when the episode did not end by itself.
        - ``self_calculated_balance`` is ``env._initial_balance`` plus the
          earning recomputed independently by ``calc_earning``.
        - ``actual_price_change`` is the second value ``calc_earning`` returns.

        The two balances are expected to agree when ``stop_at_termination``
        is true.
    """
    env.render_mode = "rgb_array"
    obs, _ = env.reset()
    # True once the episode ended on its own (termination/truncation) rather
    # than by hitting `max_step`.
    done = False
    step = 0
    while max_step is None or step < max_step:
        if model is not None:
            action, _ = model.predict(obs, deterministic=True)
        else:
            action = int(random.choice([PositionType.LONG, PositionType.SHORT]))
        obs, _, terminated, truncated, _ = env.step(action)
        logging.debug("%s %f %f", env._prices[-1].date, env._prices[-1].actual_price, env._balance)
        if (stop_at_termination and terminated) or truncated:
            done = True
            break
        step += 1

    # When the loop stopped because of `max_step` (or termination was ignored),
    # the last position may still be open: value it at the latest actual price so
    # the platform balance is comparable with the self-calculated one. If no
    # position exists at all (e.g. `max_step=0`) there is nothing to value.
    opening_position_net = 0
    if not done and len(env._positions) > 0:
        last_position = env._positions[-1]
        opening_position_net = last_position.amount * calc_position_net_ratio(
            last_position, env._prices[-1].actual_price,
        )
    logging.debug("%s %f", done, opening_position_net)
    rendered = env.render()

    # Recompute the balance independently from the full position history.
    earning, actual_price_change = calc_earning(
        env._positions, env._prices[-1],
        env._position_opening_fee, env._position_holding_daily_fee, env._short_period_penalty,
    )
    self_calculated_balance = env._initial_balance + earning
    logging.debug("%s %f", env._prices[-1].date, env._prices[-1].actual_price)
    platform_balance = env._balance + opening_position_net
    return rendered, (platform_balance, self_calculated_balance, actual_price_change)

In [None]:
# Sanity check: over many random-policy episodes, the balance reported by the
# platform must match the balance `trade` recomputes from the position history.
# Stop at (and render) the first episode where they diverge by EPSILON or more.
BALANCE_CHECK_EPISODES_NUM = 1000
BALANCE_CHECK_MAX_STEP = 90

env = TradingPlatform(
    generate_asset_pool(100), HISTORICAL_DAYS_NUM, LAST_TRAINING_DATE,
    position_opening_fee=0.01, position_holding_daily_fee=0.002, short_period_penalty=0.005,
    max_balance_loss=0.2, min_positions_num=5, min_steps_num=60,
)

for i in range(BALANCE_CHECK_EPISODES_NUM):
    if i % 100 == 0:
        print(i)
    logging.debug("==========")
    rendered, (platform_balance, self_calculated_balance, _) = trade(env, max_step=BALANCE_CHECK_MAX_STEP)
    if abs(platform_balance - self_calculated_balance) >= EPSILON:
        print(platform_balance, self_calculated_balance, platform_balance - self_calculated_balance)
        fig, ax = plt.subplots()
        ax.imshow(rendered)
        ax.axis("off")
        # Lay out after drawing so the image's aspect ratio is taken into account.
        fig.tight_layout()
        break
else:
    # Only reached when no episode triggered the `break` above.
    print(f"All {BALANCE_CHECK_EPISODES_NUM} episodes: platform balance matches the self-calculated balance.")

In [None]:
from stable_baselines3 import DQN


# Train a DQN agent on a fresh 100-asset platform using the same fee and risk
# settings as the earlier balance-check platform. "MultiInputPolicy" is SB3's
# policy for Dict observation spaces (presumably what TradingPlatform exposes).
env = TradingPlatform(
    generate_asset_pool(100),
    HISTORICAL_DAYS_NUM,
    LAST_TRAINING_DATE,
    position_opening_fee=0.01,
    position_holding_daily_fee=0.002,
    short_period_penalty=0.005,
    max_balance_loss=0.2,
    min_positions_num=5,
    min_steps_num=60,
)
model = DQN(policy="MultiInputPolicy", env=env, verbose=1)
model.learn(total_timesteps=100_000, log_interval=100)

In [None]:
# Evaluate the trained agent (`model` from the training cell) on one freshly
# generated asset, without stopping at termination. `trade` returns the
# self-calculated *balance* (initial balance + earning), so the earning is
# recovered by subtracting the platform's initial balance.
env = TradingPlatform(
    generate_asset_pool(1), HISTORICAL_DAYS_NUM, LAST_TRAINING_DATE,
    use_price_as_position_amount=True,
)
# Presumably switches the platform to evaluation behaviour — TODO confirm.
env.is_training_mode = False
rendered, (_, self_calculated_balance, actual_price_change) = trade(env, model, stop_at_termination=False)
earning = self_calculated_balance - env._initial_balance
print(earning, actual_price_change)
fig, ax = plt.subplots()
ax.imshow(rendered)
ax.axis("off")
fig.tight_layout()