# Exploring transformers architecture

In [2]:
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch import nn

print(torch.__version__)

notebook_dir = os.getcwd()

project_root = os.path.abspath(os.path.join(notebook_dir, "../.."))
print(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)

2.2.2
/Users/damianstone/Documents/Code/machine-learning/dl-sepsis-prediction


## Dataset

In [2]:
from utils import get_data

DATA_PATH = get_data.get_dataset_abspath()
load_path = os.path.join(DATA_PATH, "imputed_combined_data.parquet")
imputed_df = pd.read_parquet(load_path)

In [3]:
imputed_df.head(1)

Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,BaseExcess,HCO3,FiO2,...,Age,Gender,Unit1,Unit2,HospAdmTime,ICULOS,SepsisLabel,patient_id,dataset,cluster_id
0,102.108491,91.419811,36.789519,128.165094,88.199717,71.73211,24.712264,-0.288406,23.835971,0.467029,...,83.14,0,0.715787,0.284213,-0.03,1,0,1,A,0_9_2_2_X_X_X


In [4]:
if 'SOFA' in imputed_df.columns:
    print("Column exists.")
else:
    print("Column doesn't exist.")
    
imputed_df.columns

Column doesn't exist.


Index(['HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp', 'BaseExcess',
       'HCO3', 'FiO2', 'pH', 'PaCO2', 'SaO2', 'AST', 'BUN', 'Alkalinephos',
       'Calcium', 'Chloride', 'Creatinine', 'Bilirubin_direct', 'Glucose',
       'Lactate', 'Magnesium', 'Phosphate', 'Potassium', 'Bilirubin_total',
       'TroponinI', 'Hct', 'Hgb', 'PTT', 'WBC', 'Fibrinogen', 'Platelets',
       'Age', 'Gender', 'Unit1', 'Unit2', 'HospAdmTime', 'ICULOS',
       'SepsisLabel', 'patient_id', 'dataset', 'cluster_id'],
      dtype='object')

## Last feature engineering

In [5]:
# SOFA calculation based on Sepsis-3
def calculate_sofa(row):
    sofa = 0
    
    if row['FiO2'] > 0:
        pao2_fio2 = row['SaO2'] / row['FiO2']
        if pao2_fio2 < 100: sofa += 4
        elif pao2_fio2 < 200: sofa += 3
        elif pao2_fio2 < 300: sofa += 2
        elif pao2_fio2 < 400: sofa += 1

    if row['Platelets'] < 20: sofa += 4
    elif row['Platelets'] < 50: sofa += 3
    elif row['Platelets'] < 100: sofa += 2
    elif row['Platelets'] < 150: sofa += 1

    if row['Bilirubin_total'] >= 12: sofa += 4
    elif row['Bilirubin_total'] >= 6: sofa += 3
    elif row['Bilirubin_total'] >= 2: sofa += 2
    elif row['Bilirubin_total'] >= 1.2: sofa += 1

    if row['MAP'] < 70: sofa += 1

    if row['Creatinine'] >= 5: sofa += 4
    elif row['Creatinine'] >= 3.5: sofa += 3
    elif row['Creatinine'] >= 2: sofa += 2
    elif row['Creatinine'] >= 1.2: sofa += 1

    return sofa

imputed_df = imputed_df.drop(columns=["Unit1", "Unit2", "cluster_id","dataset", "patient_id"], errors='ignore')
imputed_df = imputed_df.dropna(subset=["HospAdmTime"])
imputed_df["SOFA"] = imputed_df.apply(calculate_sofa, axis=1)
imputed_df.head(1)


Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,BaseExcess,HCO3,FiO2,...,PTT,WBC,Fibrinogen,Platelets,Age,Gender,HospAdmTime,ICULOS,SepsisLabel,SOFA
0,102.108491,91.419811,36.789519,128.165094,88.199717,71.73211,24.712264,-0.288406,23.835971,0.467029,...,36.405357,11.737903,350.0,224.187135,83.14,0,-0.03,1,0,2


In [6]:
imputed_df.isna().sum()

