In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, Model, MultiCategoricalMixin, CategoricalMixin
from skrl.trainers.torch import ParallelTrainer
from skrl.utils import set_seed
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from timexer.TimeXer import Model as TimeXer

from preprocess import preprocess_timexer
from skrl.utils.spaces.torch import unflatten_tensorized_space

# Seeding is currently disabled; uncomment to make runs reproducible.
# set_seed(42)

# Make the custom trading env available to gym.make / gym.make_vec under this id.
gym.register(
    "MultiDatasetDiscretedTradingEnv",
    entry_point="timexer_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [2]:
# Keyword arguments passed to both gym.make and gym.make_vec below.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    "dataset_dir": "./data/futures/5m/**/**/*.pkl",
    "preprocess": preprocess_timexer,
    "positions": [-1, 0, 1],
    "multiplier": range(1, 51),  # 1..50 inclusive
    "trading_fees": 0.001,
    "borrow_interest_rate": 0.003,
    "portfolio_initial_value": 1e3,
    "max_episode_duration": "max",
    "verbose": 0,
    # NOTE(review): presumably must equal TimeXer's seq_len (60) — keep in sync.
    "window_size": 60,
    "btc_index": True,
}

In [None]:
# Build one non-vectorized env only to read the single-env observation space into
# `obs`, then close it so it does not keep resources (e.g. loaded datasets) alive
# for the rest of the session. Previously the object was silently dropped unclosed.
_probe_env = gym.make(**env_cfg)
obs = _probe_env.observation_space
_probe_env.close()

# 16 env copies; "async" mode runs each one in its own worker subprocess.
env = gym.make_vec(
    vectorization_mode="async",
    num_envs=16,
    **env_cfg,
)
env = wrap_env(env, wrapper="gymnasium")

In [4]:
# Rollout storage: 1024 transitions per environment (used as PPO's rollout length below).
device = env.device
replay_buffer_size = 1024 * env.num_envs
memory_size = replay_buffer_size // env.num_envs
memory = RandomMemory(
    memory_size=memory_size,
    num_envs=env.num_envs,
    device=device,
    replacement=False,
)

In [5]:
class SharedTimeXer(MultiCategoricalMixin, DeterministicMixin, Model):
    """Policy/value model that shares one TimeXer encoder and one MLP tail.

    - role ``"policy"``: multi-categorical action logits (unnormalized log-probs).
    - role ``"value"``: a single state-value estimate.

    The tail output computed by the most recent ``"policy"`` call is cached and
    consumed by the next ``"value"`` call, so the encoder runs once per
    policy/value pair. NOTE(review): this assumes the caller evaluates the value on
    the same states right after the policy (skrl's shared-model pattern) — confirm.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        unnormalized_log_prob=True,
        reduction="sum",
        timexer_config=None,
    ):
        Model.__init__(self, observation_space, action_space, device)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
        DeterministicMixin.__init__(self, clip_actions)

        # Not used by this model; kept so existing attribute access keeps working.
        self._shared_features = None
        # Shared-tail output cached by the last "policy" pass (see `compute`).
        self._fusion = None

        # Build the TimeXer encoder.
        self.timexer = TimeXer(timexer_config)

        # Not used by this model; kept for backward compatibility.
        self.num_infos = 5

        # The TimeXer output is flattened before the tail, so the tail's input size is
        # pred_len * n_vars. This assumes TimeXer returns (batch, pred_len, n_vars) with
        # n_vars = enc_in (or 1 for features="MS"); with pred_len=10, enc_in=5 this is
        # the 50 inputs previously hard-coded — TODO confirm against TimeXer.
        n_vars = 1 if getattr(timexer_config, "features", "M") == "MS" else timexer_config.enc_in
        tail_in_features = timexer_config.pred_len * n_vars

        # Shared tail
        self.shared_tail = nn.Sequential(
            nn.Flatten(),
            nn.Linear(tail_in_features, 32),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.2),
        )

        self.policy_head = nn.Sequential(nn.Linear(32, self.num_actions))
        self.value_head = nn.Sequential(nn.Linear(32, 1))

    def act(self, inputs, role):
        """Dispatch to the mixin that implements ``role``.

        Raises:
            ValueError: if ``role`` is neither ``"policy"`` nor ``"value"``.
        """
        if role == "policy":
            return MultiCategoricalMixin.act(self, inputs, role)
        if role == "value":
            return DeterministicMixin.act(self, inputs, role)
        raise ValueError(f"Unknown role: {role!r}")

    def _encode(self, features):
        """Run TimeXer on a batch of feature windows and return the shared-tail output.

        Columns 2:7 of ``features`` are TimeXer's first input and columns 0:2 its
        second; presumably OHLCV and BTC / BTC-dominance — confirm the env's column order.
        """
        ohlcv = features[:, :, 2:7]
        btc_btcdom = features[:, :, 0:2]
        timexer_output = self.timexer(ohlcv, btc_btcdom, None, None)
        return self.shared_tail(timexer_output)

    def compute(self, inputs, role):
        """Compute action logits (``"policy"``) or the state value (``"value"``).

        Returns:
            ``(tensor, {})`` as expected by the skrl mixins.

        Raises:
            ValueError: if ``role`` is neither ``"policy"`` nor ``"value"``.
        """
        # Use the model's own observation space rather than a notebook global, so the
        # class does not depend on state defined in another cell.
        states = unflatten_tensorized_space(self.observation_space, inputs["states"])
        features = states["features"]  # (batch_size, seq_length, num_features)

        if role == "policy":
            self._fusion = self._encode(features)
            return self.policy_head(self._fusion), {}

        if role == "value":
            fusion = self._fusion if self._fusion is not None else self._encode(features)
            # Consume the cache so a later "value" call never reuses stale features.
            self._fusion = None
            return self.value_head(fusion), {}

        raise ValueError(f"Unknown role: {role!r}")


In [6]:
class Config:
    """Expose keyword arguments as attributes (``cfg.seq_len`` instead of ``cfg["seq_len"]``).

    Used to wrap the TimeXer hyper-parameters below, which are handed to the TimeXer
    constructor as an attribute-style object.
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)


