In [None]:
# Load IPython's autoreload extension in mode 2: all imported modules are
# reloaded before each cell executes, so edits to imported project modules
# (e.g. the `trainer` package used below) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime

import numpy as np

from trainer.asset.sinusoid import OPENING_HOUR, ComposedSinusoid
from trainer.env.asset_pool import AssetPool


# Length (in days) of the historical window; passed to both
# `apply_date_range` and `TradingPlatform` further down in this notebook.
HISTORICAL_DAYS_NUM = 20

# Boundary date between the training range and the evaluation range.
LAST_TRAINING_DATE = datetime.date(2019, 12, 31)

# How many days before LAST_TRAINING_DATE the assets are published.
# The terms are summed in this order on purpose: the buffer-days call runs
# before the random draw, so the global NumPy random stream is consumed in a
# fixed sequence.
_PUBLISH_OFFSET_DAYS = (
    1  # Latest date for reset, i.e., the day before the last training date
    + (HISTORICAL_DAYS_NUM - 1)
    + sum(ComposedSinusoid.calc_buffer_days_num())
    + np.random.randint(2500)  # random extra lead of 0..2499 days
)

# Publication timestamp shared by every generated asset: the offset date above,
# at the opening hour.
PUBLISHED_TIME = datetime.datetime.combine(
    LAST_TRAINING_DATE - datetime.timedelta(days=_PUBLISH_OFFSET_DAYS),
    datetime.time(OPENING_HOUR),
)


def generate_asset_pool(assets_num: int) -> AssetPool:
    """Create an ``AssetPool`` holding ``assets_num`` synthetic sinusoid assets.

    Every asset is a ``ComposedSinusoid`` named after its index ("0", "1", ...)
    and published at the module-level ``PUBLISHED_TIME``. One integer in
    ``[4, 8)`` is drawn per asset from the global NumPy random state, so the
    result depends on that state.

    Notes on the range arguments passed to ``ComposedSinusoid``:

    - Alpha range:
      Wavelength is λ = (2*π)/α. Close prices are retrieved after each 1 unit
      of time. To make the original function easy to reconstruct from the
      sampled data (sampling theory), α is chosen small enough that the
      wavelength is greater than 2 (units), i.e. α < π.
    - Gamma2 range:
      Should be greater than 1 (the minimum of the sine function is -1),
      ensuring the price is never negative.

    Args:
        assets_num: Number of assets to generate.

    Returns:
        An ``AssetPool`` built from the generated assets, in index order.
    """
    generated = []
    for index in range(assets_num):
        # NOTE(review): presumably the number of sinusoid components per asset;
        # confirm against the ComposedSinusoid signature.
        drawn_count = np.random.randint(4, 8)
        asset = ComposedSinusoid(
            str(index),
            PUBLISHED_TIME,
            drawn_count,
            (0.1, 0.2),
            (1, 2),
            (1, 2),
            (10, 20),
            (0.018, 0.022),
        )
        generated.append(asset)
    return AssetPool(generated)


# Pool sizes for the two phases.
TRAIN_ASSETS_NUM = 10
EVAL_ASSETS_NUM = 1

# Date ranges split at LAST_TRAINING_DATE.
# NOTE(review): `None` presumably leaves that side of the range unbounded, and
# whether the boundary date is inclusive on both sides is not visible here;
# confirm in AssetPool.apply_date_range.
TRAIN_DATE_RANGE = (None, LAST_TRAINING_DATE)
EVAL_DATE_RANGE = (LAST_TRAINING_DATE, None)

# Training pool first, then evaluation pool: generation consumes the global
# NumPy random state, so this order is kept fixed.
train_asset_pool = generate_asset_pool(TRAIN_ASSETS_NUM)
train_asset_pool.apply_date_range(TRAIN_DATE_RANGE, HISTORICAL_DAYS_NUM)

eval_asset_pool = generate_asset_pool(EVAL_ASSETS_NUM)
eval_asset_pool.apply_date_range(EVAL_DATE_RANGE, HISTORICAL_DAYS_NUM)

In [None]:
import matplotlib.pyplot as plt

from stable_baselines3 import PPO

from trainer.env.trading_platform import TradingPlatform
from trainer.env.evaluation import show_image


# --- Training -----------------------------------------------------------------
TOTAL_TIMESTEPS = 20000
LOG_INTERVAL = 100

train_env = TradingPlatform(train_asset_pool, HISTORICAL_DAYS_NUM)
# NOTE(review): `True` presumably selects training mode; confirm in TradingPlatform.set_mode.
train_env.set_mode(True)

model = PPO(
    "MultiInputPolicy",
    train_env,
    verbose=1,
)
model.learn(
    total_timesteps=TOTAL_TIMESTEPS,
    log_interval=LOG_INTERVAL,
)

# --- Evaluation ---------------------------------------------------------------
eval_env = TradingPlatform(eval_asset_pool, HISTORICAL_DAYS_NUM)
eval_env.set_mode(False)

# `trade` returns a 3-item result whose middle item is a 4-tuple; only the
# rendered image, the earning and the price change are used here.
(
    rendered,
    (_, earning, price_change, _),
    _,
) = eval_env.trade(model=model, stopping_when_done=False)

# Close every open matplotlib figure before displaying the rendered image.
plt.close("all")
print(earning, price_change)
show_image(rendered)