# Libraries

In [None]:
import math

import numpy as np

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm

# Constants

In [None]:
# Tag used as the checkpoint file name stem ('<MODEL_VERSION>.pth').
MODEL_VERSION = 'v01'
# Directory the model checkpoints are written to.
MODEL_DIR = os.path.join('..', 'models')
# Directory the pre-processed .npy arrays are read from.
INPUT_DIR = os.path.join('..', 'data', 'processed')

# Config

In [None]:
# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data Import

In [None]:
class CombatForecastDataset(Dataset):
    """Sequence-to-sequence samples built from five ``.npy`` arrays.

    Every item is a ``(src, tgt_input, tgt_output)`` triple of float32 tensors:

    * ``src`` - our and bandit look-back states stacked along the time axis,
      after zero-padding the narrower of the two to a common feature width.
    * ``tgt_input`` - the forecast actions shifted one step to the right, with
      a start-of-sequence row filled with ``sos_value`` placed in front.
    * ``tgt_output`` - our and bandit forecast states concatenated along the
      feature axis.

    Args:
        path_lb_our: ``.npy`` file with our look-back states.
        path_lb_bandit: ``.npy`` file with the bandit look-back states.
        path_fc_action: ``.npy`` file with our forecast actions.
        path_fc_our: ``.npy`` file with our forecast states.
        path_fc_bandit: ``.npy`` file with the bandit forecast states.
        sos_value: Fill value of the start-of-sequence row of ``tgt_input``.

    NOTE(review): the arrays are assumed to be 3-D (sample, time, feature) and
    to share the same number of samples; neither is checked here.
    """

    def __init__(self, path_lb_our, path_lb_bandit, path_fc_action, path_fc_our, path_fc_bandit, sos_value=9999.0):
        self.lb_our = np.load(path_lb_our)          # (N, T_lb, F_our)
        self.lb_bandit = np.load(path_lb_bandit)    # (N, T_lb, F_bandit)
        self.fc_action = np.load(path_fc_action)    # (N, T_fc, F_action)
        self.fc_our = np.load(path_fc_our)          # (N, T_fc, F_our)
        self.fc_bandit = np.load(path_fc_bandit)    # (N, T_fc, F_bandit)
        self.sos_value = sos_value
        # Signed feature-width difference: positive when the bandit array is
        # narrower than ours, negative when it is wider.
        self.nr_of_padding = self.lb_our.shape[-1] - self.lb_bandit.shape[-1]

    def __len__(self):
        """Return the number of samples (first axis of the look-back array)."""
        return len(self.lb_our)

    @staticmethod
    def _pad_features(states, extra):
        """Append ``extra`` zero columns to a 2-D (time, feature) array."""
        if extra <= 0:
            return states
        return np.pad(states, pad_width=((0, 0), (0, extra)), mode='constant', constant_values=0)

    def __getitem__(self, idx):
        """Build the ``(src, tgt_input, tgt_output)`` tensors for sample ``idx``."""
        # Only the narrower array is padded; the other call is a no-op. This
        # keeps the sample usable whichever side has more features.
        lb_our = self._pad_features(self.lb_our[idx], -self.nr_of_padding)
        lb_bandit = self._pad_features(self.lb_bandit[idx], self.nr_of_padding)

        # Equal feature width -> stack along the time axis (axis=0).
        src = np.concatenate([lb_our, lb_bandit], axis=0)

        # Decoder input: SOS row followed by all actions except the last one,
        # so step t of the decoder only sees actions from steps before t.
        tgt = self.fc_action[idx]
        sos = np.ones((1, tgt.shape[1])) * self.sos_value
        tgt_input = np.vstack([sos, tgt[:-1]])

        # Target output: both forecast states side by side on the feature axis.
        tgt_output = np.concatenate([self.fc_our[idx], self.fc_bandit[idx]], axis=-1)

        return (
            torch.tensor(src, dtype=torch.float32),        # encoder input
            torch.tensor(tgt_input, dtype=torch.float32),  # decoder input
            torch.tensor(tgt_output, dtype=torch.float32)  # decoder target
        )

In [None]:
# The file names are listed in the positional order of the dataset constructor:
# look-back (our, bandit), forecast action, forecast state (our, bandit).
# Both splits use 0 as the start-of-sequence fill value.
training_dataset = CombatForecastDataset(
    *(os.path.join(INPUT_DIR, file_name) for file_name in (
        'train_lb_state_our.npy',
        'train_lb_state_bandit.npy',
        'train_fc_action_our.npy',
        'train_fc_state_our.npy',
        'train_fc_state_bandit.npy',
    )),
    sos_value=0,
)

testing_dataset = CombatForecastDataset(
    *(os.path.join(INPUT_DIR, file_name) for file_name in (
        'test_lb_state_our.npy',
        'test_lb_state_bandit.npy',
        'test_fc_action_our.npy',
        'test_fc_state_our.npy',
        'test_fc_state_bandit.npy',
    )),
    sos_value=0,
)

# Analysis

## UDF 

