# Training of a simple encoder transformer
- trained with a small balanced dataset 

In [1]:
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch import nn

print(torch.__version__)

notebook_dir = os.getcwd()

project_root = os.path.abspath(os.path.join(notebook_dir, "../.."))
print(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)
    
MODEL_V = "04_simple_transformer"

2.2.2
/Users/damianstone/Documents/Code/machine-learning/dl-sepsis-prediction


In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='cpu')

In [3]:
import torch
print(torch.cuda.is_available())  

False


## Splitting the data

In [66]:
from Dataset import Dataset

data_loader = Dataset(
    data_path="imputed_sofa.parquet",
    save_path="dataset_tensors.pth",
    method="oversample",
    balance=True,
    minority_ratio=0.5,
    target_column="SepsisLabel"
)
# as tensors
X_train, X_test, y_train, y_test = data_loader.get_train_test_tensors(size='small', train_size=0.1)
X_train, X_test, y_train, y_test = X_train.to(device), X_test.to(device), y_train.to(device), y_test.to(device)


Balanced training set balance:
SepsisLabel
0    1219428
1     609714
Name: count, dtype: int64
Reduced balanced training set balance:
SepsisLabel
0    121943
1     60971
Name: count, dtype: int64
Total records in reduced training set: 182914


## Transformer architecture

In [67]:
from architectures import TransformerClassifier


in_dim = X_train.shape[1]
model = TransformerClassifier(input_dim=in_dim)

In [68]:
from torch.utils.data import DataLoader, TensorDataset


train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

batch_size = 512 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Training loop

In [69]:
negative_count = (y_train == 0).sum().item()  # Count negatives (class 0)
positive_count = (y_train == 1).sum().item()  # Count positives (class 1)

if positive_count == 0:
    raise ValueError("No positive samples found in training set.")

print(negative_count / positive_count)

2.0000164012399337


In [70]:
from tqdm import tqdm
from torchmetrics import Accuracy, Precision, Recall
from torchvision.ops import sigmoid_focal_loss
from torch import nn


model = TransformerClassifier(input_dim=in_dim).to(device)
pos_weight = torch.tensor([2], dtype=torch.float32).to(device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) 

epochs = 100
epoch_counter = []
loss_counter = []
acc_counter = []

t_accuracy = Accuracy(task='binary').to(device)
t_precision = Precision(task='binary').to(device)
t_recall = Recall(task='binary').to(device)

for epoch in range(epochs):
    model.train()
    epoch_loss, epoch_acc, epoch_prec, epoch_rec = 0, 0, 0, 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        # forward pass
        y_logits = model(X_batch)
        # print("Logits sample:", y_logits[:5].detach().cpu().numpy())  # Should have positive values too
        y_probs = torch.sigmoid(y_logits)
        # because the dataset is too imbalance
        y_preds = torch.round(y_probs)

        # loss function
        loss = loss_fn(y_logits, y_batch.unsqueeze(1).float())
        
        acc = t_accuracy(y_preds, y_batch.unsqueeze(1).float())
        prec = t_precision(y_preds, y_batch.unsqueeze(1).float())
        rec = t_recall(y_preds, y_batch.unsqueeze(1).float())
        
        # zero grad
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # optimizer
        optimizer.step()

        # update epoch metrics
        epoch_loss += loss.item()
        epoch_acc += acc
        epoch_prec += prec
        epoch_rec += rec
        
        progress_bar.set_postfix({"Loss": loss.item(), "Acc": acc.item(), "Prec": prec.item(), "Rec": rec.item()})

    epoch_loss /= len(train_loader)
    epoch_acc /= len(train_loader)
    epoch_prec /= len(train_loader)
    epoch_rec /= len(train_loader)
    acc_counter.append(epoch_acc)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss:.5f} | Accuracy: {epoch_acc:.2f}% | Precision: {epoch_prec:.2f}% | Recall: {epoch_rec:.2f}%")

Epoch 1/100:   0%|          | 0/358 [01:45<?, ?it/s, Loss=0.968, Acc=0.554, Prec=0.41, Rec=0.314]  

Epoch 1/100 | Loss: 0.92927 | Accuracy: 0.50% | Precision: 0.33% | Recall: 0.50%


                                                                                                 

Epoch 2/100 | Loss: 0.92502 | Accuracy: 0.49% | Precision: 0.32% | Recall: 0.52%


Epoch 3/100:   0%|          | 0/358 [01:53<?, ?it/s, Loss=0.959, Acc=0.431, Prec=0.391, Rec=0.86]   

