# Libraries

In [1]:
import math

import numpy as np

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm

# Constant

In [2]:
INPUT_DIR = os.path.join('..', 'data', 'processed')
MODEL_DIR = os.path.join('..', 'models')
MODEL_VERSION = 'v01'

# Config

In [3]:
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data Import

In [4]:
class CombatForecastDataset(Dataset):
    def __init__(self, path_lb_our, path_lb_bandit, path_fc_action, path_fc_our, path_fc_bandit):
        self.lb_our = np.load(path_lb_our)
        self.lb_bandit = np.load(path_lb_bandit)
        self.fc_action = np.load(path_fc_action)
        self.fc_our = np.load(path_fc_our)
        self.fc_bandit = np.load(path_fc_bandit)
        self.nr_of_padding = self.lb_our.shape[-1] - self.lb_bandit.shape[-1]

    def __len__(self):
        return len(self.lb_our)

    def __getitem__(self, idx):
        lb_our = self.lb_our[idx]
        lb_bandit = self.lb_bandit[idx]
    
        pad_width = ((0, 0), (0, self.nr_of_padding))
        lb_bandit_padded = np.pad(lb_bandit, pad_width=pad_width, mode='constant', constant_values=0)
        src = np.concatenate([lb_our, lb_bandit_padded], axis=0)
        
        # Decoder input
        # tgt_input = self.fc_action[idx]
        dec_beg = self.lb_bandit[idx][[-1], [3, 7, 11]]
        tgt_input = self.fc_bandit[idx]
        tgt_input = np.vstack([dec_beg, tgt_input[:-1]])

        # Target output
        tgt_output = self.fc_bandit[idx]
    
        return (
            torch.tensor(src, dtype=torch.float32),
            torch.tensor(tgt_input, dtype=torch.float32),
            torch.tensor(tgt_output, dtype=torch.float32)
        )

In [5]:
training_dataset = CombatForecastDataset(
    os.path.join(INPUT_DIR, 'train_lb_state_our.npy'),
    os.path.join(INPUT_DIR, 'train_lb_state_bandit.npy'),
    os.path.join(INPUT_DIR, 'train_fc_action_our.npy'),
    os.path.join(INPUT_DIR, 'train_fc_state_our.npy'),
    os.path.join(INPUT_DIR, 'train_fc_state_bandit.npy'))

testing_dataset = CombatForecastDataset(
    os.path.join(INPUT_DIR, 'test_lb_state_our.npy'),
    os.path.join(INPUT_DIR, 'test_lb_state_bandit.npy'),
    os.path.join(INPUT_DIR, 'test_fc_action_our.npy'),
    os.path.join(INPUT_DIR, 'test_fc_state_our.npy'),
    os.path.join(INPUT_DIR, 'test_fc_state_bandit.npy'))

# Analysis

## UDF 

In [6]:
class ForecastTransformer(nn.Module):
    def __init__(self, enc_input_dim=18, dec_input_dim=10, out_dim=31,
                 d_model=256, nhead=4, num_encoder_layers=3, num_decoder_layers=3,
                 dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Project encoder and decoder inputs to d_model
        self.encoder_input_proj = nn.Linear(enc_input_dim, d_model)
        self.decoder_input_proj = nn.Linear(dec_input_dim, d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model)

        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        # Output projection to target dimension
        self.output_proj = nn.Linear(d_model, out_dim)

    def forward(self, src, tgt_input, tgt_mask=None):
        # src: (batch, src_seq_len, enc_input_dim)
        # tgt_input: (batch, tgt_seq_len, dec_input_dim)

        src = self.encoder_input_proj(src) * math.sqrt(self.d_model)  # (batch, 20, d_model)
        tgt = self.decoder_input_proj(tgt_input) * math.sqrt(self.d_model)  # (batch, 3, d_model)

        src = self.positional_encoding(src).transpose(0, 1)  # (20, batch, d_model)
        tgt = self.positional_encoding(tgt).transpose(0, 1)  # (3, batch, d_model)

        out = self.transformer(src, tgt, tgt_mask=tgt_mask)  # (3, batch, d_model)
        out = self.output_proj(out.transpose(0, 1))  # (batch, 3, out_dim)

        return out

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)  # correctly handled across devices

    def forward(self, x):
        # Automatically move self.pe to same device as x
        return x + self.pe[:, :x.size(1)].to(x.device)

## Compile

In [7]:
train_data_loader = DataLoader(training_dataset, batch_size=128, shuffle=True)
test_data_loader = DataLoader(testing_dataset, batch_size=128, shuffle=True)
model = ForecastTransformer(
    enc_input_dim=15,
    dec_input_dim=3,
    out_dim=3,
    d_model=256,
    nhead=4,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dim_feedforward=512,
    dropout=0.1)\
    .to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)



## Fit

