# Setup

In [263]:
import os, json
import pandas as pd
import yaml
from copy import deepcopy

import re
import sys
import logging
from typing import Dict, Union, Optional, List
from datetime import datetime
from config import MAIN_DIR

from llama_index.vector_stores import SimpleVectorStore
from llama_index import ServiceContext
from llama_index.storage.storage_context import StorageContext
from llama_index.llms import OpenAI
from llama_index.embeddings import OpenAIEmbedding
from llama_index.response.schema import Response
from llama_index.schema import Document, NodeWithScore
from llama_index import load_index_from_storage

In [55]:
DATA_DIR = os.path.join(MAIN_DIR, "data")
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
EMB_DIR = os.path.join(DATA_DIR, "emb_store")

with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)
    
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]

# Helper Functions

In [None]:
fix_prompt_template = """TASK: Extract the following information from the provided text query.
1. Appropropriateness of the scan ordered.
2. Most Appropriate Imaging Modality
===============
FORMAT INSTRUCTIONS: Your output should contains the following:
Appropriateness: Can be one of [USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE, INSUFFICIENT INFORMATION]
Recommendation: The most appropriate imaging modality
===============
TEXT QUERY: {query}
"""

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI

FIX_PROMPT = PromptTemplate.from_template(fix_prompt_template)

fixing_chain = LLMChain(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, max_tokens=512),
    prompt=FIX_PROMPT
)

In [281]:
def convert_prompt_to_string(prompt) -> str:
    return prompt.format(**{v: v for v in prompt.template_vars})

def generate_query(profile: str, scan: str):
    return "Patient Profile: {}\nScan ordered: {}".format(profile, scan)

def convert_doc_to_dict(doc: Union[Document, NodeWithScore, Dict]) -> Dict:
    if isinstance(doc, Union[Document, NodeWithScore]):
        json_doc = {
            "page_content": doc.text,
            "metadata": {
                "source": doc.metadata["file_name"],
                "page": doc.metadata["page_label"]
            }
            }
    elif isinstance(doc, Dict):
        json_doc = {
            "page_content": doc["text"],
            "metadata": {
                "source": doc["metadata"]["file_name"],
                "page": doc["metadata"]["page_label"]
            }
        }
    return json_doc

def get_experiment_logs(description: str, log_folder: str):
    logger = logging.getLogger(description)

    stream_handler = logging.StreamHandler(sys.stdout)

    if not os.path.exists(log_folder):
        os.makedirs(log_folder, exist_ok=True)

    file_handler = logging.FileHandler(filename=os.path.join(log_folder, "logfile.log"))

    formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    logger.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    
    return logger

def remove_final_sentence(
    text: str
):
    if text.endswith("."):
        text = text[:-1]
    sentence_list = text.split(".")
    return ".".join(sentence_list[:-1])

def query_wrapper(
    template: str, 
    input_text: Union[str, Dict[str, str]]
) -> str:
    placeholders = re.findall(pattern = r"{([A-Za-z0-9_-]+)}", string=template)
    if isinstance(input_text, str):
        assert len(placeholders) == 1, "Must Provide a single placeholder when input_text is string."
        placeholder = placeholders[0]
        return template.format(**{placeholder:input_text})
    
    assert len(input_text) == len(placeholders)
    for key in input_text.keys():
        assert key in placeholders, f"{key} not present in template."
    
    return template.format(**input_text)

def preprocess_profile(
    profile: str
):
    if profile.endswith("."):
        profile = profile[:-1]
    sentences = profile.split(".")
    return ".".join(sentences[:-1])

