# Setup

In [1]:
import logging
import os, json
import pandas as pd
import re
import sys
import tiktoken
import yaml
from pprint import pprint

from config import MAIN_DIR, GUIDELINES
from copy import deepcopy
from utils import load_vectorindex
from datetime import datetime
from textdistance import levenshtein
from tqdm import tqdm
from typing import Dict, Union, Optional, List, Literal, Callable, Sequence, Tuple
from utils import count_tokens

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate as LCChatPromptTemplate
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate

from llama_index import (ServiceContext, QueryBundle, SimpleDirectoryReader, get_response_synthesizer)
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.embeddings import OpenAIEmbedding
from llama_index.indices.base_retriever import BaseRetriever
from llama_index.indices.postprocessor import MetadataReplacementPostProcessor, LongContextReorder
from llama_index.indices.postprocessor.types import BaseNodePostprocessor
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.indices.query.schema import QueryBundle, QueryType
from llama_index.llms import ChatMessage, MessageRole, OpenAI
from llama_index.prompts import ChatPromptTemplate, BasePromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response.schema import Response, RESPONSE_TYPE
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import Document, NodeWithScore, TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore

from custom import CustomCombinedRetriever, CustomRetrieverQueryEngine

In [2]:
# Directory layout used by the rest of the notebook; everything is rooted at
# MAIN_DIR (imported from config).
DATA_DIR = os.path.join(MAIN_DIR, "data")
# Default `artifact_dir` of run_test_cases (one sub-folder is created per run).
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
# Default `emb_folder` of run_test_cases.
EMB_DIR = os.path.join(DATA_DIR, "emb_store")
# Folder read by SimpleDirectoryReader in the document-loading cell.
DOCUMENT_DIR = os.path.join(MAIN_DIR, "data", "document_sources")
# JSON file loaded later and passed to filter_by_pages as `exclude_info`.
EXCLUDE_DICT = os.path.join(DATA_DIR, "exclude_pages.json")

# Read the API keys from a local JSON file and export the OpenAI key through
# the environment.
with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)
    
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]

# Helper Functions

In [3]:
# Render GUIDELINES as a "- item" bullet list, one guideline per line; it is
# substituted into the `{}` slot of the system prompt below.
all_guidelines = "\n".join(["- " + guideline for guideline in GUIDELINES])

# System prompt of the "refine" step: asks the model to extract the relevant
# patient information and to pick matching guidelines from the list.
# extract_guidelines() parses the answer by looking for the
# "1. Relevant information:" / "2. Relevant guidelines:" markers requested here.
# NOTE(review): the prompt text contains the typo "symptomps"; it is left
# untouched because the string is sent to the model verbatim.
refine_template = """You are a radiologist expert. Do not make up additional information.
=========
TASK: You are given a PATIENT PROFILE. You need to perform the following information referencing from the PATIENT PROFILE:
1. Extract relevant information for recommendation of imaging procedure, including age, symptomps, previous diagnosis, stage of diagnosis (INITIAL IMAGING OR NEXT STUDY) and suspected conditions, if any.
Only return information given inside the PROFILE, do not make up other information.
2. Return one or more guidelines from the following list of guidelines potentially relevant to the recommendations of imaging procedure given patient profile. If there are no relevant guidelines, return empty list.
{}
=========
OUTPUT INSTRUCTION:
Output your answer as follow:
1. Relevant information:
2. Relevant guidelines: List of guidelines. Match the exact text given in the list. If no relevant guidelines, return [].
=========
""".format(all_guidelines)

# The only template variable left for LangChain to fill is `query_str`.
human_template = "PATIENT PROFILE: {query_str}"

# Chat prompt (system + human message); also the default `refine_template`
# argument of run_test_cases.
REFINE_TEMPLATE = LCChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(refine_template),
        HumanMessagePromptTemplate.from_template(human_template)
    ]
)

# Module-level refine chain. run_test_cases builds its own local chain from
# exp_args["refine_llm"], so this one is only used when called directly.
refine_chain = LLMChain(
    prompt=REFINE_TEMPLATE,
    llm=ChatOpenAI(model="gpt-4-1106-preview", temperature=0, max_tokens=512)
    )

# Prompt of the "fixing" step: turns a free-text answer into the two labelled
# lines ("Appropriateness: ..." / "Recommendation: ...") that
# process_result_json parses with a regular expression.
# NOTE(review): the prompt text contains the typo "Appropropriateness"; it is
# left untouched because the string is sent to the model verbatim.
extract_template = """TASK: Extract the following information from the provided text query.
1. Appropropriateness of the scan ordered.
2. Most Appropriate Imaging Modality
===============
FORMAT INSTRUCTIONS: Your output should contains the following:
Appropriateness: Can be one of [USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE, INSUFFICIENT INFORMATION]
Recommendation: The most appropriate imaging modality. If no appropriate imaging modality, return nothing.
===============
TEXT QUERY: {query}
"""

# LLMChain and ChatOpenAI are already imported in the setup cell; the aliases
# LCPromptTemplate / LCChatOpenAI introduced here are used by
# convert_prompt_to_string and run_test_cases.
from langchain.prompts import PromptTemplate as LCPromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI as LCChatOpenAI

FIX_PROMPT = LCPromptTemplate.from_template(extract_template)

# Chain called by process_result_json on every raw answer.
fixing_chain = LLMChain(
    llm=LCChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0, max_tokens=512),
    prompt=FIX_PROMPT
)

In [4]:
def filter_by_pages(
    doc_list: List[Document],
    exclude_info: Dict[str, List]
) -> List[Document]:
    """Return the documents whose page is not listed for exclusion.

    Args:
        doc_list: Documents carrying ``file_name`` and ``page_label`` entries
            in their ``metadata``.
        exclude_info: Maps a file name to the page numbers (ints) to drop.
            Files that are not listed keep all their pages.

    Returns:
        A new list with the surviving documents, in their original order.
    """
    kept = []
    for document in doc_list:
        meta = document.metadata
        source_file, page_label = meta["file_name"], meta["page_label"]
        # Page labels are compared as ints against the exclusion list.
        if source_file in exclude_info and int(page_label) in exclude_info[source_file]:
            continue
        kept.append(document)
    return kept

