In [None]:
import gymnasium as gym
import torch
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer, ParallelTrainer
from skrl.utils import set_seed
from reward import simple_reward
from preprocess import preprocess
from lr_schedulers import CosineAnnealingWarmUpRestarts, AdaptiveScheduler

# Seed the global random number generators through skrl's helper before anything random runs.
set_seed(42)

# Register the custom multi-dataset trading environment under a string id so that
# gym.make / gym.make_vec can build it from `env_cfg` below; the env checker is disabled.
gym.register(
    "MultiDatasetDiscretedTradingEnv",
    entry_point="mc_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [None]:
# Keyword arguments shared by every environment instance built below (single probe env and the
# vectorized env). NOTE(review): the meaning/units of the trading parameters (e.g. whether
# trading_fees=0.01 means 1% per trade) are defined in mc_env — confirm there.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    # Glob over the pickled training datasets (presumably 15-minute bars, per the folder name).
    "dataset_dir": "./data/train/month_15m/**/**/*.pkl",
    "preprocess": preprocess,
    # "reward_function": simple_reward,
    "positions": [-1, 1],
    "multiplier": [1, 2, 5],
    "trading_fees": 0.01,
    "borrow_interest_rate": 0.03,
    "portfolio_initial_value": 1e4,
    "max_episode_duration": 1000,
    "verbose": 0,
    # Look-back length; the Shared model's Linear(896, ...) is sized for 60.
    "window_size": 60,
}

In [None]:
# Build one throw-away, non-vectorized env only to capture the *unflattened* observation space.
# `obs` is read later by the Shared model to rebuild dict observations from the flat tensors,
# so it must remain defined after this cell.
_probe_env = gym.make(**env_cfg)
obs = _probe_env.observation_space
# The probe env is never stepped; close it instead of leaving it (and whatever it loaded)
# alive until the kernel shuts down.
_probe_env.close()
del _probe_env

# 32 synchronous copies; FlattenObservation turns each observation into a flat vector,
# and wrap_env adapts the vector env to skrl's interface.
vec_env = gym.make_vec(
    vectorization_mode="sync",
    wrappers=[gym.wrappers.FlattenObservation],
    num_envs=32,
    **env_cfg,
)
env = wrap_env(vec_env, wrapper="gymnasium")

In [None]:
device = env.device

# Transitions stored per environment; cfg["rollouts"] is set to this value below, so one
# PPO update sees memory_size * num_envs transitions in total.
memory_size = 1024 * 16
replay_buffer_size = memory_size * env.num_envs

memory = RandomMemory(
    memory_size=memory_size,
    num_envs=env.num_envs,
    device=device,
    replacement=False,
)

In [None]:
import torch.nn as nn


class DepthwiseSeparableConv1d(nn.Module):
    """Depthwise-separable 1-D conv block: depthwise conv -> 1x1 pointwise conv -> BatchNorm -> ReLU.

    NOTE(review): not referenced by the Shared model below, which uses plain nn.Conv1d layers.

    Args:
        in_channels: channels of the input; also the number of depthwise groups.
        out_channels: channels produced by the pointwise (1x1) convolution.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.

    Input is (batch, in_channels, length) as for nn.Conv1d; output is
    (batch, out_channels, length_out) with length_out set by the depthwise conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups == in_channels: every input channel is filtered independently (depthwise).
        self.depthwise = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # 1x1 convolution that mixes information across channels (pointwise).
        self.pointwise = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.batchnorm = nn.BatchNorm1d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply the depthwise, pointwise, batch-norm and ReLU stages in order."""
        for stage in (self.depthwise, self.pointwise, self.batchnorm, self.activation):
            x = stage(x)
        return x

In [None]:
import torch.nn as nn
from skrl.models.torch import DeterministicMixin, MultiCategoricalMixin, Model
from skrl.utils.spaces.torch import unflatten_tensorized_space
import torch.nn.functional as F