def process_result_json(
    testcase_df: pd.DataFrame,
    responses: List[Response],
    save_path: Optional[str] = None
):
    json_responses = []
    queries = testcase_df["queries"]
    scan_orders = testcase_df["MRI scan ordered"]
    
    for query, response, scan_order in zip(queries, responses, scan_orders):
        testcase_info = {
            "question": query,
            "result": response.response,
            "source_documents": [convert_doc_to_dict(doc) for doc in response.source_nodes]
        }
        answer_query = "Scan Ordered: {}\nAnswer: {}".format(scan_order, testcase_info["result"])
        fixed_answer = fixing_chain(answer_query)
        appropriateness, recommendation = re.findall(
            r"^Appropriateness: ([0-9A-Za-z ]+)\nRecommendation: ([0-9A-Za-z \.]+)$", fixed_answer["text"])[0]
        testcase_info["appropriateness"] = appropriateness
        testcase_info["recommendation"] = recommendation

        json_responses.append(testcase_info)
        
    if save_path:
        with open(save_path, "w") as f:
            json.dump(json_responses, f)
    return json_responses

def process_result_df(
    testcase_df: pd.DataFrame, results: Union[List[Dict], List[Response]], save_path: Optional[str] = None
):
    if isinstance(results[0], Response):
        results = process_result_json(testcase_df, results)
    
    result_df = deepcopy(testcase_df)
    result_df["gpt_raw_answer"] = [response["result"] for response in results]
    result_df["gpt_classification"] = [response["appropriateness"] for response in results]
    result_df["gpt_classification"] = result_df["gpt_classification"].str.upper()
    result_df["gpt_recommendation"] = [response["recommendation"] for response in results]
    result_df["context"] = [
        "\n\n".join(["Guideline: {}, Page: {}\nPage Content: {}".format(
            document["metadata"]["source"], document["metadata"]["page"],
            document["page_content"]) for document in response["source_documents"]]
                ) for response in results
    ]

    result_df = result_df.rename(columns = {"Appropriateness Category": "human_gt"})

    result_df["human_gt"] = result_df["human_gt"].str.replace(r"^UA$", "USUALLY APPROPRIATE", regex=True)
    result_df["human_gt"] = result_df["human_gt"].str.replace(r"^UNA$", "USUALLY NOT APPROPRIATE", regex=True)
    result_df["human_gt"] = result_df["human_gt"].str.replace(r"^MBA$", "MAY BE APPROPRIATE", regex=True)
    result_df["human_gt"] = result_df["human_gt"].str.replace(r"^ICI$", "INSUFFICIENT INFORMATION", regex=True)
    
    result_df["match"] = (result_df["gpt_classification"] == result_df["human_gt"])

    if save_path:
        result_df.to_csv(save_path)

    return result_df

# Load Test Data

In [235]:
testcase_df = pd.read_csv(
        os.path.join(DATA_DIR, "queries", "MSK LLM Fictitious Case Files Full.csv"),
        usecols = ['ACR scenario', 'Appropriateness Category', 'MRI scan ordered',
                   'Difficulty', 'Clinical File']
        )
patient_profiles = testcase_df["Clinical File"]
scan_orders = testcase_df["MRI scan ordered"]

# Baseline Experiment (Rau et al 2023)

## Knowledge Base Setup

In [58]:
# Retrieval 
emb_type = "openai"
chunk_size = 512
chunk_overlap = 20
db_type = "simple"
similarity_top_k = 3
index_id = "msk-mri"

# Generation
model_name = "gpt-3.5-turbo"
max_tokens = 512
temperature = 0
response_mode = "compact"

In [59]:
db_directory = os.path.join(EMB_DIR, db_type, f"{emb_type}_{chunk_size}_{chunk_overlap}")
print(db_directory)
vector_store = SimpleVectorStore.from_persist_dir(db_directory)
storage_context = StorageContext.from_defaults(
    persist_dir = db_directory
    )
vector_index = load_index_from_storage(
    storage_context=storage_context,
    index_id=index_id
    )

../data/emb_store/simple/openai_512_20


In [109]:
import tiktoken
from llama_index import get_response_synthesizer
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.callbacks import CallbackManager, TokenCountingHandler

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(model_name).encode
)
callback_manager = CallbackManager([token_counter])