HR                  0
O2Sat               0
Temp                0
SBP                 0
MAP                 0
DBP                 0
Resp                0
BaseExcess          0
HCO3                0
FiO2                0
pH                  0
PaCO2               0
SaO2                0
AST                 0
BUN                 0
Alkalinephos        0
Calcium             0
Chloride            0
Creatinine          0
Bilirubin_direct    0
Glucose             0
Lactate             0
Magnesium           0
Phosphate           0
Potassium           0
Bilirubin_total     0
TroponinI           0
Hct                 0
Hgb                 0
PTT                 0
WBC                 0
Fibrinogen          0
Platelets           0
Age                 0
Gender              0
HospAdmTime         0
ICULOS              0
SepsisLabel         0
SOFA                0
dtype: int64

In [7]:
imputed_df.to_parquet(f"{project_root}/dataset/imputed_sofa.parquet", index=False)

## Splitting the data

In [3]:
from get_splitted_data import get_dataset_tensors

X_train, X_test, y_train, y_test = get_dataset_tensors()

## Transformer architecture
- Use `Mean Pooling` (x.mean(dim=1)) if: You want a summary of the whole sequence (useful if sepsis patterns are spread across time).
- Use `Max Pooling` (x.max(dim=1).values) if: You want to focus on the most extreme feature values, which might indicate critical moments.
- Use `Last Timestep` (x[:, -1, :]) if: You believe the latest patient state matters most (recommended for your case).

Test pooling:
- Train with different pooling methods and compare AUC-ROC scores.
- If AUC increases, the new pooling method is better.
- Visualize attention weights to see where the model is focusing.

In [4]:
class TransformerClassifier(nn.Module):
    """
    num_heads = more heads capture different attention but increase computation
    num_layers = more make the model deeper but can overfit if too high
    """
    def __init__(self, input_dim, num_heads=4, num_layers=2):
        super().__init__()
        # d_model = input_dim (number of features)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads)
        # stacks multiple encoder layers (num_layers controls depth)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # linear layer to map the output to a single value (binary classification)
        self.linear_layer = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
      z = self.encoder(x)  # Output shape: (batch_size, seq_len, features) or (batch_size, features)
      if z.dim() == 3:  # If (batch_size, seq_len, features), take the last timestep
            # NOTE: last timestep -> the most recent ICU data is usually the most relevant for prediction
            z = z[:, -1, :]
      return self.linear_layer(z)


def get_valid_num_heads(input_dim, desired_heads):
    """Finds the highest valid num_heads <= desired_heads that divides input_dim."""
    while desired_heads > 0 and input_dim % desired_heads != 0:
        desired_heads -= 1
    return max(1, desired_heads)


in_dim = X_train.shape[1]
print(in_dim)
n_heads = get_valid_num_heads(in_dim, 10)
print(n_heads)
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)

38
2




In [16]:
print(X_test.dtype)  # Ensure it prints torch.float16

torch.float16


In [5]:
batch_size = 512
y_pred_list = []
with torch.inference_mode():
    for i in range(0, len(X_test), batch_size):
        batch = X_test[i:i+batch_size]
        y_pred_list.append(model(batch))
y_pred = torch.cat(y_pred_list)
y_pred[:5]

tensor([[-1.0394],
        [-1.3452],
        [-1.1834],
        [-1.0828],
        [-0.7410]])

In [6]:
from torch.utils.data import DataLoader, TensorDataset


train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

batch_size = 512 
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Training loop

In [32]:
from tqdm import tqdm
from torchmetrics import Accuracy

model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) 

epochs = 100
epoch_counter = []
loss_counter = []
acc_counter = []

t_accuracy = Accuracy(task='binary')

for epoch in range(epochs):
    model.train()
    epoch_loss, epoch_acc = 0, 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    for X_batch, y_batch in train_loader:
        # forward pass
        y_logits = model(X_batch)
        y_probs = torch.sigmoid(y_logits)
        y_pred = torch.round(y_probs)

        # loss function
        loss = loss_fn(y_logits, y_batch.unsqueeze(1).float())
        acc = t_accuracy(y_pred, y_batch.unsqueeze(1).float())
        
        # zero grad
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # optimizer
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradients
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc
        
        progress_bar.set_postfix({"Loss": loss.item(), "Acc": acc.item()})

    epoch_loss /= len(train_loader)
    epoch_acc /= len(train_loader)
    acc_counter.append(epoch_acc)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss:.5f} | Accuracy: {epoch_acc:.2f}%")