class Shared(MultiCategoricalMixin, DeterministicMixin, Model):
    """skrl policy/value model with a single shared trunk.

    The flat ``inputs["states"]`` tensor is unflattened back into its dict form, then:
      * ``states["features"]`` (permuted to channels-first) goes through a 1-D CNN + MLP
        (``net_feature``) producing a 32-dim embedding;
      * ``total_ROE``, ``position`` and ``ROE`` are concatenated and go through a small MLP
        (``net_info``) producing an 8-dim embedding.
    The concatenated (32 + 8)-dim representation feeds ``mean_layer`` (multi-categorical logits,
    role "policy") and ``value_layer`` (state value, role "value").

    A "policy" pass caches its shared representation; the next "value" pass reuses it and clears
    the cache (single forward pass). This assumes the value call follows a policy call on the same
    states — NOTE(review): confirm for every call site of the agent.

    NOTE(review): BatchNorm and Dropout make outputs depend on train/eval mode and on batch
    composition; confirm the mode used during rollouts vs. updates keeps log-probs consistent.

    Args:
        observation_space: observation space seen by skrl (the flattened one in this notebook).
        action_space: action space; ``self.num_actions`` logits are produced.
        device: torch device.
        clip_actions: forwarded to DeterministicMixin.
        unnormalized_log_prob: forwarded to MultiCategoricalMixin.
        reduction: forwarded to MultiCategoricalMixin.
        unflattened_space: original (unflattened) observation space used to rebuild the dict of
            tensors. Defaults to the notebook-global ``obs`` for backward compatibility.

    Raises:
        ValueError: from ``act``/``compute`` when ``role`` is neither "policy" nor "value".
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        unnormalized_log_prob=True,
        reduction="sum",
        unflattened_space=None,
    ):
        Model.__init__(self, observation_space, action_space, device)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
        DeterministicMixin.__init__(self, clip_actions)

        # Space used to rebuild the dict observation from the flat tensor; falls back to the
        # notebook-global `obs` so existing call sites keep working.
        self._unflattened_space = obs if unflattened_space is None else unflattened_space

        # Trunk output cached by a "policy" pass and consumed by the next "value" pass.
        # Initialised here so a "value" call made before any "policy" call does not raise
        # AttributeError.
        self._shared_output = None

        # 1-D CNN + MLP over the time-series features -> 32-dim embedding.
        # Expects 10 input channels (NOTE(review): must match the env's feature count).
        self.net_feature = nn.Sequential(
            nn.Conv1d(10, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(32),
            nn.RReLU(),
            nn.Dropout(0.2),
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(64),
            nn.RReLU(),
            nn.Dropout(0.2),
            nn.Conv1d(64, 128, kernel_size=3, stride=2),
            nn.RReLU(),
            nn.Flatten(),
            # 128 channels * 7 steps = 896: a 60-step window (window_size=60 in env_cfg, assuming
            # the features span the whole window) shrinks 60 -> 30 -> 15 -> 7 through the convs
            # above. Changing window_size requires updating this in_features.
            nn.Linear(896, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.BatchNorm1d(32),
            nn.GELU(),
        )
        # Small MLP over the 3 concatenated account signals (total_ROE, position, ROE)
        # -> 8-dim embedding.
        self.net_info = nn.Sequential(
            nn.Linear(3, 4),
            nn.BatchNorm1d(4),
            nn.RReLU(),
            nn.Dropout(0.2),
            nn.Linear(4, 8),
            nn.RReLU(),
        )

        # Heads on the fused (32 + 8)-dim representation.
        self.mean_layer = nn.Linear(32 + 8, self.num_actions)  # categorical logits
        # Not used by the categorical policy head; kept so existing checkpoints'
        # state_dict keys still match when loading.
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

        self.value_layer = nn.Linear(32 + 8, 1)

    def act(self, inputs, role):
        """Dispatch to the mixin that matches ``role`` ("policy" or "value")."""
        if role == "policy":
            return MultiCategoricalMixin.act(self, inputs, role)
        if role == "value":
            return DeterministicMixin.act(self, inputs, role)
        raise ValueError(f"Unknown role: {role!r}")

    def _shared_features(self, inputs):
        """Run the shared trunk on ``inputs["states"]`` and return the fused (batch, 40) tensor."""
        states = unflatten_tensorized_space(self._unflattened_space, inputs["states"])
        # Conv1d wants channels-first; the permute swaps the last two axes.
        # NOTE(review): assumes states["features"] is (batch, window, n_features) — confirm.
        features = self.net_feature(states["features"].permute(0, 2, 1))
        info = self.net_info(
            torch.cat([states["total_ROE"], states["position"], states["ROE"]], dim=1)
        )
        return torch.cat([features, info], dim=1)

    def compute(self, inputs, role):
        """Return ``(logits, {})`` for role "policy" or ``(value, {})`` for role "value"."""
        if role == "policy":
            self._shared_output = self._shared_features(inputs)
            return self.mean_layer(self._shared_output), {}
        if role == "value":
            shared_output = self._shared_output
            if shared_output is None:
                shared_output = self._shared_features(inputs)
            # Consume the cache so a later value call on other states recomputes the trunk.
            self._shared_output = None
            return self.value_layer(shared_output), {}
        raise ValueError(f"Unknown role: {role!r}")

In [None]:
# PPO gets the same Shared instance for both roles (one shared trunk, two heads).
models = {"policy": Shared(env.observation_space, env.action_space, device)}
models["value"] = models["policy"]

# Initialise each *distinct* model once: "policy" and "value" are the same object, so looping
# over models.values() directly would re-initialise the shared network twice.
# NOTE(review): normal_(0, 0.1) is requested for the model's parameters; if skrl applies it to
# every parameter this includes BatchNorm weights (conventionally 1) and biases — confirm intent.
for model in {id(m): m for m in models.values()}.values():
    model.init_parameters(method_name="normal_", mean=0.0, std=0.1)

In [None]:
import copy

from skrl.agents.torch.ppo import PPO_DEFAULT_CONFIG

# Deep copy: the default config holds nested dicts (e.g. "experiment"), so a shallow .copy()
# would make the nested assignments below mutate skrl's module-level PPO_DEFAULT_CONFIG itself.
cfg = copy.deepcopy(PPO_DEFAULT_CONFIG)
cfg["rollouts"] = memory_size  # collect a full memory per environment between updates
# NOTE(review): 128 epochs per update is far above common PPO settings (single digits);
# confirm this is intended.
cfg["learning_epochs"] = 128
cfg["mini_batches"] = 4
cfg["discount_factor"] = 0.99
# Base LR is 0; the warm-up/restart scheduler is expected to ramp it up to eta_max.
# NOTE(review): relies on CosineAnnealingWarmUpRestarts (lr_schedulers) treating eta_max as
# the peak LR — confirm there.
cfg["learning_rate"] = 0
# cfg["learning_starts"] = 1000000
cfg["learning_rate_scheduler"] = CosineAnnealingWarmUpRestarts
# Periods are in scheduler steps. NOTE(review): confirm how often skrl steps the scheduler
# (per epoch vs. per update) so T_0 / T_up mean what is intended.
cfg["learning_rate_scheduler_kwargs"] = {
    "T_0": 16 * cfg["learning_epochs"],
    "T_mult": 2,
    "T_up": cfg["learning_epochs"],
    "eta_max": 5e-4,
    "gamma": 0.8,
}

cfg["experiment"]["write_interval"] = 10000
cfg["experiment"]["checkpoint_interval"] = 500000
cfg["experiment"]["directory"] = "runs/torch/mddt"

In [None]:
import warnings

from skrl.agents.torch.ppo import PPO

# Process-wide: hide FutureWarning / DeprecationWarning noise from the libraries.
warnings.simplefilter(category=FutureWarning, action="ignore")
warnings.simplefilter(category=DeprecationWarning, action="ignore")

agent = PPO(
    models=models,
    memory=memory,
    cfg=cfg,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)

# NOTE(review): "environment_info" is given a list here; the skrl trainers I know look it up as a
# single key in the step `infos` dict (a list is unhashable there) — confirm the installed skrl
# accepts a list, otherwise pass a single info key string.
cfg_trainer = {
    "timesteps": 30_000_000,
    "headless": True,
    "environment_info": ["pc_counter", "portfolio_valuation"],
}

# Earlier runs / checkpoints (uncomment a load to resume from one):
# agent.load("/home/pitin/Desktop/hp/runs/torch/mddt/24-11-27_17-35-07-776155_PPO/checkpoints/best_agent.pt")
# agent.load("/home/pitin/Desktop/hp/runs/torch/mddt/24-11-27_15-29-47-376108_PPO/checkpoints/best_agent.pt")
# 24-11-27_09-40-22-711638_PPO -> 700k
# 24-11-27_10-55-49-209747_PPO -> 300k
# runs/torch/mddt/24-11-27_09-07-52-359848_PPO/checkpoints/agent_1700000.pt
# 24-11-27_13-20-25-262904_PPO

# NOTE(review): only one agent is trained; SequentialTrainer (already imported) may be the
# simpler choice — confirm why ParallelTrainer is used.
trainer = ParallelTrainer(agents=[agent], env=env, cfg=cfg_trainer)
# agent.track_data("Episode/Position changed")

In [None]:
# Start training with the trainer and agent configured in the cells above.
trainer.train()