In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load Train, Validation, Test Dataset
train_dataset = pd.read_csv('saveDir/LOS_WEEKS_adm_train.csv')
val_dataset = pd.read_csv('saveDir/LOS_WEEKS_adm_val.csv')
test_dataset = pd.read_csv('saveDir/LOS_WEEKS_adm_test.csv')

In [3]:
from torch.utils.data import DataLoader
from torch import nn

class EnsembleModel(nn.Module):
    def __init__(self, model1, model2):
        super(EnsembleModel, self).__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, input_ids, attention_mask):
        output1 = self.model1(input_ids, attention_mask=attention_mask)[0]
        output2 = self.model2(input_ids, attention_mask=attention_mask)[0]
        avg_output = (output1 + output2) / 2.00
        return avg_output

In [4]:
from transformers import AutoModelForSequenceClassification, AutoConfig

# Specify the dropout rate in the configuration
config = AutoConfig.from_pretrained('bvanaken/CORe-clinical-outcome-biobert-v1', 
                                    num_labels=4, 
                                    hidden_dropout_prob=0.2, 
                                    attention_probs_dropout_prob=0.2)

# Load the pre-trained model with the specified configuration
core_model = AutoModelForSequenceClassification.from_pretrained('bvanaken/CORe-clinical-outcome-biobert-v1', config=config)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bvanaken/CORe-clinical-outcome-biobert-v1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# from transformers import AutoModelForSequenceClassification, AutoConfig

# # Specify the dropout rate in the configuration
# config = AutoConfig.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', 
#                                     num_labels=4, 
#                                     hidden_dropout_prob=0.2, 
#                                     attention_probs_dropout_prob=0.2)

# # Load the pre-trained model with the specified configuration
# clinical_model = AutoModelForSequenceClassification.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', config=config)

In [6]:
# from transformers import AutoModelForSequenceClassification, AutoConfig

# # Specify the dropout rate in the configuration
# config = AutoConfig.from_pretrained('dmis-lab/biobert-base-cased-v1.2', 
#                                     num_labels=4, 
#                                     hidden_dropout_prob=0.2, 
#                                     attention_probs_dropout_prob=0.2)

# # Load the pre-trained model with the specified configuration
# biobert_base_model = AutoModelForSequenceClassification.from_pretrained('dmis-lab/biobert-base-cased-v1.2', config=config)

In [7]:
from transformers import AutoModelForSequenceClassification, AutoConfig

# Specify the dropout rate in the configuration
config = AutoConfig.from_pretrained('emilyalsentzer/Bio_Discharge_Summary_BERT', 
                                    num_labels=4, 
                                    hidden_dropout_prob=0.2, 
                                    attention_probs_dropout_prob=0.2)

# Load the pre-trained model with the specified configuration
discharge_model = AutoModelForSequenceClassification.from_pretrained('emilyalsentzer/Bio_Discharge_Summary_BERT', config=config)

Some weights of the model checkpoint at emilyalsentzer/Bio_Discharge_Summary_BERT were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from

In [8]:
from transformers import AutoTokenizer

# Choose a tokenizer
tokenizer = AutoTokenizer.from_pretrained('bvanaken/CORe-clinical-outcome-biobert-v1')

In [9]:
# Apply the tokenizer to the training, validation, and test datasets
train_encodings = tokenizer(train_dataset['text'].tolist(), truncation=True, padding=True, max_length = 512)
val_encodings = tokenizer(val_dataset['text'].tolist(), truncation=True, padding=True,  max_length = 512)
test_encodings = tokenizer(test_dataset['text'].tolist(), truncation=True, padding=True , max_length = 512)

In [10]:
# Create a Dataset for PyTorch
class LosDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [11]:
train_dataset = LosDataset(train_encodings, train_dataset['los_label'].tolist())
val_dataset = LosDataset(val_encodings, val_dataset['los_label'].tolist())
test_dataset = LosDataset(test_encodings, test_dataset['los_label'].tolist())

In [12]:
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm
from torch import nn
import numpy as np

# Create the ensemble model
ensemble_model = EnsembleModel(core_model, discharge_model)

In [13]:
import os

# list all files in the current directory
files = os.listdir('.')

# filter the ones that start with 'CORE_baseline'
core_models = [f for f in files if f.startswith('CORE_ensemble(core + dischargebert)')]

if core_models:
    print("Found models starting with 'CORE_ensemble(core + dischargebert)':")
    for model in core_models:
        print(model)
        
    # get the first (and supposedly only) model
    model_path = core_models[0]

    # load the model state
    ensemble_model.load_state_dict(torch.load(model_path))
    print(f"Loaded Model: {model_path}")
else:
    print("No models found starting with 'CORE_ensemble(core + dischargebert)'.")

