In [None]:
import gymnasium as gym
import torch
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
from skrl.utils import set_seed
from reward import refined_pnl
from preprocess import only_sub_indicators

# Seed skrl's random number generators up front, then make the custom trading environment
# resolvable by id for the `gym.make` / `gym.make_vec` calls further down.
set_seed(seed=42)

gym.register(
    "MultiDatasetDiscretedTradingEnv",
    entry_point="environment:MultiDatasetDiscretedTradingEnv",  # "<module>:<attribute>" import path
    disable_env_checker=True,
)

In [None]:
# Keyword arguments shared by every environment instance created below (probe env + vectorized envs).
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    # glob pattern over the training pickles — NOTE(review): recursive "**" handling is up to the env
    "dataset_dir": "./data/train/month_1h/**/*.pkl",
    "preprocess": only_sub_indicators,
    "reward_function": refined_pnl,
    # discrete position choices — presumably one action per entry; confirm in environment.py
    "positions": [-2, 0, 2],
    "trading_fees": 0.0001,
    "borrow_interest_rate": 0.0003,
    "portfolio_initial_value": 100,
    "max_episode_duration": "max",  # previously 24 * 60
    "verbose": 1,
    "window_size": None,
}

In [None]:
# Build a single, un-flattened environment only to read its native observation space (`obs` is kept
# for code that un-flattens observations), then close it so the probe instance does not leak.
_probe_env = gym.make(**env_cfg)
obs = _probe_env.observation_space
_probe_env.close()
del _probe_env

# Four synchronous copies of the environment with flattened observations, wrapped for skrl.
env = gym.make_vec(
    vectorization_mode="sync",
    num_envs=4,
    wrappers=[gym.wrappers.FlattenObservation],
    **env_cfg,
)
env = wrap_env(env, wrapper="gymnasium")

In [None]:
# Rollout storage lives on the same device the wrapped environment returns tensors on.
device = env.device
replay_buffer_size = 4 * 1024 * env.num_envs  # total transitions across all environments
memory_size = replay_buffer_size // env.num_envs  # per-environment capacity
memory = RandomMemory(
    memory_size=memory_size,
    num_envs=env.num_envs,
    device=device,
    replacement=False,
)

In [None]:
import torch.nn as nn
from skrl.models.torch import DeterministicMixin, CategoricalMixin, Model
from skrl.utils.spaces.torch import unflatten_tensorized_space


