# Training of a simple encoder transformer
- trained with a small balanced dataset 

In [1]:
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch import nn

print(torch.__version__)

notebook_dir = os.getcwd()

project_root = os.path.abspath(os.path.join(notebook_dir, "../.."))
print(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)
    
MODEL_V = "03_simple_transformer"

2.2.2
/Users/damianstone/Documents/Code/machine-learning/dl-sepsis-prediction


## Splitting the data

In [2]:
from Dataset import Dataset

data_loader = Dataset(
    data_path="imputed_sofa.parquet",
    save_path="dataset_tensors.pth",
    method="oversample",
    minority_ratio=0.3,
    target_column="SepsisLabel"
)

# as tensors
X_train, X_test, y_train, y_test = data_loader.get_train_test_tensors(size='small', train_size=0.1)

Computed pos_weight: tensor([20])
Balanced training set balance:
SepsisLabel
0    1219428
1     365828
Name: count, dtype: int64
Total records in balanced training set: 1585256
Reduced balanced training set balance:
SepsisLabel
0    121942
1     36583
Name: count, dtype: int64
Total records in reduced training set: 158525


## Transformer architecture
- Use `Mean Pooling` (x.mean(dim=1)) if: You want a summary of the whole sequence (useful if sepsis patterns are spread across time).
- Use `Max Pooling` (x.max(dim=1).values) if: You want to focus on the most extreme feature values, which might indicate critical moments.
- Use `Last Timestep` (x[:, -1, :]) if: You believe the latest patient state matters most (recommended for your case).

Test pooling:
- Train with different pooling methods and compare AUC-ROC scores.
- If AUC increases, the new pooling method is better.
- Visualize attention weights to see where the model is focusing.

In [None]:
class TransformerClassifier(nn.Module):
    """
    num_heads = more heads capture different attention but increase computation
    num_layers = more make the model deeper but can overfit if too high
    """

    def __init__(self, input_dim, num_heads=4, num_layers=4):
        super().__init__()
        # d_model = input_dim (number of features)
        # TODO: add drop out to avoid overfitting
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim, nhead=num_heads, dropout=0.1)
        # stacks multiple encoder layers (num_layers controls depth)
        self.encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=num_layers)
        # linear layer to map the output to a single value (binary classification)
        self.linear_layer = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        # Output shape: (batch_size, seq_len, features) or (batch_size, features)
        z = self.encoder(x)
        print("Shape of z:", z.shape)
        if z.dim() == 3:  # If (batch_size, seq_len, features), take the last timestep
            # NOTE: last timestep -> the most recent ICU data is usually the most relevant for prediction
            z = z[:, -1, :]
        return self.linear_layer(z)


def get_valid_num_heads(input_dim, desired_heads):
    """Finds the highest valid num_heads <= desired_heads that divides input_dim."""
    while desired_heads > 0 and input_dim % desired_heads != 0:
        desired_heads -= 1
    return max(1, desired_heads)


in_dim = X_train.shape[1]
print(in_dim)
n_heads = get_valid_num_heads(in_dim, 10)
print(n_heads)
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)

38
2




In [4]:
from torch.utils.data import DataLoader, TensorDataset


train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

batch_size = 512 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Training loop

In [10]:
from tqdm import tqdm
from torchmetrics import Accuracy

model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0], dtype=torch.float32))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) 

epochs = 100
epoch_counter = []
loss_counter = []
acc_counter = []

t_accuracy = Accuracy(task='binary')

for epoch in range(epochs):
    model.train()
    epoch_loss, epoch_acc = 0, 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    for X_batch, y_batch in train_loader:
        # forward pass
        y_logits = model(X_batch)
        y_probs = torch.sigmoid(y_logits)
        # y_pred = torch.round(y_probs)

        # loss function
        loss = loss_fn(y_logits, y_batch.unsqueeze(1).float())
        acc = t_accuracy(y_logits, y_batch.unsqueeze(1).float())
        
        # zero grad
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # optimizer
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc
        
        progress_bar.set_postfix({"Loss": loss.item(), "Acc": acc.item()})

    epoch_loss /= len(train_loader)
    epoch_acc /= len(train_loader)
    acc_counter.append(epoch_acc)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss:.5f} | Accuracy: {epoch_acc:.2f}%")

