In [1]:
import os

import gymnasium as gym
import torch
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import (
    FractionProposalNetwork,
    FullQuantileFunction,
)
from tianshou.utils.space_info import SpaceInfo
from torch.utils.tensorboard import SummaryWriter

from preprocess import preprocess

# Torch device for the networks: CUDA when it is available, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# NOTE(review): nothing is seeded in this notebook (a `set_seed(42)` call used
# to sit here, commented out), so runs are presumably not reproducible.

# Make the project's trading env constructible through `gym.make(id=...)`;
# gymnasium's environment checker is disabled for it.
gym.register(
    id="MultiDatasetDiscretedTradingEnv",
    entry_point="tianshou_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [2]:
# Keyword arguments shared by every `gym.make(**env_cfg)` call in this notebook
# (the probe env and all vectorized train/test envs). Besides `id`, the keys are
# forwarded to the registered env; NOTE(review): their exact meaning is defined
# in `tianshou_env`, not here.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    "dataset_dir": "./data/futures/5m/**/**/*.pkl",  # presumably a glob over pickled datasets
    "preprocess": preprocess,
    "positions": [-1, 0, 1],
    "multiplier": range(1, 51),
    "trading_fees": 0.0001,
    "borrow_interest_rate": 0.0003,
    "portfolio_initial_value": 1e3,
    "max_episode_duration": "max",
    "verbose": 0,
    "window_size": 60,
    "btc_index": True,
}

In [3]:
# One standalone env, built to read the observation/action space shapes; its
# `action_space` is also handed to the policy later on.
env = gym.make(**env_cfg)
space_info = SpaceInfo.from_env(env)
state_shape, action_shape = (
    space_info.observation_info.obs_shape,
    space_info.action_info.action_shape,
)

In [4]:
# ---- run identity & output ---------------------------------------------------
task = "TradingEnv"
seed = 42  # NOTE(review): not read by any cell in this notebook; nothing is seeded
logdir = "log"
# NOTE(review): this string is never read; the logging cell rebinds `logger`
# to a TensorboardLogger instance.
logger = "tensorboard"
device = "cuda" if torch.cuda.is_available() else "cpu"  # same as the import cell

# ---- epsilon-greedy exploration ----------------------------------------------
eps_test = 0.005        # epsilon during evaluation (set by test_fn)
eps_train = 1.0         # initial training epsilon
eps_train_final = 0.05  # intended end value of the training epsilon schedule

# ---- replay buffer -----------------------------------------------------------
buffer_size = 100_000
alpha = 0.6  # prioritized-replay priority exponent
beta = 0.4   # prioritized-replay importance-sampling exponent

# ---- FQF model & optimisation ------------------------------------------------
lr = 5e-5             # Adam learning rate for the quantile network
fraction_lr = 2.5e-9  # RMSprop learning rate for the fraction proposal network
gamma = 0.99
num_fractions = 32
num_cosines = 64
ent_coef = 10.0
hidden_sizes = [512]
target_update_freq = 500
update_per_step = 0.1

# ---- environments & batching -------------------------------------------------
num_train_envs = 256
num_test_envs = 32
batch_size = 32
# Number of transitions collected to warm up the replay buffer before training.
# NOTE(review): despite its name this is not a sensible n-step TD horizon
# (`estimation_step`); at 32 * 256 = 8192 it must not be used as one.
n_step = batch_size * num_train_envs

# ---- training schedule -------------------------------------------------------
reward_threshold = 1e4  # stop_fn returns True once a mean reward reaches this
epoch = 100
step_per_epoch = 8000
step_per_collect = num_train_envs * 10

# ---- not read by any cell in this notebook -----------------------------------
# NOTE(review): these look like leftovers from tianshou's Atari example CLI.
scale_obs = 0
render = 0.0
frames_stack = 4
resume_path = None
resume_id = None
wandb_project = "trading.benchmark"
watch = False
save_buffer_name = None

In [5]:
# Vectorized train/test envs; every slot is built from the shared `env_cfg`.
# NOTE(review): DummyVectorEnv runs its envs in-process one after another; with
# this many train envs a subprocess-based vector env may be worth trying.
train_envs, test_envs = (
    DummyVectorEnv([lambda: gym.make(**env_cfg) for _ in range(env_count)])
    for env_count in (num_train_envs, num_test_envs)
)

In [6]:
# Feature extractor: a tianshou `Net` with hidden layers `hidden_sizes[:-1]`
# (empty for the configured [512]); presumably its output width is the second
# positional argument, `hidden_sizes[-1]`.
feature_net = Net(
    state_shape,
    hidden_sizes[-1],
    hidden_sizes=hidden_sizes[:-1],
    softmax=False,
    device=device,
)
# FQF quantile value network on top of the feature extractor, with
# `num_cosines` cosine embedding features.
net = FullQuantileFunction(
    feature_net,
    action_shape,
    hidden_sizes,
    num_cosines=num_cosines,
    device=device,
)
# Fraction proposal network, sized to the same embedding width as `net`.
fraction_net = FractionProposalNetwork(num_fractions, net.input_dim)

# Separate optimisers: Adam for the quantile network, RMSprop for the
# fraction proposal network.
optim, fraction_optim = (
    torch.optim.Adam(net.parameters(), lr=lr),
    torch.optim.RMSprop(fraction_net.parameters(), lr=fraction_lr),
)
# NOTE(review): none of these modules is moved to `device` here; presumably the
# `.to(device)` applied to the policy afterwards takes care of that — confirm.