Epoch 3/100 | Loss: 0.92464 | Accuracy: 0.50% | Precision: 0.32% | Recall: 0.51%


                                                                                                 

Epoch 4/100 | Loss: 0.92468 | Accuracy: 0.50% | Precision: 0.30% | Recall: 0.48%


Epoch 5/100:   0%|          | 0/358 [02:00<?, ?it/s, Loss=0.929, Acc=0.415, Prec=0.295, Rec=0.523]  

Epoch 5/100 | Loss: 0.92423 | Accuracy: 0.51% | Precision: 0.33% | Recall: 0.50%


                                                                                                  

Epoch 6/100 | Loss: 0.92453 | Accuracy: 0.50% | Precision: 0.29% | Recall: 0.49%


Epoch 7/100:   0%|          | 0/358 [02:05<?, ?it/s, Loss=0.942, Acc=0.377, Prec=0.367, Rec=1]      

Epoch 7/100 | Loss: 0.92447 | Accuracy: 0.50% | Precision: 0.30% | Recall: 0.50%


                                                                                              

Epoch 8/100 | Loss: 0.92450 | Accuracy: 0.49% | Precision: 0.30% | Recall: 0.52%


Epoch 9/100:   0%|          | 0/358 [02:01<?, ?it/s, Loss=0.947, Acc=0.638, Prec=0, Rec=0]          

Epoch 9/100 | Loss: 0.92452 | Accuracy: 0.49% | Precision: 0.29% | Recall: 0.54%


                                                                                          

Epoch 10/100 | Loss: 0.92435 | Accuracy: 0.49% | Precision: 0.30% | Recall: 0.52%


Epoch 11/100:   0%|          | 0/358 [02:02<?, ?it/s, Loss=0.958, Acc=0.623, Prec=0.5, Rec=0.0204]   

Epoch 11/100 | Loss: 0.92444 | Accuracy: 0.49% | Precision: 0.30% | Recall: 0.54%


                                                                                                  

Epoch 12/100 | Loss: 0.92493 | Accuracy: 0.51% | Precision: 0.34% | Recall: 0.49%


Epoch 13/100:   0%|          | 0/358 [02:04<?, ?it/s, Loss=0.929, Acc=0.346, Prec=0.319, Rec=0.818]   

Epoch 13/100 | Loss: 0.92430 | Accuracy: 0.51% | Precision: 0.33% | Recall: 0.48%


                                                                                                   

Epoch 14/100 | Loss: 0.92433 | Accuracy: 0.51% | Precision: 0.32% | Recall: 0.48%


Epoch 15/100:   0%|          | 0/358 [02:04<?, ?it/s, Loss=0.891, Acc=0.315, Prec=0.287, Rec=0.946]  

Epoch 15/100 | Loss: 0.92430 | Accuracy: 0.49% | Precision: 0.26% | Recall: 0.52%


                                                                                                   

Epoch 16/100 | Loss: 0.92436 | Accuracy: 0.49% | Precision: 0.25% | Recall: 0.51%


Epoch 17/100:   0%|          | 0/358 [02:04<?, ?it/s, Loss=0.945, Acc=0.638, Prec=0, Rec=0]          

Epoch 17/100 | Loss: 0.92435 | Accuracy: 0.49% | Precision: 0.23% | Recall: 0.52%


                                                                                           

Epoch 18/100 | Loss: 0.92434 | Accuracy: 0.50% | Precision: 0.29% | Recall: 0.49%


Epoch 19/100:   0%|          | 0/358 [02:06<?, ?it/s, Loss=0.971, Acc=0.592, Prec=0, Rec=0]          

Epoch 19/100 | Loss: 0.92440 | Accuracy: 0.52% | Precision: 0.30% | Recall: 0.43%


                                                                                           

Epoch 20/100 | Loss: 0.92432 | Accuracy: 0.49% | Precision: 0.32% | Recall: 0.53%


Epoch 21/100:   0%|          | 0/358 [02:05<?, ?it/s, Loss=0.933, Acc=0.354, Prec=0.339, Rec=0.911] 

Epoch 21/100 | Loss: 0.92431 | Accuracy: 0.48% | Precision: 0.29% | Recall: 0.57%


                                                                                                   

Epoch 22/100 | Loss: 0.92431 | Accuracy: 0.49% | Precision: 0.25% | Recall: 0.52%