llm = OpenAI(
    temperature=temperature, model=model_name, max_tokens=max_tokens
    )
embs = OpenAIEmbedding()
service_context = ServiceContext.from_defaults(
    llm=llm, embed_model=embs, callback_manager=callback_manager
)

retriever = VectorIndexRetriever(
    index = vector_index, similarity_top_k=similarity_top_k
)

response_synthesizer = get_response_synthesizer(
    service_context=service_context, response_mode=response_mode,
)

query_engine = RetrieverQueryEngine(
    retriever=retriever, response_synthesizer=response_synthesizer
)

## Prompt Define

In [147]:
rau_prompt_template = (
    "Case: {input_text}\n"
    "Scan Ordered: {scan_order}\n"
    "Question: Is this imaging modality for this case USUALLY APPROPRIATE, "
    "MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE or INSUFFICIENT INFORMATION. "
    "Then state precisely the most appropriate imaging modality and if contrast "
    "agent is needed"
    )

## Testcases (With ending scan)

In [148]:
testcase_df["queries"] = [
    query_wrapper(rau_prompt_template, {"input_text": patient_profile, "scan_order": scan_order})
    for patient_profile, scan_order in zip(patient_profiles, scan_orders)
    ]

In [None]:
settings = dict(
    llm_model=model_name,
    emb_model=emb_type,
    framework="llama_index",
    prompt=rau_prompt_template,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    description="BaselineExperimentRau2023",
    max_tokens=temperature,
    k=similarity_top_k,
    chain_type="compact",
    db_type=db_type,
    temperature=temperature
)

save_folder = os.path.join(
    ARTIFACT_DIR,
    "{}_{}_{}_{}_{}".format(
        settings["llm_model"],
        settings["chunk_size"],
        settings["chunk_overlap"],
        settings["description"],
        datetime.now().strftime("%d-%m-%Y-%H-%M")
    )
)

if not os.path.exists(save_folder):
    os.makedirs(save_folder)

logger = get_experiment_logs(settings["description"], log_folder=save_folder)
logger.info(
    "Experiment settings:\n{}".format(
        "\n".join([f"{k}:{v}" for k, v in settings.items()])
    )
)

with open(os.path.join(save_folder, "settings.yaml"), "w") as f:
    yaml.dump(settings, f)

In [153]:
token_counter.reset_counts()
responses = []

for test_case in testcase_df["queries"]:
    response = query_engine.query(test_case)
    responses.append(response)
    
logger.info("Tokens Consumption: Total: {}, Prompt: {}, Completion: {}, Embeddings: {}"
            .format(token_counter.total_llm_token_count,
                    token_counter.prompt_llm_token_count,
                    token_counter.completion_llm_token_count,
                    token_counter.embedding_token_counts))

2023-10-18 22:37:50,508:INFO: Tokens Consumption: Total: 99257, Prompt: 96913, Completion: 2344, Embeddings: []
2023-10-18 22:37:50,508:INFO: Tokens Consumption: Total: 99257, Prompt: 96913, Completion: 2344, Embeddings: []


In [None]:
json_responses = process_result_json(
    testcase_df=testcase_df, responses=responses, save_path=os.path.join(save_folder, "results.json")
)

In [257]:
rau_result_df = process_result_df(
    testcase_df, json_responses, save_path=os.path.join(save_folder, "result.csv")
    )
rau_result_df.to_csv()

### Quick Evaluation

In [258]:
accuracy = rau_result_df["match"].sum() / len(rau_result_df) * 100

print(f"Accuracy score: {accuracy}")

Accuracy score: 32.3943661971831


In [259]:
rau_result_df.groupby("human_gt")["match"].value_counts()

human_gt                  match
ICI                       False     1
INSUFFICIENT INFORMATION  False     9
MAY BE APPROPRIATE        False    14
USUALLY APPROPRIATE       True     19
                          False     1
USUALLY NOT APPROPRIATE   False    23
                          True      4