Epoch 1/100:   0%|          | 0/310 [00:38<?, ?it/s, Loss=1.16, Acc=0.552]

Epoch 1/100 | Loss: 1.28596 | Accuracy: 0.55%


                                                                          

Epoch 2/100 | Loss: 1.15789 | Accuracy: 0.57%


Epoch 3/100:   0%|          | 0/310 [00:41<?, ?it/s, Loss=1, Acc=0.672]   

Epoch 3/100 | Loss: 1.11921 | Accuracy: 0.61%


                                                                       

Epoch 4/100 | Loss: 1.11270 | Accuracy: 0.62%


Epoch 5/100:   0%|          | 0/310 [00:41<?, ?it/s, Loss=1.14, Acc=0.631] 

Epoch 5/100 | Loss: 1.10492 | Accuracy: 0.62%


                                                                          

Epoch 6/100 | Loss: 1.10472 | Accuracy: 0.63%


Epoch 7/100:   0%|          | 0/310 [00:42<?, ?it/s, Loss=1.16, Acc=0.672] 

Epoch 7/100 | Loss: 1.09923 | Accuracy: 0.62%


                                                                          

Epoch 8/100 | Loss: 1.09015 | Accuracy: 0.63%


Epoch 9/100:   0%|          | 0/310 [00:34<?, ?it/s, Loss=1.13, Acc=0.634] 

Epoch 9/100 | Loss: 1.08722 | Accuracy: 0.63%


                                                                          

Epoch 10/100 | Loss: 1.08168 | Accuracy: 0.64%


Epoch 11/100:   0%|          | 0/310 [00:34<?, ?it/s, Loss=1.04, Acc=0.666] 

Epoch 11/100 | Loss: 1.07789 | Accuracy: 0.64%


                                                                           

Epoch 12/100 | Loss: 1.07418 | Accuracy: 0.64%


Epoch 13/100:   0%|          | 0/310 [00:38<?, ?it/s, Loss=0.987, Acc=0.716]

Epoch 13/100 | Loss: 1.06847 | Accuracy: 0.64%


                                                                            

Epoch 14/100 | Loss: 1.06413 | Accuracy: 0.65%


Epoch 15/100:   0%|          | 0/310 [00:34<?, ?it/s, Loss=1.02, Acc=0.691] 

Epoch 15/100 | Loss: 1.05708 | Accuracy: 0.65%


                                                                           

Epoch 16/100 | Loss: 1.04883 | Accuracy: 0.66%


Epoch 17/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=1.04, Acc=0.707] 

Epoch 17/100 | Loss: 1.04125 | Accuracy: 0.66%


                                                                           

Epoch 18/100 | Loss: 1.03609 | Accuracy: 0.66%


Epoch 19/100:   0%|          | 0/310 [00:40<?, ?it/s, Loss=0.953, Acc=0.7]  

Epoch 19/100 | Loss: 1.02513 | Accuracy: 0.67%


                                                                          

Epoch 20/100 | Loss: 1.02092 | Accuracy: 0.67%


Epoch 21/100:   0%|          | 0/310 [00:55<?, ?it/s, Loss=0.987, Acc=0.644]

Epoch 21/100 | Loss: 1.01204 | Accuracy: 0.67%


                                                                            

Epoch 22/100 | Loss: 1.00175 | Accuracy: 0.67%


Epoch 23/100:   0%|          | 0/310 [00:48<?, ?it/s, Loss=0.923, Acc=0.644]

