In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import FQFPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.trainer import OnpolicyTrainer
# NOTE(review): upstream tianshou defines FullQuantileFunction in
# tianshou.utils.net.discrete — confirm this import resolves in the installed version.
from tianshou.utils.net.common import FullQuantileFunction
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from preprocess import preprocess

# Use the GPU when PyTorch reports one as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# set_seed(42)

# Register the custom trading environment id so gym.make / gym.make_vec can
# resolve it from the "module:attribute" entry point string.
gym.register(
    "MultiDatasetDiscretedTradingEnv",
    entry_point="dict_env:MultiDatasetDiscretedTradingEnv",
    disable_env_checker=True,
)

In [None]:
# Keyword arguments shared by every gym.make / gym.make_vec call in this notebook.
# The "id" entry selects the registered environment; the remaining entries are
# forwarded to the environment as keyword arguments.
env_cfg = {
    "id": "MultiDatasetDiscretedTradingEnv",
    "preprocess": preprocess,
    "positions": [-1, 0, 1],
    "multiplier": range(1, 51),
    "trading_fees": 0.0001,
    "borrow_interest_rate": 0.0003,
    "portfolio_initial_value": 1e3,
    "max_episode_duration": "max",
    "verbose": 0,
    "window_size": 60,
    "btc_index": True,
}

In [None]:
# Single (non-vectorised) environment; later cells read its observation and
# action spaces.
env = gym.make(**env_cfg)
obs = env.observation_space

# Training environments: 16 copies in "async" vectorization mode.
# NOTE(review): this glob appears to also match the 2024 files used for the
# test environments below — confirm the intended train/test split.
train_envs = gym.make_vec(
    vectorization_mode="async",
    num_envs=16,
    dataset_dir="./data/futures/5m/**/**/*.pkl",
    **env_cfg,
)

# Test environments: 8 copies restricted to the 2024 datasets.
# This was previously assigned to `train_envs` a second time, which discarded
# the 16-env training set above and left `test_envs` (needed by the test
# collector) undefined.
test_envs = gym.make_vec(
    vectorization_mode="async",
    num_envs=8,
    dataset_dir="./data/futures/5m/2024/**/*.pkl",
    **env_cfg,
)

# NOTE(review): a later cell calls len(train_envs) and passes these objects to
# tianshou's Collector; confirm that gymnasium's make_vec output is accepted
# there, or build the envs with tianshou's own vector env (the imported
# DummyVectorEnv) as in the smoke-test variant below.

# env = gym.make("CartPole-v1")
# train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(20)])
# test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(10)])

In [None]:
# Build the FQF networks: a feature extractor (`feature_net`), the quantile
# network wrapping it (`net`), the fraction proposal network (`fraction_net`),
# and one optimizer for `net` and one for `fraction_net`.
# NOTE(review): `args`, `Net` and `FractionProposalNetwork` are not defined or
# imported in the visible cells, so this cell raises NameError on a fresh
# kernel unless they are provided elsewhere — confirm and add the missing
# definitions/imports.
feature_net = Net(
    args.state_shape,
    args.hidden_sizes[-1],
    hidden_sizes=args.hidden_sizes[:-1],
    device=args.device,
    softmax=False,
)
# NOTE(review): gymnasium environments expose `action_space`; `action_shape`
# is presumably a leftover and should be derived from env.action_space — confirm.
net = FullQuantileFunction(
    feature_net,
    env.action_shape,
    args.hidden_sizes,
    args.num_cosines,
    device=args.device,
).to(args.device)

# Adam updates the parameters of `net`.
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# NOTE(review): unlike `net`, `fraction_net` is not moved to a device here —
# confirm it ends up on the same device as `net`.
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
# The fraction proposal network gets its own RMSprop optimizer and learning rate.
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr)

