In [None]:
import copy
import warnings

import gymnasium as gym
import torch
import torch.nn as nn
from skrl.agents.torch.td3 import TD3, TD3_DEFAULT_CONFIG
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, Model
from skrl.resources.noises.torch import GaussianNoise
from skrl.resources.schedulers.torch import KLAdaptiveLR
from skrl.trainers.torch import ParallelTrainer
from skrl.utils import set_seed
from skrl.utils.spaces.torch import unflatten_tensorized_space
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from preprocess import preprocess

# Fix all RNGs (skrl's set_seed covers Python, NumPy and PyTorch) before any
# environment or network is created.
set_seed(seed=42)

# Make the custom trading environment constructible by id through
# gym.make / gym.make_vec; gymnasium's passive env checker is switched off for it.
gym.register(
    "MultiDatasetDiscretedTradingEnv",
    entry_point="predict_next_candle:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [None]:
# Keyword arguments forwarded verbatim to gym.make / gym.make_vec.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    # Glob pattern of pickled datasets; how it is expanded is up to the env.
    "dataset_dir": "./data/futures/15m/**/**/*.pkl",
    # Project-level preprocessing hook applied by the env.
    "preprocess": preprocess,
    "max_episode_duration": "max",
    "verbose": 0,
    # Number of past candles in each observation window.
    "window_size": 30,
    "btc_index": True,
}

In [None]:
# Vectorized training environments. "async" mode is gymnasium's AsyncVectorEnv,
# which steps the copies in worker subprocesses.
vec_env = gym.make_vec(
    vectorization_mode="async",
    num_envs=1024,
    **env_cfg,
)

# Per-environment observation space, used to unflatten the flat state tensors.
# Reading it from the vector env avoids building (and leaking, since it was never
# closed) an extra standalone environment just to inspect its space.
obs = vec_env.single_observation_space

env = wrap_env(vec_env, wrapper="gymnasium")

In [None]:
# Replay storage. RandomMemory is sized per environment, so the total capacity
# is memory_size * num_envs transitions.
device = env.device
replay_buffer_size = 1024 * 1 * env.num_envs  # total transitions across all envs
memory_size = replay_buffer_size // env.num_envs  # slots per environment
memory = RandomMemory(
    memory_size=memory_size,
    num_envs=env.num_envs,
    device=device,
    replacement=False,  # mini-batches are sampled without replacement
)

In [None]:
class LearnablePositionalEncoding(nn.Module):
    """Add a trainable embedding vector for each sequence position to the input.

    Args:
        d_model: size of the last (feature) dimension of the input.
        max_len: number of positions the embedding table holds. Longer inputs
            are rejected. The default of 512 is far above the observation
            windows used in this notebook. The previous default, the global
            replay-buffer size, allocated about a million rows that were never
            indexed. Those rows still received a dense gradient and optimizer
            update every step, and the default tied this class to a variable
            from another cell. Checkpoints saved with that default have a
            differently shaped table and will not load.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the embeddings of positions ``0 .. seq_len - 1``.

        Args:
            x: tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the table size ``max_len``.
        """
        _, seq_len, _ = x.size()
        max_len = self.position_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {max_len}"
            )
        positions = torch.arange(seq_len, device=x.device)
        # (seq_len, d_model) broadcasts across the batch dimension.
        return x + self.position_embedding(positions)


