<a href="https://colab.research.google.com/github/melissatorgbi/LLM-Clinical-Guideline-Understandability/blob/main/project_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip --quiet install openai bert-score readability textstat python-dotenv boto3[crt]

In [None]:
!unzip project.zip

In [None]:
!mkdir -p .aws
!cp -r project/aws/* .aws/

In [None]:
!mv .aws/config.txt .aws/config
!mv .aws/credentials.txt .aws/credentials

In [None]:
# Read the AWS key pair from the bundled credentials file.
with open("project/aws/credentials.txt") as infile:
    data = infile.readlines()

# Parse "key = value" lines by key name instead of by fixed line position and an
# exact "key = " separator: blank lines, reordered keys or "key=value" without
# spaces previously caused an IndexError.
aws_credentials = {}
for line in data:
    key, sep, value = line.partition("=")
    if sep:
        # First occurrence wins, i.e. the first profile in the file, as before.
        aws_credentials.setdefault(key.strip(), value.strip())

aws_access_key_id = aws_credentials["aws_access_key_id"]
aws_secret_access_key = aws_credentials["aws_secret_access_key"]

In [None]:
import os
import pandas as pd
import readability
import textstat
import nltk
from bert_score import score
from pprint import pprint
from botocore.exceptions import ClientError
import openai
from difflib import SequenceMatcher
import re
import boto3
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
from html.parser import HTMLParser
from nltk.stem.porter import *
nltk.download('punkt')
#from evaluate import load

In [None]:
# Load the variables in project/openai_api_keys.env into the process environment
# and hand the OpenAI key to the openai module.
# get_gpt_output() builds OpenAI() without an explicit key, so it presumably
# relies on the OPENAI_API_KEY environment variable loaded here.
load_dotenv(dotenv_path=find_dotenv("project/openai_api_keys.env"))

openai.api_key = os.getenv('OPENAI_API_KEY')
openai.organization = ''


In [None]:
# setup aws comprehend api

# Amazon Comprehend client (general entity detection, used by standard_entity_check).
# NOTE: despite its name, `s3_client` is a Comprehend client, not an S3 client;
# the name is kept because standard_entity_check() refers to it.
s3_client = boto3.client(
    "comprehend",
    region_name="us-east-1",
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)

# Amazon Comprehend Medical client (medical entity detection, used by
# medical_entity_check).
medical_client = boto3.client(
    "comprehendmedical",
    region_name="eu-west-2",
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)


In [None]:
def get_gpt_output(prompt, model="gpt-4", client=None):
  """Send a single-turn user prompt to the OpenAI chat API and return the reply text.

  Args:
    prompt: full prompt text, sent as a single "user" message.
    model: chat model name. Defaults to "gpt-4" as before; pass e.g.
      "gpt-3.5-turbo-1106" to reproduce the earlier experiments.
    client: optional pre-built OpenAI client to reuse across calls. When
      omitted, a new OpenAI() client is created for this call (old behaviour).

  Returns:
    The content of the first choice's message. Callers in this notebook treat
    it as a str (presumably it can be None if the reply has no text content).
  """
  if client is None:
    client = OpenAI()
  response = client.chat.completions.create(
      model=model,
      messages=[{"role": "user", "content": prompt}],
  )
  return response.choices[0].message.content

# Evaluation Functions

In [None]:
def evaluate_readability(text):
    """Return textstat's Flesch-Kincaid grade level and sentence count for ``text``.

    Note: despite its name, the "flesch_score" key holds the Flesch-Kincaid
    *grade level* (textstat.flesch_kincaid_grade), not the reading-ease score.
    """
    grade_level = textstat.flesch_kincaid_grade(text)
    n_sentences = textstat.sentence_count(text)
    return {"flesch_score": grade_level, "sentence_count": n_sentences}

In [None]:
def evaluate_BERT(original, new):
    """Compute the BERTScore F1 between ``original`` and ``new``, print it and return it.

    Returns:
        The F1 score rounded to 2 decimal places.

    NOTE(review): bert_score truncates inputs to the underlying model's maximum
    sequence length, so long guidelines may only be partly compared — confirm.
    """
    # score() takes one-element lists here, so F1 holds a single value.
    _, _, F1 = score([original], [new], lang="en", verbose=False)
    f1 = round(F1.item(), 2)
    # Previously round(F1.item()) rounded to an integer and the stray ", 2" was
    # printed as a second value (e.g. "BERT F1 is  1 2").
    print("BERT F1 is ", f1)
    return f1

## Entities

In [None]:
def find_missing(string_list, target_string):
    """Return the items of ``string_list`` that are not ``in`` ``target_string``.

    ``target_string`` may be a list (exact-match membership, as evaluate_entities
    uses it) or a str (substring test). The order and duplicates of
    ``string_list`` are preserved.
    """
    return [candidate for candidate in string_list if candidate not in target_string]

In [None]:
def remove_stems(text):
    """Porter-stem every word token of ``text`` and re-join the stems with single spaces.

    Presumably used so that inflected forms of an entity (e.g. "infusions" vs
    "infusion") compare equal after normalisation.
    """
    try:
        words = nltk.word_tokenize(text)
    except LookupError:
        # Newer NLTK releases tokenize with the "punkt_tab" resource, but the
        # setup cell only downloads "punkt"; fetch the missing data and retry.
        nltk.download("punkt_tab", quiet=True)
        words = nltk.word_tokenize(text)

    stemmer = PorterStemmer()
    return " ".join(stemmer.stem(word) for word in words)

In [None]:
def remove_non_alphanumeric(text):
    """Delete every character that is not an ASCII letter, digit, whitespace, '.' or '%'.

    Non-ASCII letters and symbols (e.g. "é", "µ", "°") are removed as well.
    """
    return re.sub(r'[^a-zA-Z0-9\s.%]+', '', text)

In [None]:
# this uses amazon comprehend to check that entities are retained in the improved versionbb
def standard_entity_check(text):
    """Detect general (non-medical) named entities in ``text`` with Amazon Comprehend.

    Used to check that entities are retained in the improved version of a guideline.

    Returns:
        (entity_text, entities): the lower-cased "Text" of every detected entity,
        and the raw "Entities" list from the detect_entities response.
    """
    response = s3_client.detect_entities(Text=text, LanguageCode='en')
    entities = response['Entities']
    entity_text = [entity["Text"].lower() for entity in entities]
    return entity_text, entities

In [None]:
#Medical Named Entity and Relationship Extraction (NERe)
# this uses amazon comprehend to check that entities are retained in the improved versionbb
def medical_entity_check(text):
    """Detect medical named entities in ``text`` with Amazon Comprehend Medical (detect_entities_v2).

    Used to check that entities are retained in the improved version of a guideline.

    Returns:
        (entity_text, entities): the lower-cased "Text" of every detected entity,
        and the raw "Entities" list from the detect_entities_v2 response.
    """
    response = medical_client.detect_entities_v2(Text=text)
    entities = response['Entities']
    entity_text = [entity["Text"].lower() for entity in entities]
    return entity_text, entities

In [None]:
def extact_all_entities(text):
    """Return the lower-cased entity strings found in ``text``.

    The strings found by Comprehend (standard_entity_check) come first, followed
    by those found by Comprehend Medical (medical_entity_check). Duplicates are kept.
    """
    general_texts = standard_entity_check(text)[0]
    medical_texts = medical_entity_check(text)[0]
    return [*general_texts, *medical_texts]

In [None]:
def reformat_entiites(entity_list):
    """Normalise entity strings for comparison and return them as a new list.

    All entries are passed through remove_non_alphanumeric, then lower-cased,
    then passed through remove_stems (each stage over the whole list in turn).
    """
    cleaned = [remove_non_alphanumeric(entity) for entity in entity_list]
    lowered = [entity.lower() for entity in cleaned]
    return [remove_stems(entity) for entity in lowered]

In [None]:
def evaluate_entities(original, new):
    """Compare the entities found in ``original`` with those found in ``new``.

    Both texts go through extact_all_entities, then reformat_entiites (which
    removes grammatical symbols, lower-cases and stems), before being compared.

    Returns:
        (all_ori_entities, all_new_entities, missing, extra): the normalised
        entity lists of each text, the original entities absent from the new
        list, and the new entities absent from the original list.
    """
    ori_entities, new_entities = (
        reformat_entiites(extracted)
        for extracted in (extact_all_entities(original), extact_all_entities(new))
    )
    missing = find_missing(ori_entities, new_entities)
    extra = find_missing(new_entities, ori_entities)
    return ori_entities, new_entities, missing, extra

# Implementation

In [None]:
# PROMPTS
# The readability and formatting prompts are read from text files; `with`
# closes each file after reading (the previous bare open() calls never did).
with open('project/readability_prompt.txt', "r", encoding="utf-8") as prompt_file:
    readability_prompt = prompt_file.read()

with open('project/formatting_prompt.txt', "r", encoding="utf-8") as prompt_file:
    formatting_prompt = prompt_file.read()

# NOTE(review): "need to be ask" below looks like a typo for "need to ask". It is
# kept verbatim so GPT outputs stay comparable with earlier runs; fix it deliberately.
questions_prompt = """You are a nurse administering IV medication,
what questions would you need to be ask a technical expert, to generate the following text?
Give the output directly as a list of numbered questions - TEXT: """

suggestions_prompt = '''You are a nurse following a guideline for administering IV medication. You need to find questions from the list provided
which are NOT already answered by the guideline you are given. Find the unanswered questions.
As well as the question, output section heading from the guideline most related to the question.
The output format will be the heading: followed by the questions.
Only output a maximum of ten questions in total across all sections.
These are the questions '''

In [None]:
!mkdir results

In [None]:
# Pipeline, for every Markdown guideline in `path`:
#   1. ask GPT for a more readable version and score it against the original
#      (BERTScore F1, Flesch-Kincaid grade level, sentence count);
#   2. ask GPT to re-format that version and save it as results/<medication>_gpt.md;
#   3. ask GPT for technical questions about the original, then for the questions
#      the guideline leaves unanswered;
#   4. write the metrics and suggestions to results/<medication>.txt.
# Entity-retention checks (evaluate_entities) are currently disabled.
path = "/content/project/medications/"


def _strip_markdown_fence(text):
    """Remove a surrounding ```markdown ... ``` code fence from a GPT reply.

    The old ``.strip("```").strip("markdown")`` treated its arguments as
    character *sets*, so it also trimmed real leading/trailing letters
    (e.g. "...administration" lost its trailing "on", "dose ..." its "do").
    """
    text = text.strip()
    text = re.sub(r"^```[ \t]*(?:markdown|md)?[ \t]*\r?\n?", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\r?\n?```$", "", text)
    return text


# Result files are opened with mode "w", so re-running this cell replaces
# (rather than appends to) earlier outputs.
os.makedirs("results", exist_ok=True)

# sorted() gives a deterministic processing order (os.listdir order is arbitrary).
for file in sorted(os.listdir(path)):
    if not file.endswith(".md"):
        continue

    # Decode as UTF-8 text. str() of the raw bytes produced the repr "b'...'"
    # with literal "\n" escapes, and that repr was what got sent to GPT and scored.
    with open(os.path.join(path, file), encoding="utf-8") as md_file:
        original_content = md_file.read()
    # splitext keeps dotted names distinct ("a.v2.md" -> "a.v2", not "a").
    medication = os.path.splitext(file)[0]

    # READABILITY
    gpt_readability = _strip_markdown_fence(
        get_gpt_output(readability_prompt + "\n" + original_content))

    bert_similarity = evaluate_BERT(original_content, gpt_readability)
    original_metrics = evaluate_readability(original_content)
    new_metrics = evaluate_readability(gpt_readability)
    # all_ori_entities, all_gpt_entities, gpt_missing_entities, gpt_extra_entities = evaluate_entities(original_content, gpt_readability)

    # FORMATTING
    gpt_formatting = get_gpt_output(formatting_prompt + "\n" + gpt_readability)
    with open(os.path.join("results", medication + "_gpt.md"), "w", encoding="utf-8") as f:
        f.write(_strip_markdown_fence(gpt_formatting))

    # TECHNICAL QUESTIONS AND SUGGESTIONS
    gpt_questions = get_gpt_output(questions_prompt + "\n" + original_content)
    gpt_suggestions = get_gpt_output(
        suggestions_prompt + gpt_questions + '. This is the guideline text: ' + original_content)

    # SAVE OUTPUTS
    with open(os.path.join("results", medication + ".txt"), "w", encoding="utf-8") as f:
        f.write(
            "BERT Similarity score: " + str(bert_similarity)
            + "\nOriginal text flesch-kincaid grade level: " + str(original_metrics['flesch_score'])
            + "\nOriginal text sentence count: " + str(original_metrics['sentence_count'])
            + "\ngpt output flesch-kincaid grade level: " + str(new_metrics['flesch_score'])
            + "\ngpt output sentence count: " + str(new_metrics['sentence_count'])
            + "\n\nTechnical Suggestions\n\n" + str(gpt_suggestions)
        )

In [None]:
!zip -r results.zip results