In [None]:
class LearnablePositionalEncoding(nn.Module):
    """Add a learned per-position embedding to a batch of sequences.

    Args:
        d_model: Width of each sequence element (embedding dimension).
        max_len: Number of distinct positions the embedding table holds.
    """

    def __init__(self, d_model, max_len=1024 * 16):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the embedding of each time-step index.

        ``x`` must be 3-D, laid out as (batch, sequence, d_model).
        """
        # Three-way unpacking keeps the requirement that the input is 3-D.
        _, num_steps, _ = x.shape
        step_ids = torch.arange(num_steps, device=x.device)
        # The (sequence, d_model) embedding broadcasts over the batch axis.
        return x + self.position_embedding(step_ids)


class AttentionPooling(nn.Module):
    """Collapse the sequence axis into one vector using learned softmax weights.

    Args:
        hidden_size: Width of each sequence element.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # One scalar score per sequence element.
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return the attention-weighted sum of ``x`` over its sequence axis.

        ``x`` is (batch, sequence, hidden_size); the result is (batch, hidden_size).
        """
        scores = self.attention(x)[..., 0]
        weights = scores.softmax(dim=-1)
        # (batch, 1, sequence) x (batch, sequence, hidden) -> (batch, 1, hidden)
        pooled = torch.bmm(weights[:, None], x)
        return pooled[:, 0]


class SharedNoFC(MultiCategoricalMixin, DeterministicMixin, Model):
    """Shared policy/value model over a dict observation.

    The observation is expected to contain two entries:

    * ``"features"`` - a sequence that is passed through a learned positional
      encoding, a transformer encoder and attention pooling;
    * ``"infos"`` - a flat vector passed through a small linear head.

    Both embeddings are concatenated and fed to a policy head (``num_actions``
    outputs) and a value head (one output). The concatenated embedding computed
    for the ``"policy"`` role is cached and reused by the next ``"value"`` call.

    NOTE(review): ``MultiCategoricalMixin``, ``DeterministicMixin``, ``Model``
    and ``unflatten_tensorized_space`` are not imported in the visible cells;
    ``self.num_actions`` and ``self.observation_space`` are presumably set by
    ``Model.__init__`` - confirm against the library providing these classes.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        clip_actions=False,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        Model.__init__(self, observation_space, action_space, device)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
        DeterministicMixin.__init__(self, clip_actions)

        self._shared_features = None
        # Fused embedding cached by the "policy" pass for the following "value" pass.
        self._fusion = None
        self.num_features = 8
        self.num_infos = 8

        # Transformer encoder (self-attention) over the feature sequence.
        transformer_layer = TransformerEncoderLayer(
            d_model=self.num_features,  # width of each sequence element
            nhead=4,  # number of attention heads
            dim_feedforward=256,  # width of the encoder's feedforward network
            dropout=0.1,
            batch_first=True,  # inputs are (batch, sequence, feature)
        )
        self.positional_encoding = LearnablePositionalEncoding(
            d_model=self.num_features
        )
        self.transformer_encoder = TransformerEncoder(transformer_layer, num_layers=2)
        self.attention_pooling = AttentionPooling(self.num_features)

        # Widths are taken from the attributes above instead of a repeated literal 8.
        self.shared_head = nn.Sequential(
            nn.LayerNorm(self.num_features),
            nn.GELU(),
            nn.Dropout(0.2),
        )

        self.info_head = nn.Sequential(
            nn.Linear(self.num_infos, self.num_infos),
            nn.LayerNorm(self.num_infos),
            nn.ELU(),
            nn.Dropout(0.2),
        )

        fused_width = self.num_features + self.num_infos
        self.policy_head = nn.Sequential(nn.Linear(fused_width, self.num_actions))
        self.value_head = nn.Sequential(nn.Linear(fused_width, 1))

    def act(self, inputs, role):
        """Dispatch to the mixin that matches ``role`` ("policy" or "value").

        Raises:
            ValueError: If ``role`` is neither "policy" nor "value".
        """
        if role == "policy":
            return MultiCategoricalMixin.act(self, inputs, role)
        if role == "value":
            return DeterministicMixin.act(self, inputs, role)
        raise ValueError(f"Unsupported role: {role!r}")

    def _encode(self, states):
        """Return the concatenated feature/info embedding for ``states``."""
        features = states["features"]  # expected (batch, seq_length, num_features)
        infos = states["infos"]  # expected (batch, num_infos)

        features = self.positional_encoding(features)
        features = self.transformer_encoder(features)
        features = self.attention_pooling(features)

        return torch.cat(
            [self.shared_head(features), self.info_head(infos)], dim=-1
        )

    def compute(self, inputs, role):
        """Compute the policy outputs or the state value for ``inputs["states"]``.

        Returns:
            A ``(tensor, {})`` tuple: the policy head output for "policy", the
            value head output for "value".

        Raises:
            ValueError: If ``role`` is neither "policy" nor "value".
        """
        # Use the space this model was built with instead of a notebook global.
        states = unflatten_tensorized_space(self.observation_space, inputs["states"])

        if role == "policy":
            self._fusion = self._encode(states)
            return self.policy_head(self._fusion), {}

        if role == "value":
            # Reuse the embedding from the preceding policy pass when there is
            # one, then clear it so it is never reused for a different batch.
            fusion = self._fusion if self._fusion is not None else self._encode(states)
            self._fusion = None
            return self.value_head(fusion), {}

        raise ValueError(f"Unsupported role: {role!r}")