Epoch 23/100:   0%|          | 0/358 [02:07<?, ?it/s, Loss=0.871, Acc=0.3, Prec=0.258, Rec=0.939]    

Epoch 23/100 | Loss: 0.92420 | Accuracy: 0.49% | Precision: 0.25% | Recall: 0.51%


                                                                                                 

Epoch 24/100 | Loss: 0.92418 | Accuracy: 0.53% | Precision: 0.23% | Recall: 0.40%


Epoch 25/100:   0%|          | 0/358 [02:11<?, ?it/s, Loss=0.917, Acc=0.354, Prec=0.328, Rec=0.952]  

Epoch 25/100 | Loss: 0.92428 | Accuracy: 0.50% | Precision: 0.27% | Recall: 0.51%


                                                                                                   

Epoch 26/100 | Loss: 0.92432 | Accuracy: 0.54% | Precision: 0.32% | Recall: 0.38%


Epoch 27/100:   0%|          | 0/358 [02:12<?, ?it/s, Loss=0.929, Acc=0.662, Prec=0, Rec=0]          

Epoch 27/100 | Loss: 0.92426 | Accuracy: 0.47% | Precision: 0.31% | Recall: 0.59%


                                                                                           

Epoch 28/100 | Loss: 0.92424 | Accuracy: 0.51% | Precision: 0.30% | Recall: 0.47%


Epoch 29/100:   0%|          | 0/358 [02:07<?, ?it/s, Loss=0.906, Acc=0.615, Prec=0.292, Rec=0.175]  

Epoch 29/100 | Loss: 0.92420 | Accuracy: 0.50% | Precision: 0.23% | Recall: 0.48%


                                                                                                   

Epoch 30/100 | Loss: 0.92421 | Accuracy: 0.51% | Precision: 0.27% | Recall: 0.47%


Epoch 31/100:   0%|          | 0/358 [02:03<?, ?it/s, Loss=0.939, Acc=0.615, Prec=0.357, Rec=0.109]  

Epoch 31/100 | Loss: 0.92428 | Accuracy: 0.52% | Precision: 0.28% | Recall: 0.45%


                                                                                                   

Epoch 32/100 | Loss: 0.92424 | Accuracy: 0.51% | Precision: 0.31% | Recall: 0.47%


Epoch 33/100:   0%|          | 0/358 [02:08<?, ?it/s, Loss=0.928, Acc=0.385, Prec=0.324, Rec=0.75]  

Epoch 33/100 | Loss: 0.92426 | Accuracy: 0.51% | Precision: 0.21% | Recall: 0.46%


                                                                                                  

Epoch 34/100 | Loss: 0.92425 | Accuracy: 0.49% | Precision: 0.29% | Recall: 0.54%


Epoch 35/100:   0%|          | 0/358 [02:07<?, ?it/s, Loss=0.976, Acc=0.592, Prec=0, Rec=0]          

Epoch 35/100 | Loss: 0.92437 | Accuracy: 0.50% | Precision: 0.30% | Recall: 0.48%


                                                                                           

Epoch 36/100 | Loss: 0.92422 | Accuracy: 0.48% | Precision: 0.27% | Recall: 0.55%


Epoch 37/100:   0%|          | 0/358 [02:08<?, ?it/s, Loss=0.978, Acc=0.592, Prec=0, Rec=0]          

Epoch 37/100 | Loss: 0.92435 | Accuracy: 0.45% | Precision: 0.31% | Recall: 0.65%


                                                                                           

Epoch 38/100 | Loss: 0.92425 | Accuracy: 0.50% | Precision: 0.24% | Recall: 0.48%


Epoch 39/100:   0%|          | 0/358 [02:11<?, ?it/s, Loss=0.885, Acc=0.723, Prec=0, Rec=0]          

Epoch 39/100 | Loss: 0.92416 | Accuracy: 0.48% | Precision: 0.30% | Recall: 0.54%


                                                                                           

Epoch 40/100 | Loss: 0.92426 | Accuracy: 0.50% | Precision: 0.27% | Recall: 0.51%


Epoch 41/100:   0%|          | 0/358 [02:11<?, ?it/s, Loss=0.933, Acc=0.346, Prec=0.346, Rec=1]      

Epoch 41/100 | Loss: 0.92426 | Accuracy: 0.56% | Precision: 0.29% | Recall: 0.32%


                                                                                               

Epoch 42/100 | Loss: 0.92419 | Accuracy: 0.51% | Precision: 0.31% | Recall: 0.47%


