# Import Packages

In [8]:
from config import MAIN_DIR
from copy import deepcopy
from utils import load_vectorindex, count_tokens, get_experiment_logs, filter_by_pages
from logging import Logger
from pprint import pprint
from statsmodels.stats.weightstats import ztest
from textdistance import levenshtein
from typing import Dict, Union, List, Literal, Sequence, Optional
import numpy as np
import os, json
import pandas as pd
import re

from llama_index import SimpleDirectoryReader
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings.base import BaseEmbedding
from llama_index.indices.base_retriever import BaseRetriever
from llama_index.schema import Document, NodeWithScore, MetadataMode

In [9]:
# --- Project layout: every path is rooted at config.MAIN_DIR ---
DATA_DIR = os.path.join(MAIN_DIR, "data")
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
EMB_DIR = os.path.join(DATA_DIR, "emb_store")
DOCUMENT_DIR = os.path.join(DATA_DIR, "document_sources")
EXCLUDE_DICT = os.path.join(DATA_DIR, "exclude_pages.json")  # pages to drop from the corpus

# --- Credentials: auth/api_keys.json must contain an "OPENAI_API_KEY" entry,
# which is exported to the environment before the embedding model is built. ---
with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]

embed_model = OpenAIEmbedding()

In [10]:
def convert_doc_to_dict(doc: Union[Document, NodeWithScore, Dict]) -> Dict:
    """Normalise a retrieved node, a document or a raw dict into a plain dict.

    Args:
        doc: A llama_index ``NodeWithScore``, a llama_index ``Document``, or a
            plain dict with ``"text"`` and ``"metadata"`` keys.

    Returns:
        ``{"page_content": ..., "metadata": ..., "score": ...}``. ``score`` is
        the node's retrieval score for a ``NodeWithScore``, ``""`` for a
        ``Document`` and the string ``"None"`` for a dict (these placeholder
        values are kept unchanged for existing consumers).

    Raises:
        TypeError: if ``doc`` is none of the supported types.
    """
    if isinstance(doc, NodeWithScore):
        score = doc.score
    elif isinstance(doc, Document):
        score = ""
    elif isinstance(doc, dict):
        return {"page_content": doc["text"], "metadata": doc["metadata"], "score": "None"}
    else:
        # Previously fell through to `return json_doc` and raised UnboundLocalError.
        raise TypeError(
            f"convert_doc_to_dict expects NodeWithScore, Document or dict, got {type(doc).__name__}"
        )
    return {"page_content": doc.text, "metadata": doc.metadata, "score": score}

def remove_final_sentence(
    text: str,
    return_final_sentence: bool = False
):
    text = text.strip()
    if text.endswith("."):
        text = text[:-1]
    sentence_list = text.split(".")
    previous_text = ".".join(sentence_list[:-1])
    final_sentence = sentence_list[-1]
    return (previous_text, final_sentence) if return_final_sentence else previous_text

def query_wrapper(
    template: str,
    input_text: Union[str, Dict[str, str]]
) -> str:
    """Fill the named ``{placeholder}`` fields of a ``str.format`` template.

    Args:
        template: Template with named fields. Literal braces written as
            ``{{`` / ``}}`` are not treated as placeholders, and a placeholder
            may appear more than once.
        input_text: Either a single string, in which case the template must
            have exactly one distinct placeholder, or a mapping whose keys are
            exactly the template's distinct placeholders.

    Returns:
        The formatted string.

    Raises:
        ValueError: if the supplied value(s) do not match the placeholders.
    """
    from string import Formatter  # local import keeps this helper self-contained

    # Formatter().parse understands escaped braces and format specs. Keep only the
    # root name of each field ("x" for "{x.attr}" / "{x[0]}"), de-duplicated in
    # first-seen order so a repeated placeholder counts once.
    placeholders = list(dict.fromkeys(
        re.split(r"[.\[]", field_name, maxsplit=1)[0]
        for _, field_name, _, _ in Formatter().parse(template)
        if field_name
    ))

    if isinstance(input_text, str):
        if len(placeholders) != 1:
            raise ValueError("Must Provide a single placeholder when input_text is string.")
        return template.format(**{placeholders[0]: input_text})

    missing = [name for name in placeholders if name not in input_text]
    unknown = [key for key in input_text if key not in placeholders]
    if missing or unknown:
        raise ValueError(
            f"Template placeholders {placeholders} do not match input keys "
            f"{list(input_text)} (missing: {missing}, not present in template: {unknown})."
        )
    return template.format(**input_text)

