### Load data

In [None]:
import json

# Path to the PubMedQA test split; the file is read as JSON Lines
# (one JSON object per line).
json_file_path = './dataset/PubMedQA/test.json'

# Parallel lists: index i of every list describes the same sample.
test_id_list = []
test_question_list = []
test_context_list = []
test_longlabel_list = []
test_shortlabel_list = []

# Open the JSON file and read line by line
with open(json_file_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        # A blank line is not a sample; json.loads('') would raise
        # JSONDecodeError and abort the whole load.
        if not line:
            continue

        # Parse each line (which is a JSON object) into a Python dictionary
        sample = json.loads(line)

        # Extract and store each field in its respective list
        test_id_list.append(sample.get('pubid'))
        test_question_list.append(sample.get('question'))
        # The context is stored as a list of passages; flatten it to one string.
        list_of_context = sample.get('context')['contexts']
        test_context_list.append(" ".join(list_of_context))
        test_longlabel_list.append(sample.get('long_answer'))
        test_shortlabel_list.append(sample.get('final_decision'))

### Getting results from LLMs

In [None]:
import openai
import os
import pathlib
import textwrap
import google.generativeai as genai
from IPython.display import display
from IPython.display import Markdown
from google.colab import userdata
from vertexai.generative_models import GenerationConfig

from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import pipeline


# Locally hosted models, keyed by the `llm` selector, so that the weights are
# loaded once per session instead of on every get_output() call.
_LOCAL_MODEL_CACHE = {}


def get_output(prompt, llm):
    """Send `prompt` to the selected LLM backend and return the generated text.

    Args:
        prompt: The complete prompt string to submit.
        llm: Backend selector. One of 3.5, 4, 'instruct', 'gemini-1.0-pro',
            'flan-ul2', 'med-alpaca' or 'pmc-llama'.

    Returns:
        The text produced by the backend. It is also printed.

    Raises:
        ValueError: If `llm` is not one of the selectors listed above.
    """
    if llm in (3.5, 4):
        # Read the key from the environment instead of overwriting any
        # configured key with an empty string.
        openai.api_key = os.environ.get('OPENAI_API_KEY', '')
        model = 'gpt-3.5-turbo-1106' if llm == 3.5 else 'gpt-4'
        message = openai.ChatCompletion.create(
            model=model,
            temperature=0,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        result = message['choices'][0]['message']['content']

    elif llm == 'instruct':
        openai.api_key = os.environ.get('OPENAI_API_KEY', '')
        model = "gpt-3.5-turbo-instruct"
        message = openai.Completion.create(
            model=model,
            prompt=prompt,
            temperature=0
        )
        result = message['choices'][0]['text']

    elif llm == 'gemini-1.0-pro':
        GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
        genai.configure(api_key=GOOGLE_API_KEY)
        model = genai.GenerativeModel('gemini-pro')

        # Disable every safety filter category listed here.
        safety_settings = [
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
        ]

        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                candidate_count=1,
                temperature=0),
            safety_settings=safety_settings)

        # NOTE(review): raises if the response carries no candidate/part;
        # confirm callers are prepared for that.
        result = response.candidates[0].content.parts[0].text

    elif llm == 'flan-ul2':
        if llm not in _LOCAL_MODEL_CACHE:
            _LOCAL_MODEL_CACHE[llm] = (
                T5ForConditionalGeneration.from_pretrained(
                    "google/flan-ul2", device_map="auto", load_in_8bit=True),
                AutoTokenizer.from_pretrained("google/flan-ul2"),
            )
        model, tokenizer = _LOCAL_MODEL_CACHE[llm]
        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(inputs, max_length=200)
        result = tokenizer.decode(outputs[0])

    elif llm == 'med-alpaca':
        # The local must NOT be named `pipeline`: assigning to that name makes
        # it function-local everywhere in this function, so the call to the
        # imported transformers.pipeline would raise UnboundLocalError.
        if llm not in _LOCAL_MODEL_CACHE:
            _LOCAL_MODEL_CACHE[llm] = pipeline(
                "text-generation",
                model="medalpaca/medalpaca-7b",
                tokenizer="medalpaca/medalpaca-7b")
        generator = _LOCAL_MODEL_CACHE[llm]
        # NOTE(review): max_length presumably counts prompt tokens as well, and
        # 'generated_text' presumably echoes the prompt -- confirm.
        result = generator(prompt, max_length=200)[0]['generated_text']

    elif llm == 'pmc-llama':
        if llm not in _LOCAL_MODEL_CACHE:
            _LOCAL_MODEL_CACHE[llm] = (
                LlamaForCausalLM.from_pretrained('axiong/PMC_LLaMA_13B'),
                LlamaTokenizer.from_pretrained('axiong/PMC_LLaMA_13B'),
            )
        model, tokenizer = _LOCAL_MODEL_CACHE[llm]
        encoded_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        # NOTE(review): no generation length is given here, so the library
        # default applies -- confirm it is long enough for a full answer.
        output = model.generate(**encoded_input)
        result = tokenizer.decode(output[0])

    else:
        raise ValueError('Invalid LLM')

    print(result)
    return result