Name: match, dtype: int64

In [260]:
rau_result_df["gpt_classification"].value_counts()

USUALLY APPROPRIATE        65
USUALLY NOT APPROPRIATE     6
Name: gpt_classification, dtype: int64

## Testcases (Without ending scan)

In [270]:
testcase_df["queries"] = [
    query_wrapper(rau_prompt_template, {"input_text": remove_final_sentence(patient_profile),
                                        "scan_order": scan_order})
    for patient_profile, scan_order in zip(patient_profiles, scan_orders)
    ]

In [271]:
settings = dict(
    llm_model=model_name,
    emb_model=emb_type,
    framework="llama_index",
    prompt=rau_prompt_template,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    description="BaselineExperimentRau2023NoEndingMRIInfo",
    max_tokens=temperature,
    k=similarity_top_k,
    chain_type="compact",
    db_type=db_type,
    temperature=temperature
)

save_folder = os.path.join(
    ARTIFACT_DIR,
    "{}_{}_{}_{}_{}".format(
        settings["llm_model"],
        settings["chunk_size"],
        settings["chunk_overlap"],
        settings["description"],
        datetime.now().strftime("%d-%m-%Y-%H-%M")
    )
)

if not os.path.exists(save_folder):
    os.makedirs(save_folder)

logger = get_experiment_logs(settings["description"], log_folder=save_folder)
logger.info(
    "Experiment settings:\n{}".format(
        "\n".join([f"{k}:{v}" for k, v in settings.items()])
    )
)

with open(os.path.join(save_folder, "settings.yaml"), "w") as f:
    yaml.dump(settings, f)

2023-10-19 16:43:15,032:INFO: Experiment settings:
llm_model:gpt-3.5-turbo
emb_model:openai
framework:llama_index
prompt:Case: {input_text}
Scan Ordered: {scan_order}
Question: Is this imaging modality for this case USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE or INSUFFICIENT INFORMATION. Then state precisely the most appropriate imaging modality and if contrast agent is needed
chunk_size:512
chunk_overlap:20
description:BaselineExperimentRau2023NoEndingMRIInfo
max_tokens:0
k:3
chain_type:compact
db_type:simple
temperature:0


In [272]:
token_counter.reset_counts()
responses = []

for test_case in testcase_df["queries"]:
    response = query_engine.query(test_case)
    responses.append(response)
    
logger.info("Tokens Consumption: Total: {}, Prompt: {}, Completion: {}, Embeddings: {}"
            .format(token_counter.total_llm_token_count,
                    token_counter.prompt_llm_token_count,
                    token_counter.completion_llm_token_count,
                    token_counter.embedding_token_counts))

2023-10-19 16:46:26,464:INFO: Tokens Consumption: Total: 99283, Prompt: 96997, Completion: 2286, Embeddings: []


In [273]:
json_responses = process_result_json(
    testcase_df=testcase_df, responses=responses, save_path=os.path.join(save_folder, "results.json")
)

Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised Timeout: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600).


In [282]:
rau_result_df_2 = process_result_df(
    testcase_df, json_responses, save_path=os.path.join(save_folder, "result.csv")
    )

### Quick Evaluation

In [287]:
accuracy = rau_result_df_2["match"].sum() / len(rau_result_df_2) * 100

print(f"Accuracy score: {accuracy}")

Accuracy score: 30.985915492957744


In [288]:
rau_result_df_2.groupby("human_gt")["match"].value_counts()

human_gt                  match
ICI                       False     1
INSUFFICIENT INFORMATION  False     9
MAY BE APPROPRIATE        False    14
USUALLY APPROPRIATE       True     19
                          False     1
USUALLY NOT APPROPRIATE   False    24
                          True      3
Name: match, dtype: int64

In [286]:
rau_result_df_2["gpt_classification"].value_counts()

USUALLY APPROPRIATE        65
USUALLY NOT APPROPRIATE     6
Name: gpt_classification, dtype: int64

