<a href="https://colab.research.google.com/github/narpat78/BERT-for-Electronic-Health-Records/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Setting up the Colab Notebook

In [None]:
# mounting google drive (every data/model path used in this notebook lives under this mount point)
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
# changing the working directory
# (same Drive folder that is appended to sys.path in the following cell)
%cd /content/drive/MyDrive/BEHRT/.models

/content/drive/MyDrive/BEHRT/.models


In [None]:
import os
import sys

# making the modules stored in the Drive folder importable; the membership check
# keeps sys.path free of duplicate entries when this cell is re-run
behrt_code_dir = '/content/drive/MyDrive/BEHRT/.models'
if behrt_code_dir not in sys.path:
    sys.path.append(behrt_code_dir)

In [None]:
# installing pretrained pytorch bert
!pip install pytorch_pretrained_bert

Collecting pytorch_pretrained_bert
  Downloading pytorch_pretrained_bert-0.6.2-py3-none-any.whl.metadata (86 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/86.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.7/86.7 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Collecting boto3 (from pytorch_pretrained_bert)
  Downloading boto3-1.38.3-py3-none-any.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=0.4.1->pytorch_pretrained_bert)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=0.4.1->pytorch_pretrained_bert)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=0.4.1->pytorch_pretrained_bert)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.

#### Loading the Required Libraries

In [None]:
# importing dependencies
import ast

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from behrt_model import BertConfig, BertForEHRPrediction, DataLoader as BehrtDataLoader
from behrt_train import train_behrt

In [None]:
# loading the dataset from the mounted Drive and previewing its first rows
mimic_csv_path = '/content/drive/MyDrive/BEHRT/.data/extended_mimiciii.csv'
df = pd.read_csv(mimic_csv_path)
df.head()

Unnamed: 0.1,Unnamed: 0,subject_id,hadm_id,admittime,diagnosis,icd9_code,gender,dob,age,insurance,ethnicity
0,0,10006,142345,2064-10-23 00:00:00,SEPSIS,"['99591', '99662', '5672', '40391', '42731', '...",F,1994-03-05,70,Medicare,BLACK/AFRICAN AMERICAN
1,1,10011,105331,2026-08-14 00:00:00,HEPATITIS B,"['570', '07030', '07054', '30401', '2875', '27...",F,1990-06-05,36,Private,UNKNOWN/NOT SPECIFIED
2,2,10013,165520,2025-10-04 00:00:00,SEPSIS,"['0389', '41071', '78551', '486', '42731', '20...",F,1938-09-03,87,Medicare,UNKNOWN/NOT SPECIFIED
3,3,10017,199207,2049-05-26 00:00:00,HUMERAL FRACTURE,"['81201', '4928', '8028', '8024', '99812', '41...",F,1975-09-21,74,Medicare,WHITE
4,4,10019,177759,2063-05-14 00:00:00,ALCOHOLIC HEPATITIS,"['0389', '51881', '5770', '30390', '5781', '58...",M,2014-06-20,49,Medicare,WHITE


In [None]:
# dropping irrelevant column
# Re-binding `df` with errors='ignore' (instead of an in-place drop) keeps this cell
# re-runnable: once 'Unnamed: 0' is gone, a second run is a no-op rather than a KeyError.
df = df.drop(columns=['Unnamed: 0'], errors='ignore')
df.head()

Unnamed: 0,subject_id,hadm_id,admittime,diagnosis,icd9_code,gender,dob,age,insurance,ethnicity
0,10006,142345,2064-10-23 00:00:00,SEPSIS,"['99591', '99662', '5672', '40391', '42731', '...",F,1994-03-05,70,Medicare,BLACK/AFRICAN AMERICAN
1,10011,105331,2026-08-14 00:00:00,HEPATITIS B,"['570', '07030', '07054', '30401', '2875', '27...",F,1990-06-05,36,Private,UNKNOWN/NOT SPECIFIED
2,10013,165520,2025-10-04 00:00:00,SEPSIS,"['0389', '41071', '78551', '486', '42731', '20...",F,1938-09-03,87,Medicare,UNKNOWN/NOT SPECIFIED
3,10017,199207,2049-05-26 00:00:00,HUMERAL FRACTURE,"['81201', '4928', '8028', '8024', '99812', '41...",F,1975-09-21,74,Medicare,WHITE
4,10019,177759,2063-05-14 00:00:00,ALCOHOLIC HEPATITIS,"['0389', '51881', '5770', '30390', '5781', '58...",M,2014-06-20,49,Medicare,WHITE


#### Data Preprocessing