In [None]:
# Worked examples shared by every few-shot prompt variant, in presentation
# order. An N-shot prompt uses the first N entries.
_FEW_SHOT_EXAMPLES = (
    '''Example Question 1: Visceral adipose tissue area measurement at a single level: can it represent visceral adipose tissue volume?
Example Reference Text 1: Measurement of visceral adipose tissue (VAT) needs to be accurate and sensitive to change for risk monitoring. The purpose of this study is to determine the CT slice location where VAT area can best reflect changes in VAT volume and body weight.", "60 plain abdominal CT images from 30 males [mean age (range) 51 (41-68) years, mean body weight (range) 71.1 (101.9-50.9) kg] who underwent workplace screenings twice within a 1-year interval were evaluated. Automatically calculated and manually corrected areas of the VAT of various scan levels using \"freeform curve\" region of interest on CT were recorded and compared with body weight changes.", "The strongest correlations of VAT area with VAT volume and body weight changes were shown in a slice 3 cm above the lower margin of L3 with r values of 0.853 and 0.902, respectively.
Example Output Text 1: Yes. VAT area measurement at a single level 3 cm above the lower margin of the L3 vertebra is feasible and can reflect changes in VAT volume and body weight. Advances in knowledge: As VAT area at a CT slice 3cm above the lower margin of L3 can best reflect interval changes in VAT volume and body weight, VAT area measurement should be selected at this location.''',
    '''Example Question 2: Can a practicing surgeon detect early lymphedema reliably?
Example Reference Text 2: Lymphedema may be identified by simpler circumference changes as compared with changes in limb volume.", "Ninety breast cancer patients were prospectively enrolled in an academic trial, and seven upper extremity circumferences were measured quarterly for 3 years. A 10% volume increase or greater than 1 cm increase in arm circumference identified lymphedema with verification by a lymphedema specialist. Sensitivity and specificity of several different criteria for detecting lymphedema were compared using the academic trial as the standard.", "Thirty-nine cases of lymphedema were identified by the academic trial. Using a 10% increase in circumference at two sites as the criterion, half the lymphedema cases were detected (sensitivity 37%). When using a 10% increase in circumference at any site, 74.4% of cases were detected (sensitivity 49%). Detection by a 5% increase in circumference at any site was 91% sensitive.
Example Output Text 2: Maybe. An increase of 5% in circumference measurements identified the most potential lymphedema cases compared with an academic trial.''',
    '''Example Question 3: It's Fournier's gangrene still dangerous?
Example Reference Text 3: Fournier's gangrene is known to have an impact in the morbidity and despite antibiotics and aggressive debridement, the mortality rate remains high.", "To assess the morbidity and mortality in the treatment of Fournier's gangrene in our experience.", "The medical records of 14 patients with Fournier's gangrene who presented at the University Hospital Center \"Mother Teresa\" from January 1997 to December 2006 were reviewed retrospectively to analyze the outcome and identify the risk factor and prognostic indicators of mortality.", "Of the 14 patients, 5 died and 9 survived. Mean age was 54 years (range from 41-61): it was 53 years in the group of survivors and 62 years in deceased group. There was a significant difference in leukocyte count between patients who survived (range 4900-17000/mm) and those died (range 20.300-31000/mm3). Mean hospital stay was about 19 days (range 2-57 days).
Example Output Text 3: Yes. The interval from the onset of clinical symptoms to the initial surgical intervention seems to be the most important prognostic factor with a significant impact on outcome. Despite extensive therapeutic efforts, Fournier's gangrene remains a surgical emergency and early recognition with prompt radical debridement is the mainstays of management.''',
    '''Example Question 4: Does a colonoscopy after acute diverticulitis affect its management?
Example Reference Text 4: Medical records of 220 patients hospitalized for acute diverticulitis between June 1, 2002 and September 1, 2009 were reviewed. Acute diverticulitis was diagnosed by clinical criteria and characteristic CT findings. Fifteen patients were excluded either because of questionable CT or hematochezia. Mean age was 61.8±14.3 years (61% females). Clinical parameters, laboratory results, imaging, endoscopic and histopathological reports, and long-term patients' outcome were analyzed.", "One hundred patients (aged 61.8±13.3 y, 54.1% females), underwent an early (4 to 6 wk) colonoscopy after hospital discharge. There were no significant differences in patients' characteristics or survival between those with or without colonoscopy (4±1.9 vs. 4.2±2.1 y, P=0.62). No colonic malignancy was detected. However, in 32 patients (32%) at least 1 polyp was found. Only 1 was determined as an advanced adenoma. No new or different diagnosis was made after colonoscopy.
Example Output Text 4: No. Our results suggest that colonoscopy does not affect the management of patients with acute diverticulitis nor alter the outcome. The current practice of a routine colonoscopy after acute diverticulitis, diagnosed by typical clinical symptoms and CT needs to be reevaluated.''',
    '''Example Question 5: Detailed analysis of sputum and systemic inflammation in asthma phenotypes: are paucigranulocytic asthmatics really non-inflammatory?
Example Reference Text 5: The technique of induced sputum has allowed to subdivide asthma patients into inflammatory phenotypes according to their level of granulocyte airway infiltration. There are very few studies which looked at detailed sputum and blood cell counts in a large cohort of asthmatics divided into inflammatory phenotypes. The purpose of this study was to analyze sputum cell counts, blood leukocytes and systemic inflammatory markers in these phenotypes, and investigate how those groups compared with healthy subjects.", "We conducted a retrospective cross-sectional study on 833 asthmatics recruited from the University Asthma Clinic of Liege and compared them with 194 healthy subjects. Asthmatics were classified into inflammatory phenotypes.", "The total non-squamous cell count per gram of sputum was greater in mixed granulocytic and neutrophilic phenotypes as compared to eosinophilic, paucigranulocytic asthma and healthy subjects (p < 0.005). Sputum eosinophils (in absolute values and percentages) were increased in all asthma phenotypes including paucigranulocytic asthma, compared to healthy subjects (p < 0.005). Eosinophilic asthma showed higher absolute sputum neutrophil and lymphocyte counts than healthy subjects (p < 0.005), while neutrophilic asthmatics had a particularly low number of sputum macrophages and epithelial cells. All asthma phenotypes showed an increased blood leukocyte count compared to healthy subjects (p < 0.005), with paucigranulocytic asthmatics having also increased absolute blood eosinophils compared to healthy subjects (p < 0.005). Neutrophilic asthma had raised CRP and fibrinogen while eosinophilic asthma only showed raised fibrinogen compared to healthy subjects (p < 0.005).
Example Output Text 5: Maybe. This study demonstrates that a significant eosinophilic inflammation is present across all categories of asthma, and that paucigranulocytic asthma may be seen as a low grade inflammatory disease.''',
)