Epoch 1/100:   0%|          | 0/2426 [03:49<?, ?it/s, Loss=0.0405, Acc=0.994]

Epoch 1/100 | Loss: 0.08505 | Accuracy: 0.98%


                                                                             

Epoch 2/100 | Loss: 0.08192 | Accuracy: 0.98%


Epoch 3/100:   0%|          | 0/2426 [03:54<?, ?it/s, Loss=0.117, Acc=0.969] 

Epoch 3/100 | Loss: 0.08116 | Accuracy: 0.98%


                                                                            

Epoch 4/100 | Loss: 0.08071 | Accuracy: 0.98%


Epoch 5/100:   0%|          | 0/2426 [04:06<?, ?it/s, Loss=0.0854, Acc=0.981]

Epoch 5/100 | Loss: 0.08040 | Accuracy: 0.98%


                                                                             

Epoch 6/100 | Loss: 0.08015 | Accuracy: 0.98%


Epoch 7/100:   0%|          | 0/2426 [04:13<?, ?it/s, Loss=0.0705, Acc=0.988]

Epoch 7/100 | Loss: 0.07994 | Accuracy: 0.98%


                                                                             

Epoch 8/100 | Loss: 0.07965 | Accuracy: 0.98%


Epoch 9/100:   0%|          | 0/2426 [04:05<?, ?it/s, Loss=0.054, Acc=0.988] 

Epoch 9/100 | Loss: 0.07938 | Accuracy: 0.98%


                                                                            

Epoch 10/100 | Loss: 0.07916 | Accuracy: 0.98%


Epoch 11/100:   0%|          | 0/2426 [04:05<?, ?it/s, Loss=0.142, Acc=0.963] 

Epoch 11/100 | Loss: 0.07895 | Accuracy: 0.98%


                                                                             

Epoch 12/100 | Loss: 0.07856 | Accuracy: 0.98%


Epoch 13/100:   0%|          | 0/2426 [04:26<?, ?it/s, Loss=0.0648, Acc=0.988]

Epoch 13/100 | Loss: 0.07830 | Accuracy: 0.98%


                                                                              

Epoch 14/100 | Loss: 0.07793 | Accuracy: 0.98%


Epoch 15/100:   0%|          | 0/2426 [04:06<?, ?it/s, Loss=0.119, Acc=0.969] 

Epoch 15/100 | Loss: 0.07759 | Accuracy: 0.98%


                                                                             

Epoch 16/100 | Loss: 0.07724 | Accuracy: 0.98%


Epoch 17/100:   0%|          | 0/2426 [04:01<?, ?it/s, Loss=0.0756, Acc=0.981]

Epoch 17/100 | Loss: 0.07689 | Accuracy: 0.98%


                                                                              

Epoch 18/100 | Loss: 0.07656 | Accuracy: 0.98%


Epoch 19/100:   0%|          | 0/2426 [04:01<?, ?it/s, Loss=0.0661, Acc=0.975]

Epoch 19/100 | Loss: 0.07629 | Accuracy: 0.98%


                                                                              

Epoch 20/100 | Loss: 0.07587 | Accuracy: 0.98%


Epoch 21/100:   0%|          | 0/2426 [03:57<?, ?it/s, Loss=0.182, Acc=0.957] 

Epoch 21/100 | Loss: 0.07551 | Accuracy: 0.98%


                                                                             

Epoch 22/100 | Loss: 0.07507 | Accuracy: 0.98%


Epoch 23/100:   0%|          | 0/2426 [04:06<?, ?it/s, Loss=0.2, Acc=0.944]   

Epoch 23/100 | Loss: 0.07474 | Accuracy: 0.98%


                                                                           

Epoch 24/100 | Loss: 0.07441 | Accuracy: 0.98%