class Policy(CategoricalMixin, Model):
    """Recurrent categorical policy: an LSTM encoder followed by an MLP head with one output per action.

    The LSTM consumes the (flattened) observation vector. Its hidden/cell states are passed in and
    out through ``inputs["rnn"]`` / ``outputs["rnn"]`` and are sized by :meth:`get_specification`.

    Args:
        observation_space: Environment observation space (forwarded to ``Model``).
        action_space: Environment action space (forwarded to ``Model``).
        device: Torch device (forwarded to ``Model``).
        clip_actions: Accepted but not used by this class.
        unnormalized_log_prob: Forwarded to ``CategoricalMixin``.
        num_envs: Number of parallel environments; second dimension of the declared RNN state sizes.
        num_layers: Number of stacked LSTM layers.
        hidden_size: LSTM hidden/cell size (also the MLP head's input size).
        sequence_length: Sequence length the training branch of ``compute`` reshapes states into.

    NOTE(review): the defaults here (num_envs=4, sequence_length=10) differ from ``Value``'s
    (num_envs=1, sequence_length=128); this notebook passes both explicitly through ``model_cfg``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        unnormalized_log_prob=True,
        num_envs=4,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        # batch_first=True: LSTM inputs/outputs are (N, L, features)
        self.lstm = nn.LSTM(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        # MLP head mapping each LSTM output step to num_actions values.
        # NOTE(review): BatchNorm1d uses batch statistics in train mode and running statistics in
        # eval mode, so the same input can yield different outputs at rollout vs update time —
        # confirm this is acceptable for PPO's probability ratio.
        self.net = nn.Sequential(
            nn.Linear(self.hidden_size, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, self.num_actions),
        )

        # NOTE(review): not referenced anywhere in this class (looks like a Gaussian-policy leftover);
        # kept because removing it would change the module's state_dict / saved checkpoints.
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def get_specification(self):
        """Declare the recurrent state layout for the skrl RNN-aware agent.

        Returns:
            dict with ``"rnn"`` -> ``{"sequence_length": int, "sizes": [hidden_size_tuple, cell_size_tuple]}``,
            where each size tuple is ``(num_layers, num_envs, hidden_size)``.
        """
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size),  # hidden states (D ∗ num_layers, N, Hout)
                    (self.num_layers, self.num_envs, self.hidden_size),  # cell states   (D ∗ num_layers, N, Hcell)
                ],
            }
        }

    # NOTE(review): legacy variant for dict (un-flattened) observations; it references attributes
    # (net_features, net_states, mean_layer) that this class does not define.
    # def compute(self, inputs, role):
    #     states = unflatten_tensorized_space(obs, inputs["states"])

    #     features = self.net_features(states["features"])
    #     states = self.net_states(
    #         torch.cat(
    #             (states["equity"], states["total_ROE"], states["PNL"], states["ROE"], states["entry_price"]), dim=1
    #         )
    #     )
    #     values = torch.cat((features, states), dim=1)

    #     self._shared_output = self.net(values)
    #     action = self.mean_layer(self._shared_output)
    #     return action, {}

    def compute(self, inputs, role):
        """Run the LSTM and MLP head.

        Args:
            inputs: dict with ``"states"`` (last dim = observation size), ``"rnn"`` (``[hidden, cell]``)
                and optionally ``"terminated"``.
            role: Not used here.

        Returns:
            Tuple of the head output for every flattened (sequence, step) position and
            ``{"rnn": [hidden, cell]}`` holding the LSTM states after the last step.
        """
        states = inputs["states"] 
        # states = unflatten_tensorized_space(obs, inputs["states"])
        terminated = inputs.get("terminated", None)
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # training
        if self.training:
            # Regroup flat samples into sequences of `sequence_length` steps.
            rnn_input = states.view(
                -1, self.sequence_length, states.shape[-1]
            )  # (N, L, Hin): N=batch_size, L=sequence_length
            # The supplied states have one entry per step (presumably recorded by the agent at each
            # step — confirm); only the state at the first step of each sequence is used below.
            hidden_states = hidden_states.view(
                self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
            )  # (D * num_layers, N, L, Hout)
            cell_states = cell_states.view(
                self.num_layers, -1, self.sequence_length, cell_states.shape[-1]
            )  # (D * num_layers, N, L, Hcell)
            # get the hidden/cell states corresponding to the initial sequence
            hidden_states = hidden_states[:, :, 0, :].contiguous()  # (D * num_layers, N, Hout)
            cell_states = cell_states[:, :, 0, :].contiguous()  # (D * num_layers, N, Hcell)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # Split points: one step after every position where ANY sequence in the batch terminated.
                indexes = (
                    [0]
                    + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                    + [self.sequence_length]
                )

                # Run the LSTM segment by segment, zeroing the carried state of the sequences whose
                # segment ended on a termination before feeding the next segment.
                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, (hidden_states, cell_states) = self.lstm(
                        rnn_input[:, i0:i1, :], (hidden_states, cell_states)
                    )
                    hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                    cell_states[:, (terminated[:, i1 - 1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_states = (hidden_states, cell_states)
                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
        # rollout
        else:
            # Each row (one per environment) is treated as a length-1 sequence.
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(rnn_output), {"rnn": [rnn_states[0], rnn_states[1]]}


class Value(DeterministicMixin, Model):
    """Recurrent state-value function: an LSTM encoder followed by an MLP head with a single output.

    The LSTM consumes the (flattened) observation vector. Its hidden/cell states are passed in and
    out through ``inputs["rnn"]`` / ``outputs["rnn"]`` and are sized by :meth:`get_specification`.

    Args:
        observation_space: Environment observation space (forwarded to ``Model``).
        action_space: Environment action space (forwarded to ``Model``).
        device: Torch device (forwarded to ``Model``).
        clip_actions: Forwarded to ``DeterministicMixin``.
        num_envs: Number of parallel environments; second dimension of the declared RNN state sizes.
        num_layers: Number of stacked LSTM layers.
        hidden_size: LSTM hidden/cell size (also the MLP head's input size).
        sequence_length: Sequence length the training branch of ``compute`` reshapes states into.

    NOTE(review): the defaults here (num_envs=1, sequence_length=128) differ from ``Policy``'s
    (num_envs=4, sequence_length=10); this notebook passes both explicitly through ``model_cfg``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=128,
    ):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        # batch_first=True: LSTM inputs/outputs are (N, L, features)
        self.lstm = nn.LSTM(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        # MLP head mapping each LSTM output step to one scalar value.
        # NOTE(review): BatchNorm1d uses batch statistics in train mode and running statistics in
        # eval mode, so values at rollout time and update time can differ for the same input.
        self.net = nn.Sequential(
            nn.Linear(self.hidden_size, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def get_specification(self):
        """Declare the recurrent state layout for the skrl RNN-aware agent.

        Returns:
            dict with ``"rnn"`` -> ``{"sequence_length": int, "sizes": [hidden_size_tuple, cell_size_tuple]}``,
            where each size tuple is ``(num_layers, num_envs, hidden_size)``.
        """
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size),  # hidden states (D ∗ num_layers, N, Hout)
                    (self.num_layers, self.num_envs, self.hidden_size),  # cell states   (D ∗ num_layers, N, Hcell)
                ],
            }
        }

    # NOTE(review): legacy variant for dict (un-flattened) observations; it references attributes
    # (net_features, net_states, value_layer, _shared_output) that this class does not define.
    # def compute(self, inputs, role):
    #     states = unflatten_tensorized_space(obs, inputs["states"])

    #     features = self.net_features(states["features"])
    #     states = self.net_states(
    #         torch.cat(
    #             (states["equity"], states["total_ROE"], states["PNL"], states["ROE"], states["entry_price"]), dim=1
    #         )
    #     )
    #     values = torch.cat((features, states), dim=1)

    #     shared_output = self.net(values) if self._shared_output is None else self._shared_output
    #     self._shared_output = None
    #     value = self.value_layer(shared_output)
    #     return value, {}

    def compute(self, inputs, role):
        """Run the LSTM and MLP head.

        Args:
            inputs: dict with ``"states"`` (last dim = observation size), ``"rnn"`` (``[hidden, cell]``)
                and optionally ``"terminated"``.
            role: Not used here.

        Returns:
            Tuple of the value for every flattened (sequence, step) position, shape (N * L, 1), and
            ``{"rnn": [hidden, cell]}`` holding the LSTM states after the last step.
        """
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # training
        if self.training:
            # Regroup flat samples into sequences of `sequence_length` steps.
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

            # The supplied states have one entry per step (presumably recorded by the agent at each
            # step — confirm); only the state at the first step of each sequence is used below.
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1])  # (D * num_layers, N, L, Hcell)
            # get the hidden/cell states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)
            cell_states = cell_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hcell)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # Split points: one step after every position where ANY sequence in the batch terminated.
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                # Run the LSTM segment by segment, zeroing the carried state of the sequences whose
                # segment ended on a termination before feeding the next segment.
                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    cell_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_states = (hidden_states, cell_states)
                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
        # rollout
        else:
            # Each row (one per environment) is treated as a length-1 sequence.
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(rnn_output), {"rnn": [rnn_states[0], rnn_states[1]]}