# (prompt_type, number of worked examples) for the few-shot variants.
_SHOT_COUNTS = (('1shot', 1), ('3shot', 3), ('5shot', 5))


def create_prompt(question, reference, prompt_type='base'):
    """Assemble the PubMedQA prompt for one sample.

    Args:
        question: The question text, embedded in the task instruction.
        reference: The reference (context) text, appended at the end.
        prompt_type: 'guide' adds a reasoning-process instruction; '1shot',
            '3shot' and '5shot' add that many worked examples; any other
            value yields the plain task prompt.

    Returns:
        The prompt string, ending with an empty '### Output Text:' section.
    """
    sections = [
        "### Task\n"
        "You are a skilled medical expert. Considering the information from a "
        "biomedical study provided in the reference text, is it correct to "
        f"conclude that '{question}'? \n"
        "Please begin with a brief response of 'Yes', 'No' or 'Maybe'. Then, "
        "provide a detailed explanation in 1-2 sentences.\n"
    ]

    # The guided variant only differs by one extra instruction line.
    if prompt_type == 'guide':
        sections.append(
            "Finally, briefly outline your reasoning process by stating: "
            "'Reasoning Process: ', ensuring it aligns with the findings of "
            "the study.\n"
        )

    # Few-shot variants: examples are separated by one blank line.
    n_examples = next(
        (count for tag, count in _SHOT_COUNTS if prompt_type == tag), 0
    )
    if n_examples:
        chosen = _FEW_SHOT_EXAMPLES[:n_examples]
        sections.append(" \n### Examples\n" + "\n\n".join(chosen) + "\n")

    # Add input text to be annotated
    sections.append(f"\n### Reference Text: {reference}\n### Output Text:\n")

    return "".join(sections)