Epoch 23/100 | Loss: 0.99272 | Accuracy: 0.68%


                                                                            

Epoch 24/100 | Loss: 0.98511 | Accuracy: 0.67%


Epoch 25/100:   0%|          | 0/310 [00:49<?, ?it/s, Loss=1.02, Acc=0.669] 

Epoch 25/100 | Loss: 0.97670 | Accuracy: 0.68%


                                                                           

Epoch 26/100 | Loss: 0.96414 | Accuracy: 0.68%


Epoch 27/100:   0%|          | 0/310 [00:48<?, ?it/s, Loss=0.853, Acc=0.666]

Epoch 27/100 | Loss: 0.95389 | Accuracy: 0.69%


                                                                            

Epoch 28/100 | Loss: 0.94678 | Accuracy: 0.68%


Epoch 29/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.843, Acc=0.713]

Epoch 29/100 | Loss: 0.93439 | Accuracy: 0.69%


                                                                            

Epoch 30/100 | Loss: 0.92289 | Accuracy: 0.69%


Epoch 31/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.771, Acc=0.675]

Epoch 31/100 | Loss: 0.90791 | Accuracy: 0.69%


                                                                            

Epoch 32/100 | Loss: 0.90153 | Accuracy: 0.70%


Epoch 33/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.689, Acc=0.726]

Epoch 33/100 | Loss: 0.88502 | Accuracy: 0.70%


                                                                            

Epoch 34/100 | Loss: 0.87754 | Accuracy: 0.71%


Epoch 35/100:   0%|          | 0/310 [00:43<?, ?it/s, Loss=0.692, Acc=0.757]

Epoch 35/100 | Loss: 0.86014 | Accuracy: 0.71%


                                                                            

Epoch 36/100 | Loss: 0.85000 | Accuracy: 0.72%


Epoch 37/100:   0%|          | 0/310 [00:42<?, ?it/s, Loss=0.898, Acc=0.741]

Epoch 37/100 | Loss: 0.83550 | Accuracy: 0.72%


                                                                            

Epoch 38/100 | Loss: 0.82260 | Accuracy: 0.73%


Epoch 39/100:   0%|          | 0/310 [00:46<?, ?it/s, Loss=0.824, Acc=0.713]

Epoch 39/100 | Loss: 0.81222 | Accuracy: 0.73%


                                                                            

Epoch 40/100 | Loss: 0.80103 | Accuracy: 0.73%


Epoch 41/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.782, Acc=0.757]

Epoch 41/100 | Loss: 0.78638 | Accuracy: 0.74%


                                                                            

Epoch 42/100 | Loss: 0.77323 | Accuracy: 0.75%


Epoch 43/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.765, Acc=0.757]

Epoch 43/100 | Loss: 0.76057 | Accuracy: 0.75%


                                                                            

Epoch 44/100 | Loss: 0.75271 | Accuracy: 0.76%


Epoch 45/100:   0%|          | 0/310 [00:43<?, ?it/s, Loss=0.765, Acc=0.782]

Epoch 45/100 | Loss: 0.74746 | Accuracy: 0.76%


                                                                            

Epoch 46/100 | Loss: 0.73048 | Accuracy: 0.76%


Epoch 47/100:   0%|          | 0/310 [00:45<?, ?it/s, Loss=0.74, Acc=0.808] 

Epoch 47/100 | Loss: 0.72186 | Accuracy: 0.77%


                                                                           

Epoch 48/100 | Loss: 0.71067 | Accuracy: 0.77%


Epoch 49/100:   0%|          | 0/310 [00:46<?, ?it/s, Loss=0.578, Acc=0.826]

Epoch 49/100 | Loss: 0.69917 | Accuracy: 0.77%


                                                                            

Epoch 50/100 | Loss: 0.68740 | Accuracy: 0.78%


Epoch 51/100:   0%|          | 0/310 [00:42<?, ?it/s, Loss=0.579, Acc=0.811]