Found models starting with 'CORE_ensemble(core + dischargebert)':
CORE_ensemble(core + dischargebert)_epoch_5roc_0.7133395365724405.pth
Loaded Model: CORE_ensemble(core + dischargebert)_epoch_5roc_0.7133395365724405.pth


In [14]:
# ensemble_model

In [15]:
# Push the model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ensemble_model = ensemble_model.to(device)

In [16]:
from transformers import DistilBertForSequenceClassification, DistilBertConfig, AdamW, get_linear_schedule_with_warmup

# create a student model
student_config = DistilBertConfig.from_pretrained('distilbert-base-uncased', 
                                                  num_labels=4, 
                                                  hidden_dropout_prob=0.2, 
                                                  attention_probs_dropout_prob=0.2)

student_model = DistilBertForSequenceClassification(student_config)

# set the temperature
temperature = 2.0

In [17]:
import os

# list all files in the current directory
files = os.listdir('.')

# filter the ones that start with 'CORE_baseline'
core_models = [f for f in files if f.startswith('CORE_ensemble(core + dischargebert) + distilBert')]

if core_models:
    print("Found models starting with 'CORE_ensemble(core + dischargebert) + distilBert':")
    for model in core_models:
        print(model)
        
    # get the first (and supposedly only) model
    model_path = core_models[0]

    # load the model state
    student_model.load_state_dict(torch.load(model_path))
    print(f"Loaded Model: {model_path}")
else:
    print("No models found starting with 'CORE_ensemble(core + dischargebert) + distilBert'.")

No models found starting with 'CORE_ensemble(core + dischargebert) + distilBert'.


In [18]:
# Push the model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = student_model.to(device)

In [19]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [20]:
epochs = 200
best_roc_auc = 0.0
min_delta = 0.0001
early_stopping_count = 0
early_stopping_patience = 3
gradient_accumulation_steps = 10

# Set the optimizer
optimizer = AdamW(student_model.parameters(), lr=1e-5, weight_decay=0.01)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=50, 
    num_training_steps=len(train_loader) * epochs // gradient_accumulation_steps
)




In [21]:
from torch.nn import functional as F

ensemble_model.eval()

