# Yes-No Answer Prediction for Medical Questions

This Notebook demonstrates how to load fined-tuned model to **generate yes-no answers to medical questions and evaluate results**.

In [None]:
# import libraries
! pip install datasets
from datasets import Dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import transformers
from transformers import AdamW, AutoTokenizer, BertModel, BertTokenizer
from transformers import pipeline, set_seed

# import libraries
from google.colab import drive

# for dataset access
drive.mount('/content/drive')

## I. Dataset

The test dataset used is the PQA-L subset from PubMedQA dataset [1], preprocessed so that questions, questions + associated abstract, and  

In [2]:
path = '/content/drive/MyDrive/Duke/Spring2024/LLM/FinalProject/'

# load datas
Testset_raw = pd.read_csv(path + 'DatasetREG_20240502.csv')

Testset_raw.sample(5)

Unnamed: 0,pubid,question,abstract,qa,long_answer,final_decision,question_cls,abstract_cls,abstract_r,qa_retrieved
979,18693227,Does a geriatric oncology consultation modify ...,This study was performed to describe the treat...,Does a geriatric oncology consultation modify ...,The geriatric oncology consultation led to a m...,yes,[-4.39530164e-01 -1.97288409e-01 -9.70160812e-...,[-7.55651236e-01 -2.28102103e-01 -4.33326036e-...,Community-based medical education is growing t...,Does a geriatric oncology consultation modify ...
865,23568387,Is bicompartmental knee arthroplasty more favo...,Bicompartmental knee arthroplasty features bon...,Is bicompartmental knee arthroplasty more favo...,"Although theoretically plausible, bicompartmen...",no,[-7.76706994e-01 -1.90531060e-01 -3.92592847e-...,[-6.98826015e-01 -6.91812515e-01 -7.53611684e-...,Tinnitus can be related to many different aeti...,Is bicompartmental knee arthroplasty more favo...
591,24793469,Is there any relation between cervical cord pl...,Multiple sclerosis (MS) is the most common chr...,Is there any relation between cervical cord pl...,The study data suggests a possible correlation...,yes,[-2.76793122e-01 -6.44086719e-01 -6.47584975e-...,[-4.52999204e-01 -6.43734634e-01 -7.00256646e-...,Tinnitus can be related to many different aeti...,Is there any relation between cervical cord pl...
753,10430303,Does laparoscopic cholecystectomy influence pe...,To investigate the influence of laparoscopic p...,Does laparoscopic cholecystectomy influence pe...,Laparoscopic procedures caused detectable dama...,yes,[-5.22794187e-01 -2.13746279e-01 -4.14774209e-...,[-5.29817939e-01 -5.54140329e-01 -4.40105826e-...,To study whether exercise during pregnancy red...,Does laparoscopic cholecystectomy influence pe...
887,21420186,Could ADMA levels in young adults born preterm...,Sporadic data present in literature report how...,Could ADMA levels in young adults born preterm...,Our findings reveal a significant decrease in ...,yes,[-1.41829416e-01 -4.72903520e-01 -1.62389055e-...,[-3.20394665e-01 -4.90744412e-01 -3.38018388e-...,Various factors contribute to the effective im...,Could ADMA levels in young adults born preterm...


In [3]:
clmns = ['question', 'qa', 'qa_retrieved', 'final_decision']

Testset = Testset_raw[Testset_raw['final_decision']!="maybe"].copy()[clmns].reset_index(drop=True)
Testset['final_decision'] = Testset['final_decision'].map({'yes': 1, 'no': 0})
Testset.sample(5)


Unnamed: 0,question,qa,qa_retrieved,final_decision
446,Increased neutrophil migratory activity after ...,Increased neutrophil migratory activity after ...,Increased neutrophil migratory activity after ...,1
153,Prevalence of chronic conditions among Medicar...,Prevalence of chronic conditions among Medicar...,Prevalence of chronic conditions among Medicar...,1
684,Can myometrial electrical activity identify pa...,Can myometrial electrical activity identify pa...,Can myometrial electrical activity identify pa...,1
73,Is Chaalia/Pan Masala harmful for health?,Is Chaalia/Pan Masala harmful for health? To d...,Is Chaalia/Pan Masala harmful for health? Vari...,1
205,PSA repeatedly fluctuating levels are reassuri...,PSA repeatedly fluctuating levels are reassuri...,PSA repeatedly fluctuating levels are reassuri...,0


In [4]:
# prepare datasets of different prompts
Testset_q   = Testset[['question', 'final_decision']]
Testset_q.columns =  ['input', 'final_decision']

Testset_qa  = Testset[['qa', 'final_decision']]
Testset_qa.columns =  ['input', 'final_decision']

Testset_qar = Testset[['qa_retrieved', 'final_decision']]
Testset_qar.columns =  ['input', 'final_decision']

Testset_q   = Dataset.from_pandas(Testset_q)
Testset_qa  = Dataset.from_pandas(Testset_qa)
Testset_qar = Dataset.from_pandas(Testset_qar)

### ii. Tokenization and mapping tokens to ind.
Tokenization is the step to **convert documents** (in this scenario, medical questions and relevant document) **into small units, or tokens**. **Token is the smallest unit to be represented by word embeddings** ($x_i$) included in the vocabulary. A token can be a word or sub-word, depending on the tokenizer. Here, the ***'bert-base-uncased'*** tokenizer is used, where **byte-pair encoding was applied**. In this case, **tokens can be words or sub-words** (e.g. ##ing).

All tokens in the corpus have their corresponding indices. Since the **BERT pretrained model is used**, **tokens can be converted to indices in the pretraining corpus**.

In [None]:
#load pretrained tokenizer
transformer_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(transformer_name)

In [6]:
# function to tokenize texts
def tokeAndMap(example, tokenizer):
    ids = tokenizer(example["input"], truncation=True)["input_ids"]
    return {"ids": ids}

In [None]:
# tokenize input
test_data_q = Testset_q.map(
    tokeAndMap, fn_kwargs={"tokenizer": tokenizer}
)
test_data_qa = Testset_qa.map(
    tokeAndMap, fn_kwargs={"tokenizer": tokenizer}
)
test_data_qar = Testset_qar.map(
    tokeAndMap, fn_kwargs={"tokenizer": tokenizer}
)

In [8]:
# Select ids and labels for model training
test_data_q = test_data_q.with_format(type="torch",
                                    columns=["ids", "final_decision"])
test_data_qa = test_data_qa.with_format(type="torch",
                                columns=["ids", "final_decision"])
test_data_qar = test_data_qar.with_format(type="torch",
                                columns=["ids", "final_decision"])

### iii. Generating batches
PyTorch's *dataloader* class is used to enable randomly selecting batches (one at a time) when training the model (or learning the parameters). When training the model, the length of text in the same batch should be the same. Thus, shorter sequences are *padded* by **adding the pad index to the end of shorter sequence** till they match the length of the longest sequence in the batch.

In [9]:
def get_batchFunc(pad_ind):
  # return a batch function
  def batchFunc(batch):
    # batch function
    # return a batch of text and labels
    batch_ids = [i['ids'] for i in batch]
    # pad sequence in a batch to the same length
    batch_ids = nn.utils.rnn.pad_sequence(
        batch_ids, padding_value = pad_ind, batch_first=True
    )
    batch_label = [i["final_decision"] for i in batch]
    batch_label = torch.stack(batch_label)
    batch = {'ids': batch_ids,
             'label': batch_label}
    return batch
  return batchFunc

In [10]:
def get_data_loader(dataset,
                    batch_size,
                    pad_index,
                    shuffle = False):
  # define a dataloader object using batchFunc
  batch_fn = get_batchFunc(pad_index)
  data_loader = torch.utils.data.DataLoader(
      dataset = dataset,
      batch_size = batch_size,
      collate_fn = batch_fn,
      shuffle = shuffle
  )
  return data_loader

In [11]:
# define batch size
batch_size = 8
pad_index = tokenizer.pad_token_id

# get dataloaders
q_data_loader = get_data_loader(test_data_q,
                                    batch_size,
                                    pad_index,
                                    shuffle=True)
qa_data_loader = get_data_loader(test_data_qa,
                                  batch_size,
                                  pad_index)

qar_data_loader = get_data_loader(test_data_qar,
                                    batch_size,
                                    pad_index,
                                    shuffle=True)


### III. Model Definition
To evaluate the trained models, we first define the model structure.

In [12]:
class Transformer(nn.Module):
  def __init__(self, transformer, output_dim, freeze):
    super().__init__()
    self.transformer = transformer
    hidden_dim = transformer.config.hidden_size
    # linear layer above the transformer model
    self.fc = nn.Linear(hidden_dim, output_dim)
    # options "frezze" transformer parameters
    if freeze:
      for param in self.transformer.parameters():
        param.requires_grad = False
  def forward(self, ids):
    # ids = [batch_size, seq_len]
    output = self.transformer(ids, output_attentions = True)
    # hidden = [batch_size, n_heads, seq_len, seq_len]
    hidden = output.last_hidden_state
    # attention = [batch_size, n_heads, seq_len, seq_len]
    attention = output.attentions[-1]
    # the first embedding is of the "cls" token
    cls_hidden = hidden[:, 0, :]
    pred = self.fc(torch.tanh(cls_hidden))
    # pred = [batch_size, output_dim]
    return pred

In [13]:
# transformer model
transformer = transformers.AutoModel.from_pretrained(transformer_name)

# model definition
output_dim = 2

# fine tune transformer parameters
freeze = True

model_NotTuned = Transformer(transformer, output_dim, freeze)
model_Tuned = Transformer(transformer, output_dim, freeze)

In [14]:

model_name_NotTuned = 'BERT_YesNoPred_NotTune_20240502.pt'
model_name_Tuned = 'BERT_YesNoPred_Tuned_20240502.pt'

# Load trained model parameters
model_NotTuned.load_state_dict(torch.load(path+model_name_NotTuned))
model_Tuned.load_state_dict(torch.load(path+model_name_Tuned))


<All keys matched successfully>

In [15]:
# assign device (cuda if possible)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load to device
model_NotTuned = model_NotTuned.to(device)
model_Tuned = model_Tuned.to(device)


In [16]:
def evaluate(data_loader, model, device):

  # evaluation mode
  model.eval()
  # store batch losses and accuracy
  preds = []
  labels = []

  # not updating parameters
  with torch.no_grad():
    for batch in tqdm.tqdm(data_loader, desc = "evaluating..."):
      ids = batch['ids'].to(device)
      labels_b = batch['label'].to(device)
      # get model prediction
      pred = model(ids)
      pred_classes = pred.argmax(dim = -1)


      # store predictions
      preds  += list(pred_classes.detach().cpu().numpy())
      labels += list(labels_b.detach().cpu().numpy())

  # calculate accuracy and f1 score
  acc = accuracy_score(labels, preds)
  f1 = f1_score(labels, preds)

  return acc, f1

In [17]:
# evaluate results using NotTuned Model

acc_q_NT, f1_q_NT = evaluate(q_data_loader, model_NotTuned, device)
acc_qa_NT, f1_qa_NT = evaluate(qa_data_loader, model_NotTuned, device)
acc_qar_NT, f1_qar_NT = evaluate(qar_data_loader, model_NotTuned, device)

evaluating...:   0%|          | 0/112 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
evaluating...: 100%|██████████| 112/112 [00:05<00:00, 21.12it/s]
evaluating...: 100%|██████████| 112/112 [00:27<00:00,  4.05it/s]
evaluating...: 100%|██████████| 112/112 [00:28<00:00,  3.92it/s]


In [18]:
print("Evaluation Result without Fine-Tuning BERT Parameters:")
print(f"Accuracy: {acc_qa_NT:.5f}, F1 Score: {f1_qa_NT:.5f}. (qustion + correct document)")
print(f"Accuracy: {acc_qar_NT:.5f}, F1 Score: {f1_qar_NT:.5f}. (qustion + retrieved document)")
print(f"Accuracy: {acc_q_NT:.5f}, F1 Score: {f1_q_NT:.5f}. (qustion only)")

Evaluation Result without Fine-Tuning BERT Parameters:
Accuracy: 0.65281, F1 Score: 0.76358. (qustion + correct document)
Accuracy: 0.60674, F1 Score: 0.74747. (qustion + retrieved document)
Accuracy: 0.62584, F1 Score: 0.76265. (qustion only)


In [19]:
# evaluate results using Tuned Model

acc_q_T,   f1_q_T   = evaluate(q_data_loader,   model_Tuned, device)
acc_qa_T,  f1_qa_T  = evaluate(qa_data_loader,  model_Tuned, device)
acc_qar_T, f1_qar_T = evaluate(qar_data_loader, model_Tuned, device)

evaluating...: 100%|██████████| 112/112 [00:02<00:00, 49.30it/s]
evaluating...: 100%|██████████| 112/112 [00:28<00:00,  3.90it/s]
evaluating...: 100%|██████████| 112/112 [00:26<00:00,  4.18it/s]


In [20]:
print("Evaluation Result BERT Parameters Fine-Tuned :")
print(f"Accuracy: {acc_qa_T:.5f}, F1 Score: {f1_qa_T:.5f}. (qustion + correct document)")
print(f"Accuracy: {acc_qar_T:.5f}, F1 Score: {f1_qar_T:.5f}. (qustion + retrieved document)")
print(f"Accuracy: {acc_q_T:.5f}, F1 Score: {f1_q_T:.5f}. (qustion only)")

Evaluation Result BERT Parameters Fine-Tuned :
Accuracy: 0.74494, F1 Score: 0.81409. (qustion + correct document)
Accuracy: 0.63596, F1 Score: 0.76316. (qustion + retrieved document)
Accuracy: 0.63708, F1 Score: 0.77011. (qustion only)


### References
1. Jin, Qiao, et al. "Pubmedqa: A dataset for biomedical research question answering." arXiv preprint arXiv:1909.06146 (2019).