# Import Packages

In [473]:
from config import MAIN_DIR
from copy import deepcopy
from custom_storage import load_vectorindex
from pprint import pprint
from statsmodels.stats.weightstats import ztest
from textdistance import levenshtein
from typing import Dict, Union, List, Literal, Sequence, Optional
import numpy as np
import os, json
import pandas as pd
import re
from utils import count_tokens, filter_by_pages

from llama_index import SimpleDirectoryReader, get_response_synthesizer
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings.base import BaseEmbedding
from llama_index.indices.base_retriever import BaseRetriever
from llama_index.schema import Document, NodeWithScore, MetadataMode

In [219]:
# Project directory layout; every path is rooted at MAIN_DIR.
DATA_DIR = os.path.join(MAIN_DIR, "data")
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
EMB_DIR = os.path.join(DATA_DIR, "emb_store")
DOCUMENT_DIR = os.path.join(DATA_DIR, "document_sources")
EXCLUDE_DICT = os.path.join(DATA_DIR, "exclude_pages.json")

# API keys live in a JSON file under <MAIN_DIR>/auth.
with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)

# The key is exported before the embedding client is constructed
# (presumably the client reads it from the environment — confirm).
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]
embed_model = OpenAIEmbedding()

In [374]:
def convert_doc_to_dict(doc: Union[Document, NodeWithScore, Dict]) -> Dict:
    """Flatten a retrieved item into a plain dictionary.

    Args:
        doc: A ``NodeWithScore``, a ``Document``, or a dict carrying
            ``"text"`` and ``"metadata"`` keys.

    Returns:
        A dict with the keys ``"page_content"``, ``"metadata"`` and
        ``"score"``. ``"score"`` is the node's score for a
        ``NodeWithScore``, the empty string for a ``Document`` and the
        literal string ``"None"`` for a dict input.

    Raises:
        TypeError: If ``doc`` is none of the supported types.
    """
    # NodeWithScore is tested first, exactly as before, so the branch
    # precedence between the two llama_index types is unchanged.
    if isinstance(doc, NodeWithScore):
        return {
            "page_content": doc.text,
            "metadata": doc.metadata,
            "score": doc.score,
        }
    if isinstance(doc, Document):
        return {
            "page_content": doc.text,
            "metadata": doc.metadata,
            "score": "",
        }
    if isinstance(doc, dict):
        return {
            "page_content": doc["text"],
            "metadata": doc["metadata"],
            "score": "None",
        }
    # Previously an unsupported type fell through every branch and crashed
    # with an UnboundLocalError on the return statement.
    raise TypeError(
        f"Unsupported document type: {type(doc).__name__}; "
        "expected NodeWithScore, Document or dict."
    )

def remove_final_sentence(
    text: str,
    return_final_sentence: bool = False
):
    """Split off whatever follows the last full stop of ``text``.

    The text is whitespace-trimmed and a single trailing ``"."`` is dropped
    before splitting. Sentences are delimited purely by ``"."``; the pieces
    are not trimmed, so the final sentence keeps any leading spaces.

    Args:
        text: Text to split.
        return_final_sentence: When true, also return the final sentence.

    Returns:
        The text before the last ``"."`` (empty if there is none), or a
        ``(previous_text, final_sentence)`` tuple when
        ``return_final_sentence`` is true.
    """
    trimmed = text.strip()
    if trimmed.endswith("."):
        trimmed = trimmed[:-1]
    # rpartition yields ("", "", trimmed) when no "." is left, which matches
    # splitting on "." and re-joining everything but the last piece.
    head, _, tail = trimmed.rpartition(".")
    if return_final_sentence:
        return head, tail
    return head