def convert_prompt_to_string(prompt) -> str:
    """Render a prompt template as text, mainly for logging.

    Every template variable is filled with its own name, so the output shows
    where each variable sits in the prompt.

    Args:
        prompt: A llama_index ``BasePromptTemplate`` or a LangChain
            ``PromptTemplate`` / ``ChatPromptTemplate``.

    Returns:
        The formatted prompt. Any other object (including ``None``) is
        returned as ``str(prompt)`` so the declared ``str`` return type holds.
    """
    if isinstance(prompt, BasePromptTemplate):
        return prompt.format(**{v: v for v in prompt.template_vars})
    # isinstance() takes a tuple of classes; a typing.Union here raises
    # TypeError on Python versions before 3.10.
    if isinstance(prompt, (LCPromptTemplate, LCChatPromptTemplate)):
        return prompt.format(**{v: v for v in prompt.input_variables})
    return str(prompt)

def generate_query(profile: str, scan: str):
    """Build the two-line query text from a patient profile and a scan order."""
    return f"Patient Profile: {profile}\nScan ordered: {scan}"

def convert_doc_to_dict(doc: Union[Document, NodeWithScore, Dict]) -> Dict:
    """Convert a retrieved document into a JSON-serialisable dictionary.

    Args:
        doc: A ``NodeWithScore``, a ``Document``, or a plain dict with
            ``"text"`` and ``"metadata"`` keys.

    Returns:
        A dict with the keys ``page_content``, ``metadata`` and ``score``.
        ``score`` is the retrieval score for a ``NodeWithScore``, an empty
        string for a ``Document`` and the string ``"None"`` for a dict
        (placeholders kept as in the original output format).

    Raises:
        TypeError: If ``doc`` is none of the supported types.
    """
    if isinstance(doc, NodeWithScore):
        return {
            "page_content": doc.text,
            "metadata": doc.metadata,
            "score": doc.score
        }
    if isinstance(doc, Document):
        return {
            "page_content": doc.text,
            "metadata": doc.metadata,
            "score": ""
        }
    if isinstance(doc, dict):
        return {
            "page_content": doc["text"],
            "metadata": doc["metadata"],
            "score": "None"
        }
    # An unsupported type used to fall through to an unbound local variable.
    raise TypeError(
        f"Unsupported document type: {type(doc).__name__}"
    )

def get_experiment_logs(description: str, log_folder: str):
    """Return a logger that writes to stdout and to ``<log_folder>/logfile.log``.

    ``logging.getLogger`` hands back the same logger object for the same name,
    so handlers left over from an earlier call (for example when a notebook
    cell is re-run) are removed first. Otherwise every record would be emitted
    once per previous call and would still be written to the old log files.

    Args:
        description: Name of the logger.
        log_folder: Folder for ``logfile.log``; created if missing.

    Returns:
        The configured ``logging.Logger`` at INFO level.
    """
    logger = logging.getLogger(description)

    # Drop (and close) handlers attached by a previous call for this name.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    os.makedirs(log_folder, exist_ok=True)

    stream_handler = logging.StreamHandler(sys.stdout)
    file_handler = logging.FileHandler(filename=os.path.join(log_folder, "logfile.log"))

    formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    logger.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)

    return logger

def remove_final_sentence(
    text: str,
    return_final_sentence: bool = False
):
    """Split off the last "."-delimited sentence of ``text``.

    The text is stripped and one trailing period is dropped before the split
    at the last remaining ``"."``. Neither part is stripped afterwards.

    Args:
        text: Text to split.
        return_final_sentence: When true, also return the removed sentence.

    Returns:
        The text before the last sentence, or a ``(previous_text,
        final_sentence)`` tuple when ``return_final_sentence`` is true. Text
        without any period yields an empty ``previous_text``.
    """
    trimmed = text.strip()
    if trimmed[-1:] == ".":
        trimmed = trimmed[:-1]
    # rpartition gives ("", "", trimmed) when there is no "." left.
    head, _, tail = trimmed.rpartition(".")
    if return_final_sentence:
        return head, tail
    return head

def query_wrapper(
    template: str,
    input_text: Union[str, Dict[str, str]]
) -> str:
    """Fill the ``{placeholder}`` slots of ``template``.

    Args:
        template: Format string with named placeholders.
        input_text: Either a single string (the template must then contain
            exactly one placeholder) or a mapping from placeholder name to
            value covering every placeholder.

    Returns:
        The formatted string.

    Raises:
        AssertionError: If the placeholders and ``input_text`` do not match.
    """
    names = re.findall(r"{([A-Za-z0-9_-]+)}", template)

    if not isinstance(input_text, str):
        assert len(input_text) == len(names)
        for key in input_text.keys():
            assert key in names, f"{key} not present in template."
        return template.format(**input_text)

    assert len(names) == 1, "Must Provide a single placeholder when input_text is string."
    return template.format(**{names[0]: input_text})

def calculate_min_dist(
    input_str: str,
    text_list: List[str] = GUIDELINES,
    return_nearest_text: bool = False
):
    """Find the smallest Levenshtein distance between ``input_str`` and a list.

    Args:
        input_str: Text to compare.
        text_list: Reference texts; defaults to ``GUIDELINES``.
        return_nearest_text: When true, also return the closest reference.

    Returns:
        The minimum distance, or a ``(min_dist, nearest_text)`` tuple when
        ``return_nearest_text`` is true. For an empty ``text_list`` the
        distance is ``inf`` and the nearest text is ``None``. On ties the
        first reference in list order wins.
    """
    scored = (
        (levenshtein.distance(input_str, candidate), candidate)
        for candidate in text_list
    )
    # min() keeps the first of several equally small items, matching a strict
    # "<" scan over the list.
    min_dist, nearest_text = min(
        scored, key=lambda pair: pair[0], default=(float("inf"), None)
    )
    if return_nearest_text:
        return min_dist, nearest_text
    return min_dist