def calculate_emb_distance(
    emb1: List[float],
    emb2: List[float],
    dist_type: Literal["l2", "ip", "cosine", "neg_exp_l2"] = "l2"
):
    """Compare two embedding vectors.

    Args:
        emb1, emb2: Vectors of equal length.
        dist_type:
            "l2": squared Euclidean distance ``||a - b||^2``.
            "ip": ``1 - <a, b>`` (a distance only for unit-norm vectors).
            "cosine": ``1 - cos(a, b)``; NaN if either vector is all zeros.
            "neg_exp_l2": ``exp(-||a - b||^2)`` -- a similarity in (0, 1],
                larger means closer (unlike the other three).

    Returns:
        A numpy float scalar.

    Raises:
        ValueError: on a length mismatch or an unknown ``dist_type``.
    """
    if len(emb1) != len(emb2):
        raise ValueError("Length of embedding vectors must match")
    a = np.asarray(emb1, dtype=float)
    b = np.asarray(emb2, dtype=float)

    if dist_type == "l2":
        return np.square(np.linalg.norm(a - b))
    if dist_type == "ip":
        return 1 - np.dot(a, b)
    if dist_type == "cosine":
        # np.linalg.norm: the previous np.norm does not exist (AttributeError).
        return 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    if dist_type == "neg_exp_l2":
        return np.exp(-np.square(np.linalg.norm(a - b)))
    raise ValueError("Invalid distance type")
    
def calculate_string_distance(
    str1: str,
    str2: Union[str, Sequence[str]],
    embeddings: BaseEmbedding,
    dist_type: Literal["l2", "ip", "cosine", "neg_exp_l2"] = "l2"
):
    """Embed strings and compare them with ``calculate_emb_distance``.

    ``str1`` is embedded with ``get_query_embedding`` and ``str2`` with
    ``get_text_embedding`` (or ``get_text_embedding_batch`` for several texts).

    Args:
        str1: The query string.
        str2: One string, or a sequence of strings to compare against ``str1``.
        embeddings: The embedding model.
        dist_type: Metric forwarded to ``calculate_emb_distance``.

    Returns:
        A single value when ``str2`` is a string, otherwise a list with one
        value per element of ``str2``.
    """
    query_emb = embeddings.get_query_embedding(str1)
    if isinstance(str2, str):
        return calculate_emb_distance(query_emb, embeddings.get_text_embedding(str2), dist_type)
    text_embs = embeddings.get_text_embedding_batch(list(str2))
    # Forward dist_type here too; the batch path previously always used "l2".
    return [calculate_emb_distance(query_emb, text_emb, dist_type) for text_emb in text_embs]
    
# NOTE: verbatim copies of remove_final_sentence and query_wrapper used to be
# redefined here. Because they came later in this cell, they silently shadowed
# the definitions earlier in the same cell, so any change made there had no
# effect. Those earlier definitions are now the single source of truth.

def _rank_columns(nodes_per_case, top_k, extract, missing=np.nan):
    """Pivot per-query node lists into ``top_k`` per-rank columns.

    Column ``r`` holds ``extract(node)`` for the r-th node retrieved for each
    query, or ``missing`` when that query returned fewer than ``r + 1`` nodes.
    """
    columns = [[] for _ in range(top_k)]
    for case_nodes in nodes_per_case:
        for rank, column in enumerate(columns):
            column.append(extract(case_nodes[rank]) if rank < len(case_nodes) else missing)
    return columns


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else np.nan


def _add_full_stop(text: str) -> str:
    """Append a '.' to ``text`` unless it already ends with one."""
    return text if text.endswith(".") else text + "."


def _retrieve_nodes(testcases, metadata_filters, table_retriever, text_retriever):
    """Query both retrievers once per test case and return ``(table_nodes, text_nodes)``.

    Without ``metadata_filters``, the retrievers' ``_kwargs`` are reset to ``{}``
    and every query is run. With filters, each query's list is applied as the
    ``where`` clause ``{"condition": {"$in": filter_list}}``. A query whose
    filter is ``None`` or empty is not run at all and records ``[]``
    (presumably "no candidate guideline", e.g. ICI cases; TODO confirm intent).
    The last ``where`` clause stays set on the retrievers after this returns.
    """
    table_nodes, text_nodes = [], []

    if not metadata_filters:
        if table_retriever:
            table_retriever._kwargs = {}
        if text_retriever:
            text_retriever._kwargs = {}
        for query in testcases:
            table_nodes.append(table_retriever.retrieve(query) if table_retriever else [])
            text_nodes.append(text_retriever.retrieve(query) if text_retriever else [])
        return table_nodes, text_nodes

    if len(metadata_filters) < len(testcases):
        raise ValueError(
            f"metadata_filters has {len(metadata_filters)} entries but there are "
            f"{len(testcases)} test cases; supply one filter (or None) per query."
        )

    for query, filter_list in zip(testcases, metadata_filters):
        for retriever in (table_retriever, text_retriever):
            if not retriever:
                continue
            if filter_list is not None:
                retriever._kwargs["where"] = {"condition": {"$in": filter_list}}
            else:
                retriever._kwargs = {}
        # A None/empty filter means "skip retrieval" for this query, not "unfiltered".
        table_nodes.append(table_retriever.retrieve(query) if (table_retriever and filter_list) else [])
        text_nodes.append(text_retriever.retrieve(query) if (text_retriever and filter_list) else [])
    return table_nodes, text_nodes


