In [1]:
import warnings
import time
import numpy as np

import gymnasium as gym
import torch
import torch.nn as nn
from gym_trading_env.renderer import Renderer
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, Model, GaussianMixin, MultivariateGaussianMixin
from skrl.resources.schedulers.torch import KLAdaptiveLR
from skrl.trainers.torch import ParallelTrainer
from skrl.utils import set_seed
from skrl.utils.spaces.torch import unflatten_tensorized_space
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from preprocess import preprocess


In [2]:
# Seed first so everything created below is reproducible.
set_seed(42)

# Register the project-local environment under a gymnasium id so that
# gym.make() can build it from the keyword configuration below.
gym.register(
    id="MultiDatasetDiscretedTradingEnv",
    entry_point="predict_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

# Keyword arguments forwarded verbatim to gym.make().
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    "dataset_dir": "./data/futures/15m/2024/11/*.pkl",
    "preprocess": preprocess,
    "max_episode_duration": "max",
    "verbose": 1,
    "leverage": 10,
    "stop_loss": -0.15,
    "take_profit": 0.6,
    "window_size": 30,
    "btc_index": True,
}
env = gym.make(**env_cfg)

# Observation space of the unwrapped environment. It is kept as a module-level
# name because SharedNoFC.compute reads it to restore the observation layout.
obs = env.observation_space

env = wrap_env(env, wrapper="gymnasium")
device = env.device

# 1024 rollout steps per environment.
replay_buffer_size = 1024 * env.num_envs
memory_size = replay_buffer_size // env.num_envs

[skrl:INFO] Seed: 42
[skrl:INFO] Environment wrapper: gymnasium


In [3]:
class LearnablePositionalEncoding(nn.Module):
    """Add a trainable embedding of each sequence position to the input.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        max_len: Number of distinct positions the embedding table holds.
    """

    def __init__(self, d_model, max_len=16 * 1024):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the embedding of every position index.

        Args:
            x: 3-D tensor laid out as (batch, sequence, feature).

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_batch, n_steps, _ = x.size()
        # One row of indices 0..n_steps-1, repeated for every batch element.
        index = torch.arange(n_steps, device=x.device)
        index = index.unsqueeze(0).expand(n_batch, n_steps)
        return x + self.position_embedding(index)


class AttentionPooling(nn.Module):
    """Collapse the sequence dimension with learned softmax weights.

    A single linear layer scores every sequence element; the scores are
    normalised with a softmax over the sequence and used as weights for a
    weighted sum of the elements.

    Args:
        hidden_size: Size of the last (feature) dimension of the input.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Pool ``x`` of shape (batch, sequence, hidden) down to (batch, hidden)."""
        scores = self.attention(x).squeeze(-1)                 # (batch, sequence)
        weights = torch.softmax(scores, dim=-1).unsqueeze(1)   # (batch, 1, sequence)
        return torch.bmm(weights, x).squeeze(1)                # (batch, hidden)


