In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the train / validation / test splits (in that order) from saveDir/.
train_dataset, val_dataset, test_dataset = (
    pd.read_csv(csv_path)
    for csv_path in (
        'saveDir/LOS_WEEKS_adm_train.csv',
        'saveDir/LOS_WEEKS_adm_val.csv',
        'saveDir/LOS_WEEKS_adm_test.csv',
    )
)

In [3]:
# from transformers import AutoModelForSequenceClassification, AutoConfig

# # Specify the dropout rate in the configuration
# config = AutoConfig.from_pretrained('bvanaken/CORe-clinical-outcome-biobert-v1', 
#                                     num_labels=4, 
#                                     hidden_dropout_prob=0.2, 
#                                     attention_probs_dropout_prob=0.2)

# # Load the pre-trained model with the specified configuration
# core_model = AutoModelForSequenceClassification.from_pretrained('bvanaken/CORe-clinical-outcome-biobert-v1', config=config)


In [4]:
# from transformers import AutoModelForSequenceClassification, AutoConfig

# # Specify the dropout rate in the configuration
# config = AutoConfig.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', 
#                                     num_labels=4, 
#                                     hidden_dropout_prob=0.2, 
#                                     attention_probs_dropout_prob=0.2)

# # Load the pre-trained model with the specified configuration
# clinical_model = AutoModelForSequenceClassification.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', config=config)

In [5]:
# from transformers import AutoModelForSequenceClassification, AutoConfig

# # Specify the dropout rate in the configuration
# config = AutoConfig.from_pretrained('dmis-lab/biobert-base-cased-v1.2', 
#                                     num_labels=4, 
#                                     hidden_dropout_prob=0.2, 
#                                     attention_probs_dropout_prob=0.2)

# # Load the pre-trained model with the specified configuration
# base_model = AutoModelForSequenceClassification.from_pretrained('dmis-lab/biobert-base-cased-v1.2', config=config)

In [6]:
from transformers import AutoModelForSequenceClassification, AutoConfig

# Configuration for a 4-label classifier; both the hidden-state dropout and
# the attention-probability dropout are set to 0.2.
config = AutoConfig.from_pretrained(
    'emilyalsentzer/Bio_Discharge_Summary_BERT',
    num_labels=4,
    hidden_dropout_prob=0.2,
    attention_probs_dropout_prob=0.2,
)

# Instantiate the sequence-classification model from the same checkpoint,
# using the configuration built above.
discharge_model = AutoModelForSequenceClassification.from_pretrained(
    'emilyalsentzer/Bio_Discharge_Summary_BERT',
    config=config,
)

Some weights of the model checkpoint at emilyalsentzer/Bio_Discharge_Summary_BERT were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from

In [7]:
from transformers import AutoTokenizer

# Load the tokenizer published under 'emilyalsentzer/Bio_Discharge_Summary_BERT',
# the same checkpoint name the model configuration and weights are loaded from.
tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_Discharge_Summary_BERT')

In [8]:
from transformers import BertTokenizer, BertModel
import torch