def retrieval_analysis(
    testcase_df: pd.DataFrame,
    testcases: Optional[Sequence[str]] = None,
    metadata_filters: Optional[Sequence[List[str]]] = None,
    table_retriever: Optional[BaseRetriever] = None,
    text_retriever: Optional[BaseRetriever] = None,
    save_folder: Optional[str] = None,
    logger: Optional[Logger] = None,
    hit_lv_threshold: int = 5,
):
    """Retrieve nodes for every test case and write retrieval-quality reports.

    Args:
        testcase_df: One row per test case. Reads the "Guideline" and
            "Appropriateness Category" columns, plus "ACR scenario" when a
            table retriever is given. Not modified; a copy is used.
        testcases: Query strings, one per row, stored as the "queries" column
            of the outputs. Defaults to ``testcase_df["queries"]``.
        metadata_filters: Optional per-query list of allowed "condition" values;
            see ``_retrieve_nodes`` for how ``None`` / empty lists behave.
        table_retriever, text_retriever: Retrievers; either may be omitted.
            Their ``similarity_top_k`` sets the number of per-rank columns.
        save_folder: Output directory (created if needed). Required.
        logger: Metric logger; defaults to ``get_experiment_logs`` named after
            the last component of ``save_folder``.
        hit_lv_threshold: A row is an exact table hit when the smallest
            Levenshtein distance between its "ACR scenario" and a retrieved
            table description is below this value.

    Writes to ``save_folder``:
        table_text.csv (node contents + score per rank), table_text_pages.csv
        (retrieved file names, hits vs "Guideline", per-row totals),
        table_text_scores.csv (node scores and their row averages) and, with a
        table retriever, hit_rate.csv. Rows whose "Appropriateness Category"
        is "ICI" count as hits for recall and hit rate.

    Returns:
        ``(description_df, file_df, score_df, table_hitrate_df)`` when a table
        retriever is given, otherwise ``(description_df, file_df, score_df)``.

    Raises:
        ValueError: if ``save_folder`` is missing or ``metadata_filters`` is
            shorter than ``testcases``.
    """
    if save_folder is None:
        raise ValueError("save_folder is required: the analysis CSVs are written there.")
    os.makedirs(save_folder, exist_ok=True)

    if not logger:
        logger = get_experiment_logs(
            os.path.basename(os.path.normpath(save_folder)), log_folder=save_folder
        )

    # Work on a copy so the caller's DataFrame is not mutated.
    testcase_df = deepcopy(testcase_df)
    if testcases is not None:
        testcase_df["queries"] = testcases
    else:
        testcases = testcase_df["queries"].tolist()
    num_cases = len(testcase_df)

    table_nodes, text_nodes = _retrieve_nodes(
        testcases, metadata_filters, table_retriever, text_retriever
    )

    table_top_k = table_retriever.similarity_top_k if table_retriever else 0
    text_top_k = text_retriever.similarity_top_k if text_retriever else 0
    logger.info(f"Successfully loaded table database k={table_top_k} and text database k={text_top_k}")

    # --- Retrieved contents ---
    def content_with_score(node):
        return node.get_content(MetadataMode.EMBED) + "\n\nScore: {}".format(node.score)

    description_df = deepcopy(testcase_df)
    for rank, column in enumerate(_rank_columns(table_nodes, table_top_k, content_with_score)):
        description_df[f"Table_{rank+1}"] = column
    for rank, column in enumerate(_rank_columns(text_nodes, text_top_k, content_with_score)):
        description_df[f"Text_{rank+1}"] = column
    description_df.to_csv(os.path.join(save_folder, "table_text.csv"))

    # --- Document-level hit rate (retrieved file name == expected "Guideline") ---
    def file_name(node):
        return node.metadata["file_name"]

    def is_hit(hits):
        return (hits > 0) | (file_df["Appropriateness Category"] == "ICI")

    file_df = deepcopy(testcase_df)
    if table_retriever:
        for rank, column in enumerate(_rank_columns(table_nodes, table_top_k, file_name)):
            file_df[f"Table_{rank+1}_Page"] = column
            file_df[f"Table_{rank+1}_HIT"] = file_df[f"Table_{rank+1}_Page"] == file_df["Guideline"]
        page_cols = [f"Table_{i+1}_Page" for i in range(table_top_k)]
        hit_cols = [f"Table_{i+1}_HIT" for i in range(table_top_k)]
        file_df["Total Retrieved Tables"] = file_df[page_cols].notna().sum(axis=1)
        file_df["Total Table HITs"] = file_df[hit_cols].sum(axis=1)
        table_retrieved_nodes_no = file_df["Total Retrieved Tables"].sum()
        table_node_retrieval_precision = _ratio(file_df["Total Table HITs"].sum(), table_retrieved_nodes_no)
        table_node_retrieval_recall = is_hit(file_df["Total Table HITs"]).sum() / num_cases

        logger.info(f"Table Node Retrieval Precision: {table_node_retrieval_precision * 100:.3f}")
        logger.info(f"Table Node Retrieval Recall: {table_node_retrieval_recall * 100:.3f}")

    if text_retriever:
        for rank, column in enumerate(_rank_columns(text_nodes, text_top_k, file_name)):
            file_df[f"Text_{rank+1}_Page"] = column
            file_df[f"Text_{rank+1}_HIT"] = file_df[f"Text_{rank+1}_Page"] == file_df["Guideline"]
        page_cols = [f"Text_{i+1}_Page" for i in range(text_top_k)]
        hit_cols = [f"Text_{i+1}_HIT" for i in range(text_top_k)]
        file_df["Total Text HITs"] = file_df[hit_cols].sum(axis=1)
        file_df["Total Retrieved Texts"] = file_df[page_cols].notna().sum(axis=1)
        text_retrieved_nodes_no = file_df["Total Retrieved Texts"].sum()
        text_node_retrieval_precision = _ratio(file_df["Total Text HITs"].sum(), text_retrieved_nodes_no)
        text_node_retrieval_recall = is_hit(file_df["Total Text HITs"]).sum() / num_cases

        logger.info(f"Text Node Retrieval Precision: {text_node_retrieval_precision * 100:.3f}")
        logger.info(f"Text Node Retrieval Recall: {text_node_retrieval_recall * 100:.3f}")

    if table_retriever and text_retriever:
        file_df["Total HITs"] = file_df[["Total Table HITs", "Total Text HITs"]].sum(axis=1)
        file_df["Total Retrieved Nodes"] = file_df["Total Retrieved Tables"] + file_df["Total Retrieved Texts"]
        total_node_retrieval_precision = _ratio(
            file_df["Total HITs"].sum(), table_retrieved_nodes_no + text_retrieved_nodes_no
        )
        total_node_retrieval_recall = is_hit(file_df["Total HITs"]).sum() / num_cases

        logger.info(f"Total Node Retrieval Precision: {total_node_retrieval_precision * 100:.3f}")
        logger.info(f"Total Node Retrieval Recall: {total_node_retrieval_recall * 100:.3f}")

    file_df.to_csv(os.path.join(save_folder, "table_text_pages.csv"))

    # --- Node scores (labelled "Distance (L2)" in the logs; the actual metric
    # depends on the vector store -- confirm) ---
    def node_score(node):
        return node.score

    score_df = deepcopy(testcase_df)
    if table_retriever:
        for rank, column in enumerate(_rank_columns(table_nodes, table_top_k, node_score)):
            score_df[f"Table_{rank+1}"] = column
        score_df['avg_table'] = score_df[[f"Table_{i+1}" for i in range(table_top_k)]].mean(axis=1)
        mean_table_score = score_df['avg_table'].mean()
        std_table_score = score_df['avg_table'].std()
        logger.info(f"Mean Table Distance (L2): {mean_table_score:.3f}")
        logger.info(f"STD Table Distance (L2): {std_table_score:.3f}")

    if text_retriever:
        for rank, column in enumerate(_rank_columns(text_nodes, text_top_k, node_score)):
            score_df[f"Text_{rank+1}"] = column
        score_df['avg_text'] = score_df[[f"Text_{i+1}" for i in range(text_top_k)]].mean(axis=1)
        mean_text_score = score_df['avg_text'].mean()
        std_text_score = score_df['avg_text'].std()
        logger.info(f"Mean Text Distance (L2): {mean_text_score:.3f}")
        logger.info(f"STD Text Distance (L2): {std_text_score:.3f}")

    if table_retriever and text_retriever:
        score_df['avg_overall'] = score_df[["avg_table", "avg_text"]].mean(axis=1)
        mean_overall_score = score_df['avg_overall'].mean()
        std_overall_score = score_df['avg_overall'].std()
        logger.info(f"Mean Overall Distance (L2): {mean_overall_score:.3f}")
        logger.info(f"STD Overall Distance (L2): {std_overall_score:.3f}")

    score_df.to_csv(os.path.join(save_folder, "table_text_scores.csv"))

    if not table_retriever:
        return description_df, file_df, score_df

    # --- Exact table hit rate: retrieved table description ~= ACR scenario ---
    table_hitrate_df = deepcopy(testcase_df)
    table_hitrate_df['ACR scenario'] = table_hitrate_df['ACR scenario'].str.strip().apply(_add_full_stop)
    scenarios = table_hitrate_df['ACR scenario'].tolist()

    descriptions = _rank_columns(
        table_nodes, table_top_k, lambda node: node.metadata["description"].strip(), missing=""
    )
    for rank, column in enumerate(descriptions):
        table_hitrate_df[f"Table_{rank+1}_descriptions"] = column
        table_hitrate_df[f"Table_{rank+1}_LVscore"] = [
            levenshtein.distance(scenario, description)
            for scenario, description in zip(scenarios, column)
        ]

    lv_cols = [f"Table_{i+1}_LVscore" for i in range(table_top_k)]
    table_hitrate_df["Min_LV_Dist"] = table_hitrate_df[lv_cols].min(axis=1)
    # Explicit boolean OR of "close enough description" and "ICI row".
    table_hitrate_df["HIT"] = (
        (table_hitrate_df["Min_LV_Dist"] < hit_lv_threshold)
        | (table_hitrate_df["Appropriateness Category"] == "ICI")
    )

    logger.info("Exact Retrieved Table Nodes Hit Rate: {:.3f}".format(table_hitrate_df["HIT"].sum() / len(table_hitrate_df["HIT"])))

    table_hitrate_df.to_csv(os.path.join(save_folder, "hit_rate.csv"))

    return description_df, file_df, score_df, table_hitrate_df