def setup_query_engine(
    db_directory: str,
    emb_store_type: Literal["simple", "faiss", "chroma"] = "simple",
    index_name: Optional[str] = None,
    similarity_top_k: int = 4,
    text_qa_template: Optional[BasePromptTemplate] = None,
    synthesizer_llm: str = "gpt-3.5-turbo-1106",
    emb_type: str = "openai",
    synthesizer_temperature: float = 0,
    synthesizer_max_tokens: int = 512,
    response_mode: str = "simple_summarize",
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
    callback_manager: Optional[CallbackManager] = None,
) ->  BaseQueryEngine:
    """Build a retriever query engine on top of a stored vector index.

    Args:
        db_directory: Location of the persisted vector store.
        emb_store_type: Store type forwarded to ``load_vectorindex``
            ("chroma" is also passed to that helper elsewhere in this file).
        index_name: Index name forwarded to ``load_vectorindex``.
        similarity_top_k: Number of nodes the retriever returns.
        text_qa_template: Prompt used by the response synthesizer.
        synthesizer_llm: OpenAI model name used for synthesis.
        emb_type: Embedding backend; only ``"openai"`` is implemented.
        synthesizer_temperature: Sampling temperature of the synthesizer LLM.
        synthesizer_max_tokens: Completion token limit of the synthesizer LLM.
        response_mode: Response mode passed to ``get_response_synthesizer``.
        node_postprocessors: Optional postprocessors applied to retrieved nodes.
        callback_manager: Callback manager shared by retriever and service
            context (e.g. for token counting).

    Returns:
        The configured ``RetrieverQueryEngine``.

    Raises:
        ValueError: If ``emb_type`` is not ``"openai"``.
    """
    # Fail early with a clear message; previously an unsupported emb_type left
    # `embs` unassigned and crashed later with UnboundLocalError.
    if emb_type != "openai":
        raise ValueError(
            f"Unsupported emb_type {emb_type!r}; only 'openai' is implemented."
        )
    embs = OpenAIEmbedding()

    vector_index = load_vectorindex(db_directory, emb_store_type=emb_store_type, index_name=index_name)

    retriever = VectorIndexRetriever(
        index=vector_index, similarity_top_k=similarity_top_k,
        callback_manager=callback_manager
    )

    # Setup Synthesizer
    service_context = ServiceContext.from_defaults(
        llm=OpenAI(
            temperature=synthesizer_temperature,
            model=synthesizer_llm, max_tokens=synthesizer_max_tokens
            ),
        embed_model=embs, callback_manager=callback_manager
    )

    response_synthesizer = get_response_synthesizer(
        service_context=service_context, response_mode=response_mode,
        text_qa_template=text_qa_template
    )

    # Setup QueryEngine
    query_engine = RetrieverQueryEngine(
        retriever=retriever, response_synthesizer=response_synthesizer,
        node_postprocessors = node_postprocessors
    )

    return query_engine

def process_result_json(
    testcase_df: pd.DataFrame, responses: List[Response], save_path: Optional[str] = None
) -> List[Dict]:
    """Turn query-engine responses into JSON-serialisable records.

    Each raw answer is passed through ``fixing_chain`` to obtain the labelled
    "Appropriateness: ..." / "Recommendation: ..." lines, which are then
    parsed into separate fields.

    Args:
        testcase_df: Test cases; the ``queries`` and ``Scan Order`` columns
            are read, row-aligned with ``responses``.
        responses: Responses returned by the query engine.
        save_path: When given, the records are also dumped there as JSON.

    Returns:
        One dict per response with the keys ``question``, ``result``,
        ``source_documents``, ``appropriateness`` and ``recommendation``.
        The last two are empty strings when the fixed answer cannot be parsed.
    """
    # The fixing prompt tells the model to return nothing after
    # "Recommendation:" when no modality applies, so the recommendation group
    # must be allowed to be empty; otherwise the appropriateness label of such
    # answers would be lost as well.
    answer_pattern = re.compile(
        r"^[^\n]*Appropriateness: ([^\n]+)\n+[^\n]*Recommendation:[ \t]*([^\n]*)$"
    )

    json_responses = []
    queries = testcase_df["queries"]
    scan_orders = testcase_df["Scan Order"]

    tk = tqdm(zip(queries, responses, scan_orders), total=len(responses))
    for query, response, scan_order in tk:
        testcase_info = {
            "question": query,
            "result": response.response,
            "source_documents": [convert_doc_to_dict(doc) for doc in response.source_nodes]
        }
        answer_query = "Scan Ordered: {}\nAnswer: {}".format(scan_order, testcase_info["result"])
        fixed_answer = fixing_chain(answer_query)

        # Parsing stays best-effort, but only the failures that can actually
        # occur here are tolerated (missing key / non-text payload).
        try:
            parsed = answer_pattern.search(fixed_answer["text"])
        except (KeyError, TypeError):
            parsed = None

        if parsed is None:
            appropriateness, recommendation = "", ""
        else:
            # Strip stray whitespace so the label compares equal to the
            # ground-truth category later on.
            appropriateness = parsed.group(1).strip()
            recommendation = parsed.group(2).strip()

        testcase_info["appropriateness"] = appropriateness
        testcase_info["recommendation"] = recommendation

        json_responses.append(testcase_info)

    if save_path:
        with open(save_path, "w") as f:
            json.dump(json_responses, f)
    return json_responses