Epoch 51/100 | Loss: 0.67691 | Accuracy: 0.78%


                                                                            

Epoch 52/100 | Loss: 0.66752 | Accuracy: 0.79%


Epoch 53/100:   0%|          | 0/310 [00:42<?, ?it/s, Loss=0.643, Acc=0.814]

Epoch 53/100 | Loss: 0.66331 | Accuracy: 0.79%


                                                                            

Epoch 54/100 | Loss: 0.65409 | Accuracy: 0.79%


Epoch 55/100:   0%|          | 0/310 [00:49<?, ?it/s, Loss=0.599, Acc=0.789]

Epoch 55/100 | Loss: 0.64314 | Accuracy: 0.79%


                                                                            

Epoch 56/100 | Loss: 0.63982 | Accuracy: 0.80%


Epoch 57/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.688, Acc=0.795]

Epoch 57/100 | Loss: 0.62882 | Accuracy: 0.80%


                                                                            

Epoch 58/100 | Loss: 0.62129 | Accuracy: 0.81%


Epoch 59/100:   0%|          | 0/310 [00:45<?, ?it/s, Loss=0.604, Acc=0.808]

Epoch 59/100 | Loss: 0.61723 | Accuracy: 0.81%


                                                                            

Epoch 60/100 | Loss: 0.60624 | Accuracy: 0.81%


Epoch 61/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.586, Acc=0.82] 

Epoch 61/100 | Loss: 0.59797 | Accuracy: 0.82%


                                                                           

Epoch 62/100 | Loss: 0.59560 | Accuracy: 0.82%


Epoch 63/100:   0%|          | 0/310 [00:59<?, ?it/s, Loss=0.67, Acc=0.845] 

Epoch 63/100 | Loss: 0.58725 | Accuracy: 0.82%


                                                                           

Epoch 64/100 | Loss: 0.58473 | Accuracy: 0.82%


Epoch 65/100:   0%|          | 0/310 [00:46<?, ?it/s, Loss=0.517, Acc=0.89] 

Epoch 65/100 | Loss: 0.57852 | Accuracy: 0.82%


                                                                           

Epoch 66/100 | Loss: 0.56869 | Accuracy: 0.82%


Epoch 67/100:   0%|          | 0/310 [00:53<?, ?it/s, Loss=0.542, Acc=0.795]

Epoch 67/100 | Loss: 0.56007 | Accuracy: 0.83%


                                                                            

Epoch 68/100 | Loss: 0.56075 | Accuracy: 0.83%


Epoch 69/100:   0%|          | 0/310 [00:56<?, ?it/s, Loss=0.518, Acc=0.836]

Epoch 69/100 | Loss: 0.55003 | Accuracy: 0.83%


                                                                            

Epoch 70/100 | Loss: 0.54641 | Accuracy: 0.83%


Epoch 71/100:   0%|          | 0/310 [00:54<?, ?it/s, Loss=0.56, Acc=0.811] 

Epoch 71/100 | Loss: 0.53979 | Accuracy: 0.84%


                                                                           

Epoch 72/100 | Loss: 0.53557 | Accuracy: 0.84%


Epoch 73/100:   0%|          | 0/310 [00:51<?, ?it/s, Loss=0.648, Acc=0.839]

Epoch 73/100 | Loss: 0.53013 | Accuracy: 0.84%


                                                                            

Epoch 74/100 | Loss: 0.53224 | Accuracy: 0.84%


Epoch 75/100:   0%|          | 0/310 [00:56<?, ?it/s, Loss=0.528, Acc=0.845]

Epoch 75/100 | Loss: 0.52303 | Accuracy: 0.84%


                                                                            

Epoch 76/100 | Loss: 0.51502 | Accuracy: 0.84%


Epoch 77/100:   0%|          | 0/310 [01:10<?, ?it/s, Loss=0.444, Acc=0.852]