# Load Test Data

In [15]:
testcase_df = pd.read_csv(
    os.path.join(DATA_DIR, "queries", "MSK LLM Fictitious Case Files Full.csv"),
    usecols=['ACR scenario', 'Guideline', 'Variant', 'Appropriateness Category',
             'Scan Order', 'Clinical File'],
)
patient_profiles = testcase_df["Clinical File"]
# NOTE(review): scan_orders is not used below -- the scan order in the queries
# is taken from the last sentence of each clinical file instead.
scan_orders = testcase_df["Scan Order"]

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"

# Split each clinical file once into (text before the last sentence, last sentence).
testcase_df["queries"] = [
    query_wrapper(question_template, {"profile": profile, "scan_order": scan_order})
    for profile, scan_order in (remove_final_sentence(text, True) for text in patient_profiles)
]

In [16]:
# Load the source documents, then drop the pages listed in EXCLUDE_DICT.
documents = SimpleDirectoryReader(DOCUMENT_DIR).load_data()
print("Total no of docs before filtering:", len(documents))

with open(EXCLUDE_DICT, "r") as f:
    exclude_pages = json.load(f)

documents = filter_by_pages(
    doc_list=documents,
    exclude_info=exclude_pages,
)
print("Total number of docs after filtering", len(documents))