def process_result_df(
    testcase_df: pd.DataFrame, results: Union[List[Dict], List[Response]], save_path: Optional[str] = None
):
    """Merge model results into a copy of the test-case table and score them.

    Args:
        testcase_df: Test cases; must contain ``Appropriateness Category``.
        results: Either records produced by ``process_result_json`` or raw
            ``Response`` objects (which are converted first).
        save_path: When given, the resulting table is written there as CSV.

    Returns:
        A DataFrame with the added columns ``gpt_raw_answer``,
        ``gpt_classification``, ``gpt_recommendation``, ``context`` and
        ``match``; the ground-truth column is renamed to ``human_gt`` with
        its abbreviations expanded.
    """
    def _render_context(source_documents) -> str:
        # One text block per retrieved document, separated by blank lines.
        rendered = []
        for document in source_documents:
            metadata_lines = "\n".join(
                [f"{k}: {v}" for k, v in document["metadata"].items()]
            )
            rendered.append(
                "Metadata: {}\nScore: {}\n\nPage Content: {}".format(
                    metadata_lines, document["score"], document["page_content"]
                )
            )
        return "\n\n\n\n".join(rendered)

    if isinstance(results[0], Response):
        results = process_result_json(testcase_df, results)

    result_df = testcase_df.copy(deep=True)
    result_df["gpt_raw_answer"] = [record["result"] for record in results]
    result_df["gpt_classification"] = [record["appropriateness"] for record in results]
    result_df["gpt_classification"] = result_df["gpt_classification"].str.upper()
    result_df["gpt_recommendation"] = [record["recommendation"] for record in results]
    result_df["context"] = [
        _render_context(record["source_documents"]) for record in results
    ]

    result_df = result_df.rename(columns={"Appropriateness Category": "human_gt"})

    # Expand the ground-truth abbreviations to the labels the model emits.
    abbreviations = {
        r"^UA$": "USUALLY APPROPRIATE",
        r"^UNA$": "USUALLY NOT APPROPRIATE",
        r"^MBA$": "MAY BE APPROPRIATE",
        r"^ICI$": "INSUFFICIENT INFORMATION",
    }
    for pattern, full_label in abbreviations.items():
        result_df["human_gt"] = result_df["human_gt"].str.replace(pattern, full_label, regex=True)

    result_df["match"] = (result_df["gpt_classification"] == result_df["human_gt"])

    if save_path:
        result_df.to_csv(save_path)

    return result_df

def count_tokens(
    texts: Union[str, TextNode,NodeWithScore,List],
    tokenizer: Callable = tiktoken.encoding_for_model("gpt-3.5-turbo")
):
    """Count the tokens of one or several texts / nodes.

    NOTE: this definition replaces the ``count_tokens`` imported from
    ``utils`` at the top of the file.

    Args:
        texts: A string, ``TextNode`` or ``NodeWithScore``, or a list/tuple
            of them.
        tokenizer: Object exposing ``encode(text)``; defaults to the tiktoken
            encoding of ``gpt-3.5-turbo`` (created once, at definition time).

    Returns:
        The total number of tokens over all items.

    Raises:
        TypeError: If an item is not a supported type. Previously such an
            item silently re-counted the preceding item's text (or raised
            UnboundLocalError when it came first).
    """
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    token_counter = 0
    for text in texts:
        if isinstance(text, NodeWithScore):
            text_str = text.node.text
        elif isinstance(text, TextNode):
            text_str = text.text
        elif isinstance(text, str):
            text_str = text
        else:
            raise TypeError(
                f"Cannot count tokens of {type(text).__name__}"
            )
        token_counter += len(tokenizer.encode(text_str))
    return token_counter

def extract_guidelines(
    profile: str,
    extract_chain: LLMChain
) -> Tuple[str, List[str]]:
    """Run the refine chain on a profile and parse its two-part answer.

    The answer is expected to contain the "1. Relevant information:" and
    "2. Relevant guidelines:" markers requested by the refine prompt.

    Args:
        profile: Patient profile text sent to the chain.
        extract_chain: Chain whose output dict has a ``"text"`` entry.

    Returns:
        ``(relevant_information, guidelines)`` where ``guidelines`` holds the
        entries of ``GUIDELINES`` lying within Levenshtein distance 1 of a
        lower-cased name found in the answer.

    Raises:
        IndexError: If the answer does not contain both markers.
    """
    raw_answer = extract_chain(profile)["text"]
    if raw_answer.endswith("."):
        raw_answer = raw_answer[:-1]

    section_pattern = r"1. Relevant information:([\S\s]+)2. Relevant guidelines:([\S\s]*)"
    info_section, guideline_section = re.findall(section_pattern, raw_answer)[0]

    # Flatten a bullet / multi-line list into one comma-separated line.
    flattened = guideline_section.replace("- ", "").strip().replace("\n", ", ")

    # An empty guideline section yields no candidates, hence an empty list.
    matched_guidelines = []
    for candidate in re.findall(r"([A-Za-z ]+)", flattened):
        distance, closest = calculate_min_dist(candidate.lower(), GUIDELINES, True)
        if distance <= 1:
            matched_guidelines.append(closest)

    return info_section, matched_guidelines