Epoch 43/100:   0%|          | 0/358 [02:12<?, ?it/s, Loss=0.906, Acc=0.662, Prec=0.167, Rec=0.025]  

Epoch 43/100 | Loss: 0.92417 | Accuracy: 0.46% | Precision: 0.31% | Recall: 0.61%


                                                                                                   

Epoch 44/100 | Loss: 0.92429 | Accuracy: 0.48% | Precision: 0.29% | Recall: 0.55%


Epoch 45/100:   0%|          | 0/358 [02:12<?, ?it/s, Loss=0.949, Acc=0.515, Prec=0.368, Rec=0.438]  

Epoch 45/100 | Loss: 0.92427 | Accuracy: 0.54% | Precision: 0.34% | Recall: 0.38%


                                                                                                   

Epoch 46/100 | Loss: 0.92434 | Accuracy: 0.52% | Precision: 0.30% | Recall: 0.43%


Epoch 47/100:   0%|          | 0/358 [02:10<?, ?it/s, Loss=0.944, Acc=0.446, Prec=0.379, Rec=0.83]  

Epoch 47/100 | Loss: 0.92427 | Accuracy: 0.45% | Precision: 0.33% | Recall: 0.63%


                                                                                                  

Epoch 48/100 | Loss: 0.92426 | Accuracy: 0.47% | Precision: 0.31% | Recall: 0.59%


Epoch 49/100:   0%|          | 0/358 [02:09<?, ?it/s, Loss=0.923, Acc=0.338, Prec=0.331, Rec=0.977]   

Epoch 49/100 | Loss: 0.92422 | Accuracy: 0.53% | Precision: 0.29% | Recall: 0.40%


                                                                                                   

Epoch 50/100 | Loss: 0.92423 | Accuracy: 0.50% | Precision: 0.34% | Recall: 0.50%


Epoch 51/100:   0%|          | 0/358 [02:13<?, ?it/s, Loss=0.952, Acc=0.377, Prec=0.377, Rec=1]      

Epoch 51/100 | Loss: 0.92428 | Accuracy: 0.50% | Precision: 0.28% | Recall: 0.49%


                                                                                               

Epoch 52/100 | Loss: 0.92417 | Accuracy: 0.48% | Precision: 0.33% | Recall: 0.57%


Epoch 53/100:   0%|          | 0/358 [02:09<?, ?it/s, Loss=0.966, Acc=0.608, Prec=0, Rec=0]           

Epoch 53/100 | Loss: 0.92430 | Accuracy: 0.51% | Precision: 0.32% | Recall: 0.46%


                                                                                           

Epoch 54/100 | Loss: 0.92416 | Accuracy: 0.51% | Precision: 0.31% | Recall: 0.46%


Epoch 55/100:   0%|          | 0/358 [02:09<?, ?it/s, Loss=0.846, Acc=0.723, Prec=0.267, Rec=0.138]  

Epoch 55/100 | Loss: 0.92406 | Accuracy: 0.51% | Precision: 0.34% | Recall: 0.48%


                                                                                                   

Epoch 56/100 | Loss: 0.92424 | Accuracy: 0.50% | Precision: 0.33% | Recall: 0.48%


Epoch 57/100:   0%|          | 0/358 [02:05<?, ?it/s, Loss=0.922, Acc=0.408, Prec=0.343, Rec=0.86]   

Epoch 57/100 | Loss: 0.92422 | Accuracy: 0.50% | Precision: 0.34% | Recall: 0.49%


                                                                                                  

Epoch 58/100 | Loss: 0.92416 | Accuracy: 0.46% | Precision: 0.30% | Recall: 0.62%


Epoch 59/100:   0%|          | 0/358 [02:10<?, ?it/s, Loss=0.906, Acc=0.392, Prec=0.307, Rec=0.775]  

Epoch 59/100 | Loss: 0.92416 | Accuracy: 0.51% | Precision: 0.24% | Recall: 0.47%


                                                                                                   

Epoch 60/100 | Loss: 0.92422 | Accuracy: 0.48% | Precision: 0.33% | Recall: 0.56%


Epoch 61/100:   0%|          | 0/358 [02:08<?, ?it/s, Loss=0.928, Acc=0.646, Prec=0, Rec=0]           

Epoch 61/100 | Loss: 0.92424 | Accuracy: 0.51% | Precision: 0.32% | Recall: 0.46%


                                                                                           

Epoch 62/100 | Loss: 0.92428 | Accuracy: 0.47% | Precision: 0.33% | Recall: 0.59%