Total no of docs before filtering: 546
Total number of docs after filtering 395


# Evaluation

## Test LLM Extraction Hit Rate

In [18]:
# LLM extraction artifacts: profiles come from the "best" run; guideline
# predictions come from both the "best" and "multiple" runs.
with open(os.path.join(ARTIFACT_DIR, "extracted_best.json"), "r") as f:
    extracted_best = json.load(f)
with open(os.path.join(ARTIFACT_DIR, "extracted_multiple.json"), "r") as f:
    extracted_multiple = json.load(f)

extracted_best_guidelines = extracted_best["guidelines"]
extracted_multiple_guidelines = extracted_multiple["guidelines"]
extracted_profiles = [profile.strip() for profile in extracted_best["profiles"]]

In [19]:
# Attach the LLM extraction results and flag whether the true guideline was found.
emb_df = deepcopy(testcase_df)
emb_df["llm_profile"] = extracted_profiles
emb_df["best_guideline"] = extracted_best_guidelines
emb_df["multiple_guidelines"] = extracted_multiple_guidelines
# Drop the first and last 4 characters of the guideline file name (the printed
# names look like "ACR <condition>.pdf") and lower-case the rest.
emb_df["condition"] = emb_df["Guideline"].str[4:-4].str.lower()

# A hit: the condition is `in` the prediction (substring test if the prediction
# is a str, membership if it is a list -- confirm which), or an ICI row with an
# empty prediction.
emb_df["best_hit"] = [
    condition in predicted or (category == "ICI" and not predicted)
    for condition, predicted, category in zip(
        emb_df["condition"], emb_df["best_guideline"], emb_df["Appropriateness Category"]
    )
]
emb_df["multiple_hit"] = [
    condition in predicted or (category == "ICI" and not predicted)
    for condition, predicted, category in zip(
        emb_df["condition"], emb_df["multiple_guidelines"], emb_df["Appropriateness Category"]
    )
]

In [20]:
# Fraction of cases whose true guideline was found by each extraction run.
for label, hit_column in (("Best Hit Rate: ", "best_hit"), ("Multiple Hit Rate: ", "multiple_hit")):
    print(label, emb_df[hit_column].sum() / len(emb_df))

Best Hit Rate:  0.8873239436619719
Multiple Hit Rate:  0.9436619718309859