In [20]:
def run_test_cases(
    testcase_df: pd.DataFrame,
    exp_args: Dict,
    testcases: Sequence[str] = None,
    patient_profiles: Sequence[str] = None,
    scan_orders: Sequence[str] = None,
    refined_profiles: Sequence[str] = None,
    relevant_guidelines: Sequence[List[str]] = None,
    query_engine: Optional[BaseQueryEngine] = None,
    query_template: str = "Patient Profile: {profile}\nScan ordered: {scan_order}",
    text_qa_template: Optional[BasePromptTemplate] = None,
    refine_template: Optional[LCChatPromptTemplate] = REFINE_TEMPLATE,
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
    artifact_dir: str = ARTIFACT_DIR,
    emb_folder: str = EMB_DIR,
):
    """Run every test case through the query engine and evaluate the answers.

    Artifacts (log file, ``settings.yaml``, ``results.json``, ``result.csv``)
    are written to a new folder under ``artifact_dir`` named after the
    experiment settings and the current time.

    Args:
        testcase_df: Test cases. ``Clinical File`` is read when profiles and
            scan orders are not supplied; a ``queries`` column is added to
            this DataFrame in place.
        exp_args: Experiment settings. Always read: ``synthesizer_llm``,
            ``chunk_size``, ``chunk_overlap``, ``description``. Read only when
            no ``query_engine`` is given: ``vectorstore``, ``emb_type``,
            ``index_name``, ``similarity_top_k``, ``synthesizer_temperature``,
            ``synthesizer_max_tokens`` and optionally ``response_mode``.
            Optional switches: ``refine_profile``, ``metadata_filter``,
            ``refine_llm``.
        testcases: Raw case texts; defaults to ``testcase_df["Clinical File"]``.
        patient_profiles: Pre-split profiles (used only together with
            ``scan_orders``).
        scan_orders: Pre-split scan orders.
        refined_profiles: Pre-computed refined profiles.
        relevant_guidelines: Pre-computed guideline lists, one per case.
        query_engine: Engine to use; built from ``exp_args`` when omitted.
        query_template: Template with ``{profile}`` and ``{scan_order}``.
        text_qa_template: QA prompt, used only when the engine is built here.
        refine_template: Prompt of the refine chain.
        node_postprocessors: Postprocessors, used only when the engine is
            built here.
        artifact_dir: Parent folder of the run folder.
        emb_folder: Root folder of the stored embeddings, used only when the
            engine is built here.

    Returns:
        ``(json_responses, result_df, responses)``.
    """
    save_folder = os.path.join(
        artifact_dir, "{}_{}_{}_{}_{}".format(
            exp_args["synthesizer_llm"],
            exp_args["chunk_size"],
            exp_args["chunk_overlap"],
            exp_args["description"],
            datetime.now().strftime("%d-%m-%Y-%H-%M")
        )
    )
    os.makedirs(save_folder, exist_ok=True)

    logger = get_experiment_logs(exp_args["description"], log_folder=save_folder)

    if query_engine is None:
        token_counter = TokenCountingHandler(
            tokenizer=tiktoken.encoding_for_model(exp_args["synthesizer_llm"]).encode
        )
        callback_manager = CallbackManager([token_counter])
        db_directory = os.path.join(
            emb_folder, exp_args["vectorstore"],
            "{}_{}_{}".format(exp_args["emb_type"], exp_args["chunk_size"], exp_args["chunk_overlap"])
            )

        logger.info(f"--------------------\nLoading VectorDB from {db_directory}")
        query_engine = setup_query_engine(
            db_directory,
            emb_store_type=exp_args["vectorstore"],
            index_name=exp_args["index_name"],
            similarity_top_k=exp_args["similarity_top_k"],
            text_qa_template=text_qa_template,
            synthesizer_llm = exp_args["synthesizer_llm"],
            synthesizer_temperature = exp_args["synthesizer_temperature"],
            synthesizer_max_tokens = exp_args["synthesizer_max_tokens"],
            # Honour the configured response mode instead of ignoring it; the
            # previous hard-coded value remains the default.
            response_mode = exp_args.get("response_mode", "simple_summarize"),
            node_postprocessors = node_postprocessors,
            callback_manager = callback_manager
        )

    else:
        # Assumes the first handler of a caller-supplied engine is the
        # TokenCountingHandler (as set up in the cells of this notebook).
        token_counter = query_engine.callback_manager.handlers[0]

    logger.info(
        "-------------\nExperiment settings:\n{}".format(
            "\n".join([f"{k}:{v}" for k, v in exp_args.items()])
        )
    )

    with open(os.path.join(save_folder, "settings.yaml"), "w") as f:
        yaml.dump(exp_args, f)

    responses = []

    logger.info(
        "-------------\nQA PROMPT: {}".format(convert_prompt_to_string(query_engine._response_synthesizer._text_qa_template))
    )

    logger.info(
        "------START RUNNING TEST CASES---------"
    )

    if patient_profiles is None or scan_orders is None:
        # `not testcases` is ambiguous (ValueError) for a pandas Series, so
        # test for None / emptiness explicitly.
        if testcases is None or len(testcases) == 0:
            testcases = testcase_df["Clinical File"]
        # Split every case once; the last sentence is the scan order.
        split_cases = [remove_final_sentence(testcase, True) for testcase in testcases]
        patient_profiles = [profile_part for profile_part, _ in split_cases]
        scan_orders = [order_part for _, order_part in split_cases]

    if exp_args.get("refine_profile") or exp_args.get("metadata_filter"):
        if not (refined_profiles and relevant_guidelines):
            logger.info(
                "-------------\nREFINE PROMPT: {}".format(convert_prompt_to_string(refine_template))
            )

            from langchain.callbacks import get_openai_callback
            refine_chain = LLMChain(
                llm=LCChatOpenAI(model_name=exp_args.get("refine_llm", "gpt-3.5-turbo-1106"), temperature=0, max_tokens=512),
                prompt=refine_template)

            with get_openai_callback() as cb:
                refined_infos = [extract_guidelines(profile, refine_chain) for profile in tqdm(patient_profiles, total=len(patient_profiles))]
            print(f"Number of refined tokens: Prompt tokens = {cb.prompt_tokens}, Completion tokens = {cb.completion_tokens}")

            refined_profiles = [refined_info[0] for refined_info in refined_infos]
            relevant_guidelines = [refined_info[1] for refined_info in refined_infos]

    if exp_args.get("refine_profile"):
        patient_profiles = refined_profiles

    testcase_df["queries"] = [query_wrapper(query_template, {"profile": patient_profile, "scan_order": scan_order})
                 for patient_profile, scan_order in zip(patient_profiles, scan_orders)]

    metadata_filters = relevant_guidelines if exp_args.get("metadata_filter") else [None] * len(testcase_df["queries"])

    for query, metadata_filter in tqdm(zip(testcase_df["queries"], metadata_filters), total=len(testcase_df["queries"])):
        input_query = {"str_or_query_bundle": query}
        if metadata_filter is not None:
            # The filter keywords are only passed when filtering is enabled.
            input_query["table_filter"] = metadata_filter
            input_query["text_filter"] = metadata_filter
        response = query_engine.query(**input_query)
        responses.append(response)

    logger.info("--------------\nTokens Consumption: Total: {}, Prompt: {}, Completion: {}, Embeddings: {}"
                .format(token_counter.total_llm_token_count,
                        token_counter.prompt_llm_token_count,
                        token_counter.completion_llm_token_count,
                        token_counter.total_embedding_token_count))

    logger.info(f"----------\nTest case Completed. Saving Artifacts into {save_folder}")
    json_responses = process_result_json(
        testcase_df, responses=responses, save_path=os.path.join(save_folder, "results.json")
        )

    result_df = process_result_df(
        testcase_df, json_responses, save_path=os.path.join(save_folder, "result.csv")
        )

    accuracy = result_df["match"].sum() / len(result_df) * 100

    logger.info("------EVALUATION-----")
    logger.info(f"Accuracy score: {accuracy}")
    logger.info(
        str(result_df.groupby(["gpt_classification", "human_gt"])["match"].value_counts())
    )
    logger.info(
        str(result_df.groupby(["human_gt", "gpt_classification"])["match"].value_counts())
    )

    return json_responses, result_df, responses