def query_wrapper(
    template: str,
    input_text: Union[str, Dict[str, str]]
) -> str:
    """Fill the ``{placeholder}`` fields of ``template``.

    Args:
        template: Format string whose placeholders are made of letters,
            digits, ``_`` or ``-``. A placeholder may appear more than once.
        input_text: Either a single string, which requires the template to
            use exactly one distinct placeholder, or a mapping from
            placeholder name to value covering every distinct placeholder.

    Returns:
        The formatted template.

    Raises:
        AssertionError: If the supplied values do not match the template's
            distinct placeholders.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so a
    # template that repeats a placeholder ("{a} ... {a}") is counted once
    # instead of being rejected.
    placeholders = list(dict.fromkeys(
        re.findall(pattern=r"{([A-Za-z0-9_-]+)}", string=template)
    ))
    if isinstance(input_text, str):
        assert len(placeholders) == 1, "Must Provide a single placeholder when input_text is string."
        return template.format(**{placeholders[0]: input_text})

    assert len(input_text) == len(placeholders), (
        "Number of provided values must match the number of distinct placeholders."
    )
    for key in input_text:
        assert key in placeholders, f"{key} not present in template."

    return template.format(**input_text)

def calculate_emb_distance(
    emb1: List[float],
    emb2: List[float],
    dist_type: Literal["l2", "ip", "cosine", "neg_exp_l2"] = "l2"
):
    """Compute a distance-like score between two embedding vectors.

    Args:
        emb1: First embedding vector.
        emb2: Second embedding vector; must have the same length as ``emb1``.
        dist_type: Which measure to compute:
            ``"l2"`` is the squared Euclidean distance,
            ``"ip"`` is ``1 - dot(emb1, emb2)``,
            ``"cosine"`` is ``1 - cosine_similarity(emb1, emb2)``,
            ``"neg_exp_l2"`` is ``exp(-squared Euclidean distance)``.

    Returns:
        The requested score as a NumPy scalar.

    Raises:
        AssertionError: If the vectors differ in length.
        ValueError: If ``dist_type`` is not one of the supported names.
    """
    assert len(emb1) == len(emb2), "Length of embedding vectors must match"
    vec1 = np.asarray(emb1)
    vec2 = np.asarray(emb2)
    if dist_type == "l2":
        return np.square(np.linalg.norm(vec1 - vec2))
    if dist_type == "ip":
        return 1 - np.dot(vec1, vec2)
    if dist_type == "cosine":
        # NumPy has no top-level `norm`; the vector norm lives in np.linalg.
        # The previous `np.norm` call made this branch fail with AttributeError.
        cosine_similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
        return 1 - cosine_similarity
    if dist_type == "neg_exp_l2":
        return np.exp(-np.square(np.linalg.norm(vec1 - vec2)))
    raise ValueError("Invalid distance type")
    
def calculate_string_distance(
    str1: str,
    str2: Union[str, Sequence[str]],
    embeddings: BaseEmbedding,
    dist_type: Literal["l2", "ip", "cosine", "neg_exp_l2"] = "l2"
):
    """Embed strings and measure the distance between their embeddings.

    ``str1`` is embedded with ``get_query_embedding``; ``str2`` is embedded
    with ``get_text_embedding`` (single string) or
    ``get_text_embedding_batch`` (sequence of strings).

    Args:
        str1: The query string.
        str2: One string, or a sequence of strings, to compare against.
        embeddings: Embedding model used for all embedding calls.
        dist_type: Distance measure forwarded to ``calculate_emb_distance``.

    Returns:
        A single distance when ``str2`` is a string, otherwise a list with
        one distance per element of ``str2``.
    """
    emb1 = embeddings.get_query_embedding(str1)
    if isinstance(str2, str):
        emb2 = embeddings.get_text_embedding(str2)
        return calculate_emb_distance(emb1, emb2, dist_type)

    emb2_list = embeddings.get_text_embedding_batch(str2)
    # Forward dist_type here as well; the batch path used to drop it and
    # always fell back to the default L2 measure.
    return [calculate_emb_distance(emb1, emb2, dist_type) for emb2 in emb2_list]
    
def remove_final_sentence(
    text: str,
    return_final_sentence: bool = False
):
    """Return ``text`` without its last ``"."``-delimited sentence.

    NOTE(review): this redefines a function of the same name and body that
    appears earlier in this notebook cell.

    Surrounding whitespace and one trailing full stop are removed first.
    Pieces are not trimmed, so the final sentence may start with spaces.

    Args:
        text: Text to split on ``"."``.
        return_final_sentence: If true, return the final sentence as well.

    Returns:
        The leading text, or ``(leading_text, final_sentence)`` when
        ``return_final_sentence`` is true.
    """
    body = text.strip()
    body = body[:-1] if body.endswith(".") else body
    # str.split always yields at least one piece, so `last` is always bound.
    *leading, last = body.split(".")
    preceding = ".".join(leading)
    return (preceding, last) if return_final_sentence else preceding

def query_wrapper(
    template: str,
    input_text: Union[str, Dict[str, str]]
) -> str:
    """Substitute values into the ``{placeholder}`` fields of ``template``.

    NOTE(review): this redefines a function of the same name that appears
    earlier in this notebook cell; this later definition is the one in
    effect after the cell runs.

    Args:
        template: Format string whose placeholders consist of letters,
            digits, ``_`` or ``-``. Placeholders may repeat.
        input_text: A single string (the template must then use exactly one
            distinct placeholder) or a mapping from placeholder name to
            value covering every distinct placeholder.

    Returns:
        The formatted template.

    Raises:
        AssertionError: If the supplied values do not match the template's
            distinct placeholders.
    """
    found = re.findall(pattern=r"{([A-Za-z0-9_-]+)}", string=template)
    # Count each placeholder once (first-seen order) so that templates
    # reusing a field, e.g. "{a} ... {a}", are accepted.
    placeholders = list(dict.fromkeys(found))
    if isinstance(input_text, str):
        assert len(placeholders) == 1, "Must Provide a single placeholder when input_text is string."
        return template.format(**{placeholders[0]: input_text})

    assert len(input_text) == len(placeholders), (
        "Number of provided values must match the number of distinct placeholders."
    )
    for key in input_text:
        assert key in placeholders, f"{key} not present in template."

    return template.format(**input_text)

# Load Test Data

In [222]:
# Load the test cases, keeping only the columns used by this notebook.
testcase_df = pd.read_csv(
    os.path.join(DATA_DIR, "queries", "MSK LLM Fictitious Case Files Full.csv"),
    usecols=[
        'ACR scenario', 'Guideline', 'Variant', 'Appropriateness Category',
        'MRI scan ordered', 'Clinical File',
    ],
)
patient_profiles = testcase_df["Clinical File"]
scan_orders = testcase_df["MRI scan ordered"]

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"

# Each clinical file is split once: everything before its final sentence
# becomes the profile and the final sentence itself becomes the scan order.
testcase_df["queries"] = [
    query_wrapper(question_template, {"profile": profile, "scan_order": scan_order})
    for profile, scan_order in (
        remove_final_sentence(patient_profile, True)
        for patient_profile in patient_profiles
    )
]

In [223]:
# Load everything under DOCUMENT_DIR and report how many documents came back.
documents = SimpleDirectoryReader(DOCUMENT_DIR).load_data()
print("Total no of docs before filtering:", len(documents))
# EXCLUDE_DICT is a JSON file handed to filter_by_pages as `exclude_info`
# (presumably the pages to drop per source file — confirm against utils).
with open(EXCLUDE_DICT, "r") as f:
    exclude_pages = json.load(f)
documents = filter_by_pages(doc_list=documents, exclude_info=exclude_pages)
print("Total number of docs after filtering", len(documents))

Total no of docs before filtering: 546
Total number of docs after filtering 395


# Evaluation

## Test LLM Extraction Hit Rate

In [420]:
# Read the two LLM-extraction artifacts from ARTIFACT_DIR.
with open(os.path.join(ARTIFACT_DIR, "extracted_best.json"), "r") as f:
    extracted_best = json.load(f)

with open(os.path.join(ARTIFACT_DIR, "extracted_multiple.json"), "r") as f:
    extracted_multiple = json.load(f)

# Profiles come from the "best" artifact and are whitespace-trimmed;
# guidelines are taken as stored from each artifact.
extracted_profiles = [profile.strip() for profile in extracted_best["profiles"]]
extracted_best_guidelines = extracted_best["guidelines"]
extracted_multiple_guidelines = extracted_multiple["guidelines"]

In [441]:
# Work on a copy so the hit-rate columns do not leak into testcase_df.
emb_df = deepcopy(testcase_df)
emb_df["llm_profile"] = extracted_profiles
emb_df["best_guideline"] = extracted_best_guidelines
emb_df["multiple_guidelines"] = extracted_multiple_guidelines
# Drop the first and last four characters of the guideline file name and
# lower-case the rest (the printed examples look like "ACR <name>.pdf").
emb_df["condition"] = emb_df["Guideline"].apply(lambda x: x[4:-4]).str.lower()

# A row is a hit when the expected condition occurs in the extracted
# guideline(s), or when the category is "ICI" and nothing was extracted.
for hit_column, extracted_column in (
    ("best_hit", "best_guideline"),
    ("multiple_hit", "multiple_guidelines"),
):
    hits = []
    for condition, extracted, category in zip(
        emb_df["condition"], emb_df[extracted_column], emb_df["Appropriateness Category"]
    ):
        matched = condition in extracted
        if not matched:
            matched = category == "ICI" and not extracted
        hits.append(matched)
    emb_df[hit_column] = hits

In [442]:
# Fraction of test cases flagged as a hit under each extraction mode.
print("Best Hit Rate: ", emb_df["best_hit"].sum()/len(emb_df))
print("Multiple Hit Rate: ", emb_df["multiple_hit"].sum()/len(emb_df))

Best Hit Rate:  0.8873239436619719
Multiple Hit Rate:  0.9436619718309859


In [None]:
# Show every case whose expected guideline was missed by the
# multiple-guideline extraction.
for idx in emb_df[~emb_df["multiple_hit"]].index:
    # `idx` is an index label taken from `.index`, so look the row up by
    # label; positional `.iloc` only agreed with it for a default RangeIndex.
    case_file = emb_df.loc[idx, :]
    print("Case:", idx + 1)
    pprint("Patient: " + remove_final_sentence(case_file["Clinical File"]))
    pprint("Variant: " + case_file["ACR scenario"])
    pprint(case_file["multiple_guidelines"])
    pprint(case_file["Guideline"])
    print()

Case: 19
('Patient: 46 year old Chinese male.  Businessman, sales.  Frequent drinker '
 'due to job, about 1-2 beers/day.  Past medical history of fatty liver, acute '
 'cholecystitis post cholecystectomy.  Now presenting with severe pain at '
 'right big toe for 2 weeks, pain on and off improves with paracetamol and '
 'ibuprofen.  On examination: swelling and erythema at right big toe 1st '
 'metatarsophalangeal joint.  No prior imaging')
('Variant: Chronic extremity joint pain. Suspect inflammatory (seropositive or '
 'seronegative arthritis), crystalline (gout or pseudogout), or erosive '
 'osteoarthritis. Initial imaging.')
['chronic foot pain', 'suspected om septic arthritis soft tissue infection']
'ACR chroni extremity joint pain inflammatory arthritis.pdf'

Case: 20
('Patient: 55 year old Philipino female.  Domestic helper.  No significant '
 'past medical history.  Bilateral finger joint pain, swelling and stiffness '
 'worse in the morning for 4 months, worsening recently.  N

## Test Embeddings

### Embedding Content:
- Metadata (Descriptions + Conditions)

In [468]:
emb_analysis_path = os.path.join(ARTIFACT_DIR, "emb_analysis")
# exist_ok=True already makes this a no-op when the directory exists, so the
# separate (and racy) os.path.exists check is unnecessary.
os.makedirs(emb_analysis_path, exist_ok=True)

In [465]:
# Distance between the ACR scenario text and three renderings of each case:
# the raw clinical file, the templated query and the LLM-extracted profile.
for dist_column, source_column in (
    ("original_dist", "Clinical File"),
    ("query_dist", "queries"),
    ("refine_dist", "llm_profile"),
):
    emb_df[dist_column] = emb_df.apply(
        lambda x: calculate_string_distance(x[source_column], x["ACR scenario"], embed_model),
        axis=1,
    )

In [469]:
# Persist the per-case embedding-distance analysis (index column included).
emb_df.to_csv(os.path.join(emb_analysis_path, "emb_analysis.csv"))

In [None]:
# Load the "descriptions" table/text indices from the chroma-backed store
# and build a top-5 retriever over each, with no metadata filters.
db_directory = os.path.join(DATA_DIR, "multimodal-chroma", "descriptions")
table_index = load_vectorindex(os.path.join(db_directory, "tables"), "chroma")
text_index = load_vectorindex(os.path.join(db_directory, "texts"), "chroma")

table_retriever = table_index.as_retriever(
    similarity_top_k = 5,
    filters = None
    )
text_retriever = text_index.as_retriever(
    similarity_top_k = 5,
    filters = None
)

2023-10-26 17:46:34,019:INFO: chroma VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-chroma/descriptions/tables.
2023-10-26 17:46:34,268:INFO: chroma VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-chroma/descriptions/texts.


## Test Retriever

### Embedding Content:
- Metadata (Descriptions + Conditions)
- Table Texts

In [225]:
# Load the "full" table/text indices from the faiss-backed store and build a
# top-5 retriever over each, with no metadata filters.
db_directory = os.path.join(DATA_DIR, "multimodal-faiss", "full")
table_index = load_vectorindex(os.path.join(db_directory, "tables"), "faiss")
text_index = load_vectorindex(os.path.join(db_directory, "texts"), "faiss")

table_retriever = table_index.as_retriever(
    similarity_top_k = 5, filters = None
    )
text_retriever = text_index.as_retriever(
    similarity_top_k = 5, filters = None
)

2023-10-26 12:06:38,503:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-faiss/full/tables.
2023-10-26 12:06:38,719:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal-faiss/full/texts.


In [179]:
# Run every test query through both retrievers; entry i of each list holds
# the nodes retrieved for query i.
retrieved_table_nodes = []
retrieved_text_nodes = []
for test_case in testcase_df["queries"]:
    retrieved_table_nodes.append(table_retriever.retrieve(test_case))
    retrieved_text_nodes.append(text_retriever.retrieve(test_case))

In [180]:
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_full")

# exist_ok=True already makes this a no-op when the directory exists, so the
# separate (and racy) os.path.exists check is unnecessary.
os.makedirs(save_folder, exist_ok=True)

In [226]:
# Generate Retrieved Context from Nodes

description_df = deepcopy(testcase_df)

# One bucket per retrieval rank (the retrievers are configured with
# similarity_top_k = 5); bucket k gathers, query by query, the embedded
# content and score of the node returned at rank k.
retrieved_tables = [[] for _ in range(5)]
retrieved_texts = [[] for _ in range(5)]
for i in range(len(description_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_tables[idx].append(
            node.get_content(MetadataMode.EMBED) + "\n\nScore: {}".format(node.score)
        )
    for idx, node in enumerate(retrieved_text_nodes[i]):
        retrieved_texts[idx].append(
            node.get_content(MetadataMode.EMBED) + "\n\nScore: {}".format(node.score)
        )

# Each rank bucket becomes one column: Table_1..Table_5 / Text_1..Text_5.
for idx, tables in enumerate(retrieved_tables):
    description_df[f"Table_{idx+1}"] = tables

for idx, texts in enumerate(retrieved_texts):
    description_df[f"Text_{idx+1}"] = texts

description_df.to_csv(os.path.join(save_folder, "table_text.csv"))

In [197]:
# HIT Rate of Nodes at document level
file_df = deepcopy(testcase_df)

# One bucket per retrieval rank; each holds the source file name of the node
# retrieved at that rank for every query.
retrieved_table_pages = [[] for _ in range(5)]
retrieved_text_pages = [[] for _ in range(5)]
for i in range(len(file_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_table_pages[idx].append(node.metadata["file_name"])
    for idx, node in enumerate(retrieved_text_nodes[i]):
        retrieved_text_pages[idx].append(node.metadata["file_name"])

# A rank is a hit when the retrieved file name equals the expected guideline.
for idx, tables in enumerate(retrieved_table_pages):
    file_df[f"Table_{idx+1}_Page"] = tables
    file_df[f"Table_{idx+1}_HIT"] = file_df[f"Table_{idx+1}_Page"].eq(file_df["Guideline"])

for idx, texts in enumerate(retrieved_text_pages):
    file_df[f"Text_{idx+1}_Page"] = texts
    file_df[f"Text_{idx+1}_HIT"] = file_df[f"Text_{idx+1}_Page"].eq(file_df["Guideline"])

# Per-query hit counts across the five ranks, then across both retrievers.
table_hit_frame = file_df[[f"Table_{idx+1}_HIT" for idx in range(5)]]
file_df["Total Table HITs"] = table_hit_frame.sum(axis=1)
text_hit_frame = file_df[[f"Text_{idx+1}_HIT" for idx in range(5)]]
file_df["Total Text HITs"] = text_hit_frame.sum(axis=1)
file_df["Total HITs"] = file_df[["Total Table HITs", "Total Text HITs"]].sum(axis=1)
file_df.to_csv(os.path.join(save_folder, "table_text_pages.csv"))

In [184]:
# L2 Distance of Nodes
score_df = deepcopy(testcase_df)

# One bucket per retrieval rank; each holds node.score for every query.
retrieved_table_scores = [[] for _ in range(5)]
retrieved_texts_scores = [[] for _ in range(5)]

for i in range(len(score_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_table_scores[idx].append(node.score)
    for idx, node in enumerate(retrieved_text_nodes[i]):
        retrieved_texts_scores[idx].append(node.score)

for idx, tables in enumerate(retrieved_table_scores):
    score_df[f"Table_{idx+1}"] = tables

for idx, texts in enumerate(retrieved_texts_scores):
    score_df[f"Text_{idx+1}"] = texts

# Row-wise means: over the five table ranks, over the five text ranks, and
# finally over those two averages.
table_score_frame = score_df[[f"Table_{idx+1}" for idx in range(5)]]
score_df['avg_table'] = table_score_frame.mean(axis=1)
text_score_frame = score_df[[f"Text_{idx+1}" for idx in range(5)]]
score_df['avg_text'] = text_score_frame.mean(axis=1)
score_df['avg_overall'] = score_df[["avg_table", "avg_text"]].mean(axis=1)

score_df.to_csv(os.path.join(save_folder, "table_text_scores.csv"))

In [185]:
# Generate Hit Rate of relevant table & Recall Score
def add_punctuation(text: str):
    """Return ``text`` ending in a full stop, appending one only if missing."""
    return text if text.endswith(".") else text + "."

table_hitrate_df = deepcopy(testcase_df)

# Normalise the expected scenario text (trimmed, full stop at the end)
# before comparing it with the retrieved table descriptions.
table_hitrate_df['ACR scenario'] = table_hitrate_df['ACR scenario'].str.strip().apply(add_punctuation)
retrieved_table_descriptions = [[] for _ in range(5)]

# One bucket per retrieval rank; each holds the trimmed "description"
# metadata of the table node retrieved at that rank for every query.
for i in range(len(table_hitrate_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_table_descriptions[idx].append(node.metadata["description"].strip())

# Levenshtein distance between the expected scenario and each description.
for idx, tables in enumerate(retrieved_table_descriptions):
    table_hitrate_df[f"Table_{idx+1}_descriptions"] = tables
    table_hitrate_df[f"Table_{idx+1}_LVscore"] = table_hitrate_df.apply(
        lambda x: levenshtein.distance(x["ACR scenario"], x[f"Table_{idx+1}_descriptions"]),
        axis=1,
    )

# A query is a hit when its closest retrieved description is within an edit
# distance of 5 from the expected scenario.
lv_score_frame = table_hitrate_df[[f"Table_{idx+1}_LVscore" for idx in range(5)]]
table_hitrate_df["Min_LV_Dist"] = lv_score_frame.min(axis=1)
table_hitrate_df["HIT"] = table_hitrate_df["Min_LV_Dist"].lt(5)

print(table_hitrate_df["HIT"].sum() / len(table_hitrate_df["HIT"]))

table_hitrate_df.to_csv(os.path.join(save_folder, "hit_rate.csv"))

0.6619718309859155


In [156]:
# Column-wise join (axis=1) of the per-query average scores from score_df
# with the Levenshtein hit information from table_hitrate_df.
combined_score_df = pd.concat(
    [score_df[['ACR scenario', 'Appropriateness Category', 'MRI scan ordered',
       'Clinical File', 'queries', 'avg_table', 'avg_text', 'avg_overall']],
     table_hitrate_df[['Min_LV_Dist', 'HIT']]], axis=1)

In [157]:
# Split the per-query average table score by whether the query was a HIT.
hit_similarity = combined_score_df[combined_score_df["HIT"]]["avg_table"]
miss_simiarity = combined_score_df[~combined_score_df["HIT"]]["avg_table"]

# Two-sample z-test of the hit scores against the miss scores with
# value=0; the cell output is the resulting 2-tuple.
ztest(hit_similarity.values, miss_simiarity.values, value=0)

(-1.2702985748691185, 0.20397829565210468)

### Embedding Content:
- Metadata (Descriptions + Conditions)

In [198]:
# Load the "descriptions" table/text indices from the faiss-backed store
# under data/multimodal and build a top-5 retriever over each, with no
# metadata filters.
db_directory = os.path.join(DATA_DIR, "multimodal", "descriptions")
table_index = load_vectorindex(os.path.join(db_directory, "tables"), "faiss")
text_index = load_vectorindex(os.path.join(db_directory, "texts"), "faiss")

table_retriever = table_index.as_retriever(
    similarity_top_k = 5,
    filters = None
    )
text_retriever = text_index.as_retriever(
    similarity_top_k = 5,
    filters = None
)

2023-10-24 20:07:20,665:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal/descriptions/tables.


2023-10-24 20:07:20,740:INFO: faiss VectorStore successfully loaded from /mnt/c/Users/QUAN/Desktop/lbp_mri/data/multimodal/descriptions/texts.


In [199]:
# Run every test query through both retrievers; entry i of each list holds
# the nodes retrieved for query i.
retrieved_table_nodes = []
retrieved_text_nodes = []
for test_case in testcase_df["queries"]:
    retrieved_table_nodes.append(table_retriever.retrieve(test_case))
    retrieved_text_nodes.append(text_retriever.retrieve(test_case))

In [200]:
save_folder = os.path.join(ARTIFACT_DIR, "retrieval_analysis_descriptions")

# exist_ok=True already makes this a no-op when the directory exists, so the
# separate (and racy) os.path.exists check is unnecessary.
os.makedirs(save_folder, exist_ok=True)

In [206]:
description_df = deepcopy(testcase_df)

# One bucket per retrieval rank (the retrievers are configured with
# similarity_top_k = 5); bucket k gathers, query by query, the embedded
# content and score of the node returned at rank k.
retrieved_tables = [[] for _ in range(5)]
retrieved_texts = [[] for _ in range(5)]
for i in range(len(description_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_tables[idx].append(
            node.get_content(MetadataMode.EMBED) + "\n\nScore: {}".format(node.score)
        )
    for idx, node in enumerate(retrieved_text_nodes[i]):
        retrieved_texts[idx].append(
            node.get_content(MetadataMode.EMBED) + "\n\nScore: {}".format(node.score)
        )

# Each rank bucket becomes one column: Table_1..Table_5 / Text_1..Text_5.
for idx, tables in enumerate(retrieved_tables):
    description_df[f"Table_{idx+1}"] = tables

for idx, texts in enumerate(retrieved_texts):
    description_df[f"Text_{idx+1}"] = texts

description_df.to_csv(os.path.join(save_folder, "table_text.csv"))

In [207]:
file_df = deepcopy(testcase_df)

# One bucket per retrieval rank; each holds the source file name of the node
# retrieved at that rank for every query.
retrieved_table_pages = [[] for _ in range(5)]
retrieved_text_pages = [[] for _ in range(5)]
for i in range(len(file_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_table_pages[idx].append(node.metadata["file_name"])
    for idx, node in enumerate(retrieved_text_nodes[i]):
        retrieved_text_pages[idx].append(node.metadata["file_name"])

# A rank is a hit when the retrieved file name equals the expected guideline.
for idx, tables in enumerate(retrieved_table_pages):
    file_df[f"Table_{idx+1}_Page"] = tables
    file_df[f"Table_{idx+1}_HIT"] = file_df[f"Table_{idx+1}_Page"].eq(file_df["Guideline"])

for idx, texts in enumerate(retrieved_text_pages):
    file_df[f"Text_{idx+1}_Page"] = texts
    file_df[f"Text_{idx+1}_HIT"] = file_df[f"Text_{idx+1}_Page"].eq(file_df["Guideline"])

# Per-query hit counts across the five ranks, then across both retrievers.
table_hit_frame = file_df[[f"Table_{idx+1}_HIT" for idx in range(5)]]
file_df["Total Table HITs"] = table_hit_frame.sum(axis=1)
text_hit_frame = file_df[[f"Text_{idx+1}_HIT" for idx in range(5)]]
file_df["Total Text HITs"] = text_hit_frame.sum(axis=1)
file_df["Total HITs"] = file_df[["Total Table HITs", "Total Text HITs"]].sum(axis=1)
file_df.to_csv(os.path.join(save_folder, "table_text_pages.csv"))

In [208]:
score_df = deepcopy(testcase_df)

# One bucket per retrieval rank; each holds node.score for every query.
retrieved_table_scores = [[] for _ in range(5)]
retrieved_texts_scores = [[] for _ in range(5)]

for i in range(len(score_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_table_scores[idx].append(node.score)
    for idx, node in enumerate(retrieved_text_nodes[i]):
        retrieved_texts_scores[idx].append(node.score)

for idx, tables in enumerate(retrieved_table_scores):
    score_df[f"Table_{idx+1}"] = tables

for idx, texts in enumerate(retrieved_texts_scores):
    score_df[f"Text_{idx+1}"] = texts

# Row-wise means: over the five table ranks, over the five text ranks, and
# finally over those two averages.
table_score_frame = score_df[[f"Table_{idx+1}" for idx in range(5)]]
score_df['avg_table'] = table_score_frame.mean(axis=1)
text_score_frame = score_df[[f"Text_{idx+1}" for idx in range(5)]]
score_df['avg_text'] = text_score_frame.mean(axis=1)
score_df['avg_overall'] = score_df[["avg_table", "avg_text"]].mean(axis=1)

score_df.to_csv(os.path.join(save_folder, "table_text_scores.csv"))

In [209]:
def add_punctuation(text: str):
    """Ensure ``text`` ends with ``"."``, leaving it untouched if it already does."""
    suffix = "" if text.endswith(".") else "."
    return text + suffix

table_hitrate_df = deepcopy(testcase_df)

# Normalise the expected scenario text (trimmed, full stop at the end)
# before comparing it with the retrieved table descriptions.
table_hitrate_df['ACR scenario'] = table_hitrate_df['ACR scenario'].str.strip().apply(add_punctuation)
retrieved_table_descriptions = [[] for _ in range(5)]

# One bucket per retrieval rank; each holds the trimmed "description"
# metadata of the table node retrieved at that rank for every query.
for i in range(len(table_hitrate_df["queries"])):
    for idx, node in enumerate(retrieved_table_nodes[i]):
        retrieved_table_descriptions[idx].append(node.metadata["description"].strip())

# Levenshtein distance between the expected scenario and each description.
for idx, tables in enumerate(retrieved_table_descriptions):
    table_hitrate_df[f"Table_{idx+1}_descriptions"] = tables
    table_hitrate_df[f"Table_{idx+1}_LVscore"] = table_hitrate_df.apply(
        lambda x: levenshtein.distance(x["ACR scenario"], x[f"Table_{idx+1}_descriptions"]),
        axis=1,
    )

# A query is a hit when its closest retrieved description is within an edit
# distance of 5 from the expected scenario.
lv_score_frame = table_hitrate_df[[f"Table_{idx+1}_LVscore" for idx in range(5)]]
table_hitrate_df["Min_LV_Dist"] = lv_score_frame.min(axis=1)
table_hitrate_df["HIT"] = table_hitrate_df["Min_LV_Dist"].lt(5)

print(table_hitrate_df["HIT"].sum() / len(table_hitrate_df["HIT"]))

table_hitrate_df.to_csv(os.path.join(save_folder, "hit_rate.csv"))

0.7887323943661971


In [210]:
# Column-wise join (axis=1) of the per-query average scores from score_df
# with the Levenshtein hit information from table_hitrate_df.
combined_score_df = pd.concat(
    [score_df[['ACR scenario', 'Appropriateness Category', 'MRI scan ordered',
       'Clinical File', 'queries', 'avg_table', 'avg_text', 'avg_overall']],
     table_hitrate_df[['Min_LV_Dist', 'HIT']]], axis=1)

In [211]:
# Split the per-query average table score by whether the query was a HIT.
hit_similarity = combined_score_df[combined_score_df["HIT"]]["avg_table"]
miss_simiarity = combined_score_df[~combined_score_df["HIT"]]["avg_table"]

# Two-sample z-test of the hit scores against the miss scores with
# value=0; the cell output is the resulting 2-tuple.
ztest(hit_similarity.values, miss_simiarity.values, value=0)

(-0.38182527768811947, 0.7025909678945342)

In [None]:
def retrieval_analysis(
    table_retriever: Optional[BaseRetriever] = None,
    text_retriever: Optional[BaseRetriever] = None,
    save_folder: Optional[str] = None
):