### Load RAG-LLM pipeline

In [1]:

# ================== GENERAL IMPORTS ==================
import os
import json
from dotenv import load_dotenv

#imports for this testing
import sys
script_dir = os.getcwd()
root_dir = os.path.join(os.path.dirname(os.path.abspath(script_dir)))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(script_dir))))

# ================== UTIL FUNCTIONS ==================
from utils.embedding import get_context_db
from utils.prompt import get_prompt
from llm.run_RAGLLM import run_RAG

# ================== MODEL & API IMPORTS ==================
from openai import OpenAI
from llm.inference import run_llm
import faiss

# --- module state ---
# Shared by init(), reset() and answer(); init() fills these in and reset() clears most of them.
# True once init() has finished; answer() refuses to run while this is False.
_READY = False
# OpenAI client created in init() from OPENAI_API_KEY.
_CLIENT = None
# Context chunks loaded from JSON in init(); handed to run_RAG together with _INDEX.
_CONTEXT = None
# FAISS index read from the cache or built by get_context_db() in init().
_INDEX = None
# Model family flag forwarded to run_RAG / run_llm; never reassigned in this notebook.
_MODEL_TYPE = "gpt"
# Chat model name; set from init()'s model_api argument.
_MODEL_NAME = None
# Embedding model name; also used by init() to name the cache files via _cache_paths().
_MODEL_EMBED = "text-embedding-3-small"

def reset():
    """
    Forget the initialised pipeline so that the next init() call starts from scratch.

    Clears the client, the context, the index and the chat-model name and lowers
    the ready flag. _MODEL_TYPE and _MODEL_EMBED keep their current values.
    """
    global _READY, _CLIENT, _CONTEXT, _INDEX, _MODEL_NAME
    _CLIENT = _CONTEXT = _INDEX = _MODEL_NAME = None
    _READY = False

def _cache_paths(embed_name: str, version: str = "v1"):
    os.makedirs("indexes", exist_ok=True)
    return (
        f"indexes/{embed_name}__{version}.faiss",
        f"indexes/{embed_name}__{version}.context.json",
    )

def init(
    context_json_path: str = "data/structured_context_chunks.json",
    model_api: str = "gpt-4o-2024-05-13",
    *,
    force_rebuild: bool = False,
):
    """
    Initializes OpenAI client, loads (or builds) the FAISS index ONCE per process.

    Arguments:
        context_json_path: JSON file holding the context chunks to index.
        model_api: chat model name stored in _MODEL_NAME for later answer() calls.
        force_rebuild: when True, ignore any cached index and re-embed the context.

    Behaviour:
        - If init() already completed (_READY is True) the call is a no-op, even if
          the arguments differ; call reset() first to re-initialise.
        - The cached index is reused only when the context snapshot stored next to it
          equals the content of context_json_path. The cache file name depends only on
          the embedding model, so without this check an index built from a different
          context file would be silently reused.
        - If context_json_path does not exist but a cache does, the cache is used as is.
        - Module state is assigned only after everything succeeded, so a failure
          leaves the module uninitialised rather than half-initialised.

    Raises:
        RuntimeError: OPENAI_API_KEY is not set.
        FileNotFoundError: context_json_path is missing and no usable cache exists.
    """
    global _READY, _CLIENT, _CONTEXT, _INDEX, _MODEL_NAME, _MODEL_EMBED
    if _READY:
        return

    load_dotenv()
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY not set")
    client = OpenAI(api_key=api_key)

    index_path, ctx_path = _cache_paths(_MODEL_EMBED)
    cache_present = os.path.exists(index_path) and os.path.exists(ctx_path)

    # Read the requested context whenever it is needed for a build or available for validation.
    # It stays None only when the file is absent and a cache can stand in for it.
    requested_context = None
    if force_rebuild or not cache_present or os.path.exists(context_json_path):
        with open(context_json_path, "r", encoding="utf-8") as f:
            requested_context = json.load(f)

    context = None
    index = None
    if cache_present and not force_rebuild:
        with open(ctx_path, "r", encoding="utf-8") as f:
            cached_context = json.load(f)
        # Reuse the cached index only if it was built from the same context.
        if requested_context is None or cached_context == requested_context:
            context = cached_context
            index = faiss.read_index(index_path)

    if index is None:
        # No cache, forced rebuild, or stale cache: embed the requested context and refresh the cache.
        context = requested_context
        index = get_context_db(context, client, _MODEL_EMBED)
        faiss.write_index(index, index_path)
        with open(ctx_path, "w", encoding="utf-8") as f:
            json.dump(context, f)

    _CLIENT, _MODEL_NAME, _CONTEXT, _INDEX = client, model_api, context, index
    _READY = True