In [21]:
# Show every case whose true guideline is missing from the LLM's multiple picks.
# iterrows yields (row label, row) directly instead of feeding a label to the
# positional .iloc, which only worked because the index is a default RangeIndex.
for idx, case_file in emb_df[~emb_df["multiple_hit"]].iterrows():
    print("Case:", idx + 1)  # 1-based case number (assumes the default 0..n-1 index)
    pprint("Patient: " + remove_final_sentence(case_file["Clinical File"]))
    pprint("Variant: " + case_file["ACR scenario"])
    pprint(case_file["multiple_guidelines"])
    pprint(case_file["Guideline"])
    print()

Case: 19
('Patient: 46 year old Chinese male.  Businessman, sales.  Frequent drinker '
 'due to job, about 1-2 beers/day.  Past medical history of fatty liver, acute '
 'cholecystitis post cholecystectomy.  Now presenting with severe pain at '
 'right big toe for 2 weeks, pain on and off improves with paracetamol and '
 'ibuprofen.  On examination: swelling and erythema at right big toe 1st '
 'metatarsophalangeal joint.  No prior imaging')
('Variant: Chronic extremity joint pain. Suspect inflammatory (seropositive or '
 'seronegative arthritis), crystalline (gout or pseudogout), or erosive '
 'osteoarthritis. Initial imaging.')
['chronic foot pain', 'suspected om septic arthritis soft tissue infection']
'ACR chroni extremity joint pain inflammatory arthritis.pdf'

Case: 20
('Patient: 55 year old Philipino female.  Domestic helper.  No significant '
 'past medical history.  Bilateral finger joint pain, swelling and stiffness '
 'worse in the morning for 4 months, worsening recently.  N

## Test Embeddings

### Embedding Content:
- Metadata (Descriptions + Conditions)

In [None]:
# Output folder for the embedding-distance analysis; exist_ok makes the
# separate existence check redundant (and avoids its check-then-create race).
emb_analysis_path = os.path.join(ARTIFACT_DIR, "emb_analysis")
os.makedirs(emb_analysis_path, exist_ok=True)

In [22]:
# Distance between each query variant (embedded as a query) and the ACR scenario
# (embedded as text), using calculate_string_distance's default "l2" metric.
for dist_column, query_column in (
    ("original_dist", "Clinical File"),
    ("query_dist", "queries"),
    ("refine_dist", "llm_profile"),
):
    emb_df[dist_column] = emb_df.apply(
        lambda row: calculate_string_distance(row[query_column], row["ACR scenario"], embed_model),
        axis=1,
    )

In [None]:
emb_df.to_csv(path_or_buf=os.path.join(emb_analysis_path, "emb_analysis.csv"))  # per-case distances

In [None]:
# Chroma indexes built from metadata (descriptions + conditions).
# index_name is passed explicitly, as the later cells that load these same
# directories do, so the "tables"/"texts" collections are opened rather than
# whatever load_vectorindex's default name is (not visible here -- confirm).
db_directory = os.path.join(DATA_DIR, "multimodal-chroma", "descriptions")
table_index = load_vectorindex(
    db_directory=os.path.join(db_directory, "tables"),
    emb_store_type="chroma", index_name="tables",
)
text_index = load_vectorindex(
    db_directory=os.path.join(db_directory, "texts"),
    emb_store_type="chroma", index_name="texts",
)

table_retriever = table_index.as_retriever(similarity_top_k=5, filters=None)
text_retriever = text_index.as_retriever(similarity_top_k=5, filters=None)

## Test Retriever

### Baseline Single Retriever

In [32]:
# Baseline: one FAISS text index under the embedding store (the directory name
# suggests OpenAI embeddings with 512/20 chunking -- unverified).
db_directory = os.path.join(EMB_DIR, "faiss", "openai_512_20")
baseline_index = load_vectorindex(db_directory, "faiss")
baseline_text_retriever = baseline_index.as_retriever(similarity_top_k=5)

2023-12-08 19:03:02,555:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/emb_store/faiss/openai_512_20.


In [None]:
# Baseline: original profile + scan order (its final sentence) against the
# single baseline text retriever.
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_baseline_nosplit_withscanorder_full")
question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"
testcases = [
    query_wrapper(question_template, {"profile": profile, "scan_order": scan_order})
    for profile, scan_order in (remove_final_sentence(text, True) for text in patient_profiles)
]

# Without a table_retriever, retrieval_analysis returns three frames (no table
# hit-rate frame); unpacking four values raised ValueError.
description_df, file_df, score_df = retrieval_analysis(
    testcase_df=testcase_df, testcases=testcases,
    text_retriever=baseline_text_retriever, save_folder=save_folder
)
table_hitrate_df = None  # no table retriever in this run; clear any stale result

### Embedding Content:
- Metadata (Descriptions + Conditions)
- Table Texts

In [34]:
# FAISS indexes embedding metadata plus table texts ("full").
db_directory = os.path.join(DATA_DIR, "multimodal-faiss", "full")
table_index = load_vectorindex(db_directory=os.path.join(db_directory, "tables"), emb_store_type="faiss")
text_index = load_vectorindex(db_directory=os.path.join(db_directory, "texts"), emb_store_type="faiss")

table_retriever, text_retriever = (
    index.as_retriever(similarity_top_k=5) for index in (table_index, text_index)
)

2023-12-08 19:03:23,068:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-faiss/full/tables.


2023-12-08 19:03:23,192:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-faiss/full/texts.


In [35]:
# Original profile + scan order (its final sentence) against the "full" FAISS
# table/text indexes.
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_original_profile_withscanorder_full")
question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"
# Split each clinical file once instead of calling remove_final_sentence twice.
testcases = [
    query_wrapper(question_template, {"profile": profile, "scan_order": scan_order})
    for profile, scan_order in (remove_final_sentence(text, True) for text in patient_profiles)
]

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df, testcases=testcases,
    table_retriever=table_retriever, text_retriever=text_retriever,
    save_folder=save_folder
)