In [None]:
def run(llm, prompt_type, ids, questions, contexts, max_retries=5, retry_delay=1.0):
    """Query `llm` for every sample and write the answers as JSON Lines.

    Output goes to './output/PubMedQA/{llm}/test_{prompt_type}.json'; any file
    left there by a previous run is overwritten. Each line is a JSON object
    {"id": ..., "answer": ...} where the answer has all whitespace runs
    collapsed to single spaces.

    Args:
        llm: Backend selector forwarded to get_output().
        prompt_type: Prompt variant forwarded to create_prompt().
        ids: Sample identifiers, parallel to `questions` and `contexts`.
        questions: Question strings.
        contexts: Reference/context strings.
        max_retries: Maximum attempts per sample before giving up.
        retry_delay: Seconds to wait before the first retry; the wait doubles
            after every further failure.

    Raises:
        RuntimeError: If a sample still fails after `max_retries` attempts.
            Samples already answered stay in the output file.
    """
    import time  # local import so this cell needs nothing beyond `os`/`json`

    output_path = f'./output/PubMedQA/{llm}/'
    os.makedirs(output_path, exist_ok=True)

    output_file_path = os.path.join(output_path, f'test_{prompt_type}.json')

    # Mode 'w' truncates a previous run's file, replacing the delete-then-append dance.
    with open(output_file_path, 'w', encoding='utf-8') as f_write:
        for sample_id, question, context in zip(ids, questions, contexts):
            prompt = create_prompt(question, context, prompt_type)

            # Bounded retry: an unbounded loop never terminates on permanent
            # failures (invalid `llm`, bad API key) and hammers the backend.
            for attempt in range(1, max_retries + 1):
                try:
                    output = get_output(prompt, llm)
                    break
                except Exception as e:
                    print(e)
                    if attempt == max_retries:
                        raise RuntimeError(
                            f'Giving up on sample {sample_id!r} after '
                            f'{max_retries} failed attempts'
                        ) from e
                    time.sleep(retry_delay * 2 ** (attempt - 1))

            # todo: model may not generate proper answer??
            cleaned_output = ' '.join(output.split())

            f_write.write(json.dumps({"id": sample_id, "answer": cleaned_output}) + '\n')
            # Flush per sample so finished answers survive an interrupted run.
            f_write.flush()

