In [1]:
import math

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, Model, MultiCategoricalMixin
from skrl.trainers.torch import ParallelTrainer
from skrl.utils import set_seed
from skrl.utils.spaces.torch import unflatten_tensorized_space
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from lr_schedulers import CosineAnnealingWarmUpRestarts
from preprocess import preprocess

# Fix the random seed through skrl's helper so repeated runs are comparable.
set_seed(42)

# Register the custom trading environment under a string id so that
# `gym.make` / `gym.make_vec` can build it. The entry point is resolved from
# the `mc_env` module; gymnasium's environment checker is switched off.
gym.register(
    "MultiDatasetDiscretedTradingEnv",
    entry_point="mc_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

[38;20m[skrl:INFO] Seed: 42[0m


In [2]:
# Keyword arguments unpacked into `gym.make` / `gym.make_vec` when the
# environments are built. Apart from "id", every key is passed on to the
# environment constructor; their exact semantics are defined in `mc_env`.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    "dataset_dir": "./data/train/month_15m/**/**/*.pkl",
    "preprocess": preprocess,
    # "reward_function": simple_reward,  # currently disabled
    "positions": [-1, 1],
    "multiplier": [1, 2, 5],
    "trading_fees": 0.01,
    "borrow_interest_rate": 0.03,
    "portfolio_initial_value": 1e4,
    "max_episode_duration": 1000,
    "verbose": 0,
    "window_size": 60,
}

In [3]:
# Build a single throw-away environment only to read its observation space
# *before* the flattening wrapper is applied. `obs` stays a module-level
# name because the model's `compute` uses it to unflatten observations.
probe_env = gym.make(**env_cfg)
obs = probe_env.observation_space
# Only the space object is needed from here on; release the probe explicitly
# instead of leaving it to be dropped when the name is rebound.
probe_env.close()

# Training environment: 32 synchronous copies, each wrapped with
# FlattenObservation, then adapted to skrl's environment interface.
env = gym.make_vec(
    vectorization_mode="sync",
    wrappers=[gym.wrappers.FlattenObservation],
    num_envs=32,
    **env_cfg,
)
env = wrap_env(env, wrapper="gymnasium")

[38;20m[skrl:INFO] Environment wrapper: gymnasium[0m


In [4]:
device = env.device

# Total number of transitions stored across all parallel environments.
replay_buffer_size = 1024 * 16 * env.num_envs
# RandomMemory takes a per-environment length, so divide the total back out
# (the total is an exact multiple of `env.num_envs`).
memory_size = replay_buffer_size // env.num_envs
memory = RandomMemory(
    memory_size=memory_size,
    num_envs=env.num_envs,
    device=device,
    replacement=False,
)

In [5]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as a buffer named ``encoding``: even feature columns hold sines, odd
    feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the last (feature) dimension of the input. Both even
            and odd values are supported.
        max_len: Number of positions precomputed; inputs passed to ``forward``
            must not have more than this many positions along dimension 1.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        # One angle per (position, frequency): shape (max_len, ceil(d_model / 2)).
        angles = position * div_term

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the last frequency for the cosine part. For even d_model the
        # slice keeps every frequency and the table is unchanged.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("encoding", encoding.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension has size ``d_model``.
        """
        return x + self.encoding[:, : x.size(1), :]


class AttentionFusion(nn.Module):
    """Project a flat feature vector to query/key/value and add it back residually.

    Each sample is treated as a length-1 sequence: three ``kernel_size=1``
    convolutions produce the query, key and value, which are fed to
    ``F.scaled_dot_product_attention``. The unprojected input is then added to
    the attention output, so ``input_dim`` must equal ``output_dim`` for the
    residual sum to be shape-compatible.

    NOTE(review): with exactly one key per sample the softmax weight is 1, so
    the attention output should equal the value projection -- confirm this is
    the intended design.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # 1x1 convolutions over a length-1 axis stand in for linear layers.
        self.query_layer = nn.Conv1d(input_dim, output_dim, kernel_size=1)
        self.key_layer = nn.Conv1d(input_dim, output_dim, kernel_size=1)
        self.value_layer = nn.Conv1d(input_dim, output_dim, kernel_size=1)

    def _project(self, layer, column):
        """Map ``[batch, input_dim, 1]`` through ``layer`` to ``[batch, 1, output_dim]``."""
        return layer(column).squeeze(-1).unsqueeze(1)

    def forward(self, i):
        """Fuse ``i`` of shape ``[batch_size, input_dim]``; returns ``[batch_size, output_dim]``."""
        column = i.unsqueeze(-1)  # [batch_size, input_dim, 1]

        q = self._project(self.query_layer, column)
        k = self._project(self.key_layer, column)
        v = self._project(self.value_layer, column)

        attended = F.scaled_dot_product_attention(q, k, v)  # [batch_size, 1, output_dim]

        # Drop the length-1 sequence axis and apply the residual connection.
        return attended.squeeze(1) + i


class SharedNoFC(MultiCategoricalMixin, DeterministicMixin, Model):
    """Shared policy/value network with a common feature trunk.

    The same instance is meant to be registered under both the "policy" and
    the "value" role. Both roles run the same trunk (Conv1d stack ->
    positional encoding -> Transformer encoder -> concatenation with three
    scalar observation entries -> ``AttentionFusion``) and differ only in the
    final head.

    The trunk output computed by a "policy" call is cached so that the
    "value" call that immediately follows on the same states can reuse it.
    The cache is consumed (cleared) by that "value" call, so any later
    "value" call on different states recomputes the trunk.

    Depends on the module-level name ``obs`` (the unflattened observation
    space) to turn the flat state tensor back into its named entries.

    Args:
        observation_space: Observation space handed to ``Model.__init__``.
        action_space: Action space handed to ``Model.__init__``.
        device: Device handed to ``Model.__init__``.
        clip_actions: Forwarded to ``DeterministicMixin.__init__``.
        unnormalized_log_prob: Forwarded to ``MultiCategoricalMixin.__init__``.
        reduction: Forwarded to ``MultiCategoricalMixin.__init__``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        Model.__init__(self, observation_space, action_space, device)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
        DeterministicMixin.__init__(self, clip_actions)

        # Trunk output cached by the "policy" role for the next "value" call.
        # Initialised here so a "value" call made before any "policy" call
        # takes the recompute path instead of raising AttributeError.
        self._shared_features = None

        # CNN-based feature extraction; expects 10 input channels and pools
        # the time axis down to a single step (128 output channels).
        self.net_feature = nn.Sequential(
            nn.Conv1d(10, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(32),
            nn.RReLU(),
            nn.Dropout(0.2),
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(64),
            nn.RReLU(),
            nn.Dropout(0.2),
            nn.Conv1d(64, 128, kernel_size=3, stride=2),
            nn.RReLU(),
            nn.AdaptiveAvgPool1d(1),
        )

        # Transformer Encoder for self-attention
        transformer_layer = TransformerEncoderLayer(
            d_model=128,  # The size of the input feature vector
            nhead=4,      # Number of attention heads
            dim_feedforward=256,  # The size of the feedforward network in the encoder
            dropout=0.1,
            batch_first=True  # Use batch_first for better inference performance
        )
        self.positional_encoding = PositionalEncoding(d_model=128)
        self.transformer_encoder = TransformerEncoder(transformer_layer, num_layers=2)
        # 131 = 128 trunk features + the 3 extra columns concatenated in
        # `_extract_features` (assumes total_ROE, position and ROE are one
        # column each -- TODO confirm against the environment).
        self.attention_fusion = AttentionFusion(131, 131)

        # Output heads using Conv1d with additional layers to capture interactions
        self.policy_head = nn.Sequential(
            nn.Conv1d(131, 32, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(32, self.num_actions, kernel_size=1)
        )
        self.value_head = nn.Sequential(
            nn.Conv1d(131, 32, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(32, 1, kernel_size=1)
        )

    def act(self, inputs, role):
        """Dispatch to the mixin that matches ``role`` ("policy" or "value")."""
        if role == "policy":
            return MultiCategoricalMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    def _extract_features(self, inputs):
        """Run the shared trunk on ``inputs["states"]`` and return the fused features."""
        states = unflatten_tensorized_space(obs, inputs["states"])

        # Move the last axis of "features" into the channel position for Conv1d.
        features = self.net_feature(states["features"].permute(0, 2, 1))
        features = features.squeeze(-1)  # Remove last dimension after AdaptiveAvgPool1d
        # The encoder sees each sample as a sequence of length 1.
        features = self.positional_encoding(features.unsqueeze(1))
        features = self.transformer_encoder(features).squeeze(1)

        info = torch.cat([states["total_ROE"], states["position"], states["ROE"]], dim=1)
        fused = torch.cat([features, info], dim=1)
        return self.attention_fusion(fused)

    def compute(self, inputs, role):
        """Compute the head output for ``role``.

        Returns:
            ``(policy_head output, {})`` for "policy" and
            ``(value_head output, {})`` for "value".
        """
        if role == "policy":
            # Cache the trunk output for the "value" call that follows.
            self._shared_features = self._extract_features(inputs)
            actions = self.policy_head(self._shared_features.unsqueeze(-1)).squeeze(-1)
            return actions, {}

        elif role == "value":
            if self._shared_features is None:
                shared_features = self._extract_features(inputs)
            else:
                shared_features = self._shared_features

            # Consume the cache. The original code reset a differently named
            # attribute (`_shared_output`), so the cache was never cleared and
            # value calls on new states silently reused stale features.
            self._shared_features = None

            value = self.value_head(shared_features.unsqueeze(-1)).squeeze(-1)
            return value, {}


In [6]:
# A single network instance serves both roles, so "policy" and "value"
# refer to the very same object.
models = dict.fromkeys(
    ("policy", "value"),
    SharedNoFC(env.observation_space, env.action_space, device),
)

# Both dictionary values are the same object, so this loop initialises the
# same parameters once per key (twice in total).
for model in models.values():
    model.init_parameters(method_name="normal_", mean=0.0, std=0.1)

In [7]:
import copy

from skrl.agents.torch.ppo import PPO_DEFAULT_CONFIG

# Deep copy on purpose: the default config contains nested dictionaries
# (e.g. "experiment"), and a shallow `.copy()` would share them, so the
# assignments below would silently modify skrl's module-level defaults.
cfg = copy.deepcopy(PPO_DEFAULT_CONFIG)
cfg["rollouts"] = memory_size
cfg["learning_epochs"] = 64
cfg["mini_batches"] = 8
cfg["discount_factor"] = 0.99
# The learning rate is driven by the scheduler below; presumably it warms up
# from this base value to `eta_max` -- verify against `lr_schedulers`.
cfg["learning_rate"] = 0
cfg["learning_rate_scheduler"] = CosineAnnealingWarmUpRestarts
# Cycle lengths are expressed as multiples of `learning_epochs`.
cfg["learning_rate_scheduler_kwargs"] = {
    "T_0": 16 * cfg["learning_epochs"],
    "T_mult": 2,
    "T_up": cfg["learning_epochs"],
    "eta_max": 5e-4,
    "gamma": 0.8,
}

# Logging / checkpointing cadence and output directory.
cfg["experiment"]["write_interval"] = 10000
cfg["experiment"]["checkpoint_interval"] = 500000
cfg["experiment"]["directory"] = "runs/torch/mddt"

In [8]:
import warnings

from skrl.agents.torch.ppo import PPO

# Hide FutureWarning / DeprecationWarning messages for the rest of the session.
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", DeprecationWarning)

# PPO agent built from the shared model, the rollout memory and the config
# prepared in the earlier cells.
agent = PPO(
    models=models,
    memory=memory,
    cfg=cfg,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)

# Trainer settings: total timesteps, no rendering, and the environment info
# entries requested under "environment_info".
cfg_trainer = dict(
    timesteps=30000000,
    headless=True,
    environment_info=["pc_counter", "portfolio_valuation"],
)
trainer = ParallelTrainer(cfg=cfg_trainer, env=env, agents=[agent])

In [9]:
# Select the scaled-dot-product-attention backend globally: keep the plain
# math implementation enabled and turn off the flash and memory-efficient
# kernels. These are process-wide switches and affect every later
# `scaled_dot_product_attention` call.
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

In [None]:
# Start training with the trainer built above; the run length is taken from
# the "timesteps" entry of `cfg_trainer` (30,000,000), so this cell is
# long-running.
trainer.train()

  0%|          | 2239/30000000 [00:05<21:01:24, 396.35it/s]