timexer_params = {
    "task_name": "short_term_forecast",
    "features": "M",
    # Input window length, taken from the env config so the two cannot drift apart
    # (presumably the seq dimension of states["features"]).
    "seq_len": env_cfg["window_size"],
    "pred_len": 10,
    "use_norm": True,
    "patch_len": 5,
    "d_model": 128,
    "n_heads": 4,
    "d_ff": 256,
    "dropout": 0.1,
    "activation": "relu",
    "e_layers": 2,
    "factor": 5,
    "embed": "timeF",
    "freq": "h",
    # Number of input variables: the features[:, :, 2:7] slice fed to TimeXer.
    # SharedTimeXer's shared tail takes pred_len * enc_in = 50 inputs.
    "enc_in": 5,
}

timexer_config = Config(**timexer_params)

In [7]:
# A single SharedTimeXer instance serves both roles, so policy and value share the
# TimeXer encoder and tail weights.
shared_model = SharedTimeXer(
    env.observation_space,
    env.action_space,
    device,
    timexer_config=timexer_config,
)
models = {"policy": shared_model, "value": shared_model}

# Both dict entries are the same module, so this initializes it twice (the second
# draw overwrites the first).
# NOTE(review): init_parameters presumably applies normal_(0, 0.1) to every
# parameter, including BatchNorm/LayerNorm gains and biases — confirm intended.
for model in models.values():
    model.init_parameters(method_name="normal_", mean=0.0, std=0.1)


In [8]:
import copy

from skrl.agents.torch.ppo import PPO_DEFAULT_CONFIG
from skrl.resources.schedulers.torch import KLAdaptiveLR

# PPO_DEFAULT_CONFIG contains nested dicts (e.g. "experiment"). A shallow .copy()
# would let the cfg["experiment"][...] assignments below mutate the library default
# itself, leaking across re-runs and other agents, so deep-copy instead.
cfg = copy.deepcopy(PPO_DEFAULT_CONFIG)
cfg["rollouts"] = memory_size  # rollout length matches the per-env memory size
cfg["learning_epochs"] = 4
cfg["mini_batches"] = 6
cfg["discount_factor"] = 0.99
cfg["learning_rate"] = 5e-4
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {
    "kl_threshold": 0.05,
    "min_lr": 1e-7,
    "max_lr": 1e-3,
}
cfg["mixed_precision"] = True
# cfg["random_timesteps"] = 10000
# cfg["learning_starts"] = 10000


cfg["experiment"]["write_interval"] = 1000
cfg["experiment"]["checkpoint_interval"] = 10000
cfg["experiment"]["directory"] = "runs/torch/mddt"

In [9]:
import warnings

from skrl.agents.torch.ppo import PPO

# Silence FutureWarning / DeprecationWarning noise (same order as before).
for _ignored_category in (FutureWarning, DeprecationWarning):
    warnings.simplefilter(action="ignore", category=_ignored_category)

agent = PPO(
    models=models,
    memory=memory,
    cfg=cfg,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)

# To resume from a checkpoint, uncomment the two lines below.
# NOTE(review): the original hint used an absolute local path
# (/home/pitin/Desktop/hp/runs/torch/mddt/...); a path relative to
# cfg["experiment"]["directory"] is presumably equivalent — confirm.
# path = "24-12-24_14-30-40-586405_PPO"
# agent.load(f"{cfg['experiment']['directory']}/{path}/checkpoints/agent_50000.pt")

# Presumably keys of the env's info dict to log.
# NOTE(review): stock skrl trainers expect a single key string for
# "environment_info"; confirm the installed version accepts a list.
tracked_env_info = [
    "pc_counter",
    "portfolio_valuation",
    "record",
    "position",
    "liquidation",
    "realized_pnl",
    "multiplier",
]
cfg_trainer = {
    "timesteps": 1_000_000,
    "headless": True,
    "environment_info": tracked_env_info,
}
trainer = ParallelTrainer(cfg=cfg_trainer, env=env, agents=[agent])

In [None]:
# Start training with the ParallelTrainer configured above (cfg_trainer["timesteps"] = 1,000,000).
trainer.train()