In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `trainer` package) are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import logging

import numpy as np

from trainer.asset.sinusoid import ComposedSinusoid
from trainer.env.asset_pool import AssetPool


# Publication timestamp shared by every generated asset: 2015-10-01 08:00:00 (naive).
PUBLISHED_TIME = datetime.datetime(year=2015, month=10, day=1, hour=8, minute=0, second=0)

# Raise the root logger to INFO for the whole notebook.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)


def generate_asset_pool(asset_num: int) -> AssetPool:
    """Create an ``AssetPool`` holding ``asset_num`` synthetic sinusoid assets.

    Each asset is a ``ComposedSinusoid`` named after its index ("0", "1", ...),
    published at ``PUBLISHED_TIME`` and configured with the parameter ranges
    below. One random integer in [4, 8) is drawn per asset via ``np.random``.

    Args:
        asset_num: Number of assets to generate.

    Returns:
        An ``AssetPool`` wrapping the generated assets, in index order.
    """
    assets = []
    for index in range(asset_num):
        # presumably the number of composed sinusoid components — confirm
        # against ComposedSinusoid's signature.
        n_components = np.random.randint(4, 8)
        asset = ComposedSinusoid(
            str(index),
            PUBLISHED_TIME,
            n_components,
            # Wavelength: λ = (2*π)/α.
            # Close prices are sampled once per unit of time, so for the sampled
            # data to allow an easy reconstruction of the original function the
            # wavelength should exceed 2 units (sampling criterion), i.e. α < π.
            alpha_range=(0.1, 0.2),
            beta_range=(1, 2),
            gamma1_range=(1, 2),
            # Should be greater than 1 (the sine function bottoms out at -1)
            # so that the price is never negative.
            gamma2_range=(1, 5),
        )
        assets.append(asset)
    return AssetPool(assets)

In [None]:
from trainer.env.trading_platform import TradingPlatform
from trainer.env.evaluation import trade


# Largest tolerated absolute gap between the two balance figures.
EPSILON = 1e-14
# Date separating the training range from the evaluation range.
LAST_TRAINING_DATE = datetime.date(2019, 12, 31)
HISTORICAL_DAYS_NUM = 90

# Sanity-check environment: 100 synthetic assets, training mode, unrestricted date range.
env = TradingPlatform(
    generate_asset_pool(100),
    HISTORICAL_DAYS_NUM,
    position_opening_fee=0.01,
    max_balance_loss=0.2,
    max_positions_num=5,
    max_steps_num=60,
)
env.is_training_mode = True
env.apply_date_range()

# Run 1000 trades of random length and stop at the first one where the
# platform-reported balance and the independently calculated balance disagree.
for round_idx in range(1_000):
    if round_idx % 100 == 0:
        print(round_idx)
    logging.debug("==========")
    step_budget = np.random.randint(10, 120)
    _, balances = trade(env, max_step=step_budget, render=False)
    platform_balance, self_calculated_balance, _ = balances
    balance_gap = platform_balance - self_calculated_balance
    if abs(balance_gap) >= EPSILON:
        print(platform_balance, self_calculated_balance, balance_gap)
        break

In [None]:
from stable_baselines3 import DQN


# Training environment: 100 synthetic assets, with LAST_TRAINING_DATE passed as
# the upper bound of the date range.
training_pool = generate_asset_pool(100)
env = TradingPlatform(
    training_pool,
    HISTORICAL_DAYS_NUM,
    position_opening_fee=0.01,
    max_balance_loss=0.2,
    max_balance_gain=0.5,
    max_positions_num=10,
    max_steps_num=120,
)
env.is_training_mode = True
env.apply_date_range(max_date=LAST_TRAINING_DATE)

# Train a DQN agent on the environment using the multi-input policy.
model = DQN("MultiInputPolicy", env, verbose=1)
model.learn(
    total_timesteps=20_000,
    log_interval=100,
)

In [None]:
from trainer.env.evaluation import show_image


# Evaluation environment: a single freshly generated asset, non-training mode,
# with LAST_TRAINING_DATE passed as the lower bound of the date range.
# NOTE(review): the training cell uses the same date as `max_date`; if both
# bounds are inclusive, that day is shared between training and evaluation —
# confirm against apply_date_range.
evaluation_pool = generate_asset_pool(1)
env = TradingPlatform(evaluation_pool, HISTORICAL_DAYS_NUM, position_opening_fee=0.01)
env.is_training_mode = False
env.apply_date_range(min_date=LAST_TRAINING_DATE)

# Let the trained model trade, then report the earning next to the actual price change.
rendered, trade_summary = trade(env, model=model, stop_when_done=False)
_, earning, actual_price_change = trade_summary
print(earning, actual_price_change)
show_image(rendered)