In [None]:
# label encoding features
# One encoder per column, so every mapping stays available after this cell
# (e.g. diagnosis_encoder.inverse_transform to turn predicted labels back into diagnoses).
gender_encoder = LabelEncoder()
diagnosis_encoder = LabelEncoder()
insurance_encoder = LabelEncoder()
ethnicity_encoder = LabelEncoder()

df['gender'] = gender_encoder.fit_transform(df['gender']) # 0 > F, 1 > M
df['labels'] = diagnosis_encoder.fit_transform(df['diagnosis'])
df['ins'] = insurance_encoder.fit_transform(df['insurance'])
df['ethni'] = ethnicity_encoder.fit_transform(df['ethnicity'])

# 'icd9_code' stores the text form of a Python list. ast.literal_eval only parses
# literals, whereas eval would execute any expression found in the CSV.
icd9_lists = df['icd9_code'].apply(ast.literal_eval)

# distinct ICD-9 codes across all admissions (sorted for a reproducible order)
all_codes = sorted({code for codes in icd9_lists for code in codes})

# `label_encoder` keeps its original name and, as before, ends up fitted on the ICD-9 codes
label_encoder = LabelEncoder()
label_encoder.fit(all_codes)
df['code'] = icd9_lists.apply(lambda codes: label_encoder.transform(codes))

In [None]:
# building sequential data
# .copy() makes mimic_df an independent frame; assigning columns on the bare
# selection of `df` is what raised SettingWithCopyWarning.
mimic_df = df[['gender', 'age', 'ethni', 'ins', 'code', 'labels']].copy()

# features repeated once per code, so each list is as long as the row's code sequence
for column in ('ethni', 'ins', 'age'):
    mimic_df[column] = mimic_df.apply(
        lambda row, column=column: [row[column]] * len(row['code']), axis=1
    )

