# Setup

In [199]:
import os, json
import pandas as pd
import yaml
import re
import sys
import logging
from copy import deepcopy
from typing import Dict, List, Optional, Union
from datetime import datetime

from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import FAISS
from langchain.document_transformers import LongContextReorder
from langchain.chains import RetrievalQA
from langchain.schema import Document, BaseDocumentTransformer
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.callbacks import get_openai_callback

In [4]:
MAIN_DIR = ".."
DATA_DIR = os.path.join(MAIN_DIR, "data")
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
DOCUMENTS_DIR = os.path.join(DATA_DIR, "document_sources")
EMB_DIR = os.path.join(DATA_DIR, "emb_store")

with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)
    
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]

# Helper Functions

In [200]:
def convert_prompt_to_string(prompt) -> str:
    return prompt.format(**{v: v for v in prompt.input_variables})

def generate_query(profile: str, scan: str):
    return "Patient Profile: {}\nScan ordered: {}".format(profile, scan)

def convert_doc_to_dict(doc: Union[Document, Dict]) -> Dict:
    if isinstance(doc, Document):
        json_doc = {
            "page_content": doc.page_content,
            "metadata": {
                "source": doc.metadata["source"].split("/")[-1],
                "page": doc.metadata["page"] + 1
            }
            }
    elif isinstance(doc, Dict):
        json_doc = {
            "page_content": doc["page_content"],
            "metadata": {
                "source": doc["metadata"]["source"].split("/")[-1],
                "page": doc["metadata"]["page"] + 1
            }
        }
    return json_doc

def get_experiment_logs(description: str, log_folder: str):
    logger = logging.getLogger(description)

    stream_handler = logging.StreamHandler(sys.stdout)

    if not os.path.exists(log_folder):
        os.makedirs(log_folder, exist_ok=True)

    file_handler = logging.FileHandler(filename=os.path.join(log_folder, "logfile.log"))

    formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    logger.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    
    return logger

## Setup VectorStore

In [34]:
embeddings = OpenAIEmbeddings()
llm = ChatOpenAI(model_name="gpt-4", temperature=0, max_tokens=512)

In [37]:
vectordb_params = dict(
    chunk_size=1024,
    chunk_overlap=128,
)

emb_db_path = os.path.join(EMB_DIR, "faiss",
                           "openai_{}_{}".format(vectordb_params["chunk_size"],
                                                 vectordb_params["chunk_overlap"]))

In [None]:
# # Create document stores
# from langchain.document_loaders import PyPDFLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter

# documents = os.listdir(DOCUMENTS_DIR)

# texts = []

# for document in documents:
#     lbp_path = os.path.join(DOCUMENTS_DIR, document)
#     docs = PyPDFLoader(lbp_path).load()
#     print("Number of document pages:", len(docs))
#     text_splitter = RecursiveCharacterTextSplitter(**vectordb_params)
#     texts.extend(text_splitter.split_documents(docs))

# print("Number of text chunks:", len(texts))

# if not os.path.exists(emb_db_path):
#     os.makedirs(emb_db_path, exist_ok=True)

# docsearch = FAISS.from_documents(texts, embeddings)
# docsearch.save_local(emb_db_path)

In [38]:
docsearch = FAISS.load_local(emb_db_path, embeddings=embeddings)

# Setup Chain

In [39]:
system_template = """
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
==========
TASK:
1. Extract from given PATIENT PROFILE relevant information for classification of imaging appropriateness.
Important information includes AGE, SYMPTOMS, DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Recommend if the image scan ordered is appropriate given the PATIENT PROFILE and CONTEXT. If the scan is not appropriate, recommend an appropriate procedure.
STRICTLY answer based on the given PATIENT PROFILE and CONTEXT.
==========
OUTPUT INSTRUCTIONS:
Your output should contain the following:
1. Classification of appropriateness for the ordered scan.
2. Provide explanation for the appropriateness classification.
3. If classification answer is USUALLY NOT APPROPRIATE, either recommend an alternative appropriate scan procedure or return NO SCAN REQUIRED.

Format your output as follow:
1. Classification: Can be one of [USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE, INSUFFICIENT INFORMATION]
2. Explanation:
3. Recommendation: Can be alternative procedure, NO SCAN REQUIRED or NO CHANGE REQUIRED 
==========
CONTEXT:
{context}
==========
"""

human_template = "{question}"

PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template(human_template)
    ]
)

print(PROMPT_TEMPLATE.format(context="context", question="question"))

System: 
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
TASK:
1. Extract from given PATIENT PROFILE relevant information for classification of imaging appropriateness.
Important information includes AGE, SYMPTOMS, DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Recommend if the image scan ordered is appropriate given the PATIENT PROFILE and CONTEXT. If the scan is not appropriate, recommend an appropriate procedure.
STRICTLY answer based on the given PATIENT PROFILE and CONTEXT.
OUTPUT INSTRUCTIONS:
Your output should contain the following:
1. Classification of appropriateness for the ordered scan.
2. Provide explanation for the appropriateness classification.
3. If classificat

In [89]:
class ReOrderQARetrieval(RetrievalQA):
    
    reorder_fn: Optional[BaseDocumentTransformer] = None
    
    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        docs = self.retriever.get_relevant_documents(
            question, callbacks=run_manager.get_child()
        )
        
        docs = self.reorder_fn.transform_documents(docs) if self.reorder_fn else docs
     
        return docs

