In [1]:
import os
import pickle

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3BCPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
from tianshou.utils.space_info import SpaceInfo
from torch.utils.tensorboard import SummaryWriter

from preprocess import preprocess

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Make the custom trading env resolvable by id via `gym.make`; the implementation
# lives in the local `tianshou_env` module. The env checker is turned off.
gym.register(
    id="MultiDatasetDiscretedTradingEnv",
    entry_point="tianshou_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [2]:
# Keyword arguments shared by every `gym.make(**env_cfg)` call in this notebook.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    # Glob over per-instrument 5-minute futures pickles.
    "dataset_dir": "./data/futures/5m/**/**/*.pkl",
    "preprocess": preprocess,
    "positions": [-1, 0, 1],
    "multiplier": range(1, 51),
    "trading_fees": 0.0001,
    "borrow_interest_rate": 0.0003,
    "portfolio_initial_value": 1e3,
    "max_episode_duration": "max",
    "verbose": 0,
    "window_size": 60,
    "btc_index": True,
}

In [3]:
# Build one env instance to read the observation/action shapes used to size the networks.
env = gym.make(**env_cfg)
space_info = SpaceInfo.from_env(env)
obs_info, act_info = space_info.observation_info, space_info.action_info
state_shape = obs_info.obs_shape
action_shape = act_info.action_shape

In [4]:
task = "TradingEnv"
expert_data_task = "trading-env"

seed = 42
# NOTE(review): scale_obs, eps_*, num_fractions, num_cosines, ent_coef,
# target_update_freq and frames_stack are not consumed by TD3BCPolicy; they look
# like leftovers from a value-based (IQN/DQN) example.
scale_obs = 0
eps_test = 0.005
eps_train = 1.0
eps_train_final = 0.05
replay_buffer_size = 100000

actor_lr = 3e-4
critic_lr = 3e-4

# NOTE(review): `alpha` is passed to TD3BCPolicy as its behaviour-cloning weight,
# but 0.6 / 0.4 (with `beta`) are the usual prioritized-replay exponents; the
# TD3+BC paper uses alpha=2.5. Confirm the intended value.
alpha = 0.6
beta = 0.4
gamma = 0.99
tau = 0.005
num_fractions = 32
num_cosines = 64
ent_coef = 10.0
hidden_sizes = [512]
target_update_freq = 500
update_per_step = 0.1

exploration_noise = 0.1
policy_noise = 0.2
update_actor_freq = 2
noise_clip = 0.05

logdir = "log"
render = 0.0
device = "cuda" if torch.cuda.is_available() else "cpu"
frames_stack = 4
resume_path = None
resume_id = None
logger = "tensorboard"
wandb_project = "trading.benchmark"
watch = False
save_buffer_name = None

reward_threshold = 1e4
num_train_envs = 256
num_test_envs = 32
batch_size = 32
# n-step TD horizon, passed to TD3BCPolicy as `estimation_step`. It must be a small
# horizon. Deriving it from batch_size * num_train_envs (= 8192) would bootstrap
# 8192 steps ahead (gamma**8192 ~ 0) and make every update walk thousands of
# transitions per sample. 1 is tianshou's default.
n_step = 1

epoch = 100
step_per_epoch = 8000
step_per_collect = num_train_envs * 10

In [5]:
# Vectorised environments; every factory re-runs gym.make with the shared config.
# NOTE(review): the OfflineTrainer call later in this notebook is not given
# `train_envs`, so building num_train_envs copies (each loading the dataset glob)
# is likely wasted startup time and memory.
train_envs = DummyVectorEnv([lambda: gym.make(**env_cfg) for _ in range(num_train_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(**env_cfg) for _ in range(num_test_envs)])

# Apply the configured `seed`, which was previously never used. Seed the global
# numpy/torch RNGs before the networks are built in the next cells, plus the envs,
# so runs are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)

In [6]:
# Deterministic actor: Net maps the observation to a hidden_sizes[-1]-dim feature,
# and Actor projects that feature to `action_shape` through its own hidden layers.
actor_net = Net(
    state_shape,
    hidden_sizes[-1],
    hidden_sizes=hidden_sizes[:-1],
    device=device,
    softmax=False,
)
# `num_cosines` is not passed: it is an IQN (ImplicitQuantileNetwork) argument,
# and tianshou's continuous Actor does not accept it (passing it raised TypeError).
# NOTE(review): the continuous Actor scales its output by `max_action` (default
# 1.0), which assumes a Box action space. Confirm env.action_space for this
# "Discreted" trading env.
actor = Actor(
    actor_net,
    action_shape,
    hidden_sizes,
    device=device,
)
actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)

In [None]:
# Twin Q-networks for TD3. Each Critic scores a (state, action) pair. tianshou's
# continuous Critic concatenates obs and act before calling its preprocess net,
# so that net needs `concat=True` and `action_shape` to widen its input by the
# action size; a state-only input layer cannot accept the concatenated tensor.
critic_net_1 = Net(
    state_shape,
    action_shape,
    hidden_sizes=hidden_sizes,
    concat=True,
    device=device,
)
critic_net_2 = Net(
    state_shape,
    action_shape,
    hidden_sizes=hidden_sizes,
    concat=True,
    device=device,
)
critic_1 = Critic(critic_net_1, device=device)
critic_2 = Critic(critic_net_2, device=device)
# Optimise the whole Critic (preprocess net + final Q head). Using only
# critic_net_*.parameters() would leave the head's weights frozen.
critic_1_optim = torch.optim.Adam(critic_1.parameters(), lr=critic_lr)
critic_2_optim = torch.optim.Adam(critic_2.parameters(), lr=critic_lr)