2023-12-08 19:04:11,339:INFO: Successfully loaded table database k=5 and text database k=5
2023-12-08 19:04:11,339:INFO: Successfully loaded table database k=5 and text database k=5
2023-12-08 19:04:11,410:INFO: Table Node Retrieval Precision: 60.000
2023-12-08 19:04:11,410:INFO: Table Node Retrieval Precision: 60.000
2023-12-08 19:04:11,412:INFO: Table Node Retrieval Recall: 87.324
2023-12-08 19:04:11,412:INFO: Table Node Retrieval Recall: 87.324
2023-12-08 19:04:11,418:INFO: Text Node Retrieval Precision: 67.324
2023-12-08 19:04:11,418:INFO: Text Node Retrieval Precision: 67.324
2023-12-08 19:04:11,420:INFO: Text Node Retrieval Recall: 87.324
2023-12-08 19:04:11,420:INFO: Text Node Retrieval Recall: 87.324
2023-12-08 19:04:11,425:INFO: Total Node Retrieval Precision: 63.662
2023-12-08 19:04:11,425:INFO: Total Node Retrieval Precision: 63.662
2023-12-08 19:04:11,427:INFO: Total Node Retrieval Recall: 90.141
2023-12-08 19:04:11,427:INFO: Total Node Retrieval Recall: 90.141
2023-12-08 1

In [38]:
# Path is relative to the notebook's working directory, not MAIN_DIR.
compare_df = pd.read_csv("../dist_compare.csv")

In [45]:
from statsmodels.stats.weightstats import ztest

# Column means, then a two-sample z-test of H0: mean difference == 0.
# NOTE(review): if both columns are measured on the same queries, a paired
# test may be more appropriate than a two-sample one.
for compared_column in ("table_desc_meta", "table_full"):
    print(compare_df[compared_column].mean())
ztest(compare_df["table_desc_meta"], compare_df["table_full"], value=0)

0.315901198087324
0.32204540616901406


(-2.5617918144365492, 0.010413371867957268)

### Embedding Content:
- Metadata (Descriptions + Conditions)

In [28]:
# Chroma indexes over metadata only (descriptions + conditions).
db_directory = os.path.join(DATA_DIR, "multimodal-chroma", "descriptions")

table_index, text_index = (
    load_vectorindex(
        db_directory=os.path.join(db_directory, collection),
        emb_store_type="chroma", index_name=collection,
    )
    for collection in ("tables", "texts")
)

table_retriever, text_retriever = (
    index.as_retriever(similarity_top_k=5) for index in (table_index, text_index)
)

2023-12-08 18:37:33,520:INFO: chroma VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-chroma/descriptions/tables.
2023-12-08 18:37:33,623:INFO: chroma VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-chroma/descriptions/texts.


In [29]:
# Standard queries: original clinical profile + scan order (its final sentence)
# against the description-only Chroma indexes.
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_original_profile_withscanorder_descriptions")

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"
# Split each clinical file once instead of calling remove_final_sentence twice.
testcases = [
    query_wrapper(question_template, {"profile": profile, "scan_order": scan_order})
    for profile, scan_order in (remove_final_sentence(text, True) for text in patient_profiles)
]

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df, testcases=testcases,
    table_retriever=table_retriever, text_retriever=text_retriever,
    save_folder=save_folder
)