In [None]:
# Constructor arguments shared by both networks, so the policy and value LSTMs are built with
# the same recurrent geometry (num_envs, num_layers, hidden_size, sequence_length).
model_cfg = {
    "observation_space": env.observation_space,
    "action_space": env.action_space,
    "device": device,
    "num_envs": env.num_envs,
    "num_layers": 1,
    "hidden_size": 64,
    "sequence_length": 128,
}
models = {
    "policy": Policy(unnormalized_log_prob=True, **model_cfg),
    "value": Value(clip_actions=False, **model_cfg),
}

In [None]:
import copy

from skrl.agents.torch.ppo.ppo_rnn import PPO_DEFAULT_CONFIG
from lr_schedulers import CosineAnnealingWarmUpRestarts

timesteps = 1_000_000  # total environment steps for the trainer

# Deep copy: a shallow `.copy()` shares nested dicts such as cfg["experiment"], so the writes at the
# bottom of this cell would mutate skrl's module-level PPO_DEFAULT_CONFIG (leaking into re-runs and
# any other agent built from the defaults).
cfg = copy.deepcopy(PPO_DEFAULT_CONFIG)
cfg["rollouts"] = memory_size  # rollout steps between updates = per-environment memory capacity
cfg["learning_epochs"] = 4
cfg["mini_batches"] = 4
cfg["discount_factor"] = 0.95
# cfg["lambda"] = 0.95
# NOTE(review): base lr 0 is presumably the floor the warm-up/restart scheduler ramps up from
# (towards eta_max) — confirm against lr_schedulers.CosineAnnealingWarmUpRestarts.
cfg["learning_rate"] = 0
# cfg["grad_norm_clip"] = 1.0
# cfg["ratio_clip"] = 0.2
# cfg["value_clip"] = 0.2
# cfg["clip_predicted_values"] = False
# cfg["entropy_loss_scale"] = 0.0
# cfg["value_loss_scale"] = 0.5
# cfg["kl_threshold"] = 0
cfg["learning_starts"] = 0
cfg["learning_rate_scheduler"] = CosineAnnealingWarmUpRestarts
# Periods are expressed in scheduler steps; they are multiples of learning_epochs, presumably because
# the agent steps the scheduler once per learning epoch — confirm.
cfg["learning_rate_scheduler_kwargs"] = {
    "T_0": 16 * cfg["learning_epochs"],
    "T_mult": 1,
    "T_up": cfg["learning_epochs"],
    "eta_max": 1e-4,
    "gamma": 0.5,
}
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = 1024
cfg["experiment"]["checkpoint_interval"] = 100000
cfg["experiment"]["directory"] = "runs/torch/Trading"