class SharedNoFC(MultivariateGaussianMixin, DeterministicMixin, Model):
    """Policy/value model whose two roles share one Transformer feature trunk.

    The trunk is: learnable positional encoding -> Transformer encoder ->
    attention pooling -> LayerNorm/GELU/Dropout. The ``"policy"`` role maps the
    pooled features to the action mean with a single linear layer and pairs it
    with a state-independent log-std parameter; the ``"value"`` role maps the
    same features to a scalar with its own linear layer.

    Features computed by a ``"policy"`` call are cached and consumed by the
    next ``"value"`` call, so the trunk runs once per policy/value pair.

    Args:
        observation_space: Observation space handed to ``Model``.
        action_space: Action space handed to ``Model``.
        device: Device handed to ``Model``.
        clip_actions: Forwarded to both mixins.
        clip_log_std: Forwarded to ``MultivariateGaussianMixin``.
        min_log_std: Forwarded to ``MultivariateGaussianMixin``.
        max_log_std: Forwarded to ``MultivariateGaussianMixin``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        clip_log_std=True,
        min_log_std=-20,
        max_log_std=2,
    ):
        Model.__init__(self, observation_space, action_space, device)
        MultivariateGaussianMixin.__init__(
            self, clip_actions, clip_log_std, min_log_std, max_log_std
        )
        DeterministicMixin.__init__(self, clip_actions)

        # Trunk output cached by the "policy" role for the following "value" call.
        self._shared_features = None
        # Width of one observation step; also the Transformer's d_model.
        self.num_features = 8

        # Transformer Encoder for self-attention
        transformer_layer = TransformerEncoderLayer(
            d_model=self.num_features,  # The size of the input feature vector
            nhead=4,  # Number of attention heads
            dim_feedforward=256,  # The size of the feedforward network in the encoder
            dropout=0.1,
            batch_first=True,  # Inputs are laid out as (batch, sequence, feature)
        )
        self.positional_encoding = LearnablePositionalEncoding(
            d_model=self.num_features
        )
        self.transformer_encoder = TransformerEncoder(transformer_layer, num_layers=2)
        self.attention_pooling = AttentionPooling(self.num_features)

        # NOTE: submodule names and Sequential positions are part of the
        # checkpoint format (state-dict keys) and must not be changed.
        self.shared_head = nn.Sequential(
            nn.LayerNorm(self.num_features),
            nn.GELU(),
            nn.Dropout(0.3),
        )

        self.policy_head = nn.Sequential(
            nn.Linear(self.num_features, self.num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Linear(self.num_features, 1),
        )

        # State-independent log standard deviation, one entry per action dimension.
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def act(self, inputs, role):
        """Dispatch to the mixin that implements ``role``.

        Returns whatever the selected mixin's ``act`` returns. For a role other
        than ``"policy"`` or ``"value"`` nothing is dispatched and ``None`` is
        returned.
        """
        if role == "policy":
            return MultivariateGaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    def _encode(self, states):
        """Run the shared trunk and return one feature vector per batch element."""
        features = self.positional_encoding(states)
        features = self.transformer_encoder(features)
        features = self.attention_pooling(features)
        return self.shared_head(features)

    def compute(self, inputs, role):
        """Compute the raw outputs for ``role``.

        Args:
            inputs: Mapping with the flattened observations under ``"states"``.
            role: ``"policy"`` or ``"value"``.

        Returns:
            ``(mean_actions, log_std_parameter, {})`` for ``"policy"`` and
            ``(value, {})`` for ``"value"``.
        """
        # `obs` is the module-level observation space captured in the setup
        # cell; it must exist before this model is called.
        states = unflatten_tensorized_space(
            obs, inputs["states"]
        )  # (batch_size, seq_length=30, num_features=8) per the original author's note

        if role == "policy":
            self._shared_features = self._encode(states)
            actions = self.policy_head(self._shared_features)
            return actions, self.log_std_parameter, {}

        elif role == "value":
            if self._shared_features is None:
                shared_features = self._encode(states)
            else:
                shared_features = self._shared_features

            # Invalidate the cache once consumed. The original reset a different
            # attribute (`_shared_output`), so the cache was never cleared and a
            # value call on new states silently reused features of older states.
            self._shared_features = None

            value = self.value_head(shared_features)
            return value, {}

In [4]:
# Only the policy role is registered; the line sharing the same instance as
# the value model is intentionally left disabled.
models = {}
models["policy"] = SharedNoFC(env.observation_space, env.action_space, device)
# models["value"] = models["policy"]

for model in models.values():
    model.init_parameters(method_name="normal_", mean=0.0, std=0.1)

# dict.copy() is shallow: top-level keys can be reassigned freely, but nested
# dicts are still the library's own objects and must be replaced, not mutated.
cfg = PPO_DEFAULT_CONFIG.copy()
cfg["rollouts"] = memory_size
cfg["learning_epochs"] = 32
cfg["mini_batches"] = 16
cfg["discount_factor"] = 0.99
cfg["learning_rate"] = 5e-4
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01, "min_lr": 1e-7}

# Build a fresh "experiment" dict so PPO_DEFAULT_CONFIG["experiment"] is not
# modified for every other agent created in this kernel.
cfg["experiment"] = dict(
    cfg["experiment"],
    write_interval=5000,
    checkpoint_interval=100000,
    directory="runs/torch/mddt",
)

In [5]:
class PPO_EVAL(PPO):
    """PPO agent whose ``act`` returns the mean of the policy distribution.

    The policy is still run once (that forward pass is what builds the
    distribution), but the action it sampled is discarded in favour of
    ``distribution.mean``.
    """

    def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> tuple:
        """Select the distribution-mean action for ``states``.

        Args:
            states: Environment observations.
            timestep: Current timestep; below ``self._random_timesteps`` a
                random action is returned instead.
            timesteps: Total number of timesteps (unused here).

        Returns:
            Tuple ``(actions, log_prob, outputs)``. In the random branch the
            tuple is whatever ``policy.random_act`` returns.
        """
        inputs = {"states": self._state_preprocessor(states)}

        if timestep < self._random_timesteps:
            return self.policy.random_act(inputs, role="policy")

        with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
            # Only the side effect (building the distribution) and `outputs` are used.
            _, _, outputs = self.policy.act(inputs, role="policy")

        dist = self.policy.distribution(role="policy")
        actions = dist.mean

        log_prob = dist.log_prob(actions)
        # A distribution that yields one log-prob per batch row drops the last
        # dimension; restore it so log_prob lines up with `actions`.
        if log_prob.dim() != actions.dim():
            log_prob = log_prob.unsqueeze(-1)

        # Keep the stored log-prob consistent with the action actually returned
        # (the original stored the log-prob of the discarded sampled action).
        self._current_log_prob = log_prob
        return actions, log_prob, outputs

In [6]:
# Hide FutureWarning / DeprecationWarning noise in the cell output.
for _category in (FutureWarning, DeprecationWarning):
    warnings.simplefilter(action='ignore', category=_category)

# No rollout memory is attached: this notebook only evaluates.
agent = PPO(
    models=models,
    memory=None,
    cfg=cfg,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)

# Run directory whose best checkpoint is evaluated.
# Earlier candidates: "24-12-11_14-07-04-844276_PPO", "24-12-11_10-57-22-046603_PPO"
path = "24-12-13_11-30-42-703580_PPO"  # v4
# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
agent.load(f"/home/pitin/Desktop/hp/runs/torch/mddt/{path}/checkpoints/best_agent.pt")

# The trainer is constructed here but its train()/eval() methods are not
# called in this notebook; the evaluation loop below steps the env manually.
cfg_trainer = dict(timesteps=10000000, headless=True, environment_info="pc_counter")
trainer = ParallelTrainer(cfg=cfg_trainer, env=env, agents=[agent])



In [None]:
# Manual evaluation rollout: step the environment with the loaded agent and
# print the portfolio valuation before and after.
states, infos = env.reset()
print(infos["portfolio_valuation"])

timesteps = 3000   # horizon value passed to the agent hooks
eval_steps = 1000  # number of environment steps actually executed
reward = torch.zeros(env.num_envs, 1, device=device)  # cumulative reward per env

# Drive `timestep` from the loop so the agent hooks see the real step index
# (it previously stayed at 0 for the whole rollout).
for timestep in range(eval_steps):
    agent.pre_interaction(timestep=timestep, timesteps=timesteps)

    with torch.no_grad():
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)[0]
        next_states, rewards, terminated, truncated, infos = env.step(actions)
        env.render()
        reward += rewards

    # Calls post_interaction of the agent's parent class rather than the
    # agent's own override (presumably to skip the learning update).
    super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)

    if env.num_envs > 1:
        states = next_states
    else:
        # Single environment: reset explicitly when the episode ends.
        if terminated.any() or truncated.any():
            with torch.no_grad():
                states, infos = env.reset()
        else:
            states = next_states

print(infos["portfolio_valuation"])

1000.0
v [360.99 355.   357.85]
v [352.48 348.33 352.36]
v [356.92 352.15 355.34]
v [350.99 345.39 350.63]
v [351.87 347.39 351.25]
v [357.27 351.16 355.2 ]
p [ 0.21496257 -0.22360215  1.0689622 ]
v [340.47 335.43 339.19]
p [ 1.137296   -0.19290473 -0.28521416]
v [340.07 331.07 331.41]
v [329.24 321.62 328.71]
p [ 1.2170407  -0.31195575 -0.09573906]
v [340.06 334.26 336.71]
v [343.91 340.17 343.1 ]
v [340.17 335.44 340.02]
p [ 1.0900803  -0.20162866  0.14130141]
v [349.56 345.61 349.03]
v [358.96 349.2  357.97]
v [371.   359.86 366.02]
v [369.97 362.6  364.25]
v [377.9  364.68 375.54]
v [380.73 375.02 378.66]
v [384.36 377.4  378.97]
v [365.69 361.2  365.42]
v [373.13 366.43 367.54]
v [382.24 376.42 380.58]
p [ 0.5082805  1.0074322 -0.4061418]
v [381.81 377.5  380.45]
p [-0.21334013 -0.05684573  1.0173258 ]
p [ 1.0323106  -0.01461381 -0.24159862]


In [8]:
# AVAX
# ADA
# SOL
# ETH
# BNB
# XLM

In [9]:
# env.save_for_render()

In [10]:
# renderer = Renderer(render_logs_dir="render_logs")
# renderer.run()