In [7]:
# TD3+BC policy built from the actor, the twin critics and their optimizers.
# The constructor arguments are gathered into a dict first so they are easy to inspect.
# NOTE(review): `alpha` comes from the config cell, where it sits next to PER's
# `beta`; `estimation_step` is the n-step TD horizon. Confirm both are intended
# for TD3BC.
# NOTE(review): DDPG-family policies expect a continuous (Box) action space;
# confirm env.action_space for this "Discreted" trading env.
td3bc_kwargs = dict(
    actor=actor,
    actor_optim=actor_optim,
    critic=critic_1,
    critic_optim=critic_1_optim,
    critic2=critic_2,
    critic2_optim=critic_2_optim,
    tau=tau,
    gamma=gamma,
    exploration_noise=GaussianNoise(sigma=exploration_noise),
    policy_noise=policy_noise,
    update_actor_freq=update_actor_freq,
    noise_clip=noise_clip,
    alpha=alpha,
    estimation_step=n_step,
    action_space=env.action_space,
)
policy: TD3BCPolicy = TD3BCPolicy(**td3bc_kwargs).to(device)

In [8]:
# Optionally warm-start from a saved policy state dict (any truthy resume_path).
if resume_path:
    checkpoint = torch.load(resume_path, map_location=device)
    policy.load_state_dict(checkpoint)
    print("Loaded agent from: ", resume_path)

In [9]:
# Collector used only for evaluation: it runs `policy` on the vectorised test envs.
test_collector = Collector(policy=policy, env=test_envs)

In [11]:
# TensorBoard logs go to <logdir>/<task>/td3bc; save_best_fn also writes policy.pth there.
# Note: this rebinds `logger` from the config string to the logger object.
log_path = os.path.join(logdir, task, "td3bc")
writer = SummaryWriter(log_dir=log_path)
logger = TensorboardLogger(writer)

In [12]:
def save_best_fn(policy: BasePolicy) -> None:
    """Write the given policy's state dict to ``<log_path>/policy.pth``.

    Passed to the trainer as its best-checkpoint hook.
    """
    weights = policy.state_dict()
    checkpoint_path = os.path.join(log_path, "policy.pth")
    torch.save(weights, checkpoint_path)

def stop_fn(mean_rewards: float) -> bool:
    """Return True once the mean test reward has reached ``reward_threshold``."""
    threshold_reached = mean_rewards >= reward_threshold
    return threshold_reached

def train_fn(epoch: int, env_step: int) -> None:
    """Epsilon-annealing training hook (demo schedule).

    Holds ``eps_train`` for the first 10k env steps, decays it linearly to
    ``0.1 * eps_train`` by 50k steps, then keeps it there.

    Args:
        epoch: current epoch (unused).
        env_step: number of environment steps taken so far.

    NOTE(review): epsilon is a value-based exploration knob. This notebook's
    TD3BCPolicy explores through GaussianNoise and does not appear to define
    ``set_eps``, so the hook is a no-op for policies without that method instead
    of raising AttributeError.
    """
    if not hasattr(policy, "set_eps"):
        return
    if env_step <= 10000:
        eps = eps_train
    elif env_step <= 50000:
        # Linear decay over the 40k steps between 10k and 50k.
        eps = eps_train - (env_step - 10000) / 40000 * (0.9 * eps_train)
    else:
        eps = 0.1 * eps_train
    policy.set_eps(eps)

def test_fn(epoch: int, env_step: int | None) -> None:
    """Switch to the evaluation epsilon ``eps_test`` before a test phase.

    The hook is a no-op for policies without ``set_eps``; the GaussianNoise-based
    TD3BCPolicy used here does not appear to define it.
    """
    if hasattr(policy, "set_eps"):
        policy.set_eps(eps_test)

In [None]:
# Offline training: learn purely from an expert replay buffer, evaluating the
# policy on `test_collector` after each epoch.
#
# `replay_buffer` used to come from a cell that is no longer in the notebook (the
# execution counts jump from 9 to 11), so a fresh kernel raised NameError here.
# Load it explicitly instead.
# NOTE(review): `save_buffer_name` is treated as the path of the expert buffer to
# load. Confirm this is where the expert data for `expert_data_task` lives.
if not save_buffer_name or not os.path.exists(save_buffer_name):
    raise FileNotFoundError(
        "Offline training needs an expert replay buffer: set `save_buffer_name` "
        f"to an existing .hdf5 or pickle file (got {save_buffer_name!r})."
    )
if save_buffer_name.endswith(".hdf5"):
    replay_buffer = ReplayBuffer.load_hdf5(save_buffer_name)
else:
    # SECURITY: pickle.load can execute arbitrary code. Only load trusted buffer files.
    with open(save_buffer_name, "rb") as buffer_file:
        replay_buffer = pickle.load(buffer_file)

result = OfflineTrainer(
    policy=policy,
    buffer=replay_buffer,
    test_collector=test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    episode_per_test=num_test_envs,
    batch_size=batch_size,
    # Stop early once the mean test reward reaches reward_threshold.
    stop_fn=stop_fn,
    save_best_fn=save_best_fn,
    logger=logger,
).run()
assert stop_fn(result.best_reward), (
    f"best test reward {result.best_reward} is below reward_threshold={reward_threshold}"
)