In [None]:
# Assemble the FQF policy from the quantile network, the fraction proposal
# network and their optimizers. discount_factor, num_fractions, ent_coef,
# estimation_step and target_update_freq are not passed, so the FQFPolicy
# defaults apply.
# The whole policy is moved to `device` (defined in the first cell) so that
# every sub-module, including the fraction network, lives on the same device;
# the earlier `.to(args.device)` was commented out because `args` is undefined.
policy: FQFPolicy = FQFPolicy(
    model=net,
    optim=optim,
    fraction_model=fraction_net,
    fraction_optim=fraction_optim,
    action_space=env.action_space,
).to(device)

In [None]:
# Training collector: writes transitions into a 20000-slot vectorised replay
# buffer with one sub-buffer per training environment.
train_collector = Collector(
    policy=policy,
    env=train_envs,
    buffer=VectorReplayBuffer(total_size=20000, buffer_num=len(train_envs)),
)
# Test collector: no buffer argument is supplied.
test_collector = Collector(policy=policy, env=test_envs)

In [None]:
# FQF is a DQN-family, off-policy algorithm that learns from minibatches
# sampled out of the replay buffer, so it is driven by the off-policy trainer.
# The on-policy trainer used previously is intended for policy-gradient
# methods that consume freshly collected data.
STEP_PER_COLLECT = 2000
# Keep the original cadence of 10 gradient updates per collect (was
# repeat_per_collect=10), expressed as updates per environment step.
UPDATES_PER_COLLECT = 10

train_result = OffpolicyTrainer(
    policy=policy,
    batch_size=256,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=50000,
    step_per_collect=STEP_PER_COLLECT,
    update_per_step=UPDATES_PER_COLLECT / STEP_PER_COLLECT,
    episode_per_test=10,
    # NOTE(review): 195 is the conventional CartPole solve threshold; confirm it
    # is meaningful for this environment's reward scale.
    stop_fn=lambda mean_reward: mean_reward >= 195,
).run()

In [None]:
# Pretty-print the result object returned by the trainer's run() call.
train_result.pprint_asdict()

In [None]:
# Evaluate the trained policy: switch it to eval mode, collect three episodes
# without rendering, and report the mean return and mean episode length.
policy.eval()
eval_result = test_collector.collect(render=False, n_episode=3)
mean_return = eval_result.returns.mean()
mean_length = eval_result.lens.mean()
print(f"Final reward: {mean_return}, length: {mean_length}")