Epoch 25/100:   0%|          | 0/2426 [04:38<?, ?it/s, Loss=0.0373, Acc=0.988]

Epoch 25/100 | Loss: 0.07412 | Accuracy: 0.98%


                                                                              

Epoch 26/100 | Loss: 0.07412 | Accuracy: 0.98%


Epoch 27/100:   0%|          | 0/2426 [04:32<?, ?it/s, Loss=0.0811, Acc=0.975]

Epoch 27/100 | Loss: 0.07365 | Accuracy: 0.98%


                                                                              

Epoch 28/100 | Loss: 0.07323 | Accuracy: 0.98%


Epoch 29/100:   0%|          | 0/2426 [04:03<?, ?it/s, Loss=0.109, Acc=0.981] 

Epoch 29/100 | Loss: 0.07303 | Accuracy: 0.98%


                                                                             

Epoch 30/100 | Loss: 0.07265 | Accuracy: 0.98%


Epoch 31/100:   0%|          | 0/2426 [04:10<?, ?it/s, Loss=0.0905, Acc=0.981]

Epoch 31/100 | Loss: 0.07234 | Accuracy: 0.98%


                                                                              

Epoch 32/100 | Loss: 0.07196 | Accuracy: 0.98%


Epoch 33/100:   0%|          | 0/2426 [04:02<?, ?it/s, Loss=0.074, Acc=0.988] 

Epoch 33/100 | Loss: 0.07180 | Accuracy: 0.98%


                                                                             

Epoch 34/100 | Loss: 0.07149 | Accuracy: 0.98%


Epoch 35/100:   0%|          | 0/2426 [04:03<?, ?it/s, Loss=0.0875, Acc=0.981]

Epoch 35/100 | Loss: 0.07115 | Accuracy: 0.98%


                                                                              

Epoch 36/100 | Loss: 0.07099 | Accuracy: 0.98%


Epoch 37/100:   0%|          | 0/2426 [04:06<?, ?it/s, Loss=0.0414, Acc=0.988]

Epoch 37/100 | Loss: 0.07067 | Accuracy: 0.98%


                                                                              

Epoch 38/100 | Loss: 0.07038 | Accuracy: 0.98%


Epoch 39/100:   0%|          | 0/2426 [04:05<?, ?it/s, Loss=0.0424, Acc=0.988]

Epoch 39/100 | Loss: 0.07009 | Accuracy: 0.98%


                                                                              

Epoch 40/100 | Loss: 0.06982 | Accuracy: 0.98%


Epoch 41/100:   0%|          | 0/2426 [04:16<?, ?it/s, Loss=0.101, Acc=0.975] 

Epoch 41/100 | Loss: 0.06954 | Accuracy: 0.98%


                                                                             

Epoch 42/100 | Loss: 0.06939 | Accuracy: 0.98%


Epoch 43/100:   0%|          | 0/2426 [04:10<?, ?it/s, Loss=0.123, Acc=0.969] 

Epoch 43/100 | Loss: 0.06899 | Accuracy: 0.98%


                                                                             

Epoch 44/100 | Loss: 0.06863 | Accuracy: 0.98%


Epoch 45/100:   0%|          | 0/2426 [04:25<?, ?it/s, Loss=0.142, Acc=0.969] 

Epoch 45/100 | Loss: 0.06829 | Accuracy: 0.98%


                                                                             

Epoch 46/100 | Loss: 0.06812 | Accuracy: 0.98%


Epoch 47/100:   0%|          | 0/2426 [03:56<?, ?it/s, Loss=0.0167, Acc=1]    

Epoch 47/100 | Loss: 0.06776 | Accuracy: 0.98%


                                                                          

Epoch 48/100 | Loss: 0.06763 | Accuracy: 0.98%


Epoch 49/100:   0%|          | 0/2426 [04:01<?, ?it/s, Loss=0.0656, Acc=0.988]

Epoch 49/100 | Loss: 0.06753 | Accuracy: 0.98%


                                                                              

Epoch 50/100 | Loss: 0.06753 | Accuracy: 0.98%