# features kept once per row, wrapped in a single-element list
for column in ('gender', 'labels'):
    mimic_df[column] = mimic_df.apply(lambda row, column=column: [row[column]] * 1, axis=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  mimic_df['ethni'] = mimic_df.apply(lambda row: [row['ethni']] * len(row['code']), axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  mimic_df['ins'] = mimic_df.apply(lambda row: [row['ins']] * len(row['code']), axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  mimic_df['age'] = mimic_df.ap

In [None]:
# previewing the list-valued columns built in the previous cell
mimic_df.head()

Unnamed: 0,gender,age,ethni,ins,code,labels
0,[0],"[70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 7...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[472, 476, 323, 196, 220, 226, 215, 214, 124, ...",[73]
1,[0],"[36, 36, 36, 36, 36, 36]","[7, 7, 7, 7, 7, 7]","[3, 3, 3, 3, 3, 3]","[330, 22, 25, 148, 125, 93]",[36]
2,[0],"[87, 87, 87, 87, 87, 87, 87, 87, 87]","[7, 7, 7, 7, 7, 7, 7, 7, 7]","[2, 2, 2, 2, 2, 2, 2, 2, 2]","[11, 199, 417, 277, 220, 60, 214, 264, 86]",[73]
3,[0],"[74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 7...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]","[449, 280, 443, 442, 487, 208, 116, 277, 507, ...",[37]
4,[1],"[49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 4...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]","[11, 293, 344, 145, 347, 350, 114, 334, 419, 4...",[8]


In [None]:
# persisting the encoded, list-valued frame to Drive
# NOTE(review): to_csv writes the index by default, which comes back as an extra
# 'Unnamed: 0' column when the file is re-read — confirm whether that is wanted.
encoded_csv_path = '/content/drive/MyDrive/BEHRT/.data/encoded_extended_mimic.csv'
mimic_df.to_csv(encoded_csv_path)

In [None]:
# model configurations
# probing CUDA once and reusing the answer for both the flag and the device
cuda_available = torch.cuda.is_available()

config = {
    # transformer architecture
    'vocab_size': 580,
    'hidden_size': 768,
    'num_hidden_layers': 12,
    'num_attention_heads': 12,
    'intermediate_size': 3072,
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'attention_probs_dropout_prob': 0.1,
    'max_position_embedding': 512,
    # sizes of the auxiliary embedding vocabularies
    # NOTE(review): these are hard-coded; confirm they cover every encoded value in the data
    'seg_vocab_size': 2,
    'age_vocab_size': 100,
    'gender_vocab_size': 2,
    'ethni_vocab_size': 9,
    'ins_vocab_size': 4,
    'number_output': 95,  # Number of classes
    # data loading / runtime
    'batch_size': 8,
    'use_cuda': cuda_available,
    'max_len_seq': 37,
    'train_loader_workers': 4,
    'test_loader_workers': 4,
    'device': torch.device("cuda" if cuda_available else "cpu"),
    # output locations and initialisation
    'output_dir': './output',
    'output_name': 'model.pth',
    'best_name': 'best_model.pth',
    'initializer_range': 0.02
}

#### Training and Testing Sets

In [None]:
# dataset
# same columns as mimic_df, reordered so that the code sequence comes first
dataset = mimic_df[['code', 'age', 'gender', 'ethni', 'ins', 'labels']]

In [None]:
# X and y
feature_columns = ['code', 'age', 'gender', 'ethni', 'ins']
X = dataset[feature_columns]
y = dataset['labels']

# 80/20 split with a fixed seed; every part is re-indexed from 0 so that the
# features and labels line up positionally when concatenated below
X_train, X_test, y_train, y_test = (
    part.reset_index(drop=True)
    for part in train_test_split(X, y, test_size=0.2, random_state=42)
)

# re-attaching the labels to the features to get the full training and testing frames
train_dataset = pd.concat([X_train, y_train], axis=1)
test_dataset = pd.concat([X_test, y_test], axis=1)

# shapes of the resulting datasets
print(f"Training dataset shape: {train_dataset.shape}")
print(f"Testing dataset shape: {test_dataset.shape}")

Training dataset shape: (738, 6)
Testing dataset shape: (185, 6)


In [None]:
# training data and train loader (batches are reshuffled on every pass)
train_data = BehrtDataLoader(train_dataset, max_len=config['max_len_seq'])
train_loader = DataLoader(
    dataset=train_data,
    batch_size=config['batch_size'],
    shuffle=True,
)

In [None]:
# model initialization: wrap the settings dict in a BertConfig, then build the classifier
bert_config = BertConfig(config)
model = BertForEHRPrediction(bert_config, num_labels=config['number_output'])

In [None]:
# criteria and optimizer
learning_rate = 2e-5
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

#### Training and Saving the Model

In [None]:
# setting up directory for the model
# 'MyDrive' (no space) matches the spelling used by every other Drive path in this notebook
# NOTE(review): assumes 'My Drive' and 'MyDrive' resolve to the same Drive folder — confirm
base_dir = '/content/drive/MyDrive/BEHRT/.saved_models'
model_dir = os.path.join(base_dir, 'models')
os.makedirs(model_dir, exist_ok=True)

In [None]:
num_epochs = 10
device = config['device']

# Training previously ran wherever the model was created (CPU), even with CUDA available.
# nn.Module.to() modifies the module in place, so the optimizer built earlier keeps
# pointing at the same parameters.
model.to(device)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")

    for batch in progress_bar:
        # the model and every tensor of the batch have to live on the same device
        batch = [tensor.to(device) for tensor in batch]
        input_ids, age_ids, gender_ids, ethni_ids, ins_ids, seg_ids, posi_ids, attMask, labels = batch

        optimizer.zero_grad()
        logits = model(input_ids, age_ids, gender_ids, ethni_ids, ins_ids, seg_ids, posi_ids, attention_mask=attMask)
        # CrossEntropyLoss expects a 1-D tensor of integer class ids
        labels = labels.long().flatten()
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})

    # mean of the per-batch losses over the epoch
    avg_loss = epoch_loss/len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

Epoch 1/10: 100%|██████████| 93/93 [03:36<00:00,  2.33s/batch, loss=4.53]


Epoch 1/10 - Average Loss: 4.6019


Epoch 2/10: 100%|██████████| 93/93 [03:34<00:00,  2.31s/batch, loss=4.34]


Epoch 2/10 - Average Loss: 3.9336


Epoch 3/10: 100%|██████████| 93/93 [03:34<00:00,  2.31s/batch, loss=2.37]


Epoch 3/10 - Average Loss: 2.7032


Epoch 4/10: 100%|██████████| 93/93 [03:35<00:00,  2.32s/batch, loss=2.31]


Epoch 4/10 - Average Loss: 1.7527


Epoch 5/10: 100%|██████████| 93/93 [03:34<00:00,  2.31s/batch, loss=0.483]


Epoch 5/10 - Average Loss: 1.1709


Epoch 6/10: 100%|██████████| 93/93 [03:41<00:00,  2.38s/batch, loss=0.319]


Epoch 6/10 - Average Loss: 0.8018


Epoch 7/10: 100%|██████████| 93/93 [03:38<00:00,  2.35s/batch, loss=0.4]


Epoch 7/10 - Average Loss: 0.5776


Epoch 8/10: 100%|██████████| 93/93 [03:40<00:00,  2.37s/batch, loss=0.267]


Epoch 8/10 - Average Loss: 0.4354


Epoch 9/10: 100%|██████████| 93/93 [03:37<00:00,  2.34s/batch, loss=0.316]


Epoch 9/10 - Average Loss: 0.3310


Epoch 10/10: 100%|██████████| 93/93 [03:38<00:00,  2.35s/batch, loss=0.279]

Epoch 10/10 - Average Loss: 0.2621





In [None]:
# saving the final model (weights only, as a state dict)
final_model_path = os.path.join(model_dir, 'final_model.pt')
trained_weights = model.state_dict()
torch.save(trained_weights, final_model_path)
print(f"Final model saved at {final_model_path}")

Final model saved at /content/drive/My Drive/BEHRT/.saved_models/models/final_model.pt


#### Testing

In [None]:
# loading the trained model
model = BertForEHRPrediction(BertConfig(config), num_labels=config['number_output'])
model_path = os.path.join(model_dir, 'final_model.pt')
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads on a CPU-only one
state_dict = torch.load(model_path, map_location=config['device'])
model.load_state_dict(state_dict)
model.to(config['device'])

BertForEHRPrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(580, 768)
      (segment_embeddings): Embedding(2, 768)
      (age_embeddings): Embedding(100, 768)
      (gender_embeddings): Embedding(2, 768)
      (ethnicity_embeddings): Embedding(9, 768)
      (ins_embeddings): Embedding(4, 768)
      (posi_embeddings): Embedding(512, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): 

In [None]:
# preparing the test dataset (order is kept fixed: no shuffling)
test_data = BehrtDataLoader(test_dataset, max_len=config['max_len_seq'])
test_loader = DataLoader(
    dataset=test_data,
    batch_size=config['batch_size'],
    shuffle=False,
)

In [None]:
# defining our model evaluation function
def evaluate_metrics(logits, labels):
    """Compute classification metrics from raw model outputs.

    Args:
        logits: tensor of per-class scores, one row per sample (classes on dim 1).
        labels: tensor of integer class ids. It is flattened here, so both
            ``(n_samples,)`` and ``(n_samples, 1)`` are accepted.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, conf_matrix)``. Precision,
        recall and F1 are support-weighted averages over the classes.
    """
    # softmax is monotonic, so the argmax of the logits already is the predicted class
    predicted_labels = torch.argmax(logits, dim=1)

    # converting labels and predictions to flat numpy arrays
    # (the training loop flattens its labels too, so the loader may yield a trailing dim)
    labels_np = labels.cpu().numpy().reshape(-1)
    predicted_np = predicted_labels.cpu().numpy().reshape(-1)

    # calculating metrics
    # zero_division=0 scores classes with no predicted/true samples as 0, which is
    # the value the default produced, but without the undefined-metric warning
    accuracy = accuracy_score(labels_np, predicted_np)
    precision = precision_score(labels_np, predicted_np, average='weighted', zero_division=0)
    recall = recall_score(labels_np, predicted_np, average='weighted', zero_division=0)
    f1 = f1_score(labels_np, predicted_np, average='weighted', zero_division=0)

    # confusion matrix
    conf_matrix = confusion_matrix(labels_np, predicted_np)

    return accuracy, precision, recall, f1, conf_matrix

In [None]:
# model evaluation
model.eval()
batch_logits = []
batch_labels = []

# gradients are not needed for inference
with torch.no_grad():
    for batch in test_loader:
        # unpacking the batch while moving every tensor to the device
        (input_ids, age_ids, gender_ids, ethni_ids, ins_ids,
         seg_ids, posi_ids, attMask, labels) = (tensor.to(config['device']) for tensor in batch)

        # forward pass
        logits = model(input_ids, age_ids, gender_ids, ethni_ids, ins_ids, seg_ids, posi_ids, attention_mask=attMask)

        # collecting the per-batch outputs
        batch_logits.append(logits)
        batch_labels.append(labels)

# stitching the per-batch tensors into one tensor each
all_logits = torch.cat(batch_logits)
all_labels = torch.cat(batch_labels)

#### Results

In [None]:
# evaluating metrics
# all_logits / all_labels are the concatenated test-set tensors from the evaluation cell
accuracy, precision, recall, f1, conf_matrix = evaluate_metrics(all_logits, all_labels)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# results
# each message ends with "\n", so print() leaves a blank line after every metric
report_lines = (
    f"Test Accuracy: {accuracy*100:.2f}%\n",
    f"Precision: {precision:.4f}\n",
    f"Recall: {recall:.4f}\n",
    f"F1 Score: {f1:.4f}\n",
)
for report_line in report_lines:
    print(report_line)

Test Accuracy: 96.22%

Precision: 0.9685

Recall: 0.9622

F1 Score: 0.9590

