In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

In [2]:

# Load Train, Validation, Test Dataset
train_dataset = pd.read_csv('/kaggle/input/mydata/LOS_WEEKS_adm_train.csv')
val_dataset = pd.read_csv('/kaggle/input/mydata/LOS_WEEKS_adm_val.csv')
test_dataset = pd.read_csv('/kaggle/input/mydata/LOS_WEEKS_adm_test.csv')

In [3]:
from torch.utils.data import DataLoader
from torch import nn

class EnsembleModel(nn.Module):
    def __init__(self, model1):
        super(EnsembleModel, self).__init__()
        self.model1 = model1

    def forward(self, input_ids, attention_mask):
        output1 = self.model1(input_ids, attention_mask=attention_mask)[0]
        avg_output = output1
        return avg_output

In [4]:
from transformers import DistilBertForSequenceClassification, DistilBertConfig, AdamW, get_linear_schedule_with_warmup

# create a student model
student_config = DistilBertConfig.from_pretrained('distilbert-base-uncased', 
                                                  num_labels=4, 
                                                  hidden_dropout_prob=0.2, 
                                                  attention_probs_dropout_prob=0.2)

student_model = DistilBertForSequenceClassification(student_config)

# set the temperature
temperature = 2.0

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [5]:
from transformers import AutoTokenizer

# Choose a tokenizer
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



In [6]:
train_encodings = tokenizer(train_dataset['text'].tolist(), truncation=True, padding=True, max_length = 512)
val_encodings = tokenizer(val_dataset['text'].tolist(), truncation=True, padding=True,  max_length = 512)
test_encodings = tokenizer(test_dataset['text'].tolist(), truncation=True, padding=True , max_length = 512)

In [7]:
# Create a Dataset for PyTorch
class LosDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [8]:
train_dataset = LosDataset(train_encodings, train_dataset['los_label'].tolist())
val_dataset = LosDataset(val_encodings, val_dataset['los_label'].tolist())
test_dataset = LosDataset(test_encodings, test_dataset['los_label'].tolist())

In [9]:
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm
from torch import nn
import numpy as np

# Create the ensemble model
ensemble_model = EnsembleModel(student_model)

In [10]:
import os

# list all files in the current directory
files = os.listdir('.')

# filter the ones that start with 'CORE_baseline'
core_models = [f for f in files if f.startswith('distilBERT_baseline')]

if core_models:
    print("Found models starting with 'distilBERT_baseline':")
    for model in core_models:
        print(model)
        
    # get the first (and supposedly only) model
    model_path = core_models[0]

    # load the model state
    ensemble_model.load_state_dict(torch.load(model_path))
    print("Loaded Model")
else:
    print("No models found starting with 'distilBERT_baseline'.")

No models found starting with 'distilBERT_baseline'.


In [11]:
# Push the model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ensemble_model = ensemble_model.to(device)

In [12]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [13]:
epochs = 200
best_roc_auc = 0.0
min_delta = 0.0001
early_stopping_count = 0
early_stopping_patience = 3
gradient_accumulation_steps = 10
best_model_path = "best_model.pth"

# Set the optimizer
optimizer = AdamW(ensemble_model.parameters(), lr=1e-5, weight_decay=0.01)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=50, 
    num_training_steps=len(train_loader) * epochs // gradient_accumulation_steps
)



