<a href="https://colab.research.google.com/github/halfmoonliu/example-graphrag/blob/main/Evaluate_10202024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Retrieval Augmented Medical Question Answering

Yun-Chung Liu (yl974)

This notebook demonstrates how to **apply RAG** (Retrieval Augmented Generation) to generate **long answers** (conclusion, or summary) to **Medical Questions** (generated from research papers on PubMed).

In [None]:
# install dependencies
!pip install datasets
!pip install evaluate
!pip install rouge-score

# import libraries
from datasets import Dataset, DatasetDict, load_dataset
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import transformers
from transformers import AdamW, AutoTokenizer, BertModel, BertTokenizer
from transformers import GPT2Tokenizer, GPT2Model
from transformers import pipeline, set_seed
from nltk.translate.bleu_score import sentence_bleu
import evaluate
from google.colab import drive

# mount Google Drive for dataset access
drive.mount('/content/drive')

In [None]:
# print version
print(transformers.__version__)

4.40.1


##  I. Dataset

The dataset used for this project is PubMedQA. The PubMedQA [1] dataset contains questions generated from research papers on PubMed, one of the most popular databases for biomedical research worldwide. The dataset has three subsets: *labeled*, *unlabeled*, and *artificial*, depending on how the questions and answers were generated.

The questions were obtained from research papers whose title is a question. For each question, there is a short answer (yes/no/maybe) and a long answer (usually the conclusion of the abstract). Below is an example:

In [None]:
Dataset_l_raw = load_dataset("qiaojin/PubMedQA", "pqa_labeled")

In [None]:
# an example of the labeled dataset
Dataset_l_raw['train'][0]

{'pubid': 21645374,
 'question': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?',
 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
   'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), ce

## II. Preprocessing

### i. Tokenization

To represent a document or prompt with the **BERT** [*cls*] **token**, **the first step is tokenization**, or **mapping text to the indices of words/sub-words in the pretrained BERT vocabulary**. Since BERT uses **WordPiece tokenization, some tokens are sub-words instead of words**, marked with a `##` prefix (e.g. ##ap, ##se). The short example below happens to contain only whole-word tokens.
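To make the sub-word idea concrete, here is a minimal sketch of the greedy longest-match-first splitting used by WordPiece-style tokenizers. The function name `wordpiece_tokenize` and the vocabulary `toy_vocab` are assumptions made for illustration; the real `bert-base-uncased` vocabulary is far larger and will split words differently.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one lowercase word into sub-words."""
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # shrink the candidate from the right until it is found in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                # pieces that continue a word carry the ## prefix
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # no piece matched: the whole word maps to the unknown token
            return [unk]
        tokens.append(piece)
        start = end
    return tokens


# hypothetical toy vocabulary, NOT the real bert-base-uncased vocabulary
toy_vocab = {"i", "love", "un", "##lov", "##able"}

print(wordpiece_tokenize("love", toy_vocab))       # ['love']
print(wordpiece_tokenize("unlovable", toy_vocab))  # ['un', '##lov', '##able']
```

A word that is in the vocabulary stays whole, while a word that is not is broken into the longest known pieces from left to right.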

In [None]:
transformer_name = "bert-base-uncased"
tokenizer_bert = BertTokenizer.from_pretrained(transformer_name)



In [None]:
# example reference sentence
example_ref = "i love you too much"
example_ref_tokenized = tokenizer_bert.tokenize(example_ref)
example_ref_tokenized

['i', 'love', 'you', 'too', 'much']

### ii. Evaluating generated response

Generated answers to medical questions were evaluated using BLEU and ROUGE-L scores. The example below shows **the limitation of these two measures**. The **reference sentence** "*I love you too much*" **has almost the opposite meaning** to the **predicted sentence** "*I hate you too*". However, **the bag-of-words, precision-based BLEU** score and **the F1-score-based ROUGE-L** score are both relatively high, because the two sentences share most of their words.

In [None]:
# example predicted sentence
example_pred = "i hate you too"
example_pred_tokenized = tokenizer_bert.tokenize(example_pred)

example_pred_tokenized

['i', 'hate', 'you', 'too']

In [None]:
# calculate 1-gram BLEU score

example_bleu = sentence_bleu([example_ref_tokenized],
                             example_pred_tokenized,
                             weights = [1])

print(f"Example BLEU score: {example_bleu:.3f}")


Example BLEU score: 0.584


In [None]:
# calculate ROUGE-L score
rouge = evaluate.load('rouge')

rouge_results = rouge.compute(predictions=[example_pred],
                        references=[example_ref])
print(rouge_results)
print(f"Example Rouge-L score: {rouge_results['rougeL']:.3f}")

{'rouge1': 0.6666666666666665, 'rouge2': 0.28571428571428575, 'rougeL': 0.6666666666666665, 'rougeLsum': 0.6666666666666665}
Example Rouge-L score: 0.667
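The two scores above can be reproduced by hand, which shows why they are high despite the opposite meaning. The sketch below is a simplified re-implementation, not the NLTK or `evaluate` code: it assumes a single reference, splits on whitespace (which gives the same tokens as the BERT tokenizer for these two sentences), and uses the helper names `unigram_bleu`, `lcs_length` and `rouge_l_f1`, which are introduced here for illustration.

```python
import math
from collections import Counter


def unigram_bleu(reference, prediction):
    """Clipped unigram precision multiplied by the brevity penalty."""
    if not prediction:
        return 0.0
    ref_counts = Counter(reference)
    pred_counts = Counter(prediction)
    # each predicted token counts at most as often as it appears in the reference
    overlap = sum(min(n, ref_counts[tok]) for tok, n in pred_counts.items())
    precision = overlap / len(prediction)
    c, r = len(prediction), len(reference)
    # predictions shorter than the reference are penalized
    brevity_penalty = 1.0 if c > r else math.exp(1 - r / c)
    return brevity_penalty * precision


def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    table = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            if x == y:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[-1][-1]


def rouge_l_f1(reference, prediction):
    """F1 score of LCS-based precision and recall."""
    lcs = lcs_length(reference, prediction)
    if lcs == 0:
        return 0.0
    precision = lcs / len(prediction)
    recall = lcs / len(reference)
    return 2 * precision * recall / (precision + recall)


ref = "i love you too much".split()
pred = "i hate you too".split()

# 3 of 4 predicted tokens match: precision 0.75, brevity penalty exp(1 - 5/4)
print(f"{unigram_bleu(ref, pred):.3f}")  # 0.584
# LCS is (i, you, too): precision 3/4, recall 3/5
print(f"{rouge_l_f1(ref, pred):.3f}")    # 0.667
```

Both metrics only count shared tokens, so replacing "love" with "hate" costs a single token even though it reverses the meaning.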


### iii. Dataset Preparation

The dataset is first converted into a pandas DataFrame. Then, **all contexts** (some or all sections in a structured abstract, e.g. background, methods, results, excluding the conclusion section) **were concatenated into a single piece of text** named *abstract*.

In [None]:
# Convert dataset to dataframe
Dataset_l_df = Dataset_l_raw['train'].to_pandas()
Dataset_l_df.sample(5)

Unnamed: 0,pubid,question,context,long_answer,final_decision
93,14978612,Does positron emission tomography change manag...,{'contexts': ['The influence of positron emiss...,Position emission tomography scanning appears ...,yes
692,17462393,Does normothermic normokalemic simultaneous an...,{'contexts': ['Beating-heart valve surgery app...,Normothermic normokalemic simultaneous antegra...,no
100,22564465,Mammographic screening in Sami speaking munici...,{'contexts': ['Female citizens of Sami (the in...,"Despite a lower risk of breast cancer, the Sam...",yes
562,23774337,Does the central venous pressure predict fluid...,{'contexts': ['Despite a previous meta-analysi...,There are no data to support the widespread pr...,no
729,7664228,Discharging patients earlier from Winnipeg hos...,{'contexts': ['To determine whether decreasing...,Improving hospital efficiency by shortening le...,no


In [None]:
# concatenate contexts
def retrieve_abstract(context):
  return ' '.join(context['contexts'])

Dataset_l_df['abstract'] = Dataset_l_df['context'].apply(retrieve_abstract)

In [None]:
# concatenate question and the abstract from the same paper
Dataset_l_df['qa'] = Dataset_l_df['question'] + ' ' + Dataset_l_df['abstract']

In [None]:
Dataset_l_df = Dataset_l_df[['pubid','question', 'abstract', 'qa','long_answer', 'final_decision']]
Dataset_l_df.sample(5)

Unnamed: 0,pubid,question,abstract,qa,long_answer,final_decision
612,10927144,Can p53 alterations be used to predict tumour ...,To examine whether p53 tumour suppressor gene ...,Can p53 alterations be used to predict tumour ...,p53 alteration detected by IHC or SSCP analysi...,no
468,16778275,Is routine chest radiography after transbronch...,Pneumothorax following flexible bronchoscopy (...,Is routine chest radiography after transbronch...,We conclude that routine CXR after bronchoscop...,no
355,25371231,Is vitamin D insufficiency or deficiency relat...,The aetiology of osteochondritis dissecans is ...,Is vitamin D insufficiency or deficiency relat...,These first data show that a vitamin D3 defici...,maybe
586,22302658,Does limb-salvage surgery offer patients bette...,Patients with aggressive lower extremity muscu...,Does limb-salvage surgery offer patients bette...,These data suggest that limb-salvage surgery o...,maybe
938,20577124,Is leptin involved in phagocytic NADPH oxidase...,Hyperleptinemia and oxidative stress play a ma...,Is leptin involved in phagocytic NADPH oxidase...,These findings show that phagocytic NADPH oxid...,yes


### iv. Representing Documents

To represent documents with the **BERT** [*cls*] **token**, **tokenized documents were fed into the pretrained BERT model**. Then, **the first token embedding**, the [*cls*] embedding, **can be retrieved to represent the document**.

In [None]:
# load BERT tokenizer (uncased)
transformer_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(transformer_name)

# load pretrained BERT model
model = BertModel.from_pretrained(transformer_name)

# assign device (cuda if possible)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load to device
model = model.to(device)



The BERT tokenizer prepends the [*cls*] token to every input, so **the** [*cls*] **token is the** ***first*** **token when feeding the text into the pretrained model**. With this in mind, all document representations in the database can be derived.
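As a small illustration of what "first token" means in terms of array shapes, the sketch below slices position 0 out of a stand-in hidden-state array. The random array is an assumption that replaces real model output; only the shapes (batch size, sequence length, and the hidden size of 768 used by `bert-base-uncased`) are meaningful here.

```python
import numpy as np

# stand-in for a BERT last hidden state: (batch_size, sequence_length, hidden_size)
batch_size, seq_len, hidden_size = 1, 6, 768
rng = np.random.default_rng(0)
last_hidden_state = rng.normal(size=(batch_size, seq_len, hidden_size))

# position 0 along the sequence axis holds the [CLS] vector of each input
cls_batch = last_hidden_state[:, 0, :]  # shape (1, 768)

# with a single input, take the only row to get a flat vector
cls_vector = cls_batch[0]               # shape (768,)

print(cls_batch.shape, cls_vector.shape)
```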

In [None]:
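def getBertCls(text):
  '''
  this function takes the following input:
  text to be represented by the BERT CLS token
  and gives the following output:
  a numpy array representing the text
  '''
  # keep only the first 512 characters (not tokens) of the text;
  # truncation=True additionally caps the input at BERT's 512-token limit
  tok_text = tokenizer(text[:512],
                       truncation=True,
                       max_length=512,
                       return_tensors='pt').to(device)
  # no gradients are needed for inference
  with torch.no_grad():
    mod_output = model(**tok_text,
                       output_hidden_states=True)
  last_hidden_states = mod_output.hidden_states[-1]
  # position 0 of the sequence is the [CLS] token
  return last_hidden_states[:,0,:].cpu().detach().numpy()[0]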

In [None]:
# create column with cls representation of document
Dataset_l_df['question_cls'] = Dataset_l_df['question'].apply(getBertCls)
Dataset_l_df['abstract_cls'] = Dataset_l_df['abstract'].apply(getBertCls)

Dataset_l_df.sample(5)

Unnamed: 0,pubid,question,abstract,qa,long_answer,final_decision,question_cls,abstract_cls
555,15954832,Is minilaparoscopic inguinal hernia repair fea...,Laparoscopy has rapidly emerged as the preferr...,Is minilaparoscopic inguinal hernia repair fea...,"While limited by its retrospective design, the...",yes,"[-0.6641139, -0.6014889, -0.13556018, -0.10516...","[-0.4979502, 0.010480256, 0.13777044, -0.58153..."
395,25887165,Does Sensation Return to the Nasal Tip After M...,Patients usually complain about numbness in th...,Does Sensation Return to the Nasal Tip After M...,Postoperative numbness occurs in most patients...,yes,"[-0.26579577, -0.26328933, -0.4678107, -0.1576...","[-0.31449622, -0.26738158, -0.029971201, -0.05..."
527,23052500,Staging laparoscopy in patients with hepatocel...,Staging laparoscopy (SL) is not regularly perf...,Staging laparoscopy in patients with hepatocel...,"The overall yield of SL for HCC was 7 %, and t...",no,"[-0.5822657, -0.7911572, -0.47684184, -0.34038...","[-0.42638335, -0.33983535, -0.08002636, -0.403..."
335,23025584,Does stress increase imitation of drinking beh...,That alcohol consumption is strongly influence...,Does stress increase imitation of drinking beh...,"Generally, it appears that among young male ad...",no,"[-0.057301488, -0.21439199, -0.38715002, 0.074...","[-0.24706757, -0.15074095, 0.32532862, -0.2806..."
666,18472368,Does treatment duration affect outcome after r...,The protraction of external beam radiotherapy ...,Does treatment duration affect outcome after r...,A proportionally longer treatment duration was...,yes,"[-0.49996147, -0.12104655, -0.3541218, -0.2416...","[-0.4997697, -0.35794866, 0.033892825, -0.1149..."


### v. Retrieve the most relevant document

To retrieve the most relevant document, **cosine similarity was calculated between every question and all the abstracts**. **The abstract with the highest similarity score to the question was retrieved**.
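The retrieval step can be sketched on toy vectors as follows. The function name `cosine_to_all` and the two-dimensional embeddings are assumptions made for illustration. The point of the sketch is that each document vector is divided by its own norm, so a long vector does not win simply because it is long.

```python
import numpy as np


def cosine_to_all(query, docs):
    """Cosine similarity between one query vector and every row of docs."""
    query = np.asarray(query, dtype=float)
    docs = np.asarray(docs, dtype=float)
    dots = docs @ query  # one dot product per document
    # axis=1 gives one norm per document row
    norms = np.linalg.norm(docs, axis=1) * np.linalg.norm(query)
    return dots / norms


# hypothetical toy embeddings
docs = np.array([[10.0, 10.0],  # long vector, 45 degrees away from the query
                 [1.0, 0.1],    # short vector, almost parallel to the query
                 [0.0, 2.0]])   # orthogonal to the query
query = np.array([1.0, 0.0])

sims = cosine_to_all(query, docs)
best = int(np.argmax(sims))              # index 1: the nearly parallel vector
best_by_dot = int(np.argmax(docs @ query))  # index 0: the longest vector

print(best, best_by_dot)
```

Ranking by the raw dot product picks the longest vector, whereas cosine similarity picks the vector that points in the most similar direction.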

In [None]:
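def calPromptCosSim(emb1, emb2):
  '''
  return the cosine similarity between the vector emb1
  and emb2, which is either a single vector or
  a matrix with one embedding per row
  '''
  result = emb1 @ emb2.T
  # normalize by the norm of each row of emb2 separately
  # (without axis=-1, the norm of a matrix is a single Frobenius norm)
  result = result / (np.linalg.norm(emb1) * np.linalg.norm(emb2, axis=-1))
  return result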

In [None]:
retrieved_abstract = list()

pubids = Dataset_l_df['pubid'].tolist()
q_embs = Dataset_l_df['question_cls'].to_numpy()
abstracts = Dataset_l_df['abstract'].tolist()

# stack all abstract embeddings once into a (n_docs, hidden_size) matrix
a_embs = np.stack(Dataset_l_df['abstract_cls'].to_numpy())

# loop through all questions
for i in range(len(pubids)):
  # retrieve the most similar document
  similarities = calPromptCosSim(q_embs[i], a_embs)
  idx_doc = np.argmax(similarities)
  retrieved_abstract.append([pubids[i], abstracts[idx_doc]])

retrieved_abstracts_df = pd.DataFrame(retrieved_abstract,
                                      columns=['pubid', 'abstract_r'])

In [None]:
# join dataset with the retrieved document
Dataset_l_df = Dataset_l_df.merge(retrieved_abstracts_df,
                                  on='pubid',
                                  how = 'inner')
Dataset_l_df.head(5)

Unnamed: 0,pubid,question,abstract,qa,long_answer,final_decision,question_cls,abstract_cls,abstract_r
0,21645374,Do mitochondria play a role in remodelling lac...,Programmed cell death (PCD) is the regulated d...,Do mitochondria play a role in remodelling lac...,Results depicted mitochondrial dynamics in viv...,yes,"[-0.5154661, -0.106142305, -0.28411922, -0.123...","[-0.20179763, -0.32204536, -0.32318643, 0.1764...",To study whether exercise during pregnancy red...
1,16418930,Landolt C and snellen e acuity: differences in...,Assessment of visual acuity depends on the opt...,Landolt C and snellen e acuity: differences in...,"Using the charts described, there was only a s...",no,"[-0.3664919, -0.18892226, -0.6237057, -0.75681...","[-0.28828055, -0.14107008, -0.44304717, -0.286...",Complex regional pain syndrome type I is treat...
2,9488747,"Syncope during bathing in infants, a pediatric...",Apparent life-threatening events in infants ar...,"Syncope during bathing in infants, a pediatric...","""Aquagenic maladies"" could be a pediatric form...",yes,"[-0.34855083, -0.34341225, -0.68562186, 0.1192...","[-0.638743, -0.52982956, -0.69166505, -0.37416...",To study whether exercise during pregnancy red...
3,17208539,Are the long-term results of the transanal pul...,The transanal endorectal pull-through (TERPT) ...,Are the long-term results of the transanal pul...,Our long-term study showed significantly bette...,no,"[-0.20517431, -0.23843785, 0.013513021, -0.034...","[-0.57676816, -0.67617077, -0.38403293, -0.354...",Various factors contribute to the effective im...
4,10808977,Can tailored interventions increase mammograph...,Telephone counseling and tailored print commun...,Can tailored interventions increase mammograph...,The effects of the intervention were most pron...,yes,"[-0.53521794, -0.33332643, -0.3663158, -0.0693...","[-0.34828824, -0.15164366, -0.6552947, -0.3681...",To study whether exercise during pregnancy red...


In [None]:
# concatenate question and the retrieved abstract
Dataset_l_df['qa_retrieved'] = Dataset_l_df['question'] + ' ' + Dataset_l_df['abstract_r']

In [None]:
# output dataset
path = '/content/drive/MyDrive/Duke/Spring2024/LLM/FinalProject/'


Dataset_l_df.to_csv(path+'DatasetREG_20240502.csv',
                    index=False)

## III. Generate Long Answers
Long answers were generated with a pretrained GPT-2 model.
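A text-generation pipeline returns the prompt followed by the continuation, so the prompt has to be sliced off to keep only the generated answer. The sketch below shows that step in isolation. `fake_generator` is a hypothetical stand-in that only mimics the output shape of the pipeline (a list of dicts with a `generated_text` key), and both the question and the continuation are made-up strings.

```python
def fake_generator(prompt, **kwargs):
    """Hypothetical stand-in mimicking the output shape of a text-generation pipeline."""
    return [{"generated_text": prompt + " The findings suggest a modest effect."}]


def continuation_only(prompt, generator):
    outputs = generator(prompt, max_new_tokens=30, num_return_sequences=1)
    full_text = outputs[0]["generated_text"]
    # the returned text starts with the prompt, so drop the first len(prompt) characters
    return full_text[len(prompt):]


question = "Does exercise reduce back pain?"  # made-up example question
print(continuation_only(question, fake_generator))
```

With the real pipeline, passing `return_full_text=False` should have a similar effect, although that option is not used in this notebook.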

In [None]:
# set random seed
set_seed(42)

# define gpt2 text generator
generator = pipeline('text-generation',
                     model='gpt2')




In [None]:
def gen_long_response(txt):
  # reuse the generator defined above instead of reloading GPT-2 on every call
  gen_text = generator(txt,
                       max_new_tokens = 30,
                       num_return_sequences = 1,
                       pad_token_id=generator.tokenizer.eos_token_id)
  # the pipeline returns the prompt followed by the continuation;
  # keep only the continuation
  return gen_text[0]['generated_text'][len(txt):]

In [None]:
Dataset_l_df['gen_answer_q'] = Dataset_l_df['question'].apply(gen_long_response)
Dataset_l_df['gen_answer_qa'] = Dataset_l_df['qa'].apply(gen_long_response)
Dataset_l_df['gen_answer_qar'] = Dataset_l_df['qa_retrieved'].apply(gen_long_response)



In [None]:
Dataset_l_df.sample(5)

Unnamed: 0,pubid,question,abstract,qa,long_answer,question_cls,abstract_cls,abstract_r,qa_retrieved,gen_answer_q,gen_answer_qa,gen_answer_qar
521,26348845,Pap smears with glandular cell abnormalities: ...,Rapid prescreening (RPS) is one of the quality...,Pap smears with glandular cell abnormalities: ...,Pap smears with glandular cell abnormalities a...,"[-0.25452468, -0.52106756, -0.2344852, -0.0410...","[-0.5702184, -0.24906324, -0.21664818, -0.2179...",Complex regional pain syndrome type I is treat...,Pap smears with glandular cell abnormalities: ...,"\n\nThis is a fascinating one, because if this...",The results indicated that only 23.3% of the ...,\n\nCONCLUSION As a consequence of the recent ...
737,10973547,Are patients with Werlhof's disease at increas...,"It is generally assumed, that patients with We...",Are patients with Werlhof's disease at increas...,Patients with WD may possibly undergo cardiac ...,"[-0.60479903, -0.31824994, -0.5445265, -0.0714...","[-0.08996212, -0.3734647, -0.31098098, -0.1186...",To determine the potential prognostic value of...,Are patients with Werlhof's disease at increas...,"\n\nThis is not a question of ""skeptical docto...",\n\nF ig. 1. View largeDownload slide Anterior...,"In a FASEB, a nonfatal stroke is an establish..."
740,26163474,Is there a connection between sublingual varic...,Sublingual varices have earlier been related t...,Is there a connection between sublingual varic...,An association was found between sublingual va...,"[-0.31770706, -0.40249935, -0.47051668, -0.164...","[-0.3533773, -0.14783579, -0.46499273, -0.2138...",Complex regional pain syndrome type I is treat...,Is there a connection between sublingual varic...,I feel like a good friend who's been studying...,Patients with different systolic blood pressu...,"At the same time, the prevalence of complex r..."
660,24374414,Does health information exchange reduce redund...,Broad-based electronic health information exch...,Does health information exchange reduce redund...,HIE was associated with reduced repeat imaging...,"[-0.37056234, -0.18045987, -0.402243, -0.13236...","[-0.509178, -0.50266314, -0.09870474, -0.24097...",Specialty pharmaceuticals have evolved beyond ...,Does health information exchange reduce redund...,How would better information exchange help to...,"HIE was associated with lower survival, reduce...",What they might do. An international coalitio...
411,18507507,The promise of specialty pharmaceuticals: are ...,Specialty pharmaceuticals have evolved beyond ...,The promise of specialty pharmaceuticals: are ...,Current evidence suggests that when used in ta...,"[-0.026619682, 0.0021823128, -0.020715406, 0.0...","[-0.34382528, -0.15578249, -0.11090787, -0.145...",Various factors contribute to the effective im...,The promise of specialty pharmaceuticals: are ...,"We're not convinced yet,"" says Tom Karp at Bl...",The benefits of prescription-drug utilization...,\n\npreventative and adaptive use of therapist...


In [None]:
# output resulting dataset
Dataset_l_df.to_csv(path+'DatasetREG_LongAns_20240502.csv',
                    index=False)

## IV. Evaluate Results

Long answers were evaluated against the ground-truth long answers using the unigram (1-gram) BLEU score and the ROUGE-L score.

### i. BLEU Score

In [None]:
def cal_bleu_q(row):
  gt = row['long_answer']
  gt_tok = tokenizer_bert.tokenize(gt)
  pred = row['gen_answer_q']
  try:
    pred_tok = tokenizer_bert.tokenize(pred)
  except Exception:
    # fall back to an empty prediction if the text cannot be tokenized
    pred_tok = []
  # Calculate unigram BLEU score (weights = [1])
  bleu = sentence_bleu([gt_tok],
                        pred_tok,
                        weights = [1])
  return bleu

Dataset_l_df['answer_q_bleu'] = Dataset_l_df.apply(cal_bleu_q, axis=1)

In [None]:
def cal_bleu_qa(row):
  gt = row['long_answer']
  gt_tok = tokenizer_bert.tokenize(gt)
  pred = row['gen_answer_qa']
  try:
    pred_tok = tokenizer_bert.tokenize(pred)
  except Exception:
    # fall back to an empty prediction if the text cannot be tokenized
    pred_tok = []
  # Calculate unigram BLEU score (weights = [1])
  bleu = sentence_bleu([gt_tok],
                        pred_tok,
                        weights = [1])
  return bleu

Dataset_l_df['answer_qa_bleu'] = Dataset_l_df.apply(cal_bleu_qa, axis=1)

In [None]:
def cal_bleu_qar(row):
  gt = row['long_answer']
  gt_tok = tokenizer_bert.tokenize(gt)
  pred = row['gen_answer_qar']
  try:
    pred_tok = tokenizer_bert.tokenize(pred)
  except Exception:
    # fall back to an empty prediction if the text cannot be tokenized
    pred_tok = []
  # Calculate unigram BLEU score (weights = [1])
  bleu = sentence_bleu([gt_tok],
                        pred_tok,
                        weights = [1])
  return bleu

Dataset_l_df['answer_qar_bleu'] = Dataset_l_df.apply(cal_bleu_qar, axis=1)

In [None]:
bleus = Dataset_l_df[['answer_q_bleu', 'answer_qa_bleu','answer_qar_bleu']]
bleus.describe()

Unnamed: 0,answer_q_bleu,answer_qa_bleu,answer_qar_bleu
count,1000.0,1000.0,1000.0
mean,0.097784,0.137156,0.088884
std,0.069654,0.081002,0.052309
min,0.0,0.0,0.0
25%,0.046613,0.078786,0.051912
50%,0.081805,0.126008,0.08106
75%,0.134697,0.183121,0.121779
max,0.460319,0.5,0.343575


### ii. ROUGE Score


In [None]:
def cal_rouge_q(row):
  gt = row['long_answer']
  pred = row['gen_answer_q']
  # Calculate ROUGE-L score
  try:
    rouge_results = rouge.compute(predictions=[pred],
                        references=[gt])
    return rouge_results['rougeL']
  except Exception:
    # return None if the score cannot be computed
    return None

Dataset_l_df['answer_q_rouge'] = Dataset_l_df.apply(cal_rouge_q, axis=1)

In [None]:
def cal_rouge_qa(row):
  gt = row['long_answer']
  pred = row['gen_answer_qa']
  # Calculate ROUGE-L score
  try:
    rouge_results = rouge.compute(predictions=[pred],
                        references=[gt])
    return rouge_results['rougeL']
  except Exception:
    # return None if the score cannot be computed
    return None

Dataset_l_df['answer_qa_rouge'] = Dataset_l_df.apply(cal_rouge_qa, axis=1)

In [None]:
def cal_rouge_qar(row):
  gt = row['long_answer']
  pred = row['gen_answer_qar']
  # Calculate ROUGE-L score
  try:
    rouge_results = rouge.compute(predictions=[pred],
                        references=[gt])
    return rouge_results['rougeL']
  except Exception:
    # return None if the score cannot be computed
    return None

Dataset_l_df['answer_qar_rouge'] = Dataset_l_df.apply(cal_rouge_qar, axis=1)

In [None]:
rouges = Dataset_l_df[['answer_q_rouge', 'answer_qa_rouge','answer_qar_rouge']]
rouges.describe()

Unnamed: 0,answer_q_rouge,answer_qa_rouge,answer_qar_rouge
count,1000.0,1000.0,1000.0
mean,0.09214,0.126588,0.09233
std,0.0586,0.059388,0.043708
min,0.0,0.0,0.0
25%,0.051282,0.086957,0.065574
50%,0.089888,0.125,0.090909
75%,0.129032,0.16,0.121212
max,0.411765,0.347826,0.25


In [None]:
Dataset_l_df.to_csv(path+'DatasetREG_LongAnsEval_20240502.csv',
                    index=False)

### References
1. Jin, Qiao, et al. "PubMedQA: A Dataset for Biomedical Research Question Answering." arXiv preprint arXiv:1909.06146 (2019).