In [8]:
patience = 10  # number of epochs without improvement before stopping
best_val_loss = float('inf')  # initially, the best validation loss is infinite
epochs_without_improvement = 0  # to count how many epochs without improvement
start_after = 200
# Prevent model from seeing future tokens during training
def generate_square_subsequent_mask(sz):
    return torch.triu(torch.ones((sz, sz)) * float('-inf'), diagonal=1)
    
for epoch in tqdm(range(1000)):
    model.train()
    total_loss = 0
    for src, tgt_input, tgt_output in train_data_loader:
        src, tgt_input, tgt_output = src.to(device), tgt_input.to(device), tgt_output.to(device)

        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)

        pred = model(src, tgt_input, tgt_mask)
        loss = criterion(pred, tgt_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt_input, tgt_output in test_data_loader:
            src, tgt_input, tgt_output = src.to(device), tgt_input.to(device), tgt_output.to(device)
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            pred = model(src, tgt_input, tgt_mask)
            loss = criterion(pred, tgt_output)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(test_data_loader)
    print(f"Epoch {epoch+1}, Training Loss: {total_loss / len(train_data_loader):.4f}, Validation Loss: {avg_val_loss:.4f}")

    if epoch <= start_after:
        torch.save(model.state_dict(), os.path.join(MODEL_DIR, f'{MODEL_VERSION}.pth'))
        continue
        
    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0  # Reset counter
        # Save the best model
        torch.save(model.state_dict(), os.path.join(MODEL_DIR, f'{MODEL_VERSION}.pth'))  # Save model weights
        print(f"Model saved with validation loss: {avg_val_loss:.4f}")
    else:
        epochs_without_improvement += 1

    # Stop training if the patience is exceeded
    if epochs_without_improvement >= patience:
        print(f"Early stopping after epoch {epoch+1} due to no improvement in validation loss.")
        break

  0% 1/1000 [00:35<9:48:01, 35.32s/it]

Epoch 1, Training Loss: 0.0914, Validation Loss: 0.0442


  0% 2/1000 [01:10<9:41:52, 34.98s/it]

Epoch 2, Training Loss: 0.0123, Validation Loss: 0.0066


  0% 3/1000 [01:44<9:39:18, 34.86s/it]

Epoch 3, Training Loss: 0.0041, Validation Loss: 0.0050


  0% 4/1000 [02:19<9:37:43, 34.80s/it]

Epoch 4, Training Loss: 0.0023, Validation Loss: 0.0042


  0% 5/1000 [02:54<9:36:22, 34.76s/it]

Epoch 5, Training Loss: 0.0016, Validation Loss: 0.0035


  1% 6/1000 [03:28<9:35:37, 34.75s/it]

Epoch 6, Training Loss: 0.0012, Validation Loss: 0.0037


  1% 7/1000 [04:03<9:35:04, 34.75s/it]

Epoch 7, Training Loss: 0.0009, Validation Loss: 0.0031


  1% 8/1000 [04:38<9:35:02, 34.78s/it]

Epoch 8, Training Loss: 0.0008, Validation Loss: 0.0032


  1% 9/1000 [05:13<9:33:48, 34.74s/it]

Epoch 9, Training Loss: 0.0007, Validation Loss: 0.0031


  1% 10/1000 [05:47<9:33:06, 34.73s/it]

Epoch 10, Training Loss: 0.0006, Validation Loss: 0.0031


  1% 11/1000 [06:22<9:32:37, 34.74s/it]

Epoch 11, Training Loss: 0.0006, Validation Loss: 0.0028


  1% 12/1000 [06:57<9:31:20, 34.70s/it]

Epoch 12, Training Loss: 0.0005, Validation Loss: 0.0028


  1% 13/1000 [07:31<9:30:48, 34.70s/it]

Epoch 13, Training Loss: 0.0005, Validation Loss: 0.0032


  1% 14/1000 [08:06<9:30:10, 34.70s/it]

Epoch 14, Training Loss: 0.0005, Validation Loss: 0.0027


  2% 15/1000 [08:41<9:29:25, 34.69s/it]

Epoch 15, Training Loss: 0.0004, Validation Loss: 0.0031


  2% 16/1000 [09:15<9:28:43, 34.68s/it]

Epoch 16, Training Loss: 0.0004, Validation Loss: 0.0026


  2% 17/1000 [09:50<9:28:07, 34.68s/it]

Epoch 17, Training Loss: 0.0004, Validation Loss: 0.0028


  2% 18/1000 [10:25<9:27:20, 34.66s/it]

Epoch 18, Training Loss: 0.0004, Validation Loss: 0.0025


  2% 19/1000 [10:59<9:26:45, 34.66s/it]

Epoch 19, Training Loss: 0.0004, Validation Loss: 0.0025


  2% 20/1000 [11:34<9:26:14, 34.67s/it]

Epoch 20, Training Loss: 0.0003, Validation Loss: 0.0026


  2% 21/1000 [12:09<9:25:24, 34.65s/it]

Epoch 21, Training Loss: 0.0003, Validation Loss: 0.0025


  2% 22/1000 [12:43<9:24:46, 34.65s/it]

Epoch 22, Training Loss: 0.0003, Validation Loss: 0.0022


  2% 23/1000 [13:18<9:24:20, 34.66s/it]

Epoch 23, Training Loss: 0.0003, Validation Loss: 0.0022


  2% 24/1000 [13:53<9:23:27, 34.64s/it]

Epoch 24, Training Loss: 0.0003, Validation Loss: 0.0023


  2% 25/1000 [14:27<9:22:04, 34.59s/it]

Epoch 25, Training Loss: 0.0003, Validation Loss: 0.0021


  3% 26/1000 [15:02<9:20:54, 34.55s/it]

Epoch 26, Training Loss: 0.0003, Validation Loss: 0.0021


  3% 27/1000 [15:36<9:19:44, 34.52s/it]

Epoch 27, Training Loss: 0.0003, Validation Loss: 0.0025


  3% 28/1000 [16:10<9:18:40, 34.49s/it]

Epoch 28, Training Loss: 0.0003, Validation Loss: 0.0023


  3% 29/1000 [16:45<9:18:12, 34.49s/it]

Epoch 29, Training Loss: 0.0002, Validation Loss: 0.0025


  3% 30/1000 [17:19<9:17:51, 34.51s/it]

Epoch 30, Training Loss: 0.0002, Validation Loss: 0.0026


  3% 31/1000 [17:54<9:17:04, 34.49s/it]

Epoch 31, Training Loss: 0.0002, Validation Loss: 0.0026


  3% 32/1000 [18:29<9:17:18, 34.54s/it]

Epoch 32, Training Loss: 0.0002, Validation Loss: 0.0027


  3% 33/1000 [19:03<9:17:36, 34.60s/it]

Epoch 33, Training Loss: 0.0002, Validation Loss: 0.0031


  3% 34/1000 [19:38<9:15:58, 34.53s/it]

Epoch 34, Training Loss: 0.0002, Validation Loss: 0.0028


  4% 35/1000 [20:12<9:15:33, 34.54s/it]

Epoch 35, Training Loss: 0.0002, Validation Loss: 0.0028


  4% 36/1000 [20:47<9:15:26, 34.57s/it]

Epoch 36, Training Loss: 0.0002, Validation Loss: 0.0027


  4% 37/1000 [21:22<9:16:57, 34.70s/it]

Epoch 37, Training Loss: 0.0001, Validation Loss: 0.0027


  4% 38/1000 [21:57<9:16:19, 34.70s/it]

Epoch 38, Training Loss: 0.0001, Validation Loss: 0.0027


  4% 39/1000 [22:31<9:15:52, 34.71s/it]

Epoch 39, Training Loss: 0.0001, Validation Loss: 0.0026


  4% 40/1000 [23:06<9:15:57, 34.75s/it]

Epoch 40, Training Loss: 0.0001, Validation Loss: 0.0025


  4% 41/1000 [23:41<9:14:58, 34.72s/it]

Epoch 41, Training Loss: 0.0001, Validation Loss: 0.0027


  4% 42/1000 [24:16<9:14:39, 34.74s/it]

Epoch 42, Training Loss: 0.0001, Validation Loss: 0.0025


  4% 43/1000 [24:50<9:14:15, 34.75s/it]

Epoch 43, Training Loss: 0.0001, Validation Loss: 0.0024


  4% 44/1000 [25:25<9:13:59, 34.77s/it]

Epoch 44, Training Loss: 0.0001, Validation Loss: 0.0025


  4% 45/1000 [26:00<9:13:52, 34.80s/it]

Epoch 45, Training Loss: 0.0001, Validation Loss: 0.0024


  5% 46/1000 [26:35<9:12:46, 34.77s/it]

Epoch 46, Training Loss: 0.0001, Validation Loss: 0.0026


  5% 47/1000 [27:09<9:11:35, 34.73s/it]

Epoch 47, Training Loss: 0.0001, Validation Loss: 0.0025


  5% 48/1000 [27:44<9:10:51, 34.72s/it]

Epoch 48, Training Loss: 0.0001, Validation Loss: 0.0026


  5% 49/1000 [28:19<9:12:18, 34.85s/it]

Epoch 49, Training Loss: 0.0001, Validation Loss: 0.0025


  5% 50/1000 [28:54<9:11:23, 34.82s/it]

Epoch 50, Training Loss: 0.0001, Validation Loss: 0.0028


  5% 51/1000 [29:29<9:10:09, 34.78s/it]

Epoch 51, Training Loss: 0.0001, Validation Loss: 0.0028


  5% 52/1000 [30:03<9:09:12, 34.76s/it]

Epoch 52, Training Loss: 0.0001, Validation Loss: 0.0028


  5% 53/1000 [30:38<9:08:06, 34.73s/it]

Epoch 53, Training Loss: 0.0001, Validation Loss: 0.0028


  5% 54/1000 [31:13<9:07:23, 34.72s/it]

Epoch 54, Training Loss: 0.0001, Validation Loss: 0.0031


  6% 55/1000 [31:47<9:06:36, 34.71s/it]

Epoch 55, Training Loss: 0.0001, Validation Loss: 0.0030


  6% 56/1000 [32:22<9:05:35, 34.68s/it]

Epoch 56, Training Loss: 0.0001, Validation Loss: 0.0030


  6% 57/1000 [32:57<9:05:06, 34.68s/it]

Epoch 57, Training Loss: 0.0001, Validation Loss: 0.0030


  6% 58/1000 [33:31<9:03:38, 34.63s/it]

Epoch 58, Training Loss: 0.0001, Validation Loss: 0.0030


  6% 59/1000 [34:06<9:02:15, 34.58s/it]

Epoch 59, Training Loss: 0.0001, Validation Loss: 0.0032


  6% 60/1000 [34:40<9:01:06, 34.54s/it]

Epoch 60, Training Loss: 0.0001, Validation Loss: 0.0032


  6% 61/1000 [35:15<9:00:05, 34.51s/it]

Epoch 61, Training Loss: 0.0001, Validation Loss: 0.0034


  6% 62/1000 [35:49<8:59:47, 34.53s/it]

Epoch 62, Training Loss: 0.0001, Validation Loss: 0.0034


  6% 63/1000 [36:24<8:59:32, 34.55s/it]

Epoch 63, Training Loss: 0.0001, Validation Loss: 0.0035


  6% 64/1000 [36:58<8:59:45, 34.60s/it]

Epoch 64, Training Loss: 0.0001, Validation Loss: 0.0038


  6% 65/1000 [37:33<8:59:53, 34.65s/it]

Epoch 65, Training Loss: 0.0001, Validation Loss: 0.0038


  7% 66/1000 [38:08<8:59:25, 34.65s/it]

Epoch 66, Training Loss: 0.0002, Validation Loss: 0.0037


  7% 67/1000 [38:43<8:58:58, 34.66s/it]

Epoch 67, Training Loss: 0.0000, Validation Loss: 0.0035


  7% 68/1000 [39:17<8:58:19, 34.66s/it]

Epoch 68, Training Loss: 0.0000, Validation Loss: 0.0036


  7% 69/1000 [39:52<8:57:51, 34.66s/it]

Epoch 69, Training Loss: 0.0000, Validation Loss: 0.0035


  7% 70/1000 [40:26<8:56:48, 34.63s/it]

Epoch 70, Training Loss: 0.0001, Validation Loss: 0.0041


  7% 71/1000 [41:01<8:56:20, 34.64s/it]

Epoch 71, Training Loss: 0.0001, Validation Loss: 0.0039


  7% 72/1000 [41:36<8:55:51, 34.65s/it]

Epoch 72, Training Loss: 0.0001, Validation Loss: 0.0046


  7% 73/1000 [42:10<8:55:07, 34.64s/it]

Epoch 73, Training Loss: 0.0000, Validation Loss: 0.0040


  7% 74/1000 [42:45<8:54:51, 34.66s/it]

Epoch 74, Training Loss: 0.0000, Validation Loss: 0.0042


  8% 75/1000 [43:20<8:54:17, 34.66s/it]

Epoch 75, Training Loss: 0.0000, Validation Loss: 0.0039


  8% 76/1000 [43:54<8:53:30, 34.64s/it]

Epoch 76, Training Loss: 0.0001, Validation Loss: 0.0045


  8% 77/1000 [44:29<8:52:19, 34.60s/it]

Epoch 77, Training Loss: 0.0000, Validation Loss: 0.0042


  8% 78/1000 [45:04<8:52:17, 34.64s/it]

Epoch 78, Training Loss: 0.0000, Validation Loss: 0.0046


  8% 79/1000 [45:38<8:52:02, 34.66s/it]

Epoch 79, Training Loss: 0.0001, Validation Loss: 0.0045


  8% 80/1000 [46:13<8:51:14, 34.65s/it]

Epoch 80, Training Loss: 0.0000, Validation Loss: 0.0045


  8% 81/1000 [46:48<8:50:56, 34.66s/it]

Epoch 81, Training Loss: 0.0000, Validation Loss: 0.0045


  8% 82/1000 [47:22<8:50:10, 34.65s/it]

Epoch 82, Training Loss: 0.0000, Validation Loss: 0.0047


  8% 83/1000 [47:57<8:49:39, 34.66s/it]

Epoch 83, Training Loss: 0.0000, Validation Loss: 0.0047


  8% 84/1000 [48:31<8:48:30, 34.62s/it]

Epoch 84, Training Loss: 0.0000, Validation Loss: 0.0043


  8% 85/1000 [49:06<8:47:55, 34.62s/it]

Epoch 85, Training Loss: 0.0001, Validation Loss: 0.0051


  9% 86/1000 [49:41<8:47:35, 34.63s/it]

Epoch 86, Training Loss: 0.0001, Validation Loss: 0.0052


  9% 87/1000 [50:15<8:47:16, 34.65s/it]

Epoch 87, Training Loss: 0.0000, Validation Loss: 0.0050


  9% 88/1000 [50:50<8:46:47, 34.66s/it]

Epoch 88, Training Loss: 0.0000, Validation Loss: 0.0049


  9% 89/1000 [51:25<8:46:22, 34.67s/it]

Epoch 89, Training Loss: 0.0000, Validation Loss: 0.0050


  9% 90/1000 [51:59<8:45:26, 34.64s/it]

Epoch 90, Training Loss: 0.0000, Validation Loss: 0.0046


  9% 91/1000 [52:34<8:45:07, 34.66s/it]

Epoch 91, Training Loss: 0.0000, Validation Loss: 0.0051


  9% 92/1000 [53:09<8:44:46, 34.68s/it]

Epoch 92, Training Loss: 0.0000, Validation Loss: 0.0048


  9% 93/1000 [53:43<8:44:19, 34.69s/it]

Epoch 93, Training Loss: 0.0001, Validation Loss: 0.0060


  9% 94/1000 [54:18<8:43:43, 34.68s/it]

Epoch 94, Training Loss: 0.0000, Validation Loss: 0.0052


 10% 95/1000 [54:53<8:43:08, 34.68s/it]

Epoch 95, Training Loss: 0.0000, Validation Loss: 0.0053


 10% 96/1000 [55:27<8:42:24, 34.67s/it]

Epoch 96, Training Loss: 0.0000, Validation Loss: 0.0055


 10% 97/1000 [56:02<8:41:28, 34.65s/it]

Epoch 97, Training Loss: 0.0001, Validation Loss: 0.0055


 10% 98/1000 [56:37<8:40:52, 34.65s/it]

Epoch 98, Training Loss: 0.0000, Validation Loss: 0.0057


 10% 99/1000 [57:11<8:40:25, 34.66s/it]

Epoch 99, Training Loss: 0.0000, Validation Loss: 0.0058


 10% 100/1000 [57:46<8:40:03, 34.67s/it]

Epoch 100, Training Loss: 0.0000, Validation Loss: 0.0057


 10% 101/1000 [58:21<8:39:08, 34.65s/it]

Epoch 101, Training Loss: 0.0000, Validation Loss: 0.0058


 10% 102/1000 [58:55<8:38:14, 34.63s/it]

Epoch 102, Training Loss: 0.0000, Validation Loss: 0.0062


 10% 103/1000 [59:30<8:37:43, 34.63s/it]

Epoch 103, Training Loss: 0.0000, Validation Loss: 0.0061


 10% 104/1000 [1:00:04<8:36:43, 34.60s/it]

Epoch 104, Training Loss: 0.0000, Validation Loss: 0.0059


 10% 105/1000 [1:00:39<8:36:02, 34.60s/it]

Epoch 105, Training Loss: 0.0000, Validation Loss: 0.0060


 11% 106/1000 [1:01:14<8:35:15, 34.58s/it]

Epoch 106, Training Loss: 0.0000, Validation Loss: 0.0063


 11% 107/1000 [1:01:48<8:34:33, 34.57s/it]

Epoch 107, Training Loss: 0.0000, Validation Loss: 0.0060


 11% 108/1000 [1:02:23<8:33:59, 34.57s/it]

Epoch 108, Training Loss: 0.0000, Validation Loss: 0.0061


 11% 109/1000 [1:02:57<8:33:27, 34.58s/it]

Epoch 109, Training Loss: 0.0000, Validation Loss: 0.0062


 11% 110/1000 [1:03:34<8:44:25, 35.35s/it]

Epoch 110, Training Loss: 0.0000, Validation Loss: 0.0063


 11% 111/1000 [1:04:09<8:40:06, 35.10s/it]

Epoch 111, Training Loss: 0.0000, Validation Loss: 0.0068


 11% 112/1000 [1:04:44<8:36:57, 34.93s/it]

Epoch 112, Training Loss: 0.0000, Validation Loss: 0.0065


 11% 113/1000 [1:05:18<8:34:59, 34.84s/it]

Epoch 113, Training Loss: 0.0000, Validation Loss: 0.0066


 11% 114/1000 [1:05:53<8:33:22, 34.77s/it]

Epoch 114, Training Loss: 0.0000, Validation Loss: 0.0067


 12% 115/1000 [1:06:27<8:32:00, 34.71s/it]

Epoch 115, Training Loss: 0.0000, Validation Loss: 0.0064


 12% 116/1000 [1:07:02<8:30:47, 34.67s/it]

Epoch 116, Training Loss: 0.0000, Validation Loss: 0.0064


 12% 117/1000 [1:07:36<8:29:41, 34.63s/it]

Epoch 117, Training Loss: 0.0000, Validation Loss: 0.0069


 12% 118/1000 [1:08:11<8:28:32, 34.59s/it]

Epoch 118, Training Loss: 0.0000, Validation Loss: 0.0068


 12% 119/1000 [1:08:45<8:27:47, 34.58s/it]

Epoch 119, Training Loss: 0.0000, Validation Loss: 0.0066


 12% 120/1000 [1:09:20<8:27:39, 34.61s/it]

Epoch 120, Training Loss: 0.0000, Validation Loss: 0.0070


 12% 121/1000 [1:09:55<8:27:27, 34.64s/it]

Epoch 121, Training Loss: 0.0000, Validation Loss: 0.0071


 12% 122/1000 [1:10:29<8:26:39, 34.62s/it]

Epoch 122, Training Loss: 0.0000, Validation Loss: 0.0072


 12% 123/1000 [1:11:04<8:26:14, 34.63s/it]

Epoch 123, Training Loss: 0.0000, Validation Loss: 0.0074


 12% 124/1000 [1:11:39<8:25:50, 34.65s/it]

Epoch 124, Training Loss: 0.0000, Validation Loss: 0.0070


 12% 125/1000 [1:12:13<8:24:55, 34.62s/it]

Epoch 125, Training Loss: 0.0000, Validation Loss: 0.0074


 13% 126/1000 [1:12:48<8:24:15, 34.62s/it]

Epoch 126, Training Loss: 0.0001, Validation Loss: 0.0074


 13% 127/1000 [1:13:23<8:23:54, 34.63s/it]

Epoch 127, Training Loss: 0.0000, Validation Loss: 0.0073


 13% 128/1000 [1:13:58<8:25:17, 34.77s/it]

Epoch 128, Training Loss: 0.0000, Validation Loss: 0.0073


 13% 129/1000 [1:14:33<8:25:13, 34.80s/it]

Epoch 129, Training Loss: 0.0000, Validation Loss: 0.0071


 13% 130/1000 [1:15:07<8:24:00, 34.76s/it]

Epoch 130, Training Loss: 0.0000, Validation Loss: 0.0075


 13% 131/1000 [1:15:42<8:22:26, 34.69s/it]

Epoch 131, Training Loss: 0.0000, Validation Loss: 0.0074


 13% 132/1000 [1:16:16<8:21:02, 34.63s/it]

Epoch 132, Training Loss: 0.0000, Validation Loss: 0.0075


 13% 133/1000 [1:16:51<8:20:28, 34.63s/it]

Epoch 133, Training Loss: 0.0000, Validation Loss: 0.0083


 13% 134/1000 [1:17:26<8:20:11, 34.65s/it]

Epoch 134, Training Loss: 0.0000, Validation Loss: 0.0080


 14% 135/1000 [1:18:03<8:29:25, 35.34s/it]

Epoch 135, Training Loss: 0.0000, Validation Loss: 0.0076


 14% 136/1000 [1:19:36<12:37:42, 52.62s/it]

Epoch 136, Training Loss: 0.0000, Validation Loss: 0.0078


 14% 137/1000 [1:21:08<15:30:48, 64.71s/it]

Epoch 137, Training Loss: 0.0000, Validation Loss: 0.0079


 14% 138/1000 [1:22:41<17:31:42, 73.20s/it]

Epoch 138, Training Loss: 0.0000, Validation Loss: 0.0077


 14% 139/1000 [1:24:02<18:03:47, 75.53s/it]

Epoch 139, Training Loss: 0.0000, Validation Loss: 0.0077


 14% 140/1000 [1:24:37<15:06:11, 63.22s/it]

Epoch 140, Training Loss: 0.0000, Validation Loss: 0.0078


 14% 141/1000 [1:25:12<13:02:22, 54.65s/it]

Epoch 141, Training Loss: 0.0000, Validation Loss: 0.0078


 14% 142/1000 [1:25:46<11:35:40, 48.65s/it]

Epoch 142, Training Loss: 0.0000, Validation Loss: 0.0077


 14% 143/1000 [1:26:21<10:34:19, 44.41s/it]

Epoch 143, Training Loss: 0.0014, Validation Loss: 0.0125


 14% 144/1000 [1:26:55<9:51:25, 41.46s/it] 

Epoch 144, Training Loss: 0.0004, Validation Loss: 0.0053


 14% 145/1000 [1:27:30<9:21:18, 39.39s/it]

Epoch 145, Training Loss: 0.0001, Validation Loss: 0.0053


 15% 146/1000 [1:28:04<8:59:56, 37.94s/it]

Epoch 146, Training Loss: 0.0001, Validation Loss: 0.0050


 15% 147/1000 [1:28:39<8:44:24, 36.89s/it]

Epoch 147, Training Loss: 0.0001, Validation Loss: 0.0052


 15% 148/1000 [1:29:13<8:34:04, 36.20s/it]

Epoch 148, Training Loss: 0.0001, Validation Loss: 0.0053


 15% 149/1000 [1:29:48<8:26:20, 35.70s/it]

Epoch 149, Training Loss: 0.0000, Validation Loss: 0.0054


 15% 150/1000 [1:30:22<8:20:38, 35.34s/it]

Epoch 150, Training Loss: 0.0000, Validation Loss: 0.0056


 15% 151/1000 [1:30:57<8:16:30, 35.09s/it]

Epoch 151, Training Loss: 0.0000, Validation Loss: 0.0057


 15% 152/1000 [1:31:32<8:14:04, 34.96s/it]

Epoch 152, Training Loss: 0.0000, Validation Loss: 0.0057


 15% 153/1000 [1:32:06<8:11:43, 34.83s/it]

Epoch 153, Training Loss: 0.0000, Validation Loss: 0.0061


 15% 154/1000 [1:32:41<8:09:43, 34.73s/it]

Epoch 154, Training Loss: 0.0000, Validation Loss: 0.0060


 16% 155/1000 [1:33:15<8:08:36, 34.69s/it]

Epoch 155, Training Loss: 0.0000, Validation Loss: 0.0060


 16% 156/1000 [1:33:50<8:07:34, 34.66s/it]

Epoch 156, Training Loss: 0.0000, Validation Loss: 0.0064


 16% 157/1000 [1:34:24<8:06:21, 34.62s/it]

Epoch 157, Training Loss: 0.0000, Validation Loss: 0.0064


 16% 158/1000 [1:34:59<8:05:22, 34.59s/it]

Epoch 158, Training Loss: 0.0000, Validation Loss: 0.0063


 16% 159/1000 [1:35:33<8:04:40, 34.58s/it]

Epoch 159, Training Loss: 0.0000, Validation Loss: 0.0065


 16% 160/1000 [1:36:08<8:04:00, 34.57s/it]

Epoch 160, Training Loss: 0.0000, Validation Loss: 0.0067


 16% 161/1000 [1:37:28<11:13:43, 48.18s/it]

Epoch 161, Training Loss: 0.0000, Validation Loss: 0.0067


 16% 162/1000 [1:38:02<10:15:03, 44.04s/it]

Epoch 162, Training Loss: 0.0000, Validation Loss: 0.0066


 16% 163/1000 [1:38:37<9:34:28, 41.18s/it] 

Epoch 163, Training Loss: 0.0000, Validation Loss: 0.0067


 16% 164/1000 [1:39:11<9:06:01, 39.19s/it]

Epoch 164, Training Loss: 0.0000, Validation Loss: 0.0068


 16% 165/1000 [1:39:46<8:45:39, 37.77s/it]

Epoch 165, Training Loss: 0.0000, Validation Loss: 0.0068


 17% 166/1000 [1:40:20<8:31:10, 36.78s/it]

Epoch 166, Training Loss: 0.0000, Validation Loss: 0.0070


 17% 167/1000 [1:40:55<8:21:14, 36.10s/it]

Epoch 167, Training Loss: 0.0000, Validation Loss: 0.0072


 17% 168/1000 [1:41:29<8:14:21, 35.65s/it]

Epoch 168, Training Loss: 0.0000, Validation Loss: 0.0071


 17% 169/1000 [1:42:04<8:09:08, 35.32s/it]

Epoch 169, Training Loss: 0.0000, Validation Loss: 0.0072


 17% 170/1000 [1:42:39<8:05:24, 35.09s/it]

Epoch 170, Training Loss: 0.0000, Validation Loss: 0.0075


 17% 171/1000 [1:43:13<8:02:25, 34.92s/it]

Epoch 171, Training Loss: 0.0000, Validation Loss: 0.0076


 17% 172/1000 [1:43:48<8:00:28, 34.82s/it]

Epoch 172, Training Loss: 0.0000, Validation Loss: 0.0076


 17% 173/1000 [1:44:22<7:58:16, 34.70s/it]

Epoch 173, Training Loss: 0.0000, Validation Loss: 0.0077


 17% 174/1000 [1:44:57<7:57:08, 34.66s/it]

Epoch 174, Training Loss: 0.0000, Validation Loss: 0.0077


 18% 175/1000 [1:45:31<7:56:23, 34.65s/it]

Epoch 175, Training Loss: 0.0000, Validation Loss: 0.0077


 18% 176/1000 [1:46:06<7:55:21, 34.61s/it]

Epoch 176, Training Loss: 0.0000, Validation Loss: 0.0075


 18% 177/1000 [1:46:40<7:54:33, 34.60s/it]

Epoch 177, Training Loss: 0.0000, Validation Loss: 0.0079


 18% 178/1000 [1:47:15<7:53:54, 34.59s/it]

Epoch 178, Training Loss: 0.0000, Validation Loss: 0.0078


 18% 179/1000 [1:47:49<7:53:04, 34.57s/it]

Epoch 179, Training Loss: 0.0000, Validation Loss: 0.0080


 18% 180/1000 [1:48:24<7:51:59, 34.54s/it]

Epoch 180, Training Loss: 0.0000, Validation Loss: 0.0079


 18% 181/1000 [1:48:58<7:51:46, 34.56s/it]

Epoch 181, Training Loss: 0.0000, Validation Loss: 0.0080


 18% 182/1000 [1:49:33<7:51:26, 34.58s/it]

Epoch 182, Training Loss: 0.0000, Validation Loss: 0.0079


 18% 183/1000 [1:50:08<7:50:43, 34.57s/it]

Epoch 183, Training Loss: 0.0000, Validation Loss: 0.0080


 18% 184/1000 [1:50:42<7:50:00, 34.56s/it]

Epoch 184, Training Loss: 0.0000, Validation Loss: 0.0081


 18% 185/1000 [1:51:17<7:49:42, 34.58s/it]

Epoch 185, Training Loss: 0.0000, Validation Loss: 0.0082


 19% 186/1000 [1:51:51<7:48:54, 34.56s/it]

Epoch 186, Training Loss: 0.0000, Validation Loss: 0.0081


 19% 187/1000 [1:52:26<7:47:55, 34.53s/it]

Epoch 187, Training Loss: 0.0000, Validation Loss: 0.0080


 19% 188/1000 [1:53:00<7:47:33, 34.55s/it]

Epoch 188, Training Loss: 0.0000, Validation Loss: 0.0079


 19% 189/1000 [1:53:35<7:47:24, 34.58s/it]

Epoch 189, Training Loss: 0.0000, Validation Loss: 0.0081


 19% 190/1000 [1:54:10<7:46:40, 34.57s/it]

Epoch 190, Training Loss: 0.0000, Validation Loss: 0.0082


 19% 191/1000 [1:54:44<7:46:03, 34.57s/it]

Epoch 191, Training Loss: 0.0000, Validation Loss: 0.0080


 19% 192/1000 [1:55:19<7:45:44, 34.58s/it]

Epoch 192, Training Loss: 0.0000, Validation Loss: 0.0081


 19% 193/1000 [1:55:53<7:44:54, 34.57s/it]

Epoch 193, Training Loss: 0.0000, Validation Loss: 0.0082


 19% 194/1000 [1:56:28<7:43:54, 34.53s/it]

Epoch 194, Training Loss: 0.0000, Validation Loss: 0.0080


 20% 195/1000 [1:57:02<7:43:49, 34.57s/it]

Epoch 195, Training Loss: 0.0000, Validation Loss: 0.0081


 20% 196/1000 [1:57:37<7:43:42, 34.61s/it]

Epoch 196, Training Loss: 0.0000, Validation Loss: 0.0082


 20% 197/1000 [1:58:12<7:43:06, 34.60s/it]

Epoch 197, Training Loss: 0.0000, Validation Loss: 0.0082


 20% 198/1000 [1:58:46<7:42:36, 34.61s/it]

Epoch 198, Training Loss: 0.0000, Validation Loss: 0.0082


 20% 199/1000 [1:59:21<7:42:04, 34.61s/it]

Epoch 199, Training Loss: 0.0000, Validation Loss: 0.0080


 20% 200/1000 [1:59:55<7:41:05, 34.58s/it]

Epoch 200, Training Loss: 0.0000, Validation Loss: 0.0084


 20% 201/1000 [2:00:30<7:39:59, 34.54s/it]

Epoch 201, Training Loss: 0.0000, Validation Loss: 0.0080


 20% 202/1000 [2:01:04<7:39:37, 34.56s/it]

Epoch 202, Training Loss: 0.0000, Validation Loss: 0.0081
Model saved with validation loss: 0.0081


 20% 203/1000 [2:01:39<7:39:13, 34.57s/it]

Epoch 203, Training Loss: 0.0000, Validation Loss: 0.0081


 20% 204/1000 [2:02:14<7:38:32, 34.56s/it]

Epoch 204, Training Loss: 0.0000, Validation Loss: 0.0082


 20% 205/1000 [2:02:48<7:38:03, 34.57s/it]

Epoch 205, Training Loss: 0.0000, Validation Loss: 0.0084


 21% 206/1000 [2:03:23<7:37:19, 34.56s/it]

Epoch 206, Training Loss: 0.0000, Validation Loss: 0.0084


 21% 207/1000 [2:03:57<7:36:18, 34.53s/it]

Epoch 207, Training Loss: 0.0000, Validation Loss: 0.0086


 21% 208/1000 [2:04:32<7:35:29, 34.51s/it]

Epoch 208, Training Loss: 0.0000, Validation Loss: 0.0085


 21% 209/1000 [2:05:06<7:35:24, 34.54s/it]

Epoch 209, Training Loss: 0.0000, Validation Loss: 0.0085


 21% 210/1000 [2:05:41<7:34:57, 34.55s/it]

Epoch 210, Training Loss: 0.0000, Validation Loss: 0.0085


 21% 211/1000 [2:06:15<7:34:33, 34.57s/it]

Epoch 211, Training Loss: 0.0000, Validation Loss: 0.0085


 21% 211/1000 [2:06:50<7:54:18, 36.07s/it]

Epoch 212, Training Loss: 0.0000, Validation Loss: 0.0087
Early stopping after epoch 212 due to no improvement in validation loss.