In [None]:
import warnings
from skrl.agents.torch.ppo.ppo_rnn import PPO_RNN

# Silence library FutureWarning / DeprecationWarning noise during training.
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", DeprecationWarning)

agent = PPO_RNN(
    models=models,
    memory=memory,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
    cfg=cfg,
)

# Trainer settings: total steps from the config cell, headless (no rendering by the trainer).
cfg_trainer = dict(timesteps=timesteps, headless=True)
trainer = SequentialTrainer(env=env, agents=[agent], cfg=cfg_trainer)

In [None]:
# Run PPO_RNN training with the SequentialTrainer configured by cfg_trainer (timesteps, headless).
trainer.train()

In [None]:
# Evaluate the trained recurrent policy. Roll it out until every sub-environment has finished one
# episode, carrying the LSTM state between steps and rendering each step.
#
# This replaces a loop that never ran (`while terminated:` with terminated=False) and that used
# undefined names (`states`, `state_preprocessor`, `policy`). That loop also omitted the "rnn"
# inputs that Policy.compute requires.
eval_policy = models["policy"]
# Eval mode selects the single-step (rollout) branch of Policy.compute and BatchNorm running stats.
eval_policy.eval()

# Apply the same state preprocessing the agent uses when acting; identity if the agent exposes none.
# NOTE(review): `_state_preprocessor` is a private skrl attribute — confirm for the installed version.
preprocess_states = getattr(agent, "_state_preprocessor", None) or (lambda s: s)

# Zero-initialised (hidden, cell) LSTM states, shaped as declared by the policy's rnn specification.
rnn_states = [
    torch.zeros(size, device=device) for size in eval_policy.get_specification()["rnn"]["sizes"]
]
episode_done = torch.zeros(env.num_envs, dtype=torch.bool, device=device)

# NOTE(review): skrl's gymnasium wrapper may reset vectorized envs only once and return cached
# observations afterwards — confirm this starts fresh episodes after training.
states, infos = env.reset()

while not bool(episode_done.all()):
    # state-preprocessor + policy (the recurrent state must be supplied on every call)
    with torch.no_grad():
        actions, _, outputs = eval_policy.act(
            {"states": preprocess_states(states), "rnn": rnn_states}, role="policy"
        )
    rnn_states = outputs["rnn"]

    # step the environment (vectorized envs auto-reset sub-environments whose episode ended)
    states, rewards, terminated, truncated, infos = env.step(actions)

    # render the environment
    env.render()

    # clear the recurrent state of sub-environments whose episode just ended and mark them done
    finished = (terminated | truncated).view(-1).to(episode_done.device)
    rnn_states = [state.masked_fill(finished.view(1, -1, 1), 0.0) for state in rnn_states]
    episode_done |= finished