Epoch 77/100 | Loss: 0.54417 | Accuracy: 0.84%


                                                                            

Epoch 78/100 | Loss: 0.51281 | Accuracy: 0.85%


Epoch 79/100:   0%|          | 0/310 [00:55<?, ?it/s, Loss=0.51, Acc=0.836] 

Epoch 79/100 | Loss: 0.50487 | Accuracy: 0.85%


                                                                           

Epoch 80/100 | Loss: 0.50554 | Accuracy: 0.85%


Epoch 81/100:   0%|          | 0/310 [00:55<?, ?it/s, Loss=0.519, Acc=0.849]

Epoch 81/100 | Loss: 0.50029 | Accuracy: 0.85%


                                                                            

Epoch 82/100 | Loss: 0.50038 | Accuracy: 0.85%


Epoch 83/100:   0%|          | 0/310 [00:58<?, ?it/s, Loss=0.571, Acc=0.855]

Epoch 83/100 | Loss: 0.49203 | Accuracy: 0.85%


                                                                            

Epoch 84/100 | Loss: 0.48485 | Accuracy: 0.86%


Epoch 85/100:   0%|          | 0/310 [00:49<?, ?it/s, Loss=0.496, Acc=0.864]

Epoch 85/100 | Loss: 0.48562 | Accuracy: 0.86%


                                                                            

Epoch 86/100 | Loss: 0.48258 | Accuracy: 0.86%


Epoch 87/100:   0%|          | 0/310 [00:43<?, ?it/s, Loss=0.394, Acc=0.877]

Epoch 87/100 | Loss: 0.48021 | Accuracy: 0.86%


                                                                            

Epoch 88/100 | Loss: 0.47210 | Accuracy: 0.86%


Epoch 89/100:   0%|          | 0/310 [00:45<?, ?it/s, Loss=0.412, Acc=0.849]

Epoch 89/100 | Loss: 0.46660 | Accuracy: 0.86%


                                                                            

Epoch 90/100 | Loss: 0.46908 | Accuracy: 0.86%


Epoch 91/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.47, Acc=0.855] 

Epoch 91/100 | Loss: 0.46827 | Accuracy: 0.86%


                                                                           

Epoch 92/100 | Loss: 0.45590 | Accuracy: 0.87%


Epoch 93/100:   0%|          | 0/310 [00:44<?, ?it/s, Loss=0.573, Acc=0.811]

Epoch 93/100 | Loss: 0.45428 | Accuracy: 0.87%


                                                                            

Epoch 94/100 | Loss: 0.46218 | Accuracy: 0.86%


Epoch 95/100:   0%|          | 0/310 [00:48<?, ?it/s, Loss=0.541, Acc=0.864]

Epoch 95/100 | Loss: 0.51120 | Accuracy: 0.85%


                                                                            

Epoch 96/100 | Loss: 0.44959 | Accuracy: 0.87%


Epoch 97/100:   0%|          | 0/310 [00:43<?, ?it/s, Loss=0.451, Acc=0.849]

Epoch 97/100 | Loss: 0.44536 | Accuracy: 0.87%


                                                                            

Epoch 98/100 | Loss: 0.44089 | Accuracy: 0.87%


Epoch 99/100:   0%|          | 0/310 [00:42<?, ?it/s, Loss=0.392, Acc=0.915]

Epoch 99/100 | Loss: 0.44128 | Accuracy: 0.87%


                                                                            

Epoch 100/100 | Loss: 0.43743 | Accuracy: 0.87%


## Save model and predictions

In [11]:
from pathlib import Path

# save the model
model_path = Path('./saved')
model_path.mkdir(exist_ok=True)
model_file = model_path / f"{MODEL_V}.pth"

torch.save(model.state_dict(), model_file)

In [12]:
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
model.load_state_dict(torch.load(f"./saved/{MODEL_V}.pth"))



<All keys matched successfully>