def answer(
    text: str,
    strategy: int = 0,
    num_vec: int = 10,
    max_len: int = 2048,
    temp: float = 0.0,
    random_seed: int = 2025,
    *,
    rag: bool = True,
) -> str:
    """
    Answer `text` with the pipeline prepared by init().

    With rag=True the query goes through run_RAG together with the loaded context,
    index and client. With rag=False the prompt is built by get_prompt(strategy, text)
    and sent directly to run_llm (no retrieval).

    Returns the first element of the helper's result, or "" when that element is falsy.

    Raises:
        RuntimeError: init() has not completed yet.
    """
    if not _READY:
        raise RuntimeError("Call init(...) first.")

    if not rag:
        # LLM-only path (no retrieval)
        llm_prompt = get_prompt(strategy, text)
        reply, _ = run_llm(llm_prompt, _CLIENT, _MODEL_TYPE, _MODEL_NAME, max_len, temp, random_seed)
    else:
        reply, _ = run_RAG(
            _CONTEXT, text, strategy, _INDEX, _CLIENT, num_vec,
            _MODEL_TYPE, _MODEL_NAME, _MODEL_EMBED,
            max_len, temp, random_seed,
        )
    return reply if reply else ""

In [2]:
#other imports needed for testing
import pandas as pd
import ast
from utils.io import load_object, save_object


### Load synthetic query and moalmanac db

In [3]:
#load synthetic query entities
#the entity column is read from CSV as text, so it is parsed back into Python dicts with ast.literal_eval
synthetic_query_ner = pd.read_csv(f"{root_dir}/data/synthetic_query_ner_entities.csv")
synthetic_query_ner["standardized_entities_dict"] = synthetic_query_ner["standardized_entities_dict"].apply(ast.literal_eval)
print(synthetic_query_ner.head())

#and moalmanac entities (same parsing of the entity column; the 'context' column of this table is exported to JSON further down)
moalmanac_db = pd.read_csv(f"{root_dir}/data/latest_db/moalmanac-draft.dereferenced.unique.context_db.standardized_entities.csv")
moalmanac_db['standardized_entities_dict'] = moalmanac_db['standardized_entities_dict'].apply(ast.literal_eval)

                                     synthetic_query  \
0  If a chronic myelogenous leukemia patient has ...   
1  If a chronic myelogenous leukemia patient has ...   
2  If a acute lymphoblastic leukemia patient has ...   
3  If a acute lymphoblastic leukemia patient has ...   
4  If a chronic myeloid leukemia patient has a re...   

                                    biobert_entities  \