Epoch 51/100:   0%|          | 0/2426 [04:10<?, ?it/s, Loss=0.0391, Acc=0.988]

Epoch 51/100 | Loss: 0.06751 | Accuracy: 0.98%


                                                                              

Epoch 52/100 | Loss: 0.06714 | Accuracy: 0.98%


Epoch 53/100:   0%|          | 0/2426 [04:09<?, ?it/s, Loss=0.136, Acc=0.969] 

Epoch 53/100 | Loss: 0.06695 | Accuracy: 0.98%


                                                                             

Epoch 54/100 | Loss: 0.06674 | Accuracy: 0.98%


Epoch 55/100:   0%|          | 0/2426 [04:10<?, ?it/s, Loss=0.0766, Acc=0.975]

Epoch 55/100 | Loss: 0.06675 | Accuracy: 0.98%


                                                                              

Epoch 56/100 | Loss: 0.06637 | Accuracy: 0.98%


Epoch 57/100:   0%|          | 0/2426 [04:03<?, ?it/s, Loss=0.0455, Acc=0.988]

Epoch 57/100 | Loss: 0.06631 | Accuracy: 0.98%


                                                                              

Epoch 58/100 | Loss: 0.06612 | Accuracy: 0.98%


Epoch 59/100:   0%|          | 0/2426 [04:01<?, ?it/s, Loss=0.0762, Acc=0.981]

Epoch 59/100 | Loss: 0.06548 | Accuracy: 0.98%


                                                                              

Epoch 60/100 | Loss: 0.06522 | Accuracy: 0.98%


Epoch 61/100:   0%|          | 0/2426 [04:06<?, ?it/s, Loss=0.0908, Acc=0.981]

Epoch 61/100 | Loss: 0.06531 | Accuracy: 0.98%


                                                                              

Epoch 62/100 | Loss: 0.06518 | Accuracy: 0.98%


Epoch 63/100:   0%|          | 0/2426 [04:12<?, ?it/s, Loss=0.0698, Acc=0.988]

Epoch 63/100 | Loss: 0.06491 | Accuracy: 0.98%


                                                                              

Epoch 64/100 | Loss: 0.06453 | Accuracy: 0.98%


Epoch 65/100:   0%|          | 0/2426 [04:03<?, ?it/s, Loss=0.0834, Acc=0.969]

Epoch 65/100 | Loss: 0.06442 | Accuracy: 0.98%


                                                                              

Epoch 66/100 | Loss: 0.06451 | Accuracy: 0.98%


Epoch 67/100:   0%|          | 0/2426 [04:08<?, ?it/s, Loss=0.0956, Acc=0.975]

Epoch 67/100 | Loss: 0.06427 | Accuracy: 0.98%


                                                                              

Epoch 68/100 | Loss: 0.06444 | Accuracy: 0.98%


Epoch 69/100:   0%|          | 0/2426 [04:05<?, ?it/s, Loss=0.0488, Acc=0.988]

Epoch 69/100 | Loss: 0.06419 | Accuracy: 0.98%


                                                                              

Epoch 70/100 | Loss: 0.06393 | Accuracy: 0.98%


Epoch 71/100:   0%|          | 0/2426 [04:10<?, ?it/s, Loss=0.0379, Acc=0.988]

Epoch 71/100 | Loss: 0.06373 | Accuracy: 0.98%


                                                                              

Epoch 72/100 | Loss: 0.06380 | Accuracy: 0.98%


Epoch 73/100:   0%|          | 0/2426 [04:16<?, ?it/s, Loss=0.0658, Acc=0.975]

Epoch 73/100 | Loss: 0.06341 | Accuracy: 0.98%


                                                                              

Epoch 74/100 | Loss: 0.06337 | Accuracy: 0.98%


Epoch 75/100:   0%|          | 0/2426 [04:04<?, ?it/s, Loss=0.0333, Acc=0.988]

Epoch 75/100 | Loss: 0.06309 | Accuracy: 0.98%


                                                                              

Epoch 76/100 | Loss: 0.06340 | Accuracy: 0.98%


Epoch 77/100:   0%|          | 0/2426 [04:04<?, ?it/s, Loss=0.0469, Acc=0.988]