# Load Test Data

In [6]:
# Load the test cases, keeping only the columns used by the evaluation code.
testcase_df = pd.read_csv(
        os.path.join(DATA_DIR, "queries", "MSK LLM Fictitious Case Files Full.csv"),
        usecols = ['ACR scenario', 'Appropriateness Category', 'Scan Order',
                   'Difficulty', 'Clinical File']
        )

# Split every clinical file into the profile (everything before the last
# sentence) and the scan order (the last sentence), stripping both parts.
patient_profiles = [remove_final_sentence(patient_profile, True)[0].strip() for patient_profile in testcase_df["Clinical File"]]
scan_orders = [remove_final_sentence(patient_profile, True)[1].strip() for patient_profile in testcase_df["Clinical File"]]

In [29]:
# Read every source document, then drop the pages listed in the exclusion
# file (a JSON mapping loaded from EXCLUDE_DICT).
documents = SimpleDirectoryReader(DOCUMENT_DIR).load_data()
print("Total no of docs before filtering:", len(documents))
with open(EXCLUDE_DICT, "r") as f:
    exclude_pages = json.load(f)
documents = filter_by_pages(doc_list=documents, exclude_info=exclude_pages)
print("Total number of docs after filtering", len(documents))

Total no of docs before filtering: 546
Total number of docs after filtering 395


# Run 4 - Separate Tables and Text Vectorstores

## Define Experiment Settings

In [30]:
# System prompt of the synthesizer for Run 4. `{context_str}` is the only
# placeholder in this message; `{query_str}` lives in the user message below.
system_template = """
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
==========
TASK: You are given a PATIENT PROFILE and a SCAN ORDER. Your task is to evaluate the appropriateness of the SCAN ORDER based on the PATIENT PROFILE.
Perform step-by-step the following sequence of reasoning.
1. Extract from PATIENT PROFILE relevant information for classification of imaging appropriateness. DO NOT make any assumptions from the SCAN ORDER.
Important information includes AGE, SYMPTOMS, PREVIOUS DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Identify there are superior imaging procedures or treatments with a more favorable risk-benefit ratio.
4. Based on the SCORING CRITERIA, recommend if the SCAN ORDER is USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE or there is INSUFFICIENT INFORMATION to recommend the appropriateness.  
If the scan is not appropriate, recommend an appropriate procedure.

STRICTLY answer based on the given PATIENT PROFILE and CONTEXT. 
==========
SCORING CRITERIA:
- USUALLY APPROPRIATE: The imaging procedure or treatment is indicated in the specified clinical scenarios at a favorable risk-benefit ratio for patients.
- MAY BE APPROPRIATE: The imaging procedure or treatment may be indicated in the specified clinical scenarios as an alternative to imaging procedures or treatments with a more favorable risk-benefit ratio, or the risk-benefit ratio for patients is equivocal.
- USUALLY NOT APPROPRIATE: The imaging procedure or treatment is unlikely to be indicated in the specified clinical scenarios, or the risk-benefit ratio for patients is likely to be unfavorable.
- INSUFFICIENT INFORMATION: There is not enough information from PATIENT PROFILE and CONTEXT to give the recommendation
==========
CONTEXT:

{context_str}

Take note for scenarios involving IV CONTRAST there are 3 distinct scan protocols: (1) with IV CONTRAST, (2) without IV CONTRAST, (3) without and with IV CONTRAST. 
Each of them is different and can have different appropriateness category.
==========
"""

# The user message is the query as-is. This rebinds the `human_template`
# name that the helper-function cell used for the refine prompt.
human_template = "{query_str}"
messages = [
    ChatMessage(role=MessageRole.SYSTEM, content=system_template),
    ChatMessage(role=MessageRole.USER, content=human_template)   
]

# llama_index chat prompt handed to the response synthesizer.
CHAT_PROMPT_TEMPLATE = ChatPromptTemplate(messages)

In [31]:
# Settings for Run 4. run_test_cases reads synthesizer_llm, chunk_size,
# chunk_overlap and description to name the output folder and dumps the whole
# dict to settings.yaml.
exp_args = dict(
    # Retrieval 
    emb_type = "openai",
    vectorstore = "faiss",
    chunk_size = 512,
    chunk_overlap = 20,
    # Separate top-k values for the table and text retrievers built below.
    table_similarity_top_k = 4,
    text_similarity_top_k = 5,
    index_name = "msk-mri",
    description="DescriptionsTableAndText",

    # Generation
    synthesizer_llm = "gpt-3.5-turbo",
    synthesizer_max_tokens = 512,
    synthesizer_temperature = 0,
    response_mode = "simple_summarize",
)