In [None]:
# Experiment configuration. Both names are module-level on purpose: the
# evaluation cells read them again to locate the output file.
llm = 4  # backend selector for get_output(); 4 selects the 'gpt-4' branch
prompt_type = 'base'  # any value other than 'guide'/'1shot'/'3shot'/'5shot' gives the plain task prompt
# Query the model for every test sample and write the answers to disk.
run(llm, prompt_type, test_id_list, test_question_list, test_context_list)

### Evaluation

In [None]:
import re
import json

# Load results from the specified directory
output_file_path = f'./output/PubMedQA/{llm}/test_{prompt_type}.json'

test_id2_list = []
test_short_answer_list = []
test_long_answer_list = []

# Leading decision word. The `\b` keeps words that merely start with a label
# ("Not ...", "Nothing ...", "None ...") from being read as "No".
short_answer_pattern = re.compile(r"^\s*(Yes|No|Maybe)\b", re.IGNORECASE)

# Explanation that follows the decision word. Separator punctuation after the
# decision ("Yes. ...", "No, ...", "Maybe: ...") is dropped so it does not
# become the first token of the long answer.
if prompt_type == 'guide':
    # If under CoT prompt, we only keep the long answer before "Reasoning Process"
    long_answer_pattern = re.compile(
        r"^\s*(Yes|No|Maybe)\b[\s.,:;]*(.*?)(?=\s*Reasoning Process:|$)",
        re.IGNORECASE | re.DOTALL,
    )
else:
    long_answer_pattern = re.compile(
        r"^\s*(Yes|No|Maybe)\b[\s.,:;]*(.*)",
        re.IGNORECASE | re.DOTALL,
    )

