In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

In [2]:
# Load Train, Validation, Test Dataset
dataset = pd.read_csv('/kaggle/input/mimic-bengali-preprocessed/balanced_translated_dataset_preprocessed.csv')

In [3]:
dataset.head()

Unnamed: 0,english_text,english_text_preprocessed,label,bengali_text
0,CHIEF COMPLAINT: weak/fatigue\n\nPRESENT ILLNE...,chief complaint weakfatigue present illness ye...,0,"প্রধান অভিযোগ হল, দুর্দশাগ্রস্ত রোগ বর্তমান পু..."
1,CHIEF COMPLAINT: Fever and hypotension.\n\nPRE...,chief complaint fever hypotension present illn...,0,প্রধান অভিযোগ জ্বর হাইপোটেনশন বর্তমান অসুস্থ র...
2,CHIEF COMPLAINT: fall\n\nPRESENT ILLNESS: HPI:...,chief complaint fall present illness hpi f fal...,0,"প্রধান অভিযোগ হচ্ছে, গত রাতে হেডনো লক্রেট্রেশন..."
3,CHIEF COMPLAINT: Worsening shortness of breath...,chief complaint worsening shortness breath pre...,0,"প্রধান অভিযোগ হচ্ছে, রোগী রোগী বছর বয়সী মহিলা..."
4,CHIEF COMPLAINT: Chest heaviness\n\nPRESENT IL...,chief complaint chest heaviness present illnes...,0,প্রধান অভিযোগের বুকটি বর্তমান অসুস্থতাকে ক্ষুব...


In [4]:
from sklearn.model_selection import train_test_split

# Split the data into train (80%) and temp (20%)
train_df, temp_df = train_test_split(dataset, test_size=0.2, stratify=dataset['label'], random_state=42)

# Split the temp data into validation (10%) and test (10%)
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42)

# Rename the datasets
train_dataset = train_df
val_dataset = val_df
test_dataset = test_df

In [5]:
# Print class distribution for the train dataset
print("Train dataset class distribution:")
print(train_dataset['label'].value_counts())

# Print class distribution for the validation dataset
print("Validation dataset class distribution:")
print(val_dataset['label'].value_counts())

# Print class distribution for the test dataset
print("Test dataset class distribution:")
print(test_dataset['label'].value_counts())

Train dataset class distribution:
label
1    800
0    800
Name: count, dtype: int64
Validation dataset class distribution:
label
1    100
0    100
Name: count, dtype: int64
Test dataset class distribution:
label
1    100
0    100
Name: count, dtype: int64


In [6]:
from torch.utils.data import DataLoader
from torch import nn

class EnsembleModel(nn.Module):
    def __init__(self, model1):
        super(EnsembleModel, self).__init__()
        self.model1 = model1

    def forward(self, input_ids, attention_mask):
        output1 = self.model1(input_ids, attention_mask=attention_mask)[0]
        avg_output = output1
        return avg_output

In [7]:
from transformers import AutoModelForSequenceClassification, AutoConfig

# Specify the dropout rate in the configuration
config = AutoConfig.from_pretrained('sagorsarker/bangla-bert-base', 
                                    num_labels=2, 
                                    hidden_dropout_prob=0.2, 
                                    attention_probs_dropout_prob=0.2)

# Load the pre-trained model with the specified configuration
bert_model = AutoModelForSequenceClassification.from_pretrained('sagorsarker/bangla-bert-base', config=config)


config.json:   0%|          | 0.00/491 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/660M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sagorsarker/bangla-bert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from transformers import AutoTokenizer

# Choose a tokenizer
tokenizer = AutoTokenizer.from_pretrained('sagorsarker/bangla-bert-base')

vocab.txt:   0%|          | 0.00/2.24M [00:00<?, ?B/s]

In [9]:
# Apply the tokenizer to the training, validation, and test datasets
train_encodings = tokenizer(train_dataset['bengali_text'].tolist(), truncation=True, padding=True, max_length = 512)
val_encodings = tokenizer(val_dataset['bengali_text'].tolist(), truncation=True, padding=True,  max_length = 512)
test_encodings = tokenizer(test_dataset['bengali_text'].tolist(), truncation=True, padding=True , max_length = 512)

In [10]:
# Create a Dataset for PyTorch
class LosDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [11]:
train_dataset = LosDataset(train_encodings, train_dataset['label'].tolist())
val_dataset = LosDataset(val_encodings, val_dataset['label'].tolist())
test_dataset = LosDataset(test_encodings, test_dataset['label'].tolist())

In [12]:
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm
from torch import nn
import numpy as np

# Create the ensemble model
ensemble_model = EnsembleModel(bert_model)

In [13]:
# Push the model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ensemble_model = ensemble_model.to(device)

In [14]:
train_loader = DataLoader(train_dataset, batch_size=18, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=18, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=18, shuffle=False)

