### Using the Facebook BART model to classify five diseases from cleaned text

Data pre-processing

In [None]:
import pandas as pd

# Load the pickle file into a pandas DataFrame. Later cells read its 'y'
# (label collections) and 'text' (note text) columns.
# NOTE(review): unpickling can execute arbitrary code embedded in the file;
# only load this pickle if it comes from a trusted source.
sampled_df = pd.read_pickle("mimic_iv_sampled_df.pkl")

In [7]:
from sklearn.preprocessing import MultiLabelBinarizer

# Turn the per-row label collections in sampled_df['y'] into a 0/1 indicator
# matrix with one column per distinct label.
mlb = MultiLabelBinarizer()
one_hot_encoded_data = mlb.fit_transform(sampled_df['y'])
# Wrap the matrix in a DataFrame whose column names are the labels learned by
# the binarizer. The frame gets a fresh default index.
one_hot_encoded_df = pd.DataFrame(one_hot_encoded_data, columns=mlb.classes_)

In [8]:
# Sanity check of the encoding: first the (rows, label columns) shape, then
# the (pandas-truncated) frame itself.
print(one_hot_encoded_df.shape)
print(one_hot_encoded_df)

(12000, 5)
       anemia  atrial fibrillation  hyperlipidemia  hypertension  pneumonia
0           0                    0               0             1          0
1           0                    1               0             0          1
2           1                    0               0             0          0
3           0                    1               0             0          0
4           0                    0               0             1          1
...       ...                  ...             ...           ...        ...
11995       0                    0               0             1          0
11996       0                    1               0             0          1
11997       0                    0               0             0          0
11998       0                    0               1             1          0
11999       0                    0               0             0          0

[12000 rows x 5 columns]


In [9]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing; random_state=42 makes the shuffle
# repeatable. The returned pandas objects keep the shuffled original row
# labels as their index (they are NOT re-numbered from 0).
X_train, X_test, y_train, y_test = train_test_split(sampled_df['text'], one_hot_encoded_df, test_size=0.2, random_state=42)

In [10]:
# Inspect the training labels. The index shown is the shuffled original row
# labels, so positional access must go through .iloc.
print(y_train)

       anemia  atrial fibrillation  hyperlipidemia  hypertension  pneumonia
9182        0                    0               0             1          0
11091       0                    0               0             1          0
6428        0                    0               0             0          0
288         0                    1               0             0          0
2626        0                    1               0             1          0
...       ...                  ...             ...           ...        ...
11964       1                    0               0             1          0
5191        0                    1               0             1          0
5390        0                    0               0             0          1
860         0                    0               1             0          0
7270        0                    0               0             1          0

[9600 rows x 5 columns]


In [11]:
# Longest note in each split, measured in whitespace-separated words.
# NOTE: this is a word count, not a tokenizer token count, so it is not
# directly usable as a tokenizer max_length.
max_length_train = max(map(lambda note: len(note.split()), X_train))
max_length_test = max(map(lambda note: len(note.split()), X_test))

print("Maximum length in X_train:", max_length_train)
print("Maximum length in X_test:", max_length_test)

# Largest word count seen across both splits
overall_max_length = max_length_test if max_length_test > max_length_train else max_length_train
print("Overall maximum length:", overall_max_length)

Maximum length in X_train: 5678
Maximum length in X_test: 5280
Overall maximum length: 5678


### Implementing BART Model

This section uses a BART model to predict the labels (the extracted diseases).

In [12]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import time
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import precision_score, recall_score, f1_score, hamming_loss
from transformers import BartForConditionalGeneration, AutoTokenizer, BartConfig
import os

# Ask PyTorch whether a CUDA device is usable in this environment and keep the
# answer in `cuda_available`.
cuda_available = torch.cuda.is_available()
print("Is CUDA available? ", cuda_available)

# The line above prints "Is CUDA available?  True" when a CUDA device is
# usable, and "Is CUDA available?  False" otherwise.

Is CUDA available?  True


In [13]:
# Load the tokenizer that matches the facebook/bart-base checkpoint.
# Padding is added on the left; over-long inputs are truncated on the right
# (the tail of the text is dropped).
# NOTE(review): left padding is unusual for an encoder-decoder model such as
# BART - confirm this is intentional.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base", padding_side="left", truncation_side='right')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

