# Exploring transformers architecture

In [2]:
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch import nn

print(torch.__version__)

notebook_dir = os.getcwd()

project_root = os.path.abspath(os.path.join(notebook_dir, "../.."))
print(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)

2.2.2
/Users/damianstone/Documents/Code/machine-learning/dl-sepsis-prediction


## Dataset

In [2]:
from utils import get_data

DATA_PATH = get_data.get_dataset_abspath()
load_path = os.path.join(DATA_PATH, "imputed_combined_data.parquet")
imputed_df = pd.read_parquet(load_path)

In [3]:
imputed_df.head(1)

Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,BaseExcess,HCO3,FiO2,...,Age,Gender,Unit1,Unit2,HospAdmTime,ICULOS,SepsisLabel,patient_id,dataset,cluster_id
0,102.108491,91.419811,36.789519,128.165094,88.199717,71.73211,24.712264,-0.288406,23.835971,0.467029,...,83.14,0,0.715787,0.284213,-0.03,1,0,1,A,0_9_2_2_X_X_X


In [4]:
if 'SOFA' in imputed_df.columns:
    print("Column exists.")
else:
    print("Column doesn't exist.")
    
imputed_df.columns

Column doesn't exist.


Index(['HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp', 'BaseExcess',
       'HCO3', 'FiO2', 'pH', 'PaCO2', 'SaO2', 'AST', 'BUN', 'Alkalinephos',
       'Calcium', 'Chloride', 'Creatinine', 'Bilirubin_direct', 'Glucose',
       'Lactate', 'Magnesium', 'Phosphate', 'Potassium', 'Bilirubin_total',
       'TroponinI', 'Hct', 'Hgb', 'PTT', 'WBC', 'Fibrinogen', 'Platelets',
       'Age', 'Gender', 'Unit1', 'Unit2', 'HospAdmTime', 'ICULOS',
       'SepsisLabel', 'patient_id', 'dataset', 'cluster_id'],
      dtype='object')

## Last feature engineering

In [5]:
# SOFA calculation based on Sepsis-3
def calculate_sofa(row):
    sofa = 0
    
    if row['FiO2'] > 0:
        pao2_fio2 = row['SaO2'] / row['FiO2']
        if pao2_fio2 < 100: sofa += 4
        elif pao2_fio2 < 200: sofa += 3
        elif pao2_fio2 < 300: sofa += 2
        elif pao2_fio2 < 400: sofa += 1

    if row['Platelets'] < 20: sofa += 4
    elif row['Platelets'] < 50: sofa += 3
    elif row['Platelets'] < 100: sofa += 2
    elif row['Platelets'] < 150: sofa += 1

    if row['Bilirubin_total'] >= 12: sofa += 4
    elif row['Bilirubin_total'] >= 6: sofa += 3
    elif row['Bilirubin_total'] >= 2: sofa += 2
    elif row['Bilirubin_total'] >= 1.2: sofa += 1

    if row['MAP'] < 70: sofa += 1

    if row['Creatinine'] >= 5: sofa += 4
    elif row['Creatinine'] >= 3.5: sofa += 3
    elif row['Creatinine'] >= 2: sofa += 2
    elif row['Creatinine'] >= 1.2: sofa += 1

    return sofa

imputed_df = imputed_df.drop(columns=["Unit1", "Unit2", "cluster_id","dataset", "patient_id"], errors='ignore')
imputed_df = imputed_df.dropna(subset=["HospAdmTime"])
imputed_df["SOFA"] = imputed_df.apply(calculate_sofa, axis=1)
imputed_df.head(1)


Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,BaseExcess,HCO3,FiO2,...,PTT,WBC,Fibrinogen,Platelets,Age,Gender,HospAdmTime,ICULOS,SepsisLabel,SOFA
0,102.108491,91.419811,36.789519,128.165094,88.199717,71.73211,24.712264,-0.288406,23.835971,0.467029,...,36.405357,11.737903,350.0,224.187135,83.14,0,-0.03,1,0,2


In [6]:
imputed_df.isna().sum()

HR                  0
O2Sat               0
Temp                0
SBP                 0
MAP                 0
DBP                 0
Resp                0
BaseExcess          0
HCO3                0
FiO2                0
pH                  0
PaCO2               0
SaO2                0
AST                 0
BUN                 0
Alkalinephos        0
Calcium             0
Chloride            0
Creatinine          0
Bilirubin_direct    0
Glucose             0
Lactate             0
Magnesium           0
Phosphate           0
Potassium           0
Bilirubin_total     0
TroponinI           0
Hct                 0
Hgb                 0
PTT                 0
WBC                 0
Fibrinogen          0
Platelets           0
Age                 0
Gender              0
HospAdmTime         0
ICULOS              0
SepsisLabel         0
SOFA                0
dtype: int64