In [15]:
epochs = 200
best_roc_auc = 0.0
min_delta = 0.0001
early_stopping_count = 0
early_stopping_patience = 10
gradient_accumulation_steps = 10
best_model_path = "best_model.pth"

# Set the optimizer
optimizer = AdamW(ensemble_model.parameters(), lr=1e-5, weight_decay=0.01)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=50, 
    num_training_steps=len(train_loader) * epochs // gradient_accumulation_steps
)




In [16]:
from torch.nn import functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

best_model_state_dict = None

# Training
for epoch in range(epochs):
    ensemble_model.train()
    train_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad() if step % gradient_accumulation_steps == 0 else None
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = ensemble_model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            optimizer.step()
            scheduler.step()

    ensemble_model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = ensemble_model(input_ids, attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())
            
    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    val_loss /= len(val_loader)
    train_loss /= len(train_loader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class)
    precision = precision_score(val_labels, val_preds_class)
    f1 = f1_score(val_labels, val_preds_class)
    roc_auc = roc_auc_score(val_labels, val_preds[:, 1])

    print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Roc Auc: {roc_auc}')

    # Implement early stopping
    if epoch > 0 and roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = roc_auc
        early_stopping_count = 0
        best_model_state_dict = ensemble_model.state_dict()
        torch.save(ensemble_model.state_dict(), f"BERT_baseline_MP_2000.pth")


100%|██████████| 89/89 [01:24<00:00,  1.06it/s]
100%|██████████| 12/12 [00:03<00:00,  3.83it/s]


Epoch: 1/200, Training Loss: 0.7076852877488297, Validation Loss: 0.7092338254054388
Accuracy: 0.475, Recall: 0.66, Precision: 0.48175182481751827, F1: 0.5569620253164558, Roc Auc: 0.48580000000000007


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 2/200, Training Loss: 0.703408987334605, Validation Loss: 0.7083308945099512
Accuracy: 0.49, Recall: 0.72, Precision: 0.4931506849315068, F1: 0.5853658536585366, Roc Auc: 0.4845
EarlyStopping counter: 1 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 3/200, Training Loss: 0.6979838060529044, Validation Loss: 0.6993743032217026
Accuracy: 0.525, Recall: 0.66, Precision: 0.5196850393700787, F1: 0.581497797356828, Roc Auc: 0.5076999999999999


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 4/200, Training Loss: 0.6915537005060175, Validation Loss: 0.7033769538005193
Accuracy: 0.52, Recall: 0.81, Precision: 0.5126582278481012, F1: 0.627906976744186, Roc Auc: 0.5214000000000001


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 5/200, Training Loss: 0.6818881155399794, Validation Loss: 0.6940501729647318
Accuracy: 0.55, Recall: 0.52, Precision: 0.5531914893617021, F1: 0.536082474226804, Roc Auc: 0.5457


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.80it/s]


Epoch: 6/200, Training Loss: 0.6854128763916787, Validation Loss: 0.7077377438545227
Accuracy: 0.535, Recall: 0.82, Precision: 0.5222929936305732, F1: 0.6381322957198443, Roc Auc: 0.5619


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 7/200, Training Loss: 0.6718895214327266, Validation Loss: 0.7051560829083124
Accuracy: 0.535, Recall: 0.69, Precision: 0.5267175572519084, F1: 0.5974025974025975, Roc Auc: 0.5843


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 8/200, Training Loss: 0.6566414109776529, Validation Loss: 0.6993634551763535
Accuracy: 0.585, Recall: 0.61, Precision: 0.580952380952381, F1: 0.5951219512195122, Roc Auc: 0.6036999999999999


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 9/200, Training Loss: 0.6552585028530507, Validation Loss: 0.6941869954268137
Accuracy: 0.565, Recall: 0.44, Precision: 0.5866666666666667, F1: 0.5028571428571429, Roc Auc: 0.623


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 10/200, Training Loss: 0.6238648898146125, Validation Loss: 0.6937173654635748
Accuracy: 0.58, Recall: 0.49, Precision: 0.5975609756097561, F1: 0.5384615384615385, Roc Auc: 0.6219
EarlyStopping counter: 1 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 11/200, Training Loss: 0.6126788734050279, Validation Loss: 0.7201506743828455
Accuracy: 0.585, Recall: 0.67, Precision: 0.5726495726495726, F1: 0.6175115207373272, Roc Auc: 0.6169
EarlyStopping counter: 2 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 12/200, Training Loss: 0.582122622915868, Validation Loss: 0.7440549532572428
Accuracy: 0.565, Recall: 0.4, Precision: 0.5970149253731343, F1: 0.47904191616766467, Roc Auc: 0.6159
EarlyStopping counter: 3 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 13/200, Training Loss: 0.5674050902382711, Validation Loss: 0.7757652153571447
Accuracy: 0.61, Recall: 0.79, Precision: 0.5808823529411765, F1: 0.6694915254237288, Roc Auc: 0.624


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 14/200, Training Loss: 0.5278755731127235, Validation Loss: 0.761370008190473
Accuracy: 0.58, Recall: 0.66, Precision: 0.5689655172413793, F1: 0.6111111111111112, Roc Auc: 0.6242


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 15/200, Training Loss: 0.4914301392737399, Validation Loss: 0.7981007347504298
Accuracy: 0.57, Recall: 0.58, Precision: 0.5686274509803921, F1: 0.5742574257425743, Roc Auc: 0.6260999999999999


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 16/200, Training Loss: 0.4586210699563616, Validation Loss: 0.8465259869893392
Accuracy: 0.615, Recall: 0.82, Precision: 0.5815602836879432, F1: 0.6804979253112032, Roc Auc: 0.6337999999999999


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 17/200, Training Loss: 0.4111902142843504, Validation Loss: 0.8857194483280182
Accuracy: 0.605, Recall: 0.82, Precision: 0.5734265734265734, F1: 0.6748971193415638, Roc Auc: 0.6277
EarlyStopping counter: 1 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.80it/s]