text_qa_template = CHAT_PROMPT_TEMPLATE
# node_postprocessors = [LongContextReorder()]
node_postprocessors = None

## Setup Query Engine

In [32]:
# Load the two FAISS indices (tables and texts) stored side by side.
db_directory = os.path.join(DATA_DIR, "multimodal-faiss", "descriptions")
table_index = load_vectorindex(os.path.join(db_directory, "tables"), "faiss")
text_index = load_vectorindex(os.path.join(db_directory, "texts"), "faiss")

# One retriever per index, each with its own top-k from exp_args.
table_retriever = table_index.as_retriever(
    similarity_top_k = exp_args["table_similarity_top_k"],
    )
text_retriever = text_index.as_retriever(
    similarity_top_k = exp_args["text_similarity_top_k"],
)
# Project-local retriever combining both sources.
# NOTE(review): token_limit presumably caps the combined context size -
# confirm against custom.CustomCombinedRetriever.
text_and_table_retriever = CustomCombinedRetriever(
    table_retriever=table_retriever, text_retriever=text_retriever, token_limit = 7000
)

2023-10-28 20:19:34,356:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-faiss/descriptions/tables.


2023-10-28 20:19:34,432:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-faiss/descriptions/texts.


In [33]:
embs = OpenAIEmbedding()

# Token counter using the tokenizer of the synthesizer model.
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(exp_args["synthesizer_llm"]).encode
)
callback_manager = CallbackManager([token_counter])

# LLM and embedding configuration shared by the response synthesizer.
service_context = ServiceContext.from_defaults(
    llm=OpenAI(
        temperature=exp_args["synthesizer_temperature"],
        model=exp_args["synthesizer_llm"], max_tokens=exp_args["synthesizer_max_tokens"]
        ),
    embed_model=embs, callback_manager=callback_manager
)

response_synthesizer = get_response_synthesizer(
    service_context=service_context, response_mode=exp_args["response_mode"],
    text_qa_template=CHAT_PROMPT_TEMPLATE
)

# The engine gets its own CallbackManager wrapping the same token_counter;
# run_test_cases reads it back as `query_engine.callback_manager.handlers[0]`.
query_engine = RetrieverQueryEngine(
    retriever=text_and_table_retriever, response_synthesizer=response_synthesizer,
    node_postprocessors = node_postprocessors, callback_manager = CallbackManager([token_counter])
)

## Run Test Cases

In [None]:
# Run 4: pre-split profiles / scan orders and the prepared query engine.
# Because `query_engine` is supplied, run_test_cases does not build its own
# engine, so text_qa_template and node_postprocessors are not used by it here.
json_responses, result_df, responses = run_test_cases(
    testcase_df=testcase_df,
    exp_args=exp_args,
    patient_profiles=patient_profiles,
    scan_orders=scan_orders,
    query_engine=query_engine,
    text_qa_template=text_qa_template,
    node_postprocessors=node_postprocessors
    )

# Run 5 - Multi-step LLM Pipeline

In [168]:
# Extract and save first
# multiple_guideline_answers = []
# for clinical_file in tqdm(testcase_df["Clinical File"],
#                           total=len(testcase_df["Clinical File"])):
#     profile = remove_final_sentence(clinical_file)
#     response = refine_chain(profile)
#     multiple_guideline_answers.append(response["text"])

# profiles, guidelines = [], []

# for answer in multiple_guideline_answers:
#     if answer.endswith("."):
#         answer = answer[:-1]
#     pattern = r"1. Relevant information:([\S\s]+)2. Relevant guidelines:([\S\s]*)"
    
#     profile, guidelines_str = re.findall(pattern, answer)[0]
#     guidelines_str = guidelines_str.replace("- ", "")
#     guidelines_str = guidelines_str.strip()
#     guidelines_str = guidelines_str.replace("\n", ", ")
    
#     if not guidelines_str:
#         profiles.append(profile)
#         guidelines.append([])
#     else:
#         regex_guidelines = re.findall(r"([A-Za-z ]+)", guidelines_str)
#         extracted_guidelines = []
#         for i, extracted_guideline in enumerate(regex_guidelines):
#             extracted_guideline = extracted_guideline.lower()
#             min_dist, nearest_text = calculate_min_dist(extracted_guideline, GUIDELINES, True)
#             if min_dist <= 1:
#                 extracted_guideline = nearest_text
#                 extracted_guidelines.append(extracted_guideline)
#             else:
#                 print(extracted_guideline + "_" + nearest_text)
                
#     profiles.append(profile.strip())   
#     guidelines.append(extracted_guidelines)

# extracted_info_multiple_json = {"profiles": profiles, "guidelines": guidelines}

# with open(os.path.join(ARTIFACT_DIR, "extracted_multiple.json"), "w") as f:
#     json.dump(extracted_info_multiple_json, f)

## On our dataset

