In [1]:
import gymnasium as gym
import torch
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import Shape, deterministic_model
from reward import refined_pnl
from preprocess import only_sub_indicators

# Make the custom trading environment creatable by id through gymnasium.
# The entry point is resolved from the local `environment` module.
gym.register(
    entry_point="environment:MultiDatasetDiscretedTradingEnv",
    id="MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [None]:
# Fix the skrl seed up front so repeated runs start from the same value.
set_seed(seed=42)

In [6]:
# Build 16 asynchronous copies of the registered trading environment.
env = gym.make_vec(
    "MultiDatasetDiscretedTradingEnv",
    # --- vectorization settings ---
    num_envs=16,
    vectorization_mode="async",
    wrappers=[gym.wrappers.FlattenObservation],
    # --- keyword arguments forwarded to the environment ---
    dataset_dir="./data/train/day/**/**/*.pkl",
    preprocess=only_sub_indicators,
    reward_function=refined_pnl,
    window_size=120,
    positions=[-5, -2, 0, 2, 5],
    portfolio_initial_value=100,
    trading_fees=0.0001,
    borrow_interest_rate=0.0003,
    max_episode_duration="max",  # a fixed 24 * 60 is the commented-out alternative
    verbose=0,
)

In [None]:
# Hand the vectorized environment to skrl through its gymnasium wrapper.
# NOTE(review): this rebinds `env`, so re-running this cell alone would wrap
# the already-wrapped object; re-run the `gym.make_vec` cell first.
env = wrap_env(env=env, wrapper="gymnasium")

In [8]:
# Reuse the wrapped environment's device for the memory (and, later, the models).
device = env.device

# Replay memory: 4096 entries per environment, sampled without replacement.
memory = RandomMemory(
    memory_size=4096,
    num_envs=env.num_envs,
    replacement=False,
    device=device,
)

In [None]:
# The online and target Q-networks share one architecture: observations in,
# two ReLU hidden layers of 64 units, one unscaled linear output per action.
# Each role gets its own independently instantiated model, created in the
# order listed ("q_network" first, then "target_q_network").
models = {
    role: deterministic_model(
        observation_space=env.observation_space,
        action_space=env.action_space,
        device=device,
        clip_actions=False,
        input_shape=Shape.OBSERVATIONS,
        hiddens=[64, 64],
        hidden_activation=["relu", "relu"],
        output_shape=Shape.ACTIONS,
        output_activation=None,
        output_scale=1.0,
    )
    for role in ("q_network", "target_q_network")
}

In [10]:
import copy

from skrl.agents.torch.dqn import DQN
from skrl.agents.torch.dqn import DQN_DEFAULT_CONFIG

# Deep-copy the defaults: the entries edited below ("exploration",
# "experiment") are nested dicts, so a shallow `.copy()` would share them with
# DQN_DEFAULT_CONFIG and silently change the library-wide defaults for every
# later agent created in this kernel.
cfg = copy.deepcopy(DQN_DEFAULT_CONFIG)
cfg["learning_starts"] = 10000
# exploration schedule: final epsilon value and schedule length (in timesteps)
cfg["exploration"]["final_epsilon"] = 0.04
cfg["exploration"]["timesteps"] = 100000
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = 1000
cfg["experiment"]["checkpoint_interval"] = 100000
cfg["experiment"]["directory"] = "runs/"

In [11]:
# Assemble the DQN agent from the pieces built in the previous cells:
# the two Q-networks, the replay memory and the tuned configuration.
agent = DQN(
    cfg=cfg,
    models=models,
    memory=memory,
    device=device,
    observation_space=env.observation_space,
    action_space=env.action_space,
)

In [12]:
# Trainer settings: one million timesteps, run headless.
cfg_trainer = dict(timesteps=1_000_000, headless=True)

# A sequential trainer driving the single DQN agent on the wrapped environment.
trainer = SequentialTrainer(env=env, agents=[agent], cfg=cfg_trainer)

In [None]:
# Run training for the number of timesteps set in `cfg_trainer`.
trainer.train()

In [None]:
# Greedy evaluation rollout of the trained Q-network.
#
# Fixes over the previous version of this cell:
#   * the loop was guarded by `while terminated` with `terminated = False`,
#     so its body never executed;
#   * the reset observation was bound to `observation`, but the loop read an
#     undefined `states`;
#   * `state_preprocessor` and `policy` were never defined in this notebook.
#
# The rollout runs until the first sub-environment reports the end of an
# episode (terminated or truncated).
# NOTE(review): `cfg` above does not set a state preprocessor, so raw
# observations are fed to the network; if one is configured later, apply it
# to `states` before querying the Q-network.
states, infos = env.reset()
episode_over = False

while not episode_over:
    # greedy policy: pick the action with the highest Q-value per environment
    with torch.no_grad():
        q_values = models["q_network"].act({"states": states}, role="q_network")[0]
        actions = torch.argmax(q_values, dim=1, keepdim=True)

    # step the environment
    next_states, rewards, terminated, truncated, infos = env.step(actions)

    # render the environment
    env.render()

    # check for termination/truncation in any sub-environment
    episode_over = bool(terminated.any()) or bool(truncated.any())
    states = next_states