Epoch: 18/200, Training Loss: 0.3627424037523484, Validation Loss: 0.9980068405469259
Accuracy: 0.585, Recall: 0.89, Precision: 0.5527950310559007, F1: 0.6819923371647509, Roc Auc: 0.6202000000000001
EarlyStopping counter: 2 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 19/200, Training Loss: 0.3338022595376111, Validation Loss: 0.95865931113561
Accuracy: 0.595, Recall: 0.64, Precision: 0.5871559633027523, F1: 0.6124401913875598, Roc Auc: 0.6266000000000002
EarlyStopping counter: 3 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 20/200, Training Loss: 0.29639284932211546, Validation Loss: 1.0208994348843892
Accuracy: 0.605, Recall: 0.78, Precision: 0.5777777777777777, F1: 0.6638297872340425, Roc Auc: 0.6332
EarlyStopping counter: 4 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 21/200, Training Loss: 0.2593337574534202, Validation Loss: 1.0383544663588207
Accuracy: 0.6, Recall: 0.63, Precision: 0.5943396226415094, F1: 0.6116504854368932, Roc Auc: 0.6248
EarlyStopping counter: 5 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 22/200, Training Loss: 0.23598114735959622, Validation Loss: 1.175237109263738
Accuracy: 0.575, Recall: 0.85, Precision: 0.5483870967741935, F1: 0.6666666666666665, Roc Auc: 0.6173
EarlyStopping counter: 6 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.80it/s]


Epoch: 23/200, Training Loss: 0.17126537553882332, Validation Loss: 1.1789650867382686
Accuracy: 0.605, Recall: 0.75, Precision: 0.5813953488372093, F1: 0.6550218340611353, Roc Auc: 0.6198
EarlyStopping counter: 7 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]


Epoch: 24/200, Training Loss: 0.18493987753819885, Validation Loss: 1.3440769215424855
Accuracy: 0.575, Recall: 0.85, Precision: 0.5483870967741935, F1: 0.6666666666666665, Roc Auc: 0.6178
EarlyStopping counter: 8 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


Epoch: 25/200, Training Loss: 0.1872397368627318, Validation Loss: 1.224830001592636
Accuracy: 0.605, Recall: 0.6, Precision: 0.6060606060606061, F1: 0.6030150753768845, Roc Auc: 0.6283
EarlyStopping counter: 9 out of 10


100%|██████████| 89/89 [01:23<00:00,  1.07it/s]
100%|██████████| 12/12 [00:03<00:00,  3.81it/s]

Epoch: 26/200, Training Loss: 0.1695151004228699, Validation Loss: 1.3045382897059123
Accuracy: 0.605, Recall: 0.77, Precision: 0.5789473684210527, F1: 0.6609442060085837, Roc Auc: 0.6131
EarlyStopping counter: 10 out of 10
Early stopping





In [17]:
# Load the best model state dict into the model if needed
ensemble_model.load_state_dict(best_model_state_dict)

<All keys matched successfully>

In [18]:
from torch.nn import functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Put the model in evaluation mode
ensemble_model.eval()

# Initialize lists to store predictions and true labels
test_preds = []
test_labels = []

# Iterate over test data
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = ensemble_model(input_ids, attention_mask)
        test_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        test_labels.append(labels.cpu().numpy())


100%|██████████| 12/12 [00:03<00:00,  3.81it/s]


In [19]:
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

# Calculate metrics
test_preds_class = np.argmax(test_preds, axis=1)
accuracy = accuracy_score(test_labels, test_preds_class)
recall = recall_score(test_labels, test_preds_class)
precision = precision_score(test_labels, test_preds_class)
f1 = f1_score(test_labels, test_preds_class)
roc_auc = roc_auc_score(test_labels, test_preds[:, 1])

print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Roc Auc: {roc_auc}')

Accuracy: 0.555, Recall: 0.68, Precision: 0.544, F1: 0.6044444444444443, Roc Auc: 0.6106
