# Training of a simple encoder transformer
- trained with a small balanced dataset 

In [1]:
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch import nn

print(torch.__version__)

notebook_dir = os.getcwd()

project_root = os.path.abspath(os.path.join(notebook_dir, "../.."))
print(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)
    
MODEL_V = "03_simple_transformer"

2.2.2
/Users/damianstone/Documents/Code/machine-learning/dl-sepsis-prediction


## Splitting the data

In [3]:
from Dataset import Dataset

data_loader = Dataset(
    data_path="imputed_sofa.parquet",
    save_path="dataset_tensors.pth",
    method="oversample",
    minority_ratio=0.2,
    target_column="SepsisLabel"
)

# as tensors
X_train, X_test, y_train, y_test = data_loader.get_train_test_tensors(size='small', train_size=0.3)

Computed pos_weight: tensor([54.6021], dtype=torch.float64)
Balanced training set balance:
SepsisLabel
0    1219428
1     243885
Name: count, dtype: int64
Total records in balanced training set: 1463313
Reduced balanced training set balance:
SepsisLabel
0    365828
1     73165
Name: count, dtype: int64
Total records in reduced training set: 438993


## Transformer architecture
- Use `Mean Pooling` (x.mean(dim=1)) if: You want a summary of the whole sequence (useful if sepsis patterns are spread across time).
- Use `Max Pooling` (x.max(dim=1).values) if: You want to focus on the most extreme feature values, which might indicate critical moments.
- Use `Last Timestep` (x[:, -1, :]) if: You believe the latest patient state matters most (recommended for your case).

Test pooling:
- Train with different pooling methods and compare AUC-ROC scores.
- If AUC increases, the new pooling method is better.
- Visualize attention weights to see where the model is focusing.

In [4]:
class TransformerClassifier(nn.Module):
    """
    num_heads = more heads capture different attention but increase computation
    num_layers = more make the model deeper but can overfit if too high
    """
    def __init__(self, input_dim, num_heads=4, num_layers=3):
        super().__init__()
        # d_model = input_dim (number of features)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads)
        # stacks multiple encoder layers (num_layers controls depth)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # linear layer to map the output to a single value (binary classification)
        self.linear_layer = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
      z = self.encoder(x)  # Output shape: (batch_size, seq_len, features) or (batch_size, features)
      if z.dim() == 3:  # If (batch_size, seq_len, features), take the last timestep
            # NOTE: last timestep -> the most recent ICU data is usually the most relevant for prediction
            z = z[:, -1, :]
      return self.linear_layer(z)


def get_valid_num_heads(input_dim, desired_heads):
    """Finds the highest valid num_heads <= desired_heads that divides input_dim."""
    while desired_heads > 0 and input_dim % desired_heads != 0:
        desired_heads -= 1
    return max(1, desired_heads)


in_dim = X_train.shape[1]
print(in_dim)
n_heads = get_valid_num_heads(in_dim, 10)
print(n_heads)
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)

38
2




In [5]:
from torch.utils.data import DataLoader, TensorDataset


train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

batch_size = 512 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Training loop

In [6]:
from tqdm import tqdm
from torchmetrics import Accuracy

model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=data_loader.pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) 

epochs = 100
epoch_counter = []
loss_counter = []
acc_counter = []

t_accuracy = Accuracy(task='binary')

for epoch in range(epochs):
    model.train()
    epoch_loss, epoch_acc = 0, 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    for X_batch, y_batch in train_loader:
        # forward pass
        y_logits = model(X_batch)
        y_probs = torch.sigmoid(y_logits)
        y_pred = torch.round(y_probs)

        # loss function
        loss = loss_fn(y_logits, y_batch.unsqueeze(1).float())
        acc = t_accuracy(y_pred, y_batch.unsqueeze(1).float())
        
        # zero grad
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # optimizer
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc
        
        progress_bar.set_postfix({"Loss": loss.item(), "Acc": acc.item()})

    epoch_loss /= len(train_loader)
    epoch_acc /= len(train_loader)
    acc_counter.append(epoch_acc)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss:.5f} | Accuracy: {epoch_acc:.2f}%")

Epoch 1/100:   0%|          | 0/858 [01:40<?, ?it/s, Loss=2.37, Acc=0.163]

Epoch 1/100 | Loss: 2.68666 | Accuracy: 0.17%


                                                                          

Epoch 2/100 | Loss: 2.58706 | Accuracy: 0.17%


Epoch 3/100:   0%|          | 0/858 [01:29<?, ?it/s, Loss=2.59, Acc=0.15] 

KeyboardInterrupt: 

## Save model and predictions

In [None]:
from pathlib import Path

# save the model
model_path = Path('./saved')
model_path.mkdir(exist_ok=True)
model_file = model_path / f"{MODEL_V}.pth"

torch.save(model.state_dict(), model_file)

In [None]:
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
model.load_state_dict(torch.load(f"./saved/{MODEL_V}.pth"))

<All keys matched successfully>