In [7]:
# n-step horizon of FQF's bootstrapped TD target. This is deliberately NOT the
# config cell's `n_step`, which is the replay-buffer warm-up size
# (batch_size * num_train_envs, currently 32 * 256 = 8192 transitions). Passing
# that as `estimation_step` turns every target into an ~8k-step discounted sum
# whose bootstrap term is weighted by gamma**8192 (effectively zero), and makes
# each gradient update far more expensive. 3 matches tianshou's FQF examples.
estimation_step = 3

# FQF policy combining the quantile network (`net`, Adam) and the fraction
# proposal network (`fraction_net`, RMSprop). The action space comes from the
# probe env; the assembled policy is moved to `device`.
policy: FQFPolicy = FQFPolicy(
    model=net,
    optim=optim,
    fraction_model=fraction_net,
    fraction_optim=fraction_optim,
    action_space=env.action_space,
    discount_factor=gamma,
    num_fractions=num_fractions,
    ent_coef=ent_coef,
    estimation_step=estimation_step,
    target_update_freq=target_update_freq,
).to(device)

In [8]:
# Replay buffer with `buffer_num=len(train_envs)`, i.e. one sub-buffer per train
# env (NOTE(review): tianshou splits `buffer_size` across the sub-buffers —
# confirm). With prioritized replay on, `alpha`/`beta` go to the prioritized
# buffer.
prioritized_replay = True

buf: ReplayBuffer = (
    PrioritizedVectorReplayBuffer(
        buffer_size,
        buffer_num=len(train_envs),
        alpha=alpha,
        beta=beta,
    )
    if prioritized_replay
    else VectorReplayBuffer(buffer_size, buffer_num=len(train_envs))
)

In [9]:
# The train collector writes into `buf`; the test collector is given no buffer.
# Both enable exploration noise — on the test side this presumably makes the
# `eps_test` set in test_fn take effect during evaluation.
train_collector, test_collector = (
    Collector(policy, train_envs, buf, exploration_noise=True),
    Collector(policy, test_envs, exploration_noise=True),
)

In [None]:
# Warm-up: reset the train collector, then collect `n_step`
# (= batch_size * num_train_envs) transitions into the replay buffer before
# training starts; the returned collect stats are the cell's displayed output.
# NOTE(review): no `policy.set_eps(...)` call happens before this point (train_fn
# only runs once training starts), so the warm-up uses whatever epsilon the
# policy was constructed with. Call `policy.set_eps(eps_train)` first if random
# warm-up data is intended.
train_collector.reset()
train_collector.collect(n_step=n_step)

In [11]:
# TensorBoard output goes to <logdir>/<task>/fqf. The checkpoint saved by
# save_best_fn lives in this directory too.
# NOTE(review): the path has no per-run component (seed/timestamp), so repeated
# runs share one directory and overwrite policy.pth.
log_path = os.path.join(logdir, task, "fqf")
writer = SummaryWriter(log_dir=log_path)
# Rebinds `logger`, which the config cell had set to the string "tensorboard".
logger = TensorboardLogger(writer=writer)

In [12]:
def save_best_fn(policy: BasePolicy) -> None:
    """Checkpoint ``policy``'s weights to ``<log_path>/policy.pth``.

    Handed to the trainer as its best-model hook; each call overwrites the
    previous checkpoint file.
    """
    checkpoint_file = os.path.join(log_path, "policy.pth")
    torch.save(policy.state_dict(), checkpoint_file)

def stop_fn(mean_rewards: float) -> bool:
    """Early-stopping predicate: True once ``mean_rewards`` reaches ``reward_threshold``.

    The comparison is inclusive; a NaN reward never triggers stopping.
    """
    reached = reward_threshold <= mean_rewards
    return reached

def train_fn(epoch: int, env_step: int) -> None:
    """Set the training exploration epsilon for the current env step.

    Schedule, with ``env_step`` the env-step count passed by the trainer:

    * ``env_step <= 10_000``: hold ``eps_train``;
    * ``10_000 < env_step <= 50_000``: decay linearly from ``eps_train`` to
      ``eps_train_final``;
    * afterwards: hold ``eps_train_final``.

    Previously the schedule ended at a hard-coded ``0.1 * eps_train`` and ignored
    the configured ``eps_train_final``. ``epoch`` is unused; it is part of the
    trainer's callback signature.
    """
    warmup_steps, decay_end = 10_000, 50_000
    if env_step <= warmup_steps:
        eps = eps_train
    elif env_step <= decay_end:
        progress = (env_step - warmup_steps) / (decay_end - warmup_steps)
        eps = eps_train - progress * (eps_train - eps_train_final)
    else:
        eps = eps_train_final
    policy.set_eps(eps)

def test_fn(epoch: int, env_step: int | None) -> None:
    policy.set_eps(eps_test)

In [None]:
# Off-policy training: `max_epoch` epochs of `step_per_epoch` env steps each,
# collecting `step_per_collect` transitions at a time, with `update_per_step`
# controlling how many gradient updates follow each collect. Evaluation runs
# `num_test_envs` episodes. The four callbacks defined above drive epsilon,
# early stopping and checkpointing.
result = OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    # schedule
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    update_per_step=update_per_step,
    batch_size=batch_size,
    episode_per_test=num_test_envs,
    # callbacks & logging
    train_fn=train_fn,
    test_fn=test_fn,
    stop_fn=stop_fn,
    save_best_fn=save_best_fn,
    logger=logger,
).run()
# Raises AssertionError if training finished without the best reward reaching
# `reward_threshold`.
assert stop_fn(result.best_reward)