# Libraries

In [6]:
import math

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Config

In [3]:
# Device: use the CUDA backend whenever torch reports one is available,
# and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Analysis

## Train 

In [None]:
class CombatForecastDataset(Dataset):
    """Encoder-decoder training samples built from five parallel ``.npy`` files.

    Each sample ``idx`` yields three float32 tensors:

    * encoder input  -- ``lb_our[idx]`` and ``lb_bandit[idx]`` concatenated
      along the last axis;
    * decoder input  -- ``fc_action[idx]`` shifted right by one step behind a
      start row filled with ``sos_value`` (the last action step is dropped so
      the length equals the forecast horizon);
    * decoder target -- ``fc_our[idx]`` and ``fc_bandit[idx]`` concatenated
      along the last axis.

    The shapes noted next to each array are the ones the data was prepared
    with; only the shared leading sample count is checked here.

    Args:
        path_lb_our, path_lb_bandit: files feeding the encoder input.
        path_fc_action: file feeding the decoder input.
        path_fc_our, path_fc_bandit: files feeding the decoder target.
        sos_value: fill value for the start-of-sequence row of the decoder input.

    Raises:
        ValueError: if the five arrays do not share the same first dimension.
    """

    def __init__(self, path_lb_our, path_lb_bandit, path_fc_action, path_fc_our, path_fc_bandit, sos_value=9999.0):
        self.lb_our = np.load(path_lb_our)          # (N, 10, 18)
        self.lb_bandit = np.load(path_lb_bandit)    # (N, 10, 13)
        self.fc_action = np.load(path_fc_action)    # (N, 3, 10)
        self.fc_our = np.load(path_fc_our)          # (N, 3, 18)
        self.fc_bandit = np.load(path_fc_bandit)    # (N, 3, 13)
        self.sos_value = sos_value

        # __len__ trusts lb_our alone, so a shorter companion file would only
        # surface later as an IndexError inside __getitem__; fail fast instead.
        counts = {
            'lb_our': len(self.lb_our),
            'lb_bandit': len(self.lb_bandit),
            'fc_action': len(self.fc_action),
            'fc_our': len(self.fc_our),
            'fc_bandit': len(self.fc_bandit),
        }
        if len(set(counts.values())) != 1:
            raise ValueError(f"all input arrays must have the same number of samples, got {counts}")

    def __len__(self):
        """Return the number of samples N."""
        return len(self.lb_our)

    def __getitem__(self, idx):
        """Return ``(encoder_input, decoder_input, decoder_target)`` for sample ``idx``."""
        # Encoder input: concat states, e.g. (10, 18) + (10, 13) -> (10, 31).
        src = np.concatenate([self.lb_our[idx], self.lb_bandit[idx]], axis=-1)

        # Decoder input: action forecast shifted right behind an SOS row. The
        # SOS width follows the action feature width instead of a hard-coded 10.
        tgt = self.fc_action[idx]
        sos = np.full((1, tgt.shape[-1]), self.sos_value, dtype=np.float64)
        tgt_input = np.vstack([sos, tgt[:-1]])

        # Decoder target: concat future states, e.g. (3, 18) + (3, 13) -> (3, 31).
        tgt_output = np.concatenate([self.fc_our[idx], self.fc_bandit[idx]], axis=-1)

        return (
            torch.tensor(src, dtype=torch.float32),
            torch.tensor(tgt_input, dtype=torch.float32),
            torch.tensor(tgt_output, dtype=torch.float32),
        )


In [5]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Column ``2i`` of the table holds ``sin(pos / 10000**(2i / d_model))`` and
    column ``2i + 1`` holds the matching cosine. The table is registered as a
    buffer, so it follows ``.to(device)`` and is saved in ``state_dict`` but
    is not a trainable parameter.

    Args:
        d_model: feature width of the inputs; odd widths are supported.
        max_len: longest sequence length ``forward`` can encode.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns -- one fewer than the
        # number of frequencies when d_model is odd -- so trim the cosine part.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, seq, d_model); ``seq`` must not exceed
        ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]

# Transformer Model
class TransformerModel(nn.Module):
    """Encoder-decoder ``nn.Transformer`` with input projections and a linear output head.

    Callers pass batch-first tensors. The wrapped ``nn.Transformer`` is built
    with its default sequence-first layout, so ``forward`` transposes on the
    way in and back on the way out.

    By default (``continuous=False``) the inputs are integer token ids looked
    up in ``nn.Embedding`` tables. Real-valued sequences cannot be indexed by
    ``nn.Embedding``; the float feature tensors produced by
    ``CombatForecastDataset`` are one example. For those, pass
    ``continuous=True``. ``src_vocab_size`` and ``tgt_vocab_size`` are then
    read as per-step feature widths and projected with ``nn.Linear``.

    In both modes the projected inputs are scaled by ``sqrt(d_model)`` before
    positional encoding, as in the original design.

    Args:
        src_vocab_size: source vocabulary size, or source feature width when
            ``continuous``.
        tgt_vocab_size: target vocabulary size, or decoder-input feature width
            when ``continuous``; also the default output width.
        d_model, nhead, num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout: forwarded to ``nn.Transformer``.
        continuous: keyword-only; use ``nn.Linear`` input projections instead
            of ``nn.Embedding``.
        output_dim: keyword-only width of each predicted step; defaults to
            ``tgt_vocab_size``.
    """

    def __init__(
            self,
            src_vocab_size,
            tgt_vocab_size,
            d_model=256,
            nhead=4,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=512,
            dropout=0.1,
            *,
            continuous=False,
            output_dim=None):
        super().__init__()
        self.d_model = d_model
        self.continuous = continuous
        if output_dim is None:
            output_dim = tgt_vocab_size

        # Submodule names and construction order are the same in both modes.
        # This keeps state_dict keys stable and leaves the default path's
        # parameter initialization unchanged.
        if continuous:
            self.src_embedding = nn.Linear(src_vocab_size, d_model)
            self.tgt_embedding = nn.Linear(tgt_vocab_size, d_model)
        else:
            self.src_embedding = nn.Embedding(src_vocab_size, d_model)
            self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        self.fc_out = nn.Linear(d_model, output_dim)

    def _embed(self, x, projection):
        """Project batch-first ``x`` to d_model, scale, add positions, return sequence-first."""
        scaled = projection(x) * math.sqrt(self.d_model)
        return self.positional_encoding(scaled).transpose(0, 1)

    def forward(
            self,
            src,
            tgt,
            src_mask=None,
            tgt_mask=None,
            src_key_padding_mask=None,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=None):
        """Run the encoder-decoder and return per-step outputs, batch-first.

        Args:
            src: (batch, src_len) token ids, or (batch, src_len, src_vocab_size)
                floats when ``continuous``.
            tgt: (batch, tgt_len) token ids, or (batch, tgt_len, tgt_vocab_size)
                floats when ``continuous``.
            src_mask, tgt_mask, *_key_padding_mask: passed unchanged to
                ``nn.Transformer``. No causal mask is built here. For
                autoregressive decoding, pass one as ``tgt_mask``.

        Returns:
            Tensor of shape (batch, tgt_len, output_dim).
        """
        src = self._embed(src, self.src_embedding)
        tgt = self._embed(tgt, self.tgt_embedding)

        output = self.transformer(
            src, tgt, src_mask=src_mask, tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )

        return self.fc_out(output.transpose(0, 1))