# Run 1

## Knowledge Base Setup

In [306]:
# Retrieval 
emb_type = "openai"
chunk_size = 512
chunk_overlap = 20
db_type = "faiss"
similarity_top_k = 3
index_id = "msk-mri"

# Generation
model_name = "gpt-3.5-turbo"
max_tokens = 512
temperature = 0
response_mode = "compact"

## Prompt Define

In [308]:
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate

system_template = """
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
==========
TASK:
1. Extract from given PATIENT PROFILE relevant information for classification of imaging appropriateness.
Important information includes AGE, SYMPTOMS, DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Given the PATIENT PROFILE and CONTEXT, refer to the SCORING CRITERIA and recommend if the image scan ordered is USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE or there is INSUFFICIENT INFORMATION to recommend the appropriateness.  
If the scan is not appropriate, recommend an appropriate procedure.

STRICTLY answer based on the given PATIENT PROFILE and CONTEXT. 
==========
SCORING CRITERIA:
- USUALLY APPROPRIATE: The imaging procedure or treatment is indicated in the specified clinical scenarios at a favorable risk-benefit ratio for patients.
- MAY BE APPROPRIATE: The imaging procedure or treatment may be indicated in the specified clinical scenarios as an alternative to imaging procedures or treatments with a more favorable risk-benefit ratio, or the risk-benefit ratio for patients is equivocal.
- USUALLY NOT APPROPRIATE: The imaging procedure or treatment is unlikely to be indicated in the specified clinical scenarios, or the risk-benefit ratio for patients is likely to be unfavorable.
- INSUFFICIENT INFORMATION: There is not enough information from PATIENT PROFILE and CONTEXT information to conclude the appropriateness
==========
OUTPUT INSTRUCTIONS:
Your output should contain the following:
1. Classification of appropriateness for the ordered scan.
2. Provide explanation for the appropriateness classification.
3. If classification answer is USUALLY NOT APPROPRIATE, either recommend an alternative appropriate scan procedure or return NO SCAN REQUIRED.

Format your output as follow:
1. Classification: Can be one of [USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE, INSUFFICIENT INFORMATION]
2. Explanation:
3. Recommendation: Can be alternative procedure, NO SCAN REQUIRED or NO CHANGE REQUIRED 
==========
CONTEXT:

{context_str}
==========
"""

human_template = "{query_str}"
messages = [
    ChatMessage(role=MessageRole.SYSTEM, content=system_template),
    ChatMessage(role=MessageRole.USER, content=human_template)   
]

CHAT_PROMPT_TEMPLATE = ChatPromptTemplate(messages)

## Query Engine Setup

In [None]:
import tiktoken
from llama_index import get_response_synthesizer
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.vector_stores import FaissVectorStore

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(model_name).encode
)
callback_manager = CallbackManager([token_counter])

# Service Context
llm = OpenAI(temperature=temperature, model=model_name, max_tokens=max_tokens)
embs = OpenAIEmbedding()
service_context = ServiceContext.from_defaults(
    llm=llm, embed_model=embs, callback_manager=callback_manager
)

# Retriever
db_directory = os.path.join(EMB_DIR, db_type, f"{emb_type}_{chunk_size}_{chunk_overlap}")
print(db_directory)
vector_store = FaissVectorStore.from_persist_dir(db_directory)
storage_context = StorageContext.from_defaults(
    vector_store=vector_store, persist_dir=db_directory)
vector_index = load_index_from_storage(storage_context=storage_context)
retriever = VectorIndexRetriever(
    index = vector_index, similarity_top_k=similarity_top_k
)

# Synthesizer
response_synthesizer = get_response_synthesizer(
    service_context=service_context, response_mode=response_mode,
)

# Node Reorder

# Query Engine
query_engine = RetrieverQueryEngine(
    retriever=retriever, response_synthesizer=response_synthesizer
)

../data/emb_store/faiss/openai_512_20