In [None]:
class ForecastTransformer(nn.Module):
    """Encoder-decoder transformer producing one output vector per decoder step.

    Encoder and decoder inputs are linearly projected to ``d_model``, scaled by
    ``sqrt(d_model)``, given positional encodings, run through
    ``nn.Transformer`` and projected to ``out_dim``.

    Args:
        enc_input_dim: Feature width of the encoder input.
        dec_input_dim: Feature width of the decoder input.
        out_dim: Feature width of the prediction.
        d_model, nhead, num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout: Forwarded unchanged to ``nn.Transformer``.
    """

    def __init__(self, enc_input_dim=18, dec_input_dim=10, out_dim=31,
                 d_model=256, nhead=4, num_encoder_layers=3, num_decoder_layers=3,
                 dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Linear "embeddings" lifting raw feature vectors to the model width.
        self.encoder_input_proj = nn.Linear(enc_input_dim, d_model)
        self.decoder_input_proj = nn.Linear(dec_input_dim, d_model)

        # One positional-encoding module shared by the encoder and decoder sides.
        self.positional_encoding = PositionalEncoding(d_model)

        # Core sequence-to-sequence model.
        transformer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(**transformer_kwargs)

        # Maps decoder states back to the target feature width.
        self.output_proj = nn.Linear(d_model, out_dim)

    def forward(self, src, tgt_input, tgt_mask=None):
        """Predict the target sequence.

        Args:
            src: Encoder input, (batch, src_seq_len, enc_input_dim).
            tgt_input: Decoder input, (batch, tgt_seq_len, dec_input_dim).
            tgt_mask: Optional attention mask passed to the decoder.

        Returns:
            Tensor of shape (batch, tgt_seq_len, out_dim).
        """
        scale = math.sqrt(self.d_model)

        # Project, scale and add positions while still batch-first.
        enc_embedded = self.positional_encoding(self.encoder_input_proj(src) * scale)
        dec_embedded = self.positional_encoding(self.decoder_input_proj(tgt_input) * scale)

        # The transformer is built without batch_first, so it works on
        # (seq_len, batch, d_model); swap axes on the way in and out.
        decoded = self.transformer(
            enc_embedded.transpose(0, 1),
            dec_embedded.transpose(0, 1),
            tgt_mask=tgt_mask,
        )
        return self.output_proj(decoded.transpose(0, 1))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor.

    For position ``p`` and frequency index ``k``, column ``2k`` holds
    ``sin(p * w_k)`` and column ``2k + 1`` holds ``cos(p * w_k)`` with
    ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature width of the tensors to encode (odd values allowed).
        max_len: Number of positions precomputed; inputs must not be longer.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # last frequency is dropped for the cosine half. For even d_model the
        # slice is a no-op and the table is unchanged.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Buffer: follows the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
        """
        # The explicit .to() keeps this working when x is on another device
        # than the buffer.
        return x + self.pe[:, :x.size(1)].to(x.device)

## Compile

In [None]:
train_data_loader = DataLoader(training_dataset, batch_size=128, shuffle=True)
# Validation data is not shuffled: the loss is averaged per batch, so a fixed
# batch composition keeps the validation loss comparable between epochs.
test_data_loader = DataLoader(testing_dataset, batch_size=128, shuffle=False)

# Read the layer widths off one real sample instead of hard-coding them, so
# the model cannot drift out of sync with the arrays on disk.
_sample_src, _sample_tgt_input, _sample_tgt_output = training_dataset[0]

model = ForecastTransformer(
    enc_input_dim=_sample_src.shape[-1],
    dec_input_dim=_sample_tgt_input.shape[-1],
    out_dim=_sample_tgt_output.shape[-1],
    d_model=256,
    nhead=4,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dim_feedforward=512,
    dropout=0.1,
).to(device)

# Mean squared error between predicted and true forecast states.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

## Fit

In [None]:
# Early-stopping settings.
start_after = 200  # epochs whose index is <= this value skip the early-stopping check
patience = 10  # epochs without improvement tolerated before stopping

# Early-stopping bookkeeping.
epochs_without_improvement = 0  # consecutive epochs that did not beat the best loss
best_val_loss = float('inf')  # any real validation loss is lower than this
# Prevent model from seeing future tokens during training
def generate_square_subsequent_mask(sz):
    """Return an (sz, sz) float mask: 0 on and below the diagonal, -inf above it."""
    blocked = torch.full((sz, sz), float('-inf'))
    # triu zeroes everything below the first super-diagonal.
    return blocked.triu(diagonal=1)
    
# torch.save does not create missing directories. Create the checkpoint folder
# up front so a missing folder cannot fail the run after a full training epoch.
os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_path = os.path.join(MODEL_DIR, f'{MODEL_VERSION}.pth')

for epoch in tqdm(range(1000)):
    # ---- Training pass ----
    model.train()
    total_loss = 0
    for src, tgt_input, tgt_output in train_data_loader:
        src, tgt_input, tgt_output = src.to(device), tgt_input.to(device), tgt_output.to(device)

        # Causal mask: decoder step t may only attend to steps <= t.
        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)

        pred = model(src, tgt_input, tgt_mask)
        loss = criterion(pred, tgt_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # ---- Validation pass (no gradients, dropout disabled via eval()) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt_input, tgt_output in test_data_loader:
            src, tgt_input, tgt_output = src.to(device), tgt_input.to(device), tgt_output.to(device)
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            pred = model(src, tgt_input, tgt_mask)
            loss = criterion(pred, tgt_output)
            val_loss += loss.item()

    # Both reported losses are means over batches, not over samples.
    avg_val_loss = val_loss / len(test_data_loader)
    print(f"Epoch {epoch+1}, Training Loss: {total_loss / len(train_data_loader):.4f}, Validation Loss: {avg_val_loss:.4f}")

    # Warm-up: while epoch <= start_after the latest weights are always saved
    # and neither the best loss nor the patience counter is updated.
    if epoch <= start_after:
        torch.save(model.state_dict(), checkpoint_path)
        continue

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0  # Reset counter
        # Save the best model (overwrites the previous checkpoint)
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved with validation loss: {avg_val_loss:.4f}")
    else:
        epochs_without_improvement += 1

    # Stop training if the patience is exceeded
    if epochs_without_improvement >= patience:
        print(f"Early stopping after epoch {epoch+1} due to no improvement in validation loss.")
        break