In [14]:
from torch.nn import functional as F
# Training
for epoch in range(epochs):
    ensemble_model.train()
    train_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad() if step % gradient_accumulation_steps == 0 else None
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = ensemble_model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            optimizer.step()
            scheduler.step()

    ensemble_model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = ensemble_model(input_ids, attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())
            

    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    val_loss /= len(val_loader)
    train_loss /= len(train_loader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')
        
        
    # Implement early stopping
    if epoch > 0 and macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0
        torch.save(ensemble_model.state_dict(), f"distilBERT_baseline_epoch_{epoch}roc_{best_roc_auc}.pth")

100%|██████████| 951/951 [23:11<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.88it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 1/200, Training Loss: 1.3284779172841181, Validation Loss: 1.3184998173644578
Accuracy: 0.3664313368253245, Recall: 0.3664313368253245, Precision: 0.13427192460759443, F1: 0.19652934031731573, Micro F1: 0.3664313368253245, Macro Roc Auc: 0.562613747819919


100%|██████████| 951/951 [23:14<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2/200, Training Loss: 1.3159794274690149, Validation Loss: 1.3091818370680879
Accuracy: 0.36688681393759964, Recall: 0.36688681393759964, Precision: 0.23741034582091947, F1: 0.1984843363746275, Micro F1: 0.36688681393759964, Macro Roc Auc: 0.5718401572458308


100%|██████████| 951/951 [23:13<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 3/200, Training Loss: 1.307406319554045, Validation Loss: 1.298474641813748
Accuracy: 0.3693919380551127, Recall: 0.3693919380551127, Precision: 0.34199584964022106, F1: 0.2094257724489603, Micro F1: 0.3693919380551127, Macro Roc Auc: 0.5845949891060462


100%|██████████| 951/951 [23:13<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 4/200, Training Loss: 1.285442347907619, Validation Loss: 1.2975817141325579
Accuracy: 0.3739467091778638, Recall: 0.3739467091778638, Precision: 0.30474990004344726, F1: 0.2219270905623767, Micro F1: 0.3739467091778638, Macro Roc Auc: 0.6239034565306528


100%|██████████| 951/951 [23:13<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 5/200, Training Loss: 1.245383406537815, Validation Loss: 1.2377834933391516
Accuracy: 0.4046914142564336, Recall: 0.4046914142564336, Precision: 0.3988188233348714, F1: 0.34961261172123576, Micro F1: 0.4046914142564336, Macro Roc Auc: 0.657183922276356


100%|██████████| 951/951 [23:12<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.88it/s]


Epoch: 6/200, Training Loss: 1.2132584561058148, Validation Loss: 1.2296101934667947
Accuracy: 0.4076520154862218, Recall: 0.4076520154862218, Precision: 0.4170074999671502, F1: 0.3649539460289119, Micro F1: 0.4076520154862219, Macro Roc Auc: 0.6702639416931335


100%|██████████| 951/951 [23:12<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 7/200, Training Loss: 1.196351284484382, Validation Loss: 1.2348235525946687
Accuracy: 0.40901844682304717, Recall: 0.40901844682304717, Precision: 0.41729603879674626, F1: 0.35032612212026615, Micro F1: 0.4090184468230472, Macro Roc Auc: 0.678217637238053


100%|██████████| 951/951 [23:12<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 8/200, Training Loss: 1.1829090195499383, Validation Loss: 1.2180959193602852
Accuracy: 0.41903894329309954, Recall: 0.41903894329309954, Precision: 0.41920751288343633, F1: 0.39717156509193824, Micro F1: 0.41903894329309954, Macro Roc Auc: 0.6828126603102828


100%|██████████| 951/951 [23:13<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 9/200, Training Loss: 1.1687320855639083, Validation Loss: 1.2222098107787147
Accuracy: 0.40446367570029607, Recall: 0.40446367570029607, Precision: 0.4103083160892997, F1: 0.39915113366094895, Micro F1: 0.40446367570029607, Macro Roc Auc: 0.6837150915829353


100%|██████████| 951/951 [23:12<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.88it/s]


Epoch: 10/200, Training Loss: 1.1540524703847121, Validation Loss: 1.2073225949121558
Accuracy: 0.4245046686404008, Recall: 0.4245046686404008, Precision: 0.4189291479399433, F1: 0.4039707885133677, Micro F1: 0.4245046686404008, Macro Roc Auc: 0.689246835621718


100%|██████████| 951/951 [23:14<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 11/200, Training Loss: 1.1419687437835677, Validation Loss: 1.2271083100982334
Accuracy: 0.42723753131405146, Recall: 0.42723753131405146, Precision: 0.4236852506505754, F1: 0.41668210316605403, Micro F1: 0.42723753131405146, Macro Roc Auc: 0.6908063788663262


100%|██████████| 951/951 [23:13<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 12/200, Training Loss: 1.1309586594784171, Validation Loss: 1.2162234951620516
Accuracy: 0.4304258710999772, Recall: 0.4304258710999772, Precision: 0.43549309273459297, F1: 0.41270940474990664, Micro F1: 0.4304258710999772, Macro Roc Auc: 0.6911873596124583


100%|██████████| 951/951 [23:13<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 13/200, Training Loss: 1.1220415522247458, Validation Loss: 1.2050916682118955
Accuracy: 0.42541562286495105, Recall: 0.42541562286495105, Precision: 0.4308650765060821, F1: 0.4144773487782117, Micro F1: 0.42541562286495105, Macro Roc Auc: 0.6902761174863986
EarlyStopping counter: 1 out of 3


100%|██████████| 951/951 [23:14<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 14/200, Training Loss: 1.1159402930523443, Validation Loss: 1.2199497304964757
Accuracy: 0.414711910726486, Recall: 0.414711910726486, Precision: 0.4198855471472415, F1: 0.41039224731946156, Micro F1: 0.414711910726486, Macro Roc Auc: 0.6898550758323146
EarlyStopping counter: 2 out of 3


100%|██████████| 951/951 [23:11<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 15/200, Training Loss: 1.107572077200616, Validation Loss: 1.227894092383592
Accuracy: 0.4247324071965384, Recall: 0.4247324071965384, Precision: 0.4214972711379851, F1: 0.4059911742110678, Micro F1: 0.4247324071965384, Macro Roc Auc: 0.6943067512025753


100%|██████████| 951/951 [23:12<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.89it/s]


Epoch: 16/200, Training Loss: 1.098484591354707, Validation Loss: 1.220688401788905
Accuracy: 0.414711910726486, Recall: 0.414711910726486, Precision: 0.4247529633925429, F1: 0.4072815978078644, Micro F1: 0.414711910726486, Macro Roc Auc: 0.6896190165397577
EarlyStopping counter: 1 out of 3


100%|██████████| 951/951 [23:14<00:00,  1.47s/it]
100%|██████████| 138/138 [01:13<00:00,  1.88it/s]


Epoch: 17/200, Training Loss: 1.0887792513448231, Validation Loss: 1.2220693766206936
Accuracy: 0.42382145297198814, Recall: 0.42382145297198814, Precision: 0.4210168865122808, F1: 0.41534330405168624, Micro F1: 0.42382145297198814, Macro Roc Auc: 0.6910325790268687
EarlyStopping counter: 2 out of 3


100%|██████████| 951/951 [23:12<00:00,  1.46s/it]
100%|██████████| 138/138 [01:13<00:00,  1.88it/s]

Epoch: 18/200, Training Loss: 1.0827796709274269, Validation Loss: 1.2256703670474067
Accuracy: 0.41607834206331135, Recall: 0.41607834206331135, Precision: 0.4156861027670152, F1: 0.40640547347358336, Micro F1: 0.41607834206331135, Macro Roc Auc: 0.6886251099280493
EarlyStopping counter: 3 out of 3
Early stopping





In [15]:
import os

# list all files in the current directory
files = os.listdir('.')

# filter the ones that start with 'CORE_baseline'
core_models = [f for f in files if f.startswith('distilBERT_baseline')]

if core_models:
    print("Found models starting with 'distilBERT_baseline':")
    for model in core_models:
        print(model)
        
    # get the first (and supposedly only) model
    model_path = core_models[0]

    # load the model state
    ensemble_model.load_state_dict(torch.load(model_path))
    print("Loaded Model")
else:
    print("No models found starting with 'distilBERT_baseline'.")

  ensemble_model.load_state_dict(torch.load(model_path))


Found models starting with 'distilBERT_baseline':
distilBERT_baseline_epoch_8roc_0.6837150915829353.pth
distilBERT_baseline_epoch_5roc_0.6702639416931335.pth
distilBERT_baseline_epoch_6roc_0.678217637238053.pth
distilBERT_baseline_epoch_11roc_0.6911873596124583.pth
distilBERT_baseline_epoch_0roc_0.562613747819919.pth
distilBERT_baseline_epoch_2roc_0.5845949891060462.pth
distilBERT_baseline_epoch_10roc_0.6908063788663262.pth
distilBERT_baseline_epoch_1roc_0.5718401572458308.pth
distilBERT_baseline_epoch_7roc_0.6828126603102828.pth
distilBERT_baseline_epoch_4roc_0.657183922276356.pth
distilBERT_baseline_epoch_3roc_0.6239034565306528.pth
distilBERT_baseline_epoch_9roc_0.689246835621718.pth
distilBERT_baseline_epoch_14roc_0.6943067512025753.pth
Loaded Model


In [16]:
# Put the model in evaluation mode
ensemble_model.eval()

# Initialize lists to store predictions and true labels
test_preds = []
test_labels = []

# Iterate over test data
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = ensemble_model(input_ids, attention_mask)
        test_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        test_labels.append(labels.cpu().numpy())

100%|██████████| 275/275 [02:26<00:00,  1.87it/s]


In [17]:
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

# Calculate metrics
test_preds_class = np.argmax(test_preds, axis=1)
accuracy = accuracy_score(test_labels, test_preds_class)
recall = recall_score(test_labels, test_preds_class, average='weighted')
precision = precision_score(test_labels, test_preds_class, average='weighted')
f1 = f1_score(test_labels, test_preds_class, average='weighted')
micro_f1 = f1_score(test_labels, test_preds_class, average='micro')
macro_roc_auc = roc_auc_score(test_labels, test_preds, multi_class='ovo', average='macro')

print(f'Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}, F1: {f1}, Micro F1: {micro_f1}, Macro Roc Auc: {macro_roc_auc}')

Accuracy: 0.4161646015687166, Recall: 0.4161646015687166, Precision: 0.42059029646987267, F1: 0.4115122082263095, Micro F1: 0.4161646015687166, Macro Roc Auc: 0.6937660343892668