# Open the JSON file and read line by line
with open(output_file_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if not line:
            # json.loads('') would raise on a stray blank line
            continue
        sample = json.loads(line)

        # Extract sample id
        test_id2_list.append(sample.get('id'))

        # Extract the answer; a missing/null answer is treated as empty text
        # so it is reported as "not found" instead of crashing re.search.
        sample_answer = sample.get('answer') or ''

        # group(1) is the bare decision word; group(0) would also include any
        # leading whitespace matched by `^\s*`.
        short_answer_match = short_answer_pattern.search(sample_answer)
        short_answer = short_answer_match.group(1) if short_answer_match else "Answer not found!"

        long_answer_match = long_answer_pattern.search(sample_answer)
        long_answer = long_answer_match.group(2).strip() if long_answer_match else "Long answer not found!"

        # Append the extracted answers to their respective lists
        test_short_answer_list.append(short_answer)
        test_long_answer_list.append(long_answer)

# Ensure that the IDs match between the original and extracted lists.
# Raised explicitly because `assert` is stripped when Python runs with -O.
if test_id_list != test_id2_list:
    raise ValueError(
        f"Sample ids in {output_file_path} do not match the test set "
        f"({len(test_id2_list)} predictions vs {len(test_id_list)} samples)."
    )


#### Accuracy for short answer

In [None]:
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.preprocessing import LabelEncoder, label_binarize  # Added label_binarize import

print("Evaluating: ", llm, prompt_type)

# Define possible options for labels
valid_labels = ['yes', 'no', 'maybe']

# Labels and predictions are compared position by position, so they must be
# the same length before anything is filtered.
if len(test_shortlabel_list) != len(test_short_answer_list):
    raise ValueError("Mismatch between the number of labels and predictions.")

# Filter PAIRS, never the two lists independently: dropping an item from only
# one list shifts every later pair and silently corrupts the scores.
labels = []
predictions = []
skipped_labels = 0
unparsed_predictions = 0
for label, prediction in zip(test_shortlabel_list, test_short_answer_list):
    label = label.lower()
    prediction = prediction.strip().lower()
    if label not in valid_labels:
        # No usable gold label: the sample cannot be scored at all.
        skipped_labels += 1
        continue
    if prediction not in valid_labels:
        # Kept on purpose: an unparsable prediction counts as a wrong answer.
        unparsed_predictions += 1
    labels.append(label)
    predictions.append(prediction)

if skipped_labels:
    print("Samples skipped (invalid gold label):", skipped_labels)
if unparsed_predictions:
    print("Predictions without a yes/no/maybe answer (counted as wrong):", unparsed_predictions)
if not labels:
    raise ValueError("No samples with a valid label to evaluate.")

# Convert labels to integer class ids using LabelEncoder
le = LabelEncoder()
le.fit(valid_labels)
labels_binary = le.transform(labels)
class_index = {name: index for index, name in enumerate(le.classes_)}
# -1 marks an unparsable prediction; it matches no class id.
predictions_binary = [class_index.get(prediction, -1) for prediction in predictions]

# One-hot encode labels and predictions. Built explicitly so an unparsable
# prediction (-1) becomes an all-zero row.
classes = [0, 1, 2]
labels_binarized = [[int(value == cls) for cls in classes] for value in labels_binary]
predictions_binarized = [[int(value == cls) for cls in classes] for value in predictions_binary]

# Calculate accuracy
accuracy = accuracy_score(labels_binary, predictions_binary)
print("Accuracy:", accuracy)

try:
    # Macro-averaged one-vs-rest AUC over the one-hot encoded hard predictions
    auc_score = roc_auc_score(labels_binarized, predictions_binarized, average='macro', multi_class='ovr')
    print("AUC Score (One-vs-Rest):", auc_score)
except ValueError as e:
    print(f"Unable to calculate AUC: {e}")


#### ROUGE score for long answer

In [None]:
from rouge import Rouge

# Initialize the Rouge scoring object
rouge = Rouge()

# Prepare lists to hold the per-sample F-scores for each metric
scores_rouge1 = []
scores_rouge2 = []
scores_rougel = []

# Iterate over each pair and calculate ROUGE scores
for idx, gold_standard, predicted_output in zip(test_id_list, test_longlabel_list, test_long_answer_list):

    try:
        scores = rouge.get_scores(predicted_output, gold_standard, avg=False)
    except ValueError as e:
        # Only the "empty hypothesis" case is expected; anything else is a
        # real problem and must not be dropped silently.
        if "Hypothesis is empty." not in str(e):
            print("Unexpected ROUGE error in id:", idx)
            raise
        print("Hypothesis is empty in id:", idx)
        # An empty answer overlaps with nothing, so it scores 0. Skipping it
        # would average over answered samples only and inflate the result.
        scores_rouge1.append(0.0)
        scores_rouge2.append(0.0)
        scores_rougel.append(0.0)
        continue

    # Append the F-score for each metric
    scores_rouge1.append(scores[0]['rouge-1']['f'])
    scores_rouge2.append(scores[0]['rouge-2']['f'])
    scores_rougel.append(scores[0]['rouge-l']['f'])

# Calculate average scores (0.0 when there was nothing to score, instead of
# failing with ZeroDivisionError)
avg_rouge1 = sum(scores_rouge1) / len(scores_rouge1) if scores_rouge1 else 0.0
avg_rouge2 = sum(scores_rouge2) / len(scores_rouge2) if scores_rouge2 else 0.0
avg_rougel = sum(scores_rougel) / len(scores_rougel) if scores_rougel else 0.0

print(f"Average ROUGE-1 Score: {avg_rouge1}")
print(f"Average ROUGE-2 Score: {avg_rouge2}")
print(f"Average ROUGE-L Score: {avg_rougel}")