Epoch 63/100:   0%|          | 0/358 [02:12<?, ?it/s, Loss=0.959, Acc=0.385, Prec=0.375, Rec=0.9]   

Epoch 63/100 | Loss: 0.92428 | Accuracy: 0.51% | Precision: 0.34% | Recall: 0.46%


                                                                                                 

Epoch 64/100 | Loss: 0.92418 | Accuracy: 0.47% | Precision: 0.32% | Recall: 0.60%


Epoch 65/100:   0%|          | 0/358 [02:14<?, ?it/s, Loss=0.9, Acc=0.654, Prec=0.286, Rec=0.103]     

Epoch 65/100 | Loss: 0.92416 | Accuracy: 0.50% | Precision: 0.34% | Recall: 0.49%


                                                                                                 

Epoch 66/100 | Loss: 0.92423 | Accuracy: 0.52% | Precision: 0.29% | Recall: 0.45%


Epoch 67/100:   0%|          | 0/358 [02:13<?, ?it/s, Loss=0.923, Acc=0.331, Prec=0.331, Rec=1]       

Epoch 67/100 | Loss: 0.92421 | Accuracy: 0.48% | Precision: 0.31% | Recall: 0.55%


                                                                                               

Epoch 68/100 | Loss: 0.92420 | Accuracy: 0.51% | Precision: 0.34% | Recall: 0.47%


Epoch 69/100:   0%|          | 0/358 [02:10<?, ?it/s, Loss=0.966, Acc=0.608, Prec=0, Rec=0]           

Epoch 69/100 | Loss: 0.92431 | Accuracy: 0.52% | Precision: 0.30% | Recall: 0.45%


                                                                                           

Epoch 70/100 | Loss: 0.92422 | Accuracy: 0.47% | Precision: 0.33% | Recall: 0.57%


Epoch 71/100:   0%|          | 0/358 [02:12<?, ?it/s, Loss=0.949, Acc=0.577, Prec=0.182, Rec=0.0417] 

Epoch 71/100 | Loss: 0.92426 | Accuracy: 0.51% | Precision: 0.30% | Recall: 0.46%


                                                                                                    

Epoch 72/100 | Loss: 0.92421 | Accuracy: 0.48% | Precision: 0.31% | Recall: 0.57%


Epoch 73/100:   0%|          | 0/358 [02:12<?, ?it/s, Loss=0.933, Acc=0.623, Prec=0.167, Rec=0.0222]

Epoch 73/100 | Loss: 0.92424 | Accuracy: 0.47% | Precision: 0.33% | Recall: 0.60%


                                                                                                    

Epoch 74/100 | Loss: 0.92410 | Accuracy: 0.54% | Precision: 0.33% | Recall: 0.37%


Epoch 75/100:   0%|          | 0/358 [02:11<?, ?it/s, Loss=0.986, Acc=0.415, Prec=0.415, Rec=0.927]  

Epoch 75/100 | Loss: 0.92434 | Accuracy: 0.51% | Precision: 0.23% | Recall: 0.47%


                                                                                                   

Epoch 76/100 | Loss: 0.92421 | Accuracy: 0.47% | Precision: 0.29% | Recall: 0.59%


Epoch 77/100:   0%|          | 0/358 [02:14<?, ?it/s, Loss=0.901, Acc=0.331, Prec=0.31, Rec=1]       

Epoch 77/100 | Loss: 0.92415 | Accuracy: 0.49% | Precision: 0.29% | Recall: 0.53%


                                                                                              

Epoch 78/100 | Loss: 0.92419 | Accuracy: 0.49% | Precision: 0.33% | Recall: 0.52%


Epoch 79/100:   0%|          | 0/358 [02:11<?, ?it/s, Loss=0.932, Acc=0.377, Prec=0.357, Rec=1]       

Epoch 79/100 | Loss: 0.92422 | Accuracy: 0.55% | Precision: 0.29% | Recall: 0.37%


                                                                                               

Epoch 80/100 | Loss: 0.92419 | Accuracy: 0.50% | Precision: 0.34% | Recall: 0.49%


Epoch 81/100:   0%|          | 0/358 [02:13<?, ?it/s, Loss=0.906, Acc=0.662, Prec=0.333, Rec=0.1]     

Epoch 81/100 | Loss: 0.92417 | Accuracy: 0.53% | Precision: 0.33% | Recall: 0.40%


                                                                                                 

Epoch 82/100 | Loss: 0.92419 | Accuracy: 0.51% | Precision: 0.34% | Recall: 0.46%