def split_inputs(input_ids, attention_mask, tokenizer, chunk_len=510):
    """Split a batch into two consecutive chunks, each wrapped in [CLS] ... [SEP].

    Columns ``[0, chunk_len)`` of ``input_ids`` / ``attention_mask`` form the
    first chunk and columns ``[chunk_len, 2 * chunk_len)`` the second; any
    further columns are dropped. Every chunk gets the tokenizer's CLS id
    prepended and its SEP id appended, and the attention mask is extended with
    a 1 at both of those positions. The added tokens are placed at the chunk
    boundaries regardless of what the chunk already contains.

    Args:
        input_ids: tensor indexed as ``[batch, position, ...]``.
        attention_mask: tensor indexed the same way as ``input_ids``.
        tokenizer: object exposing ``cls_token_id`` and ``sep_token_id``.
        chunk_len: number of input columns per chunk (default 510, which
            yields chunks of 512 columns once [CLS] and [SEP] are added).

    Returns:
        ``((input_ids1, attention_mask1), (input_ids2, attention_mask2))``.
        When the input has ``chunk_len`` columns or fewer, the second chunk
        consists of just [CLS] [SEP] with an all-ones mask.
    """
    # Get the special token ids
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Build the one-column boundary tensors from the batch dimension instead of
    # from a slice of the chunk: slicing an empty chunk gives an empty column,
    # which would leave the second chunk with no tokens at all for short inputs.
    ids_edge_shape = (input_ids.shape[0], 1, *input_ids.shape[2:])
    mask_edge_shape = (attention_mask.shape[0], 1, *attention_mask.shape[2:])
    cls_column = input_ids.new_full(ids_edge_shape, cls_token_id)
    sep_column = input_ids.new_full(ids_edge_shape, sep_token_id)
    ones_column = attention_mask.new_ones(mask_edge_shape)

    def _wrap(start, stop):
        # Slice one chunk and add the [CLS]/[SEP] columns plus their mask entries.
        chunk_ids = torch.cat([cls_column, input_ids[:, start:stop], sep_column], dim=1)
        chunk_mask = torch.cat([ones_column, attention_mask[:, start:stop], ones_column], dim=1)
        return chunk_ids, chunk_mask

    return _wrap(0, chunk_len), _wrap(chunk_len, 2 * chunk_len)

In [9]:
from torch.utils.data import DataLoader
from torch import nn

class EnsembleModel(nn.Module):
    """Apply one classifier to both chunks of a split input and average the results.

    The wrapped model is stored as the submodule ``model1``. ``forward`` relies
    on the module-level ``split_inputs`` function and ``tokenizer`` object.
    """

    def __init__(self, model1):
        super().__init__()
        self.model1 = model1

    def forward(self, input_ids, attention_mask):
        """Return the element-wise mean of ``model1``'s first output over the two chunks."""
        first_chunk, second_chunk = split_inputs(input_ids, attention_mask, tokenizer)

        # The chunks are fed through the same model one after the other, first
        # chunk first; element [0] of each call's result is what gets averaged.
        chunk_outputs = []
        for chunk_ids, chunk_mask in (first_chunk, second_chunk):
            chunk_outputs.append(self.model1(chunk_ids, attention_mask=chunk_mask)[0])

        return (chunk_outputs[0] + chunk_outputs[1]) / 2.00

In [10]:
# Tokenize the 'text' column of each split (train, validation, test in that
# order), truncating to at most 1020 tokens and padding within each split.
train_encodings, val_encodings, test_encodings = (
    tokenizer(split_frame['text'].tolist(), truncation=True, padding=True, max_length=1020)
    for split_frame in (train_dataset, val_dataset, test_dataset)
)