0  [{"type": "Cancer", "text": "chronic myelogeno...   
1  [{"type": "Cancer", "text": "chronic myelogeno...   
2  [{"type": "Cancer", "text": "acute lymphoblast...   
3  [{"type": "Cancer", "text": "acute lymphoblast...   
4  [{"type": "Cancer", "text": "chronic myeloid l...   

                          standardized_entities_dict  
0  {'cancer_type': ['CHRONIC MYELOGENOUS LEUKEMIA...  
1  {'cancer_type': ['CHRONIC MYELOGENOUS LEUKEMIA...  
2  {'cancer_type': ['ACUTE LYMPHOBLASTIC LEUKEMIA...  
3  {'cancer_type': ['ACUTE LYMPHOBLASTIC LEUKEMIA...  
4  {'cancer_type': ['CHRONIC MYELOID LEUKEMIA'], ..

In [None]:
#export the moalmanac context chunks and both sets of NER entities as JSON lists
#(entity dumps go to ./entities relative to the working directory; the context goes under root_dir)
os.makedirs("entities", exist_ok=True)
json_exports = [
    (f"{root_dir}/data/latest_db/moalmanac-draft.dereferenced.unique.context_db.json", moalmanac_db['context'].values),
    ("entities/moalmanac_db_ner_entities.json", moalmanac_db['standardized_entities_dict'].values),
    ("entities/synthetic_query_ner_entities.json", synthetic_query_ner['standardized_entities_dict']),
]
for export_path, records in json_exports:
    with open(export_path, "w") as f:
        json.dump(list(records), f)

In [4]:
# Locations of the JSON files written by the export cell above.
# new_context_json_path becomes the default context for the init() defined in the next cell.
new_context_json_path = f"{root_dir}/data/latest_db/moalmanac-draft.dereferenced.unique.context_db.json"
db_entities_path = "entities/moalmanac_db_ner_entities.json"
query_entities_path = "entities/synthetic_query_ner_entities.json"
# One entity dict per database row, in the row order of moalmanac_db at export time.
with open(db_entities_path, "r") as f:
    _DB_ENTITY = json.load(f)
# One entity dict per synthetic query, in the row order of synthetic_query_ner at export time.
with open(query_entities_path, "r") as f:
    _QUERY_ENTITY = json.load(f)

In [5]:
def init(
    context_json_path: str = new_context_json_path,
    model_api: str = "gpt-4o-2024-05-13",
    *,
    force_rebuild: bool = False,
):
    """
    Initializes OpenAI client, loads (or builds) the FAISS index ONCE per process.

    This definition replaces the earlier init() in this notebook; the only interface
    difference is the default context file (the latest moalmanac db export).

    Arguments:
        context_json_path: JSON file holding the context chunks to index.
        model_api: chat model name stored in _MODEL_NAME for later answer() calls.
        force_rebuild: when True, ignore any cached index and re-embed the context.

    Behaviour:
        - If init() already completed (_READY is True) the call is a no-op, even if
          the arguments differ; call reset() first to re-initialise.
        - The cached index is reused only when the context snapshot stored next to it
          equals the content of context_json_path. The cache file name depends only on
          the embedding model, so without this check an index built from an older
          context database would be silently reused for the new one.
        - If context_json_path does not exist but a cache does, the cache is used as is.
        - Module state is assigned only after everything succeeded, so a failure
          leaves the module uninitialised rather than half-initialised.

    Raises:
        RuntimeError: OPENAI_API_KEY is not set.
        FileNotFoundError: context_json_path is missing and no usable cache exists.
    """
    global _READY, _CLIENT, _CONTEXT, _INDEX, _MODEL_NAME, _MODEL_EMBED
    if _READY:
        return

    load_dotenv()
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY not set")
    client = OpenAI(api_key=api_key)

    index_path, ctx_path = _cache_paths(_MODEL_EMBED)
    cache_present = os.path.exists(index_path) and os.path.exists(ctx_path)

    # Read the requested context whenever it is needed for a build or available for validation.
    # It stays None only when the file is absent and a cache can stand in for it.
    requested_context = None
    if force_rebuild or not cache_present or os.path.exists(context_json_path):
        with open(context_json_path, "r", encoding="utf-8") as f:
            requested_context = json.load(f)

    context = None
    index = None
    if cache_present and not force_rebuild:
        with open(ctx_path, "r", encoding="utf-8") as f:
            cached_context = json.load(f)
        # Reuse the cached index only if it was built from the same context.
        if requested_context is None or cached_context == requested_context:
            context = cached_context
            index = faiss.read_index(index_path)

    if index is None:
        # No cache, forced rebuild, or stale cache: embed the requested context and refresh the cache.
        context = requested_context
        index = get_context_db(context, client, _MODEL_EMBED)
        faiss.write_index(index, index_path)
        with open(ctx_path, "w", encoding="utf-8") as f:
            json.dump(context, f)

    _CLIENT, _MODEL_NAME, _CONTEXT, _INDEX = client, model_api, context, index
    _READY = True


In [20]:
# Run init() with its defaults (latest-db context JSON, gpt-4o model, cached index allowed).
init()

### Load NER-based retrieval pipeline

In [8]:
from rapidfuzz import fuzz

def check_list(input):
    """Return `input` itself when it is a list; otherwise wrap it in a one-element list."""
    return input if isinstance(input, list) else [input]

def match_entities(user_entities, db_entities, fuzzy_thres=70, id_col='statement_id', entities_col='standardized_entities_dict', context_col='context'):
    """
    Score database entries by how well their cancer types and biomarkers match the user's query.

    Scoring is done separately for 'cancer_type' and 'biomarker' and the two results are added.
    For each element of an entry's field (a single term, or a list of alternative terms):
        - add the number of query terms it matches exactly;
        - otherwise add 0.5 if any alternative has a rapidfuzz `fuzz.ratio` strictly above
          `fuzzy_thres` with any query term.

    Arguments:
        user_entities (dict): query entities under the keys 'cancer_type' and 'biomarker'.
            Each value may be a single term or a list of terms; a missing key or None means "no terms".
        db_entities (iterable of dict): one entity dict per context chunk, same two keys.
            A bare string value is treated as one term (it is not iterated character by character);
            a missing key or None contributes nothing.
        fuzzy_thres (float): fuzzy-match threshold on the 0-100 `fuzz.ratio` scale.
        id_col, entities_col, context_col: not used by this function; kept so existing calls keep working.

    Returns:
        list of (idx, score, db_entity) tuples, one per entry with score > 0, where idx is the
        entry's position in `db_entities`. Sorted by score, highest first; ties keep database order.
    """

    def _as_list(value):
        # None -> no terms, list -> itself, anything else -> single-term list.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def _field_score(db_field, user_terms):
        # Score one field ('cancer_type' or 'biomarker') of one database entry.
        user_term_set = set(user_terms)
        field_score = 0
        for element in _as_list(db_field):
            candidates = _as_list(element)
            exact_hits = len(set(candidates) & user_term_set)
            if exact_hits > 0:
                field_score += exact_hits
            elif any(fuzz.ratio(candidate, user_term) > fuzzy_thres for user_term in user_terms for candidate in candidates):
                field_score += 0.5
        return field_score

    user_cancer = _as_list(user_entities.get('cancer_type'))
    user_biomarker = _as_list(user_entities.get('biomarker'))
    match_score_all = []

    #iterate over all db entities
    for idx, db_entity in enumerate(db_entities):
        score = _field_score(db_entity.get('cancer_type'), user_cancer)
        score += _field_score(db_entity.get('biomarker'), user_biomarker)
        if score > 0:
            match_score_all.append((idx, score, db_entity))

    #sort by score descending (list.sort is stable, so equal scores stay in database order)
    match_score_all.sort(key=lambda x: x[1], reverse=True)
    return match_score_all

### Test on one synthetic query

In [None]:
#load ground truth answers (a dict looked up by the query text in the cells below)
# NOTE(review): load_object comes from utils.io and presumably unpickles this .pkl — only load files from a trusted source.
synthetic_prompt_groundtruth_dict=load_object(os.path.join(root_dir, 'data/synthetic_prompt_groundtruth_dict.pkl'))

In [49]:
# Peek at the cancer-type entities of the first synthetic query.
_QUERY_ENTITY[0]['cancer_type']

['CHRONIC MYELOGENOUS LEUKEMIA']

In [61]:
# Run entity matching for one synthetic query and show it next to the ground-truth answer.
test_idx = 99
# Label-based Series lookup; assumes the default 0..n-1 index so the label equals the row position — TODO confirm
prompt_chunk = synthetic_query_ner.synthetic_query[test_idx]
# _QUERY_ENTITY was dumped from synthetic_query_ner's entity column, so test_idx addresses the same query.
matched_score_all=match_entities(_QUERY_ENTITY[test_idx], _DB_ENTITY)
# Ground truth is keyed by the query text.
true_drug=synthetic_prompt_groundtruth_dict[prompt_chunk]

print(prompt_chunk)
print(matched_score_all)
print(true_drug)

If a medullary thyroid cancer patient has a somatic variant in gene RET, and is advanced or metastatic, what would be the suggested lines of treatment?
[(438, 2, {'cancer_type': ['MEDULLARY THYROID CANCER'], 'biomarker': ['RET']}), (392, 0.5, {'cancer_type': ['PAPILLARY THYROID CANCER'], 'biomarker': ['V::RET']})]
[{'selpercatinib'}]


In [12]:
# Database positions of the entity-matched chunks, best score first (order inherited from match_entities).
ner_selected_indices = [db_position for db_position, _score, _entity in matched_score_all]
ner_selected_indices

[438, 392]

[Retrieval logic]
1. Retrieve context chunks by similarity search: embed the query with the OpenAI embedding model and search the FAISS index (cosine similarity).
2. Find the context chunks whose cancer types and biomarkers match those of the query. If none exist, put the chunks retrieved in step 1 into the prompt unchanged.

If context chunks with matching entities exist:

3. Move the matching chunks that already appear in the step-1 list to the top of the list; add the matching chunks that do not appear in it to the beginning of the list.
4. Provide the reordered context in the prompt.

In [46]:
import numpy as np
from utils.embedding import get_text_embedding, prep_embed_for_search
from sklearn.metrics.pairwise import cosine_similarity

#test single query
test_idx = 99
prompt_chunk = synthetic_query_ner.synthetic_query[test_idx]

#compute query embeddings
query_embeddings = np.array([get_text_embedding(prompt_chunk, _CLIENT, _MODEL_EMBED)])
print(query_embeddings[0][0])

#normalize query embeddings for cosine similarity
query_embeddings_norm = query_embeddings/np.linalg.norm(query_embeddings, axis=1, keepdims=True)
print(query_embeddings_norm.shape[1])

#search similar embeddings
num_vec=10
D, I = _INDEX.search(prep_embed_for_search(query_embeddings_norm, n_dim=2), k=num_vec) #D: cosine similarity scores
#always keep the similarity-search result: it is the fallback when no entity matches
retrieved_indices = I[0]

#matching context index and score for the test query (both lists are empty when nothing matches)
matched_score_all=match_entities(_QUERY_ENTITY[test_idx], _DB_ENTITY)
ner_selected_indices = [matched[0] for matched in matched_score_all]
ner_selected_scores = [matched[1] for matched in matched_score_all]

#entity-matched chunks first, then the remaining similarity hits in their original order.
#with no entity match this is just the similarity-search result, so `ordered` is defined on every path
#(previously it was only assigned inside the `if`, raising NameError or reusing stale kernel state).
#FAISS pads missing results with -1; drop those so they cannot index _CONTEXT[-1].
ner_selected_set = set(ner_selected_indices)
ordered = ner_selected_indices + [int(i) for i in retrieved_indices if i >= 0 and int(i) not in ner_selected_set]
print(retrieved_indices)
print(ner_selected_indices)
print(ordered)

#retrieve final context text
retrieved_contexts = [_CONTEXT[i] for i in ordered]


-0.015075127594172955
1536
[439 438 440 437 392 398 397 396 130 129]
[438, 392]
[438, 392, 439, 440, 437, 398, 397, 396, 130, 129]


In [63]:
# Scratch arithmetic; the value is only displayed, it is not assigned to any name.
0.7*2

1.4

In [43]:
# Euclidean (L2) distance between each row of subset_embeddings and the query embedding.
# NOTE(review): subset_embeddings is not defined anywhere in the visible notebook, so this cell fails on a fresh kernel.
# NOTE(review): this rebinds D, which the retrieval cell uses for the scores returned by _INDEX.search.
D = np.linalg.norm(subset_embeddings - query_embeddings, axis=1)
D

array([0.8553842 , 0.9072539 , 0.83718907, 0.86318736, 0.878151  ,
       0.94753678, 0.94785957, 0.94824718, 0.95342115, 0.95431344])

In [46]:
# L2 distance for the first row only, computed on the squeezed query vector (cross-check of the batched result).
# NOTE(review): subset_embeddings is not defined anywhere in the visible notebook.
np.linalg.norm(subset_embeddings[0] - query_embeddings.squeeze(), axis=0)

0.8553841958528482

## Test latest db context augmentation on existing pipeline & synthetic queries

In [3]:
from utils.evaluation import calc_eval_metrics
# Query table; its 'prompt' column is passed to calc_eval_metrics below. The first CSV column is used as the index.
moalmanac_data = pd.read_csv(f"{root_dir}/data/moa_fda_queries_answers.csv", index_col=0)
# Same pickle as loaded in the "Test on one synthetic query" section; loaded again so this section can run on its own.
# NOTE(review): presumably unpickled by load_object — only load files from a trusted source.
synthetic_prompt_groundtruth_dict=load_object(os.path.join(root_dir, 'data/synthetic_prompt_groundtruth_dict.pkl'))

Query: synthetic free-text queries \
Context: structured json-style context from latest moalmanac db

In [33]:
# Saved RAG results for the json-style latest-db context (output/RAG_res_gpt4o/structured_latest_db).
rag_gpt4o_new_db_res_dict = load_object(filename=os.path.join(root_dir, 'output/RAG_res_gpt4o/structured_latest_db/RAGstra0n1temp0.0_res_dict.pkl'))
# Evaluate the first entry under 'full output' against the ground-truth dict.
rag_gpt4o_new_db_eval_res = calc_eval_metrics(rag_gpt4o_new_db_res_dict['full output'][0], moalmanac_data['prompt'], synthetic_prompt_groundtruth_dict)
rag_gpt4o_new_db_eval_res

{'avg_exact_match_acc': 0.49145299145299143,
 'avg_partial_match_acc': 0.6153846153846154,
 'avg_precision': 0.46994301994301996,
 'avg_recall': 0.5534188034188035,
 'avg_f1': 0.4871184371184376,
 'avg_specificity': 0.9840810552178076,
 'exact_match_acc': [False,
  False,
  False,
  False,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  False,
  False,
  True,
  True,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  T

Query: synthetic free-text queries\
Context: structured free-text context from latest db

In [35]:
# Saved RAG results for the free-text latest-db context (output/RAG_res_gpt4o/structured_latest_db_text).
rag_gpt4o_new_db_text_res_dict = load_object(filename=os.path.join(root_dir, 'output/RAG_res_gpt4o/structured_latest_db_text/RAGstra0n1temp0.0_res_dict.pkl'))
# Evaluate the first entry under 'full output' against the ground-truth dict.
rag_gpt4o_new_db_text_eval_res = calc_eval_metrics(rag_gpt4o_new_db_text_res_dict['full output'][0], moalmanac_data['prompt'], synthetic_prompt_groundtruth_dict)
rag_gpt4o_new_db_text_eval_res

{'avg_exact_match_acc': 0.44017094017094016,
 'avg_partial_match_acc': 0.5042735042735043,
 'avg_precision': 0.40655270655270653,
 'avg_recall': 0.4672364672364671,
 'avg_f1': 0.41338353005019707,
 'avg_specificity': 0.9850233556968898,
 'exact_match_acc': [True,
  True,
  False,
  True,
  True,
  False,
  True,
  False,
  False,
  False,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  False,
  True,
  False,
  True,
  True,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  True,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  False,
  True

Query: synthetic free-text queries \
Context: structured json-style context from latest moalmanac db \
Retrieval: entity matching + similarity search 
(`output/RAG_res_gpt4o/structured_latest_db_entity_matching`)

In [8]:
# Saved RAG results for entity matching + similarity search (output/RAG_res_gpt4o/structured_latest_db_entity_matching).
rag_gpt4o_new_db_entity_res_dict = load_object(filename=os.path.join(root_dir, 'output/RAG_res_gpt4o/structured_latest_db_entity_matching/RAGstra0n1temp0.0_res_dict.pkl'))
# Evaluate the entity-matching results loaded on the line above. The previous version passed
# rag_gpt4o_new_db_res_dict here, i.e. it re-scored the plain json-style run (or whatever that
# name held in the kernel) instead of this one.
rag_gpt4o_new_db_entity_eval_res = calc_eval_metrics(rag_gpt4o_new_db_entity_res_dict['full output'][0], moalmanac_data['prompt'], synthetic_prompt_groundtruth_dict)
rag_gpt4o_new_db_entity_eval_res

{'avg_exact_match_acc': 0.7735042735042735,
 'avg_partial_match_acc': 0.8760683760683761,
 'avg_precision': 0.5263719305385969,
 'avg_recall': 0.8245014245014246,
 'avg_f1': 0.599775485672922,
 'avg_specificity': 0.9753155328846659,
 'exact_match_acc': [True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  Tru

- partial match accuracy improved substantially (0.62 -> 0.88, a ~40% relative increase)
- confirms that entity matching improves performance
- need to examine the retrieved context for a few cases to check whether anything went wrong or looks concerning
- lower accuracy than the previous preprint results, which could be due to:
    - the validation QA pairs were formulated based on the previous db
    - the current db is much larger (~200 -> ~600 entries), i.e., a larger search space

Small testing:
- revisit context - reformatting:
    - description only
    - indication only
    - indication + cancer/gene