Epoch 83/100:   0%|          | 0/358 [02:06<?, ?it/s, Loss=0.939, Acc=0.6, Prec=0.364, Rec=0.174]    

Epoch 83/100 | Loss: 0.92426 | Accuracy: 0.50% | Precision: 0.34% | Recall: 0.50%


                                                                                                 

Epoch 84/100 | Loss: 0.92436 | Accuracy: 0.49% | Precision: 0.33% | Recall: 0.52%


Epoch 85/100:   0%|          | 0/358 [02:11<?, ?it/s, Loss=0.965, Acc=0.369, Prec=0.378, Rec=0.941]  

Epoch 85/100 | Loss: 0.92430 | Accuracy: 0.49% | Precision: 0.28% | Recall: 0.52%


                                                                                                   

Epoch 86/100 | Loss: 0.92429 | Accuracy: 0.48% | Precision: 0.31% | Recall: 0.57%


Epoch 87/100:   0%|          | 0/358 [02:07<?, ?it/s, Loss=0.949, Acc=0.6, Prec=0.25, Rec=0.0417]   

Epoch 87/100 | Loss: 0.92427 | Accuracy: 0.48% | Precision: 0.27% | Recall: 0.55%


                                                                                                 

Epoch 88/100 | Loss: 0.92425 | Accuracy: 0.52% | Precision: 0.30% | Recall: 0.45%


Epoch 89/100:   0%|          | 0/358 [02:15<?, ?it/s, Loss=0.912, Acc=0.377, Prec=0.331, Rec=0.951]   

Epoch 89/100 | Loss: 0.92420 | Accuracy: 0.48% | Precision: 0.33% | Recall: 0.56%


                                                                                                   

Epoch 90/100 | Loss: 0.92426 | Accuracy: 0.49% | Precision: 0.31% | Recall: 0.54%


Epoch 91/100:   0%|          | 0/358 [02:07<?, ?it/s, Loss=0.944, Acc=0.362, Prec=0.339, Rec=0.809]  

Epoch 91/100 | Loss: 0.92425 | Accuracy: 0.53% | Precision: 0.23% | Recall: 0.42%


                                                                                                   

Epoch 92/100 | Loss: 0.92420 | Accuracy: 0.45% | Precision: 0.28% | Recall: 0.63%


Epoch 93/100:   0%|          | 0/358 [02:13<?, ?it/s, Loss=0.971, Acc=0.6, Prec=0.5, Rec=0.135]      

Epoch 93/100 | Loss: 0.92432 | Accuracy: 0.51% | Precision: 0.28% | Recall: 0.46%


                                                                                               

Epoch 94/100 | Loss: 0.92423 | Accuracy: 0.57% | Precision: 0.29% | Recall: 0.29%


Epoch 95/100:   0%|          | 0/358 [02:10<?, ?it/s, Loss=0.955, Acc=0.546, Prec=0.25, Rec=0.102]   

Epoch 95/100 | Loss: 0.92429 | Accuracy: 0.43% | Precision: 0.32% | Recall: 0.70%


                                                                                                  

Epoch 96/100 | Loss: 0.92417 | Accuracy: 0.50% | Precision: 0.29% | Recall: 0.51%


Epoch 97/100:   0%|          | 0/358 [02:09<?, ?it/s, Loss=0.901, Acc=0.669, Prec=0.25, Rec=0.0513]   

Epoch 97/100 | Loss: 0.92417 | Accuracy: 0.50% | Precision: 0.30% | Recall: 0.49%


                                                                                                   

Epoch 98/100 | Loss: 0.92424 | Accuracy: 0.56% | Precision: 0.34% | Recall: 0.33%


Epoch 99/100:   0%|          | 0/358 [02:08<?, ?it/s, Loss=0.895, Acc=0.708, Prec=0, Rec=0]          

Epoch 99/100 | Loss: 0.92416 | Accuracy: 0.48% | Precision: 0.29% | Recall: 0.55%


                                                                                           

Epoch 100/100 | Loss: 0.92416 | Accuracy: 0.51% | Precision: 0.25% | Recall: 0.47%


## Save model and predictions

In [11]:
from pathlib import Path

# save the model
model_path = Path('./saved')
model_path.mkdir(exist_ok=True)
model_file = model_path / f"{MODEL_V}.pth"

torch.save(model.state_dict(), model_file)

In [12]:
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
model.load_state_dict(torch.load(f"./saved/{MODEL_V}.pth"))



<All keys matched successfully>