In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import logging
import random
from typing import Any, Tuple

import matplotlib.pyplot as plt

from trainer.asset import ComposedSinusoid
from trainer.trading_platform import OrderType, TradingPlatform
from trainer.util import find_min_tradable_start_date


MIN_INITIAL_TIME = datetime.datetime.strptime("2018/01/01 08:00:00", "%Y/%m/%d %H:%M:%S")
LAST_TRAINING_DATE = datetime.datetime.strptime("2019/12/31", "%Y/%m/%d").date()
HISTORICAL_DAYS_NUM = 90
ASSET_POOL_SIZE = 20

# Upper bound (in days) for each asset's random start offset from MIN_INITIAL_TIME:
# the span from the first tradable date to LAST_TRAINING_DATE, minus HISTORICAL_DAYS_NUM.
# It is the same for every asset, so it is computed once rather than per pool entry.
_MAX_START_OFFSET_DAYS = (
    LAST_TRAINING_DATE - find_min_tradable_start_date(MIN_INITIAL_TIME)
).days - HISTORICAL_DAYS_NUM

# Pool of synthetic assets, each starting at a random day within the allowed window.
asset_pool = [
    ComposedSinusoid(
        MIN_INITIAL_TIME + datetime.timedelta(days=random.randint(0, _MAX_START_OFFSET_DAYS)),
        # NOTE(review): presumably the number of composed sinusoid terms — confirm against ComposedSinusoid.
        random.randint(4, 8),
        # Wavelength: λ = (2*π)/α
        # Close prices are sampled once per unit of time. By the Nyquist–Shannon sampling theorem,
        # the underlying function can be reconstructed from the samples when each component spans
        # more than 2 samples per wavelength, i.e. λ > 2, or equivalently α < π.
        alpha_range=(0.1, 0.2),
        beta_range=(1, 2),
        gamma1_range=(1, 2),
        # Should be greater than 1 (since the minimum of the sine function is -1), ensuring the price is never negative.
        # NOTE(review): if gamma1 scales the sine term, gamma2 must exceed gamma1 (up to 2 here),
        # not just 1, for prices to stay non-negative — confirm against ComposedSinusoid.
        gamma2_range=(1, 5),
    )
    for _ in range(ASSET_POOL_SIZE)
]

plt.axis("off")
plt.tight_layout()
logging.getLogger().setLevel(logging.INFO)


def trade() -> Tuple[Any, Tuple[float, float]]:
    """Run one random trading episode and independently recompute its final balance.

    A fresh ``TradingPlatform`` over ``asset_pool`` is reset and stepped with randomly
    chosen LONG/SHORT orders for at most 90 steps. Stepping stops early as soon as
    ``step`` reports ``terminated`` or ``truncated``. The final balance is then rebuilt
    by hand from the platform's positions and fee attributes, so that it can be
    compared with the platform's own ``_balance``.

    Returns:
        ``(frame, (platform_balance, recomputed_balance))``. ``frame`` is whatever
        ``render()`` returns in ``"rgb_array"`` mode. It is presumably an RGB image
        array, since the caller passes it to ``plt.imshow``.
    """
    # NOTE(review): the four positional numbers are interpreted by TradingPlatform
    # (presumably fee/ratio settings) — confirm against its constructor signature.
    platform = TradingPlatform(
        asset_pool, HISTORICAL_DAYS_NUM, LAST_TRAINING_DATE,
        0.01, 0.002, 0.1, 0.2,
        min_positions_num=5, min_steps_num=60,
    )
    platform.render_mode = "rgb_array"
    platform.reset()

    # Play a single episode of random actions, within a 90-step budget.
    episode_over = False
    steps_taken = 0
    while steps_taken < 90 and not episode_over:
        order_type = random.choice([OrderType.LONG, OrderType.SHORT])
        _, _, terminated, truncated, _ = platform.step(int(order_type))
        newest = platform._prices[-1]
        logging.debug("%s %f %f", newest.time, newest.actual_price, platform._balance)
        episode_over = bool(terminated or truncated)
        steps_taken += 1

    # Rebuild the balance: initial balance plus the earning calc_earning reports
    # over the orders of every recorded position.
    recomputed = platform._INITIAL_BALANCE
    total_earning, _ = platform.calc_earning([pos.order for pos in platform._positions])
    recomputed += total_earning

    # The last position additionally pays its opening fee, plus a holding fee for each
    # calendar day between its order date and the date of the latest price.
    last_order = platform._positions[-1].order
    recomputed += last_order.amount * -platform._position_opening_fee
    days_held = (platform._prices[-1].time.date() - last_order.time.date()).days
    recomputed += days_held * last_order.amount * -platform._position_holding_daily_fee
    if episode_over:
        # Only when the episode ended within the step budget is the last
        # position's net ratio applied to its amount.
        recomputed += last_order.amount * platform._last_position_net_ratio

    final_price = platform._prices[-1]
    logging.debug("%s %f", final_price.time, final_price.actual_price)
    frame = platform.render()
    return frame, (platform._balance, recomputed)


# Absolute tolerance when comparing the platform's balance with the recomputed one.
# NOTE(review): 1e-14 is tighter than float rounding error for balances much larger
# than ~1; consider a relative tolerance (math.isclose) if false alarms appear.
EPSILON = 1e-14
# Run many random episodes and stop at the first one whose balances disagree,
# showing its rendered frame for inspection.
for i in range(1000):
    if i % 100 == 0:
        print(i)  # progress indicator
    logging.debug("==========")
    rendered, (platform_balance, self_calculated_balance) = trade()
    discrepancy = platform_balance - self_calculated_balance
    # Compare the magnitude: a one-sided check would miss episodes where the
    # platform balance falls below the self-calculated balance.
    if abs(discrepancy) >= EPSILON:
        print(platform_balance, self_calculated_balance, discrepancy)
        plt.imshow(rendered)
        break