2023-12-08 18:38:56,219:INFO: Successfully loaded table database k=5 and text database k=5
2023-12-08 18:38:56,292:INFO: Table Node Retrieval Precision: 66.761
2023-12-08 18:38:56,293:INFO: Table Node Retrieval Recall: 91.549
2023-12-08 18:38:56,300:INFO: Text Node Retrieval Precision: 71.268
2023-12-08 18:38:56,301:INFO: Text Node Retrieval Recall: 88.732
2023-12-08 18:38:56,304:INFO: Total Node Retrieval Precision: 69.014
2023-12-08 18:38:56,305:INFO: Total Node Retrieval Recall: 92.958
2023-12-08 18:38:56,320:INFO: Mean Table Distance (L2): 0.729
2023-12-08 18:38:56,321:INFO: STD Table Distance (L2): 0.021
2023-12-08 18:38:56,325:INFO: Mean Text Distance (L2): 0.730
2023-12-08 18:38:56,327:INFO: STD Text Distance (L2): 0.020
2023-12-08 18:38:56,330:INFO: Mean Overall Distance (L2): 0.730
2023-12-08 18:38:56,333:INFO: STD Overall Distance (L2): 0.020
2023-12-08 18:38:59,282:INFO: Exact Retrieved Table Nodes Hit Rate: 0.803


In [None]:
# Clinical profile only: the final sentence (the scan order in other runs) is dropped.
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_original_profile_noscan_descriptions")
testcases = list(map(remove_final_sentence, testcase_df["Clinical File"]))

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df,
    testcases=testcases,
    table_retriever=table_retriever,
    text_retriever=text_retriever,
    save_folder=save_folder,
)

In [None]:
# LLM-extracted profile only (no scan order). Note: this rebinds
# extracted_profiles / extracted_multiple_guidelines to the "multiple" run,
# left unstripped, for the cells that follow.
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_LLMrefine_noscanorder_descriptions")
with open(os.path.join(ARTIFACT_DIR, "extracted_multiple.json"), "r") as f:
    extracted_multiple = json.load(f)
extracted_profiles = extracted_multiple["profiles"]
extracted_multiple_guidelines = extracted_multiple["guidelines"]

testcases = [llm_profile.strip() for llm_profile in extracted_profiles]

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df,
    testcases=testcases,
    table_retriever=table_retriever,
    text_retriever=text_retriever,
    save_folder=save_folder,
)

In [None]:
# LLM-extracted profile + scan order (the clinical file's final sentence).
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_refine_withscanorder_descriptions")

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"
testcases = [
    query_wrapper(
        question_template,
        {
            "profile": llm_profile.strip(),
            "scan_order": remove_final_sentence(clinical_file, True)[1],
        },
    )
    for llm_profile, clinical_file in zip(extracted_profiles, patient_profiles)
]

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df,
    testcases=testcases,
    table_retriever=table_retriever,
    text_retriever=text_retriever,
    save_folder=save_folder,
)

### With Metadata Filter

In [None]:
# Fresh retrievers over the description-only Chroma indexes for the
# metadata-filter experiments.
db_directory = os.path.join(DATA_DIR, "multimodal-chroma", "descriptions")

table_index, text_index = (
    load_vectorindex(
        db_directory=os.path.join(db_directory, collection),
        emb_store_type="chroma", index_name=collection,
    )
    for collection in ("tables", "texts")
)

table_retriever, text_retriever = (
    index.as_retriever(similarity_top_k=5) for index in (table_index, text_index)
)

In [None]:
# Original clinical profile + scan order (its final sentence), with retrieval
# restricted to the LLM-extracted guideline conditions (metadata filter).
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_original_withscan_metadata_filter_descriptions")

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"
# Split each clinical file once instead of calling remove_final_sentence twice.
testcases = [
    query_wrapper(question_template, {"profile": profile, "scan_order": scan_order})
    for profile, scan_order in (remove_final_sentence(text, True) for text in patient_profiles)
]

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df, testcases=testcases,
    table_retriever=table_retriever, text_retriever=text_retriever,
    metadata_filters=extracted_multiple_guidelines,
    save_folder=save_folder
)

In [None]:
# Reload the "multiple" extraction run: stripped profiles plus guideline lists.
with open(os.path.join(ARTIFACT_DIR, "extracted_multiple.json"), "r") as f:
    extracted_multiple = json.load(f)

extracted_multiple_guidelines = extracted_multiple["guidelines"]
extracted_profiles = [llm_profile.strip() for llm_profile in extracted_multiple["profiles"]]

In [None]:
# LLM-extracted profile + scan order (the clinical file's final sentence), with
# retrieval restricted to the LLM-extracted guideline conditions.
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_refine_withscan_metadata_filter_descriptions")

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"
testcases = [
    query_wrapper(
        question_template,
        {
            "profile": llm_profile.strip(),
            "scan_order": remove_final_sentence(clinical_file, True)[1],
        },
    )
    for llm_profile, clinical_file in zip(extracted_profiles, patient_profiles)
]

description_df, file_df, score_df, table_hitrate_df = retrieval_analysis(
    testcase_df=testcase_df,
    testcases=testcases,
    table_retriever=table_retriever,
    text_retriever=text_retriever,
    metadata_filters=extracted_multiple_guidelines,
    save_folder=save_folder,
)