In [41]:
qa_chain = ReOrderQARetrieval.from_chain_type(
    llm=llm,
    retriever=docsearch.as_retriever(search_kwargs={"k": 5}),
    reorder_fn = LongContextReorder(),
    chain_type="stuff",
    chain_type_kwargs=dict(
        document_variable_name = "context",
        prompt=PROMPT_TEMPLATE
    ),
    input_key="question",
    return_source_documents = True,
    verbose=True
)

# Run Test Cases

In [201]:
settings = dict(
    llm_model="gpt=4",
    emb_model="text-embeddings-ada-2",
    framework="langchain",
    prompt=convert_prompt_to_string(PROMPT_TEMPLATE),
    chunk_size=1024,
    chunk_overlap=128,
    description="RawVectorSearch",
    max_tokens=512,
    max_tokens_limit="default",
    k=5,
    chain_type="stuff"
)

save_folder = os.path.join(
    ARTIFACT_DIR,
    "{}_{}_{}_{}_{}".format(
        "gpt-4",
        settings["chunk_size"],
        settings["chunk_overlap"],
        settings["description"],
        datetime.now().strftime("%d-%m-%Y-%H-%M")
    )
)

if not os.path.exists(save_folder):
    os.makedirs(save_folder)

LOGGER = get_experiment_logs(settings["description"], log_folder=save_folder)
LOGGER.info(
    "Experiment settings:\n{}".format(
        "\n".join([f"{k}:{v}" for k, v in settings.items()])
    )
)

with open(os.path.join(save_folder, "settings.yaml"), "w") as f:
    yaml.dump(settings, f)

2023-10-13 19:33:57,447:INFO: Experiment settings:
llm_model:gpt=4
emb_model:text-embeddings-ada-2
framework:langchain
prompt:System: 
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
TASK:
1. Extract from given PATIENT PROFILE relevant information for classification of imaging appropriateness.
Important information includes AGE, SYMPTOMS, DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Recommend if the image scan ordered is appropriate given the PATIENT PROFILE and CONTEXT. If the scan is not appropriate, recommend an appropriate procedure.
STRICTLY answer based on the given PATIENT PROFILE and CONTEXT.
OUTPUT INSTRUCTIONS:
Your output should contain the following:
1. Classifica

In [91]:
testcase_df = pd.read_csv(os.path.join(DATA_DIR, "queries", "MSK LLM Fictitious Case Files Full.csv"))
patient_profiles = testcase_df["Clinical File"]
scan_orders = testcase_df["MRI scan ordered"]
testcase_df["queries"] = [generate_query(patient_profile, scan_order)
        for patient_profile, scan_order in zip(patient_profiles, scan_orders)]

LOGGER.info(f"Number of test cases: {len(testcase_df)}")

Number of test cases: 71


In [None]:
responses = []
total_tokens, total_cost, prompt_tokens, completion_tokens = 0, 0, 0, 0

for test_case in testcase_df["queries"] :
    with get_openai_callback() as cb:
        response = qa_chain(test_case)
    responses.append(response)
    total_tokens += cb.total_tokens
    total_cost += cb.total_cost
    prompt_tokens += cb.prompt_tokens
    completion_tokens += cb.completion_tokens
    
LOGGER.info("Tokens Consumption: Total: {}, Prompt: {}, Completion: {}\nTotal Cost (USD): {}"
            .format(total_tokens, prompt_tokens, completion_tokens, total_cost))

In [138]:
json_responses = deepcopy(responses)

for response in json_responses:
    response["source_documents"] = [convert_doc_to_dict(doc) for doc in response["source_documents"]]

with open(os.path.join(save_folder, "results.json"), "w") as f:
    json.dump(json_responses, f)

# Result Processing

In [172]:
raw_answers = []
classifications = []
recommendations = []
retrieved_contexts = []

for response in json_responses:
    classification = re.search(r"Classification: ([A-Z ]+)\n", response["result"]).group(1)
    recommendation = re.search(r"Recommendation: (.+)$", response["result"]).group(1)
    extracted_texts = [
        "- Source: {}, page {}\n- Page Content: {}".format(
            doc["metadata"]["source"], doc["metadata"]["page"], doc["page_content"])
        for doc in response["source_documents"]
    ]
    combined_texts = "\n\n".join(extracted_texts)    
    
    raw_answers.append(response["result"])
    classifications.append(classification)
    recommendations.append(recommendation)
    retrieved_contexts.append(combined_texts)

In [178]:
result_df = testcase_df[['Clinical File', 'MRI scan ordered',
                        'Difficulty', 'queries', 'Appropriateness Category', ]]

result_df = result_df.rename(columns = {"Appropriateness Category": "human_gt"})
result_df["gpt_raw_answer"] = raw_answers
result_df["gpt4_classification"] = classifications
result_df["gpt4_recommendation"] = recommendations
result_df["retrieved_context"] = retrieved_contexts
result_df["human_gt"] = result_df["human_gt"].str.replace(r"^UA$", "USUALLY APPROPRIATE", regex=True)
result_df["human_gt"] = result_df["human_gt"].str.replace(r"^UNA$", "USUALLY NOT APPROPRIATE", regex=True)
result_df["human_gt"] = result_df["human_gt"].str.replace(r"^MBA$", "MAY BE APPROPRIATE", regex=True)
result_df["human_gt"] = result_df["human_gt"].str.replace(r"^ICI$", "INSUFFICIENT INFORMATION", regex=True)

In [218]:
result_df.to_csv(os.path.join(save_folder, "results.csv"), header=True)