# Training
for epoch in range(epochs):
    student_model.train()
    train_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad() if step % gradient_accumulation_steps == 0 else None
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # get student model's logits
        student_logits = student_model(input_ids, attention_mask)[0]
        
        # get teacher model's logits
        with torch.no_grad():
            teacher_logits = ensemble_model(input_ids, attention_mask)
            
            
        # calculate loss
        loss = (
            nn.KLDivLoss()(F.log_softmax(student_logits/temperature, dim=1), 
                           F.softmax(teacher_logits/temperature, dim=1)) * (temperature ** 2) +
            nn.CrossEntropyLoss()(student_logits, labels)
        )
        
        (loss / gradient_accumulation_steps).backward()
        
        train_loss += loss.item()
        
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            optimizer.step()
            scheduler.step()

    student_model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = student_model(input_ids, attention_mask)[0]
            loss = nn.CrossEntropyLoss()(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())
            

    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    val_loss /= len(val_loader)
    train_loss /= len(train_loader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')
        
        
    # Implement early stopping
    if macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0
        torch.save(student_model.state_dict(), f"CORE_ensemble(core + dischargebert) + distilBert_epoch_{epoch}roc_{best_roc_auc}.pth")

100%|█████████████████████████████████████████| 951/951 [54:05<00:00,  3.41s/it]
100%|█████████████████████████████████████████| 138/138 [01:10<00:00,  1.94it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 1/200, Training Loss: 1.3936300123527599, Validation Loss: 1.3139223274977312
Accuracy: 0.3664313368253245, Recall: 0.3664313368253245, Precision: 0.13427192460759443, F1: 0.19652934031731573, Micro F1: 0.3664313368253245, Macro Roc Auc: 0.5668297955270275


100%|█████████████████████████████████████████| 951/951 [54:24<00:00,  3.43s/it]
100%|█████████████████████████████████████████| 138/138 [01:11<00:00,  1.93it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2/200, Training Loss: 1.3794044772658565, Validation Loss: 1.3001905852469846
Accuracy: 0.3675700296060123, Recall: 0.3675700296060123, Precision: 0.37914853979803653, F1: 0.20316625017938786, Micro F1: 0.3675700296060123, Macro Roc Auc: 0.5800700965664007


100%|█████████████████████████████████████████| 951/951 [54:19<00:00,  3.43s/it]
100%|█████████████████████████████████████████| 138/138 [01:07<00:00,  2.06it/s]


Epoch: 3/200, Training Loss: 1.3534808536182568, Validation Loss: 1.2712859960569851
Accuracy: 0.38077886586199045, Recall: 0.38077886586199045, Precision: 0.3681338027605608, F1: 0.31179573771520347, Micro F1: 0.38077886586199045, Macro Roc Auc: 0.6192753537713008


100%|█████████████████████████████████████████| 951/951 [54:39<00:00,  3.45s/it]
100%|█████████████████████████████████████████| 138/138 [01:11<00:00,  1.93it/s]


Epoch: 4/200, Training Loss: 1.3041345427590338, Validation Loss: 1.2484596069308296
Accuracy: 0.3919380551127306, Recall: 0.3919380551127306, Precision: 0.38651102962746503, F1: 0.3690170545169432, Micro F1: 0.3919380551127306, Macro Roc Auc: 0.6549495494721381


100%|█████████████████████████████████████████| 951/951 [54:13<00:00,  3.42s/it]
100%|█████████████████████████████████████████| 138/138 [01:11<00:00,  1.94it/s]


Epoch: 5/200, Training Loss: 1.2627145069505388, Validation Loss: 1.2247694488884746
Accuracy: 0.40696879981780915, Recall: 0.40696879981780915, Precision: 0.405096903117529, F1: 0.3735218897803632, Micro F1: 0.40696879981780915, Macro Roc Auc: 0.6693954865383643


100%|█████████████████████████████████████████| 951/951 [54:08<00:00,  3.42s/it]
100%|█████████████████████████████████████████| 138/138 [01:11<00:00,  1.92it/s]


Epoch: 6/200, Training Loss: 1.2387308366165803, Validation Loss: 1.2311573952868364
Accuracy: 0.4108403552721476, Recall: 0.4108403552721476, Precision: 0.41103383833754353, F1: 0.3892569997420046, Micro F1: 0.4108403552721476, Macro Roc Auc: 0.6794620713120789


100%|█████████████████████████████████████████| 951/951 [55:08<00:00,  3.48s/it]
100%|█████████████████████████████████████████| 138/138 [01:05<00:00,  2.09it/s]


Epoch: 7/200, Training Loss: 1.2203472957375423, Validation Loss: 1.2181485373040903
Accuracy: 0.42040537462992483, Recall: 0.42040537462992483, Precision: 0.4247213362919185, F1: 0.3859282401896057, Micro F1: 0.42040537462992483, Macro Roc Auc: 0.6838770006917363


100%|█████████████████████████████████████████| 951/951 [54:25<00:00,  3.43s/it]
100%|█████████████████████████████████████████| 138/138 [01:11<00:00,  1.94it/s]


Epoch: 8/200, Training Loss: 1.212052574919852, Validation Loss: 1.214036452597466
Accuracy: 0.42199954452288774, Recall: 0.42199954452288774, Precision: 0.42060472079129774, F1: 0.3908388793612484, Micro F1: 0.4219995445228878, Macro Roc Auc: 0.6899074809243989


100%|█████████████████████████████████████████| 951/951 [54:33<00:00,  3.44s/it]
100%|█████████████████████████████████████████| 138/138 [01:11<00:00,  1.94it/s]


Epoch: 9/200, Training Loss: 1.1982120798840006, Validation Loss: 1.2045108637084132
Accuracy: 0.4292871783192895, Recall: 0.4292871783192895, Precision: 0.4346382649793847, F1: 0.40518699110547624, Micro F1: 0.4292871783192895, Macro Roc Auc: 0.6886542695294741
EarlyStopping counter: 1 out of 3


100%|█████████████████████████████████████████| 951/951 [54:45<00:00,  3.45s/it]
100%|█████████████████████████████████████████| 138/138 [01:12<00:00,  1.91it/s]


Epoch: 10/200, Training Loss: 1.1824384224527642, Validation Loss: 1.203469338192456
Accuracy: 0.4270097927579139, Recall: 0.4270097927579139, Precision: 0.4350474504812953, F1: 0.4110667956156379, Micro F1: 0.4270097927579139, Macro Roc Auc: 0.6904037278750308


100%|█████████████████████████████████████████| 951/951 [55:27<00:00,  3.50s/it]
100%|█████████████████████████████████████████| 138/138 [01:12<00:00,  1.91it/s]


Epoch: 11/200, Training Loss: 1.175224215199644, Validation Loss: 1.2135251993718355
Accuracy: 0.4151673878387611, Recall: 0.4151673878387611, Precision: 0.43343990813419997, F1: 0.40531798765823374, Micro F1: 0.4151673878387611, Macro Roc Auc: 0.6900092889601988
EarlyStopping counter: 1 out of 3


100%|█████████████████████████████████████████| 951/951 [54:39<00:00,  3.45s/it]
100%|█████████████████████████████████████████| 138/138 [01:12<00:00,  1.91it/s]


Epoch: 12/200, Training Loss: 1.168273323091172, Validation Loss: 1.1972021089083906
Accuracy: 0.4281484855386017, Recall: 0.4281484855386017, Precision: 0.4237239704426197, F1: 0.41864876391570366, Micro F1: 0.4281484855386017, Macro Roc Auc: 0.6966396165203791


100%|█████████████████████████████████████████| 951/951 [53:53<00:00,  3.40s/it]
100%|█████████████████████████████████████████| 138/138 [01:05<00:00,  2.12it/s]


Epoch: 13/200, Training Loss: 1.1601959469441234, Validation Loss: 1.2195514254811881
Accuracy: 0.4135732179457982, Recall: 0.4135732179457982, Precision: 0.40802623666665877, F1: 0.3925015600618677, Micro F1: 0.4135732179457982, Macro Roc Auc: 0.6974187591483713


100%|█████████████████████████████████████████| 951/951 [52:59<00:00,  3.34s/it]
100%|█████████████████████████████████████████| 138/138 [01:10<00:00,  1.95it/s]


Epoch: 14/200, Training Loss: 1.1598734786708522, Validation Loss: 1.2404665104720904
Accuracy: 0.39717604190389433, Recall: 0.39717604190389433, Precision: 0.42862070999538493, F1: 0.38520385407182295, Micro F1: 0.39717604190389433, Macro Roc Auc: 0.6865784149869573
EarlyStopping counter: 1 out of 3


100%|█████████████████████████████████████████| 951/951 [53:24<00:00,  3.37s/it]
100%|█████████████████████████████████████████| 138/138 [01:04<00:00,  2.14it/s]


Epoch: 15/200, Training Loss: 1.1516084192177725, Validation Loss: 1.2137463887532551
Accuracy: 0.42199954452288774, Recall: 0.42199954452288774, Precision: 0.4156814625580146, F1: 0.39927619605340564, Micro F1: 0.4219995445228878, Macro Roc Auc: 0.6966672025889888
EarlyStopping counter: 2 out of 3


100%|█████████████████████████████████████████| 951/951 [54:29<00:00,  3.44s/it]
100%|█████████████████████████████████████████| 138/138 [01:12<00:00,  1.91it/s]

Epoch: 16/200, Training Loss: 1.1420242595246413, Validation Loss: 1.2083302097044128
Accuracy: 0.4260988385333637, Recall: 0.4260988385333637, Precision: 0.42124606163227657, F1: 0.41462073336596816, Micro F1: 0.4260988385333637, Macro Roc Auc: 0.696985810236773
EarlyStopping counter: 3 out of 3
Early stopping





In [24]:
import os

# list all files in the current directory
files = os.listdir('.')

# filter the ones that start with 'CORE_baseline'
core_models = [f for f in files if f.startswith('CORE_ensemble(core + dischargebert) + distilBert')]

if core_models:
    print("Found models starting with 'CORE_ensemble(core + dischargebert) + distilBert':")
    for model in core_models:
        print(model)
        
    # get the first (and supposedly only) model
    model_path = core_models[0]

    # load the model state
    student_model.load_state_dict(torch.load(model_path))
    print("Loaded Model")
else:
    print("No models found starting with 'CORE_ensemble(core + dischargebert) + distilBert'.")

Found models starting with 'CORE_ensemble(core + dischargebert) + distilBert':
CORE_ensemble(core + dischargebert) + distilBert_epoch_12roc_0.6974187591483713.pth
Loaded Model


In [25]:
# Put the model in evaluation mode
student_model.eval()

# Initialize lists to store predictions and true labels
test_preds = []
test_labels = []

# Iterate over test data
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = student_model(input_ids, attention_mask)[0]
        test_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        test_labels.append(labels.cpu().numpy())



100%|█████████████████████████████████████████| 275/275 [02:20<00:00,  1.96it/s]


In [26]:
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

# Calculate metrics
test_preds_class = np.argmax(test_preds, axis=1)
accuracy = accuracy_score(test_labels, test_preds_class)
recall = recall_score(test_labels, test_preds_class, average='weighted')
precision = precision_score(test_labels, test_preds_class, average='weighted')
f1 = f1_score(test_labels, test_preds_class, average='weighted')
micro_f1 = f1_score(test_labels, test_preds_class, average='micro')
macro_roc_auc = roc_auc_score(test_labels, test_preds, multi_class='ovo', average='macro')

print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')

Accuracy: 0.4265090371717631, Recall: 0.4265090371717631, Precision: 0.42588944607031426, F1: 0.4072445736479922, Micro F1: 0.4265090371717631, Macro Roc Auc: 0.7055687966725189