class Actor(DeterministicMixin, Model):
    """Deterministic TD3 policy mapping an observation window to an action vector.

    Architecture, in order:
    - a per-timestep 1x1 conv that projects the input features to 8 channels;
    - a learnable positional encoding;
    - a 2-layer Transformer encoder, with an extra residual around the whole stack;
    - a small conv head averaged over time, giving one output per action dimension.

    NOTE(review): outputs are not squashed (no tanh). They stay unbounded unless
    ``clip_actions=True`` or the agent clamps them; confirm this is intended.
    """

    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        # NOTE(review): hard-coded; must equal the per-timestep feature count
        # (last axis) of the observation. Confirm against observation_space.shape.
        self.num_features = 7
        self.net_projection = nn.Sequential(
            # kernel_size=1 is a per-timestep linear map. padding=0: with a 1-wide
            # kernel, padding only appended a bias-only fake timestep at each end
            # of the window. No parameter shapes change.
            nn.Conv1d(self.num_features, 8, kernel_size=1, padding=0),
        )

        transformer_layer = TransformerEncoderLayer(
            d_model=8,
            nhead=4,
            dim_feedforward=256,
            dropout=0.1,
            batch_first=True,
        )
        self.positional_encoding = LearnablePositionalEncoding(d_model=8)
        self.transformer_encoder = TransformerEncoder(transformer_layer, num_layers=2)

        self.policy_head = nn.Sequential(
            nn.Conv1d(8, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(8, self.num_actions, kernel_size=1),
            nn.AdaptiveAvgPool1d(1),  # average over time -> (batch, num_actions, 1)
        )

    def _encode_states(self, flat_states):
        """Encode flat states into channel-first features of shape (batch, 8, seq_len).

        Unflattens with the model's own ``observation_space`` (the space it was
        constructed with, expected to be the per-environment space) instead of a
        notebook global. The model therefore no longer depends on another cell's state.
        """
        states = unflatten_tensorized_space(self.observation_space, flat_states)
        # assumes states is (batch, seq_len, num_features); the permute relies on it
        features = self.net_projection(states.permute(0, 2, 1))  # (batch, 8, seq_len)
        features = self.positional_encoding(features.permute(0, 2, 1))  # (batch, seq_len, 8)
        features = features + self.transformer_encoder(features)
        return features.permute(0, 2, 1)  # (batch, 8, seq_len)

    def compute(self, inputs, role):
        """Return ``(actions, {})`` with ``actions`` of shape (batch, num_actions)."""
        shared_features = self._encode_states(inputs["states"])
        actions = self.policy_head(shared_features)  # (batch, num_actions, 1)
        return actions.squeeze(-1), {}


class Critic(DeterministicMixin, Model):
    """TD3 Q-network estimating Q(state, taken_action) for an observation window.

    The states go through the same encoder design as the actor:
    - a 1x1 conv projection;
    - a learnable positional encoding;
    - a 2-layer Transformer encoder, with an extra residual around the whole stack.

    The taken action is then tiled along time and stacked as extra channels
    before the conv value head. Previously the action was ignored, which made
    this a state-value V(s). V(s) gives the TD3 actor loss no gradient path back
    to the policy.

    The value head's first conv now takes ``8 + num_actions`` input channels, so
    checkpoints from the action-free critic will not load.
    """

    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        # NOTE(review): hard-coded; must equal the per-timestep feature count
        # (last axis) of the observation. Confirm against observation_space.shape.
        self.num_features = 7
        self.net_projection = nn.Sequential(
            # kernel_size=1 is a per-timestep linear map. padding=0: with a 1-wide
            # kernel, padding only appended a bias-only fake timestep at each end.
            nn.Conv1d(self.num_features, 8, kernel_size=1, padding=0),
        )

        transformer_layer = TransformerEncoderLayer(
            d_model=8,
            nhead=4,
            dim_feedforward=256,
            dropout=0.1,
            batch_first=True,
        )
        self.positional_encoding = LearnablePositionalEncoding(d_model=8)
        self.transformer_encoder = TransformerEncoder(transformer_layer, num_layers=2)

        self.value_head = nn.Sequential(
            # Input channels: 8 encoded features + the action tiled along time.
            nn.Conv1d(8 + self.num_actions, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(8, 1, kernel_size=1),
            nn.AdaptiveAvgPool1d(1),  # average over time -> (batch, 1, 1)
        )

    def _encode_states(self, flat_states):
        """Encode flat states into channel-first features of shape (batch, 8, seq_len).

        Unflattens with the model's own ``observation_space`` (expected to be the
        per-environment space it was constructed with) instead of a notebook global.
        """
        states = unflatten_tensorized_space(self.observation_space, flat_states)
        # assumes states is (batch, seq_len, num_features); the permute relies on it
        features = self.net_projection(states.permute(0, 2, 1))  # (batch, 8, seq_len)
        features = self.positional_encoding(features.permute(0, 2, 1))  # (batch, seq_len, 8)
        features = features + self.transformer_encoder(features)
        return features.permute(0, 2, 1)  # (batch, 8, seq_len)

    def compute(self, inputs, role):
        """Return ``(q_values, {})`` with ``q_values`` of shape (batch, 1).

        Reads ``inputs["states"]`` and ``inputs["taken_actions"]``.
        """
        shared_features = self._encode_states(inputs["states"])  # (batch, 8, seq_len)
        batch_size, _, seq_len = shared_features.shape

        # Condition on the action: (batch, num_actions) -> (batch, num_actions, seq_len).
        taken_actions = inputs["taken_actions"].reshape(batch_size, -1).to(shared_features.dtype)
        action_channels = taken_actions.unsqueeze(-1).expand(-1, -1, seq_len)

        value = self.value_head(torch.cat([shared_features, action_channels], dim=1))
        return value.squeeze(-1), {}


In [None]:
# TD3 networks: one online policy and two online critics, each paired with a
# target copy. Construction order: policy, target_policy, critic_1, critic_2,
# target_critic_1, target_critic_2.
models = {
    role: Actor(env.observation_space, env.action_space, device)
    for role in ("policy", "target_policy")
}
models.update(
    (role, Critic(env.observation_space, env.action_space, device))
    for role in ("critic_1", "critic_2", "target_critic_1", "target_critic_2")
)

# Re-initialise every parameter of every network from a normal distribution
# with mean 0 and std 0.1.
for network in models.values():
    network.init_parameters(method_name="normal_", mean=0.0, std=0.1)

In [None]:
# Deep copy: .copy() is shallow, so editing the nested "exploration"/"experiment"
# dicts below would silently mutate skrl's module-level TD3_DEFAULT_CONFIG. That
# change would then leak into every later copy made in this kernel.
cfg = copy.deepcopy(TD3_DEFAULT_CONFIG)

# TD3 exploration: Gaussian noise on the behaviour policy's actions. Without it,
# the deterministic policy does not explore once random_timesteps is over.
cfg["exploration"]["noise"] = GaussianNoise(0, 0.1, device=device)
# Target-policy smoothing noise, clipped to +-smooth_regularization_clip. Without
# the noise, the clip value has nothing to act on.
cfg["smooth_regularization_noise"] = GaussianNoise(0, 0.2, device=device)
cfg["smooth_regularization_clip"] = 0.5

cfg["batch_size"] = 4096
cfg["random_timesteps"] = 1000
cfg["learning_starts"] = 1000

cfg["discount_factor"] = 0.99
# No LR scheduler. KLAdaptiveLR only adjusts the learning rate when step() is
# given a KL-divergence value, as on-policy agents such as PPO do. TD3 has no
# such value, so the scheduler never changed anything here.
cfg["learning_rate_scheduler"] = None

cfg["experiment"]["write_interval"] = 1000
cfg["experiment"]["checkpoint_interval"] = 10000
cfg["experiment"]["directory"] = "runs/torch/mddt"

In [None]:
# Suppress FutureWarning / DeprecationWarning output for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)

agent = TD3(
    models=models,
    memory=memory,
    cfg=cfg,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)

# Optional warm start. Set this to a checkpoint path relative to the notebook,
# e.g. "runs/torch/mddt/<run-name>/checkpoints/best_agent.pt". None trains from
# scratch. This replaces a commented-out load with a hard-coded absolute home path.
RESUME_CHECKPOINT = None
if RESUME_CHECKPOINT is not None:
    agent.load(RESUME_CHECKPOINT)

# "environment_info" must be a single (string) key of the env's info dict; skrl's
# default is "episode". The previous ["reward"] list is unhashable, so it cannot be
# used to look up a key in the info dict. It is left at the default here; skrl's
# agent already tracks reward statistics itself.
cfg_trainer = {"timesteps": 10_000_000, "headless": True}
trainer = ParallelTrainer(cfg=cfg_trainer, env=env, agents=[agent])

In [None]:
# Run the training loop configured by cfg_trainer (timesteps, headless mode).
trainer.train()