In [7]:
imputed_df.to_parquet(f"{project_root}/dataset/imputed_sofa.parquet", index=False)

## Splitting the data

In [3]:
from get_splitted_data import get_dataset_tensors

X_train, X_test, y_train, y_test = get_dataset_tensors()

## Transformer architecture
- Use `Mean Pooling` (x.mean(dim=1)) if: You want a summary of the whole sequence (useful if sepsis patterns are spread across time).
- Use `Max Pooling` (x.max(dim=1).values) if: You want to focus on the most extreme feature values, which might indicate critical moments.
- Use `Last Timestep` (x[:, -1, :]) if: You believe the latest patient state matters most (recommended for your case).

Test pooling:
- Train with different pooling methods and compare AUC-ROC scores.
- If AUC increases, the new pooling method is better.
- Visualize attention weights to see where the model is focusing.

In [4]:
class TransformerClassifier(nn.Module):
    """
    num_heads = more heads capture different attention but increase computation
    num_layers = more make the model deeper but can overfit if too high
    """
    def __init__(self, input_dim, num_heads=4, num_layers=2):
        super().__init__()
        # d_model = input_dim (number of features)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads)
        # stacks multiple encoder layers (num_layers controls depth)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # linear layer to map the output to a single value (binary classification)
        self.linear_layer = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
      z = self.encoder(x)  # Output shape: (batch_size, seq_len, features) or (batch_size, features)
      if z.dim() == 3:  # If (batch_size, seq_len, features), take the last timestep
            # NOTE: last timestep -> the most recent ICU data is usually the most relevant for prediction
            z = z[:, -1, :]
      return self.linear_layer(z)


def get_valid_num_heads(input_dim, desired_heads):
    """Finds the highest valid num_heads <= desired_heads that divides input_dim."""
    while desired_heads > 0 and input_dim % desired_heads != 0:
        desired_heads -= 1
    return max(1, desired_heads)


in_dim = X_train.shape[1]
print(in_dim)
n_heads = get_valid_num_heads(in_dim, 10)
print(n_heads)
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)

38
2




In [16]:
print(X_test.dtype)  # Ensure it prints torch.float16

torch.float16


In [5]:
batch_size = 512
y_pred_list = []
with torch.inference_mode():
    for i in range(0, len(X_test), batch_size):
        batch = X_test[i:i+batch_size]
        y_pred_list.append(model(batch))
y_pred = torch.cat(y_pred_list)
y_pred[:5]

tensor([[-1.0394],
        [-1.3452],
        [-1.1834],
        [-1.0828],
        [-0.7410]])

In [6]:
from torch.utils.data import DataLoader, TensorDataset


train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

batch_size = 512 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Training loop

In [7]:
from tqdm import tqdm
from torchmetrics import Accuracy

model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) 

epochs = 1
epoch_counter = []
loss_counter = []
acc_counter = []

t_accuracy = Accuracy(task='binary')

for epoch in range(epochs):
    model.train()
    epoch_loss, epoch_acc = 0, 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    for X_batch, y_batch in train_loader:
        # forward pass
        y_logits = model(X_batch)
        y_probs = torch.sigmoid(y_logits)
        y_pred = torch.round(y_probs)

        # loss function
        loss = loss_fn(y_logits, y_batch.unsqueeze(1).float())
        acc = t_accuracy(y_pred, y_batch.unsqueeze(1).float())
        
        # zero grad
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # optimizer
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradients
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc
        
        progress_bar.set_postfix({"Loss": loss.item(), "Acc": acc.item()})

    epoch_loss /= len(train_loader)
    epoch_acc /= len(train_loader)
    acc_counter.append(epoch_acc)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss:.5f} | Accuracy: {epoch_acc:.2f}%")

Epoch 1/1:   0%|          | 0/2426 [03:16<?, ?it/s, Loss=0.116, Acc=0.975] 

Epoch 1/1 | Loss: 0.08523 | Accuracy: 0.98%


## Save model and predictions

In [27]:
from pathlib import Path

# save the model
model_path = Path('./saved')
model_path.mkdir(exist_ok=True)
model_file = model_path / 'simple_transformer.pth'

torch.save(model.state_dict(), model_file)

In [29]:
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
model.load_state_dict(torch.load("./saved/simple_transformer.pth"))



<All keys matched successfully>