In [11]:
# Create a Dataset for PyTorch
class LosDataset(torch.utils.data.Dataset):
    """Dataset that pairs tokenizer encodings with their labels.

    ``encodings`` is a mapping from field name to a per-example sequence and
    ``labels`` is a per-example sequence; the dataset length is ``len(labels)``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``: every encoding field plus 'labels'."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

In [12]:
# Wrap each split's encodings together with its 'los_label' column.
# NOTE: this rebinds train_dataset / val_dataset / test_dataset from the
# DataFrames read from CSV to LosDataset objects, so this cell (and the
# tokenization cell before it) cannot be re-run without reloading the CSVs.
train_dataset, val_dataset, test_dataset = (
    LosDataset(split_encodings, split_frame['los_label'].tolist())
    for split_encodings, split_frame in (
        (train_encodings, train_dataset),
        (val_encodings, val_dataset),
        (test_encodings, test_dataset),
    )
)

In [13]:
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm
from torch import nn
import numpy as np

# Create the ensemble model: wraps discharge_model so that each input is split
# into two chunks, both chunks are scored by that same model, and the two
# outputs are averaged (see EnsembleModel.forward).
ensemble_model = EnsembleModel(discharge_model)

In [15]:
import os
import re


def _checkpoint_roc_auc(filename):
    """Return the score embedded in a '...roc_<value>.pth' filename, or -inf if there is none."""
    match = re.search(r'roc_(.+)\.pth$', filename)
    if match is None:
        return float('-inf')
    try:
        return float(match.group(1))
    except ValueError:
        return float('-inf')


# list all files in the current directory
files = os.listdir('.')

# keep the ones that start with 'dischargeBERT_chunk'
core_models = [f for f in files if f.startswith('dischargeBERT_chunk')]

if core_models:
    print("Found models starting with 'dischargeBERT_chunk':")
    for model in core_models:
        print(model)

    # Several checkpoints can be present and os.listdir() returns them in
    # arbitrary order, so choose the one whose filename carries the highest
    # ROC AUC instead of whichever happens to be listed first. If no name
    # carries a parsable score, max() falls back to the first entry.
    model_path = max(core_models, key=_checkpoint_roc_auc)

    # load the model state; map_location='cpu' lets a checkpoint saved from a
    # GPU run load on a machine without CUDA (load_state_dict copies the values
    # into the model's existing parameters, wherever they live)
    ensemble_model.load_state_dict(torch.load(model_path, map_location='cpu'))
    print(f"Loaded Model{model_path}")
else:
    print("No models found starting with 'dischargeBERT_chunk'.")

Found models starting with 'dischargeBERT_chunk':
dischargeBERT_chunk_epoch_5roc_0.7204260583852379.pth
dischargeBERT_chunk_epoch_4roc_0.7189238265020537.pth
Loaded ModeldischargeBERT_chunk_epoch_5roc_0.7204260583852379.pth


In [16]:
# ensemble_model

In [19]:
# Use the GPU when CUDA is available, otherwise the CPU, and move the model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
ensemble_model = ensemble_model.to(device)

In [20]:
# One loader per split, 9 examples per batch; only the training split is shuffled.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=9, shuffle=reshuffle)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [21]:
# Upper bound on epochs; the training loop stops earlier via early stopping.
epochs = 200
# Validation macro ROC AUC to beat. NOTE(review): this value matches the score
# in the name of the checkpoint loaded earlier (..._epoch_5roc_0.7204...), i.e.
# it looks like a resume value rather than a fresh-run default - confirm.
best_roc_auc = 0.720426
# An epoch must beat best_roc_auc by at least this much to count as an improvement.
min_delta = 0.0001
early_stopping_count = 0
# Number of consecutive non-improving epochs tolerated before stopping.
early_stopping_patience = 3
# Batches whose gradients are accumulated before each optimizer step.
gradient_accumulation_steps = 10

# Set the optimizer
optimizer = AdamW(ensemble_model.parameters(), lr=1e-5, weight_decay=0.01)

# The training loop steps the optimizer/scheduler after every full accumulation
# window and once more on the last batch of an epoch when a partial window is
# left over, i.e. ceil(len(train_loader) / gradient_accumulation_steps) times
# per epoch. Flooring the total instead undercounts the schedule length, so the
# learning rate would reach zero before the final steps.
optimizer_steps_per_epoch = -(-len(train_loader) // gradient_accumulation_steps)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=50,
    num_training_steps=optimizer_steps_per_epoch * epochs,
)


In [22]:
# device = "cpu"

In [None]:
from torch.nn import functional as F

# Build the loss module once and reuse it for every training and validation
# batch instead of constructing a new nn.CrossEntropyLoss() on each step.
criterion = nn.CrossEntropyLoss()

# Training
for epoch in range(epochs):
    ensemble_model.train()
    train_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        # Gradients are cleared only at the start of an accumulation window so
        # that the backward passes of the following batches add up.
        if step % gradient_accumulation_steps == 0:
            optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = ensemble_model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        # The scaled loss drives the gradients; the unscaled loss is logged.
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        # Step after every full window, and on the last batch of the epoch so
        # that a trailing partial window is not discarded.
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            optimizer.step()
            scheduler.step()

    # Validation pass: no gradients, collect class probabilities and labels.
    ensemble_model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = ensemble_model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())

    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    # Both losses are reported as the mean of the per-batch losses.
    val_loss /= len(val_loader)
    train_loss /= len(train_loader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')

    # Implement early stopping.
    # The first epoch (epoch == 0) always takes the else branch, so its score
    # replaces whatever best_roc_auc held before the loop started.
    if epoch > 0 and macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0
        # The filename uses the 0-based epoch index, one less than the epoch
        # number printed in the log line above.
        torch.save(ensemble_model.state_dict(), f"dischargeBERT_chunk_epoch_{epoch}roc_{best_roc_auc}.pth")

100%|█████████████████████████████████████| 3381/3381 [1:27:46<00:00,  1.56s/it]
100%|█████████████████████████████████████████| 488/488 [04:36<00:00,  1.76it/s]


Epoch: 1/200, Training Loss: 1.3088396803862383, Validation Loss: 1.2622539084710058
Accuracy: 0.38601685265315416, Recall: 0.38601685265315416, Precision: 0.37831052512810526, F1: 0.30033015545581854, Micro F1: 0.38601685265315416, Macro Roc Auc: 0.6356198625140287


100%|█████████████████████████████████████| 3381/3381 [1:26:16<00:00,  1.53s/it]
100%|█████████████████████████████████████████| 488/488 [04:35<00:00,  1.77it/s]


Epoch: 2/200, Training Loss: 1.2414106542244003, Validation Loss: 1.2111764597843906
Accuracy: 0.4108403552721476, Recall: 0.4108403552721476, Precision: 0.4214153317367192, F1: 0.38994785735175774, Micro F1: 0.4108403552721476, Macro Roc Auc: 0.6823343112689256


100%|█████████████████████████████████████| 3381/3381 [1:26:32<00:00,  1.54s/it]
100%|█████████████████████████████████████████| 488/488 [04:16<00:00,  1.90it/s]


Epoch: 3/200, Training Loss: 1.1975202450770592, Validation Loss: 1.186746760225687
Accuracy: 0.42382145297198814, Recall: 0.42382145297198814, Precision: 0.4200792255144625, F1: 0.41230206782723683, Micro F1: 0.42382145297198814, Macro Roc Auc: 0.7023726871987758


100%|█████████████████████████████████████| 3381/3381 [1:27:29<00:00,  1.55s/it]
100%|█████████████████████████████████████████| 488/488 [04:34<00:00,  1.78it/s]


Epoch: 4/200, Training Loss: 1.1659957835110855, Validation Loss: 1.1757249424203498
Accuracy: 0.43429742655431564, Recall: 0.43429742655431564, Precision: 0.4301291993890742, F1: 0.42573554157774934, Micro F1: 0.43429742655431564, Macro Roc Auc: 0.711401781441193


100%|█████████████████████████████████████| 3381/3381 [1:28:00<00:00,  1.56s/it]
100%|█████████████████████████████████████████| 488/488 [04:35<00:00,  1.77it/s]


Epoch: 5/200, Training Loss: 1.1384664577903285, Validation Loss: 1.1702117442351874
Accuracy: 0.43885219767706674, Recall: 0.43885219767706674, Precision: 0.4458319096802863, F1: 0.42891152329328, Micro F1: 0.43885219767706674, Macro Roc Auc: 0.7189238265020537


100%|█████████████████████████████████████| 3381/3381 [1:27:28<00:00,  1.55s/it]
100%|█████████████████████████████████████████| 488/488 [04:39<00:00,  1.75it/s]


Epoch: 6/200, Training Loss: 1.1105772353812315, Validation Loss: 1.1726674172477645
Accuracy: 0.4438624459120929, Recall: 0.4438624459120929, Precision: 0.4391937225914009, F1: 0.4314557802985876, Micro F1: 0.4438624459120929, Macro Roc Auc: 0.7204260583852379


 17%|██████▍                               | 574/3381 [15:14<1:15:03,  1.60s/it]

In [23]:
from torch.nn import functional as F

# Build the loss module once and reuse it for every training and validation
# batch instead of constructing a new nn.CrossEntropyLoss() on each step.
criterion = nn.CrossEntropyLoss()

# Training (resumed run).
# NOTE(review): the loop starts at index 7, so the first log line reads
# "Epoch: 8"; the recorded earlier run completed "Epoch: 6" and the loaded
# checkpoint is named epoch_5 (0-based), so "Epoch 7" is skipped - confirm
# whether the start index should be 6.
for epoch in range(7, epochs):
    ensemble_model.train()
    train_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        # Gradients are cleared only at the start of an accumulation window so
        # that the backward passes of the following batches add up.
        if step % gradient_accumulation_steps == 0:
            optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = ensemble_model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        # The scaled loss drives the gradients; the unscaled loss is logged.
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        # Step after every full window, and on the last batch of the epoch so
        # that a trailing partial window is not discarded.
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            optimizer.step()
            scheduler.step()

    # Validation pass: no gradients, collect class probabilities and labels.
    ensemble_model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = ensemble_model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())

    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    # Both losses are reported as the mean of the per-batch losses.
    val_loss /= len(val_loader)
    train_loss /= len(train_loader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')

    # Implement early stopping.
    # With the loop starting at 7, `epoch > 0` is always true here, so every
    # epoch is compared against best_roc_auc.
    if epoch > 0 and macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0
        # The filename uses the 0-based epoch index, one less than the epoch
        # number printed in the log line above.
        torch.save(ensemble_model.state_dict(), f"dischargeBERT_chunk_epoch_{epoch}roc_{best_roc_auc}.pth")

100%|█████████████████████████████████████| 3381/3381 [1:27:03<00:00,  1.55s/it]
100%|█████████████████████████████████████████| 488/488 [04:31<00:00,  1.80it/s]


Epoch: 8/200, Training Loss: 1.0802841222134612, Validation Loss: 1.1987475198311883
Accuracy: 0.4413573217945798, Recall: 0.4413573217945798, Precision: 0.4444183544418358, F1: 0.4240782193066044, Micro F1: 0.4413573217945798, Macro Roc Auc: 0.7207026555731604


100%|█████████████████████████████████████| 3381/3381 [1:27:00<00:00,  1.54s/it]
100%|█████████████████████████████████████████| 488/488 [04:40<00:00,  1.74it/s]


Epoch: 9/200, Training Loss: 1.0582498106791471, Validation Loss: 1.1803361584172873
Accuracy: 0.4459120929173309, Recall: 0.4459120929173309, Precision: 0.4419050565361835, F1: 0.43646566092697553, Micro F1: 0.44591209291733086, Macro Roc Auc: 0.7219033483113719


100%|█████████████████████████████████████| 3381/3381 [1:26:54<00:00,  1.54s/it]
100%|█████████████████████████████████████████| 488/488 [04:39<00:00,  1.75it/s]


Epoch: 10/200, Training Loss: 1.0351085993919948, Validation Loss: 1.1899355753767686
Accuracy: 0.4522887724891824, Recall: 0.4522887724891824, Precision: 0.4499690848516941, F1: 0.4381513512884322, Micro F1: 0.4522887724891824, Macro Roc Auc: 0.7212348576189193
EarlyStopping counter: 1 out of 3


100%|█████████████████████████████████████| 3381/3381 [1:26:32<00:00,  1.54s/it]
100%|█████████████████████████████████████████| 488/488 [04:22<00:00,  1.86it/s]


Epoch: 11/200, Training Loss: 1.0098954354101861, Validation Loss: 1.2047599957126085
Accuracy: 0.4493281712593942, Recall: 0.4493281712593942, Precision: 0.4566540966027, F1: 0.4425002341744109, Micro F1: 0.4493281712593942, Macro Roc Auc: 0.7171445179538215
EarlyStopping counter: 2 out of 3


100%|█████████████████████████████████████| 3381/3381 [1:27:39<00:00,  1.56s/it]
100%|█████████████████████████████████████████| 488/488 [04:34<00:00,  1.78it/s]

Epoch: 12/200, Training Loss: 0.9840962408791982, Validation Loss: 1.2202373275258502
Accuracy: 0.44409018446823045, Recall: 0.44409018446823045, Precision: 0.44564006884696183, F1: 0.43868574922782994, Micro F1: 0.44409018446823045, Macro Roc Auc: 0.7134916118524237
EarlyStopping counter: 3 out of 3
Early stopping





In [24]:
import os
import re


def _checkpoint_roc_auc(filename):
    """Return the score embedded in a '...roc_<value>.pth' filename, or -inf if there is none."""
    match = re.search(r'roc_(.+)\.pth$', filename)
    if match is None:
        return float('-inf')
    try:
        return float(match.group(1))
    except ValueError:
        return float('-inf')


# list all files in the current directory
files = os.listdir('.')

# keep the ones that start with 'dischargeBERT_chunk'
core_models = [f for f in files if f.startswith('dischargeBERT_chunk')]

if core_models:
    print("Found models starting with 'dischargeBERT_chunk':")
    for model in core_models:
        print(model)

    # Training saves a new file every time the validation ROC AUC improves, so
    # several checkpoints can be present and os.listdir() returns them in
    # arbitrary order. Evaluate the one whose filename carries the highest
    # ROC AUC instead of whichever happens to be listed first. If no name
    # carries a parsable score, max() falls back to the first entry.
    model_path = max(core_models, key=_checkpoint_roc_auc)

    # load the model state; map_location='cpu' lets a checkpoint saved from a
    # GPU run load on a machine without CUDA (load_state_dict copies the values
    # into the model's existing parameters, wherever they live)
    ensemble_model.load_state_dict(torch.load(model_path, map_location='cpu'))
    print("Loaded Model")
else:
    print("No models found starting with 'dischargeBERT_chunk'.")

Found models starting with 'dischargeBERT_chunk':
dischargeBERT_chunk_epoch_8roc_0.7219033483113719.pth
Loaded Model


In [25]:
# Evaluation mode for scoring the test set.
ensemble_model.eval()

# Per-batch class probabilities and true labels; concatenated in a later cell.
test_preds, test_labels = [], []

# Score every test batch without tracking gradients.
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, labels = (
            batch[field].to(device)
            for field in ('input_ids', 'attention_mask', 'labels')
        )
        outputs = ensemble_model(input_ids, attention_mask)
        batch_probs = F.softmax(outputs, dim=1)
        test_preds.append(batch_probs.cpu().numpy())
        test_labels.append(labels.cpu().numpy())



100%|█████████████████████████████████████████| 978/978 [08:57<00:00,  1.82it/s]


In [26]:
# Merge the per-batch arrays collected during test inference.
# NOTE: test_preds / test_labels are rebound from lists of arrays to single
# arrays here, so this cell is meant to run once per inference pass.
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

# Calculate metrics: the predicted class is the highest-probability column.
test_preds_class = test_preds.argmax(axis=1)
accuracy = accuracy_score(test_labels, test_preds_class)
recall, precision, f1 = (
    weighted_metric(test_labels, test_preds_class, average='weighted')
    for weighted_metric in (recall_score, precision_score, f1_score)
)
micro_f1 = f1_score(test_labels, test_preds_class, average='micro')
# ROC AUC is computed from the class probabilities, not the hard predictions.
macro_roc_auc = roc_auc_score(test_labels, test_preds, multi_class='ovo', average='macro')

print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')

Accuracy: 0.45413209048539277, Recall: 0.45413209048539277, Precision: 0.44993397470703245, F1: 0.44530342628934716, Micro F1: 0.45413209048539277, Macro Roc Auc: 0.726812645778557