In [21]:
# System prompt of the synthesizer for Run 5. It matches the Run 4 prompt
# except for the wording of the USUALLY NOT APPROPRIATE and
# INSUFFICIENT INFORMATION scoring criteria.
system_template = """
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
==========
TASK: You are given a PATIENT PROFILE and a SCAN ORDER. Your task is to evaluate the appropriateness of the SCAN ORDER based on the PATIENT PROFILE.
Perform step-by-step the following sequence of reasoning.
1. Extract from PATIENT PROFILE relevant information for classification of imaging appropriateness. DO NOT make any assumptions from the SCAN ORDER.
Important information includes AGE, SYMPTOMS, PREVIOUS DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Identify there are superior imaging procedures or treatments with a more favorable risk-benefit ratio.
4. Based on the SCORING CRITERIA, recommend if the SCAN ORDER is USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE or there is INSUFFICIENT INFORMATION to recommend the appropriateness.  
If the scan is not appropriate, recommend an appropriate procedure.

STRICTLY answer based on the given PATIENT PROFILE and CONTEXT. 
==========
SCORING CRITERIA:
- USUALLY APPROPRIATE: The imaging procedure or treatment is indicated in the specified clinical scenarios at a favorable risk-benefit ratio for patients.
- MAY BE APPROPRIATE: The imaging procedure or treatment may be indicated in the specified clinical scenarios as an alternative to imaging procedures or treatments with a more favorable risk-benefit ratio, or the risk-benefit ratio for patients is equivocal.
- USUALLY NOT APPROPRIATE: The imaging procedure is unlikely to be recommended in the specified clinical scenarios, or the risk-benefit ratio for patients is likely to be unfavorable
- INSUFFICIENT INFORMATION: The imaging procedure or treatment is not mentioned under CONTEXT or not enough relevant information from the PATIENT PROFILE to recommend based on information in CONTEXT.
==========
CONTEXT:

{context_str}

Take note for scenarios involving IV CONTRAST there are 3 distinct scan protocols: (1) with IV CONTRAST, (2) without IV CONTRAST, (3) without and with IV CONTRAST. 
Each of them is different and can have different appropriateness category.
==========
"""

# The user message is the query as-is.
human_template = "{query_str}"

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content=system_template),
    ChatMessage(role=MessageRole.USER, content=human_template)   
]

# Rebinds CHAT_PROMPT_TEMPLATE for the Run 5 cells below.
CHAT_PROMPT_TEMPLATE = ChatPromptTemplate(messages)

In [22]:
# Settings for Run 5 (Chroma stores plus guideline-based metadata filtering).
exp_args = dict(
    # Retrieval 
    emb_type = "openai",
    vectorstore = "chroma",
    chunk_size = 512,
    chunk_overlap = 20,
    table_similarity_top_k = 4,
    text_similarity_top_k = 5,
    index_name = "msk-mri",
    description="DescriptionsTableAndTextWithMetadataFilter",
    # run_test_cases passes the relevant guidelines as table/text filters to
    # the query engine when this is truthy.
    metadata_filter = True,
    # When truthy, run_test_cases queries with the refined profiles instead of
    # the original ones.
    refine_profile = False,
    # Model of the refine chain, used only if refined data is not supplied.
    refine_llm = "gpt-4-1106-preview",

    # Generation
    synthesizer_llm = "gpt-3.5-turbo-1106",
    synthesizer_max_tokens = 512,
    synthesizer_temperature = 0,
    response_mode = "simple_summarize",
)

text_qa_template = CHAT_PROMPT_TEMPLATE
# node_postprocessors = [LongContextReorder()]
node_postprocessors = None

In [23]:
# Load the two Chroma indices (tables and texts) stored side by side.
db_directory = os.path.join(DATA_DIR, "multimodal-chroma", "descriptions")
table_index = load_vectorindex(os.path.join(db_directory, "tables"), "chroma")
text_index = load_vectorindex(os.path.join(db_directory, "texts"), "chroma")

# One retriever per index, each with its own top-k from exp_args.
table_retriever = table_index.as_retriever(
    similarity_top_k = exp_args["table_similarity_top_k"]
    )
text_retriever = text_index.as_retriever(
    similarity_top_k = exp_args["text_similarity_top_k"]
)
# NOTE(review): token_limit presumably caps the combined context size -
# confirm against custom.CustomCombinedRetriever.
text_and_table_retriever = CustomCombinedRetriever(
    table_retriever=table_retriever, text_retriever=text_retriever, token_limit = 7000
)

2023-11-25 13:20:11,955:INFO: chroma VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-chroma/descriptions/tables.
2023-11-25 13:20:12,090:INFO: chroma VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-chroma/descriptions/texts.


In [24]:
embs = OpenAIEmbedding()

# Token counter using the tokenizer of the synthesizer model.
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(exp_args["synthesizer_llm"]).encode
)
callback_manager = CallbackManager([token_counter])

service_context = ServiceContext.from_defaults(
    llm=OpenAI(
        temperature=exp_args["synthesizer_temperature"],
        model=exp_args["synthesizer_llm"], max_tokens=exp_args["synthesizer_max_tokens"]
        ),
    embed_model=embs, callback_manager=callback_manager
)

response_synthesizer = get_response_synthesizer(
    service_context=service_context, response_mode=exp_args["response_mode"],
    text_qa_template=CHAT_PROMPT_TEMPLATE
)

# Project-local engine instead of RetrieverQueryEngine: run_test_cases calls
# query(..., table_filter=..., text_filter=...) when metadata filtering is on.
# NOTE(review): presumably CustomRetrieverQueryEngine accepts those keyword
# arguments - confirm against custom.py.
query_engine = CustomRetrieverQueryEngine(
    retriever=text_and_table_retriever, response_synthesizer=response_synthesizer,
    node_postprocessors = node_postprocessors, callback_manager = CallbackManager([token_counter])
)

In [27]:
# Load the refined profiles and guideline lists saved earlier (the
# commented-out "Extract and save first" cell writes this file with the keys
# "profiles" and "guidelines").
with open(os.path.join(ARTIFACT_DIR, "extracted_multiple.json"), "r") as f:
    extracted_infos = json.load(f)
    
refined_profiles = extracted_infos["profiles"]
relevant_guidelines = extracted_infos["guidelines"]

In [None]:
# Run 5: profiles and scan orders are not passed, so run_test_cases derives
# them from testcase_df["Clinical File"]. The pre-computed refined profiles
# and guidelines make it skip the refine chain.
json_responses, result_df, responses = run_test_cases(
    testcase_df=testcase_df,
    exp_args=exp_args,
    # patient_profiles=patient_profiles,
    # scan_orders=scan_orders,
    refined_profiles=refined_profiles,
    relevant_guidelines=relevant_guidelines,
    query_engine=query_engine,
    text_qa_template=text_qa_template,
    node_postprocessors=node_postprocessors
    )