Using device:  cuda


In [15]:
class TextDataset(Dataset):
    """Map-style dataset pairing raw texts with multi-label target vectors.

    Args:
        texts: Sequence of strings (list, numpy array or pandas Series).
        labels: 2-D array-like of shape (n_samples, n_labels) holding 0/1
            indicators (nested list, numpy array or pandas DataFrame).
        tokenizer: Callable that accepts ``text, padding=, truncation=,
            max_length=, return_tensors=`` and returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors of shape
            (1, max_length).
        max_length: Length every sample is padded / truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        # Materialise positional containers. A pandas Series / DataFrame coming
        # out of train_test_split keeps its shuffled row labels, so `obj[idx]`
        # would be a label lookup (Series) or a column lookup (DataFrame)
        # instead of "the idx-th sample" that DataLoader asks for.
        self.texts = list(texts)
        self.labels = np.asarray(labels, dtype=np.float32)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: {len(self.texts)} vs {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the idx-th sample as a dict of tensors."""
        # Tokenize text
        encoded_text = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Return as a dictionary
        return {
            'input_ids': encoded_text['input_ids'].squeeze(0),  # Remove batch dimension
            # Samples are padded to max_length, so the mask is needed to keep
            # the model from attending to the padding positions.
            'attention_mask': encoded_text['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float)  # Ensure labels are tensors
        }

# BART-base has learned position embeddings for at most 1024 tokens, so longer
# inputs fail inside the model. The previous value (5678) was the longest note
# measured in whitespace-separated words, not in tokens.
max_length = 1024

# Hand the datasets positional containers: the split Series/DataFrame keep
# their shuffled row labels, so integer `obj[idx]` would not mean "idx-th row".
train_dataset = TextDataset(X_train.tolist(), y_train.to_numpy(), tokenizer, max_length)
test_dataset = TextDataset(X_test.tolist(), y_test.to_numpy(), tokenizer, max_length)

batch_size = 32

# Shuffle only the training loader; evaluation order stays fixed.
dl_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dl_test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [16]:
class BartForMultiLabelClassification(BartForConditionalGeneration):
    """BART with a linear multi-label classification head.

    The head maps one ``config.d_model``-sized hidden vector per example to
    ``num_labels`` independent logits (one per label).

    Args:
        config: BART configuration used to build the underlying model.
        num_labels: Number of independent binary labels to predict.
    """

    def __init__(self, config, num_labels):
        super().__init__(config)
        self.num_labels = num_labels
        self.classifier = torch.nn.Linear(config.d_model, self.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute per-label logits and, when targets are given, the BCE loss.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Optional mask marking real (1) vs padding (0) tokens.
            labels: Optional 0/1 targets, shape (batch, num_labels).

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``
            of shape (batch, num_labels).
        """
        # Run the bare encoder-decoder (`self.model`), not the parent's
        # forward: the parent's first output is the LM-head logits, whose last
        # dimension is the vocabulary size and cannot feed a
        # Linear(d_model, num_labels) layer.
        outputs = self.model(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # decoder hidden states, last dim = d_model
        # Pool by taking the hidden state at the first decoder position.
        logits = self.classifier(sequence_output[:, 0, :])

        if labels is not None:
            # Sigmoid + binary cross-entropy per label (multi-label objective).
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.float().view(-1, self.num_labels))
            return loss, logits
        return logits


num_labels = 5

# Start from the published bart-base hyper-parameters and record the number of
# labels on the configuration object as well.
config = BartConfig.from_pretrained("facebook/bart-base")
config.num_labels = num_labels

# Build the classifier from the configuration alone ...
model = BartForMultiLabelClassification(config, num_labels=num_labels)

# ... then copy the pretrained weights in. strict=False tolerates keys that
# exist on only one side (the new `classifier` layer is not in the checkpoint).
model.load_state_dict(
    BartForConditionalGeneration.from_pretrained("facebook/bart-base").state_dict(),
    strict=False,
)

model = model.to(device)

In [17]:
class MedicalKeywordDataset(Dataset):
    """Dataset over a DataFrame yielding ``(token_ids, label)`` pairs.

    Args:
        df: DataFrame holding both the text and the label column.
        transcript: Name of the column containing the text to tokenize.
        labels: Name of the label column (expected to hold integer-encoded
            labels, not text).
        tokenizer: Callable invoked with ``text, padding=, truncation=,
            max_length=, return_tensors=`` whose result exposes ``'input_ids'``.
        max_length: Length each text is padded / truncated to.
    """

    def __init__(self, df, transcript, labels, tokenizer, max_length):
        self.df, self.transcript, self.labels = df, transcript, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the tokenized text and label tensor for the idx-th row."""
        # Positional (.iloc) access, so the frame's index labels do not matter.
        encoding = self.tokenizer(
            self.df[self.transcript].iloc[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        token_ids = encoding['input_ids'].squeeze(0)  # drop the batch dimension

        # Labels are assumed to be integer-encoded already.
        target = torch.tensor(self.df[self.labels].iloc[idx])
        return token_ids, target


In [26]:
# For each epoch, call train(...) on the training loader and then test(...) on
# the test loader.
# NOTE(review): `epochs`, `train`, `test`, `train_batches`, `test_batches` and
# `optimizer` are not defined in any cell visible here - confirm they are
# created by an earlier cell so the notebook survives Restart & Run All.
for e in range(epochs):
    train(dl_train,train_batches,model, optimizer)
    test(dl_test, test_batches,model)

  attn_output = torch.nn.functional.scaled_dot_product_attention(




In [38]:
# Display the module tree of the current `model`.
# NOTE(review): the saved output shows a plain BartForConditionalGeneration,
# so `model` appears to have been rebound after the classifier cell - confirm.
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), 

In [34]:
def generate_keywords(df,text, model, tokenizer):
    """Generate text with `model` for every row of `df[text]`.

    Each row is tokenized (padded / truncated to 750 tokens), passed through
    ``model.generate`` (20-100 output tokens) and decoded back to strings.

    Args:
        df: DataFrame to process. It is modified in place: a 'Result' column
            is added (or overwritten).
        text: Name of the column holding the input strings.
        model: Module exposing ``generate(input_ids, attention_mask=...,
            min_length=..., max_length=...)`` and at least one parameter.
        tokenizer: Tokenizer whose call returns 'input_ids' and
            'attention_mask', and which provides ``batch_decode``.

    Returns:
        The same ``df`` object; each 'Result' cell is the list of decoded
        strings for that row (one element, since rows are generated singly).
    """
    # Follow the model's own device rather than assuming "cuda whenever it is
    # available", which breaks if the model was left on the CPU.
    target_device = next(model.parameters()).device

    def _generate_one(transcript):
        encoded = tokenizer(transcript, max_length=750,
                            padding='max_length', truncation=True, return_tensors='pt')
        # Inputs are padded, so the mask must accompany them; without it the
        # model attends to padding tokens as if they were content.
        generated = model.generate(encoded['input_ids'].to(target_device),
                                   attention_mask=encoded['attention_mask'].to(target_device),
                                   min_length=20,
                                   max_length=100)
        return tokenizer.batch_decode(generated, skip_special_tokens=True)

    # Single pass per row: no intermediate (GPU) tensors are stored in the frame.
    df['Result'] = df[text].apply(_generate_one)
    return df

In [39]:
# Run generation over the 'text' column. generate_keywords adds the 'Result'
# column to the frame it is given and returns that same frame, so `df_res`
# and `df_test` refer to the same object.
# NOTE(review): `df_test` is not defined in any cell visible here - confirm.
df_res = generate_keywords(df_test,'text',model,tokenizer)

In [40]:
# Persist the frame, including the generated 'Result' column, to a pickle
# file in the working directory.
df_res.to_pickle("mimic_iv_results.pkl")

In [41]:
# Save the model twice: first the whole module object, then only its
# state_dict (the weights).
# NOTE(review): both file names look like placeholders - confirm the intended
# output paths.
torch.save(model, "path_to_model.pth")
torch.save(model.state_dict(), "path_to_model_state_dict.pth")


In [42]:
# Spot check: generate from the first test note (padded / truncated to 750
# tokens), with the output capped at max_length=100.
# NOTE(review): `.to('cuda')` is hard-coded, so this cell needs a GPU; only
# input_ids are passed, i.e. no attention mask accompanies the padded input.
model.generate(tokenizer(df_test['text'].iloc[0],max_length=750,
                         padding="max_length", truncation=True,
                         return_tensors='pt')['input_ids'].to('cuda'),max_length=100)

tensor([[2, 0, 1, 0, 1, 1, 1, 2]], device='cuda:0')

In [43]:
# Show the raw tokenizer output for the first test note, using the same
# max_length / padding / truncation settings as the generation cells.
tokenizer(df_test['text'].iloc[0],max_length=750,
          padding="max_length", truncation=True)

{'input_ids': [0, 13650, 1933, 2329, 17745, 1248, 2982, 36068, 1248, 1248, 3113, 2099, 856, 18542, 26467, 179, 41767, 118, 684, 41767, 118, 37930, 1262, 4289, 2725, 834, 3674, 1236, 26097, 636, 538, 37548, 636, 12259, 281, 31902, 710, 228, 8267, 260, 12581, 4003, 5090, 118, 364, 27122, 23671, 118, 1455, 4812, 76, 793, 693, 23671, 118, 32447, 179, 1668, 2292, 25806, 3894, 1455, 92, 23808, 1236, 26097, 636, 880, 183, 536, 40274, 14887, 1023, 1988, 4063, 2070, 1343, 2400, 39117, 1469, 1075, 405, 34548, 3529, 375, 20242, 186, 353, 47585, 179, 2400, 3977, 3914, 295, 8367, 786, 7822, 118, 10759, 1588, 94, 741, 119, 65, 186, 536, 67, 46931, 44153, 1073, 493, 2705, 6936, 685, 36612, 757, 291, 17243, 375, 76, 20181, 3953, 2714, 2379, 2520, 2985, 3069, 118, 11696, 363, 14711, 13146, 15352, 4242, 47160, 18339, 329, 493, 25599, 31695, 23385, 8367, 493, 36555, 14210, 23671, 118, 40436, 6106, 4076, 4437, 385, 3760, 118, 6602, 4603, 2052, 4603, 47012, 1208, 485, 43565, 257, 689, 20411, 3320, 48079, 4

In [44]:
# generate_keywords stores a one-element list of decoded strings per row;
# unwrap it to the plain string. The isinstance guard keeps the cell
# idempotent: re-running it on already-unwrapped strings leaves them intact
# instead of reducing each one to its first character.
df_res['Result'] = df_res['Result'].apply(lambda x: x[0] if isinstance(x, (list, tuple)) else x)

In [45]:
# Compare the reference label(s) with the generated text for the first row.
print(df_res['extracted_diseases'].iloc[0])
print(df_res['Result'].iloc[0])

anemia
,pertensionemiahy


In [47]:
# Dump the first five rows: note text, reference diseases and generated text.
# Each section is (heading, column, number of newline characters printed after it).
for i in range(5):
    print(f"-----------------Row no {i+1}------------------")
    for heading, column, gap in (
        ("Transcription:", 'text', 1),
        ("Extracted_diseaes:", 'extracted_diseases', 1),
        ("Result:", 'Result', 3),
    ):
        print(heading)
        print(df_res[column].iloc[i])
        print("\n" * gap)

-----------------Row no 1------------------
Transcription:
name unit admiss date discharg date date birth sex f servic medicin allergi known allergi advers drug reaction attend chief complaint jaundic major surgic invas procedur percutan liver biopsi egd histori present ill year old woman histori vagin cancer sp resect present new onset jaundic began day ago associ epigastricruq pain nauseavomit exacerb eat past sever week month abdomin pain cramp natur nonradi constip last bm one week ago also endors dysphagia solid liquid lost approxim 20lb past year gener weak especi leg deni fever night sweat chill melenahematochezia dysuria hematuria frequenc histori ivda drink beer daili quit smoke prior smoke cigday recent unusu food consumpt includ unusu mushroom deni use herbal supplement histori blood transfus recent bugtick bite seen md sinc regular cancer screen never colonoscopi recent mammogram pap note resect vagin cancer chemoradi tylenol use sick contact initi present morn lab notabl w