Epoch 77/100 | Loss: 0.06297 | Accuracy: 0.98%


                                                                              

Epoch 78/100 | Loss: 0.06270 | Accuracy: 0.98%


Epoch 79/100:   0%|          | 0/2426 [04:08<?, ?it/s, Loss=0.0931, Acc=0.969]

Epoch 79/100 | Loss: 0.06254 | Accuracy: 0.98%


                                                                              

Epoch 80/100 | Loss: 0.06230 | Accuracy: 0.98%


Epoch 81/100:   0%|          | 0/2426 [04:04<?, ?it/s, Loss=0.0415, Acc=0.981]

Epoch 81/100 | Loss: 0.06209 | Accuracy: 0.98%


                                                                              

Epoch 82/100 | Loss: 0.06184 | Accuracy: 0.98%


Epoch 83/100:   0%|          | 0/2426 [04:23<?, ?it/s, Loss=0.0144, Acc=1]    

Epoch 83/100 | Loss: 0.06173 | Accuracy: 0.98%


                                                                          

Epoch 84/100 | Loss: 0.06153 | Accuracy: 0.98%


Epoch 85/100:   0%|          | 0/2426 [04:13<?, ?it/s, Loss=0.151, Acc=0.957] 

Epoch 85/100 | Loss: 0.06157 | Accuracy: 0.98%


                                                                             

Epoch 86/100 | Loss: 0.06120 | Accuracy: 0.98%


Epoch 87/100:   0%|          | 0/2426 [04:13<?, ?it/s, Loss=0.0887, Acc=0.969]

Epoch 87/100 | Loss: 0.06104 | Accuracy: 0.98%


                                                                              

Epoch 88/100 | Loss: 0.06099 | Accuracy: 0.98%


Epoch 89/100:   0%|          | 0/2426 [04:02<?, ?it/s, Loss=0.0668, Acc=0.988]

Epoch 89/100 | Loss: 0.06093 | Accuracy: 0.98%


                                                                              

Epoch 90/100 | Loss: 0.06075 | Accuracy: 0.98%


Epoch 91/100:   0%|          | 0/2426 [04:13<?, ?it/s, Loss=0.0534, Acc=0.981]

Epoch 91/100 | Loss: 0.06120 | Accuracy: 0.98%


                                                                              

Epoch 92/100 | Loss: 0.06076 | Accuracy: 0.98%


Epoch 93/100:   0%|          | 0/2426 [04:15<?, ?it/s, Loss=0.0896, Acc=0.975]

Epoch 93/100 | Loss: 0.06097 | Accuracy: 0.98%


                                                                              

Epoch 94/100 | Loss: 0.06054 | Accuracy: 0.98%


Epoch 95/100:   0%|          | 0/2426 [04:13<?, ?it/s, Loss=0.0382, Acc=0.994]

Epoch 95/100 | Loss: 0.06015 | Accuracy: 0.98%


                                                                              

Epoch 96/100 | Loss: 0.06057 | Accuracy: 0.98%


Epoch 97/100:   0%|          | 0/2426 [04:14<?, ?it/s, Loss=0.0398, Acc=0.994]

Epoch 97/100 | Loss: 0.06027 | Accuracy: 0.98%


                                                                              

Epoch 98/100 | Loss: 0.06029 | Accuracy: 0.98%


Epoch 99/100:   0%|          | 0/2426 [04:16<?, ?it/s, Loss=0.144, Acc=0.969] 

Epoch 99/100 | Loss: 0.06031 | Accuracy: 0.98%


                                                                             

Epoch 100/100 | Loss: 0.05989 | Accuracy: 0.98%


## Save model and predictions

In [33]:
from pathlib import Path

# save the model
model_path = Path('./saved')
model_path.mkdir(exist_ok=True)
model_file = model_path / '01_simple_transformer.pth'

torch.save(model.state_dict(), model_file)

In [34]:
model = TransformerClassifier(input_dim=in_dim, num_heads=n_heads)
model.load_state_dict(torch.load("./saved/01_simple_transformer.pth"))

<All keys matched successfully>