In [1]:
import sys, os, pandas as pd, ast, json
from glob import glob
import warnings
warnings.filterwarnings('ignore')


In [2]:
# Change to the directory two levels above the one the kernel started in;
# the relative paths used by the cells below are resolved from there.
# The start directory is remembered on the first run so that re-executing
# this cell is idempotent (a bare os.chdir('../..') would climb two more
# levels every time the cell is run again).
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..', '..'))

In [3]:
# ================== GENERAL IMPORTS ==================
import os
import json
from dotenv import load_dotenv

# ================== UTIL FUNCTIONS ==================
from utils.embedding import get_context_db, retrieve_context
from utils.prompt import get_prompt
from llm.run_RAGLLM import run_RAG


# ================== MODEL & API IMPORTS ==================
from mistralai.client import MistralClient
from openai import OpenAI
from llm.inference import run_llm
import faiss


---

In [4]:
# --- module-level state, filled in by init() and read by answer() ---
_READY = False        # becomes True once init() has finished successfully
_CLIENT = None        # API client object (OpenAI or MistralClient)
_CONTEXT = None       # context chunks loaded from JSON
_INDEX = None         # FAISS index over the context chunks
_MODEL_TYPE = None    # model family string passed to init()
_MODEL_NAME = None    # concrete model identifier passed to init()
_MODEL_EMBED = None   # embedding model name chosen by init()


def _cache_paths(embed_name: str, version: str = "v1"):
    """Return the on-disk cache locations for an embedding model.

    Ensures the ``indexes`` directory exists (relative to the current
    working directory) and returns a 2-tuple
    ``(faiss_index_path, context_json_path)`` whose file names are built
    from ``embed_name`` and ``version``.
    """
    cache_dir = "indexes"
    os.makedirs(cache_dir, exist_ok=True)
    stem = f"{cache_dir}/{embed_name}__{version}"
    return (stem + ".faiss", stem + ".context.json")

In [5]:
def init(
    context_json_path: str = "data/structured_context_chunks.json",
    model_type: str = "gpt",
    model_api: str = "gpt-4o-2024-05-13",
    *,
    force_rebuild: bool = False,
) -> None:
    """Set up the API client, context and FAISS index used by ``answer``.

    The call is a no-op when initialisation has already completed
    (``_READY`` is true); arguments passed to such a later call are ignored.

    Args:
        context_json_path: JSON file with the context chunks. Only read when
            the index has to be (re)built.
        model_type: ``"gpt"`` / ``"gpt_reasoning"`` select the OpenAI client
            with the ``text-embedding-3-small`` embedding model;
            ``"mistral"`` / ``"mistral-7b"`` select ``MistralClient`` with
            ``mistral-embed``.
        model_api: Model identifier, stored and later handed to ``run_RAG``
            by ``answer``.
        force_rebuild: When true, ignore any cached index/context and rebuild
            them from ``context_json_path``.

    Raises:
        ValueError: ``model_type`` is not one of the supported values.
        RuntimeError: the API key environment variable for the selected
            provider is not set.

    Note:
        The cache files are keyed only by the embedding model name (see
        ``_cache_paths``), not by ``context_json_path``. After changing the
        context file, call with ``force_rebuild=True`` to avoid loading a
        stale index.
    """
    global _READY, _CLIENT, _CONTEXT, _INDEX, _MODEL_TYPE, _MODEL_NAME, _MODEL_EMBED
    if _READY:
        return

    openai_types = ("gpt", "gpt_reasoning")
    mistral_types = ("mistral", "mistral-7b")

    # Validate the argument up front so a typo fails fast, before the
    # environment is touched.
    if model_type not in openai_types + mistral_types:
        raise ValueError(
            "Invalid model_type. Please choose from: mistral-7b, mistral, gpt, gpt_reasoning"
        )

    load_dotenv()
    if model_type in openai_types:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY not set")
        client = OpenAI(api_key=api_key)
        embed_model = "text-embedding-3-small"
    else:
        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            raise RuntimeError("MISTRAL_API_KEY not set")
        client = MistralClient(api_key=api_key)
        embed_model = "mistral-embed"

    index_path, ctx_path = _cache_paths(embed_model)

    cache_available = os.path.exists(index_path) and os.path.exists(ctx_path)
    if (not force_rebuild) and cache_available:
        # Reuse the cached context together with the index built from it.
        with open(ctx_path, "r", encoding="utf-8") as f:
            context = json.load(f)
        index = faiss.read_index(index_path)
    else:
        with open(context_json_path, "r", encoding="utf-8") as f:
            context = json.load(f)
        index = get_context_db(context, client, embed_model)
        faiss.write_index(index, index_path)
        with open(ctx_path, "w", encoding="utf-8") as f:
            json.dump(context, f)

    # Publish the module state only after every step succeeded, so a failed
    # init() never leaves the globals half-populated.
    _CLIENT = client
    _MODEL_EMBED = embed_model
    _MODEL_TYPE = model_type
    _MODEL_NAME = model_api
    _CONTEXT = context
    _INDEX = index
    _READY = True


In [6]:
# Initialise the pipeline once for this session, reusing any cached index.
init(
    context_json_path="data/structured_context_chunks.json",
    model_type="gpt",
    model_api="gpt-4.1-nano",  # alternative used previously: "gpt-4o-2024-05-13"
    force_rebuild=False,
)

---

In [7]:
def answer(
    text: str,
    strategy: int = 0,
    num_vec: int = 10,
    max_len: int = 2048,
    temp: float = 0.0,
    random_seed: int = 2025,
    *,
    rag: bool = True
):
    """Answer ``text`` with the RAG pipeline configured by ``init``.

    Args:
        text: The prompt/question passed to ``run_RAG``.
        strategy, num_vec, max_len, temp, random_seed: Forwarded unchanged
            (positionally) to ``run_RAG``. NOTE(review): their exact meaning
            is defined by ``run_RAG``; the names suggest prompt strategy id,
            number of retrieved vectors, max generation length, sampling
            temperature and seed — confirm there.
        rag: Must be truthy. Only the retrieval-augmented path is
            implemented in this function.

    Returns:
        The first element of ``run_RAG``'s result, or ``""`` when that value
        is falsy (e.g. ``None`` or an empty string).

    Raises:
        RuntimeError: ``init`` has not completed yet.
        NotImplementedError: ``rag`` is falsy. Previously this case fell
            through to the ``return`` and crashed with ``UnboundLocalError``
            because ``out`` was never assigned.
    """
    if not _READY:
        raise RuntimeError("Call init(...) first.")
    if not rag:
        raise NotImplementedError(
            "answer(rag=False) is not supported: only the RAG path is implemented."
        )

    # The second element of run_RAG's result is not used here.
    out, _ = run_RAG(
        _CONTEXT,
        text,
        strategy,
        _INDEX,
        _CLIENT,
        num_vec,
        _MODEL_TYPE,
        _MODEL_NAME,
        _MODEL_EMBED,
        max_len,
        temp,
        random_seed,
    )
    return out or ""

-----

In [23]:
# Patient IDs whose reports are sent through the pipeline in the cells below.
patient_list = [
    "P-0000495",
    "P-0002861",
]

In [24]:
# Gold answer: the recorded first-line treatment for each patient in `patient_list`.
df = pd.read_csv('external-validation/panel-sequencing/reports/filtered-data/first_line_treatments_post_msk.csv')
treatment = df.loc[
    df['PATIENT_ID'].isin(patient_list),
    ['PATIENT_ID', 'FIRST_LINE_TREATMENT'],
]
treatment

Unnamed: 0,PATIENT_ID,FIRST_LINE_TREATMENT
42,P-0000495,"{'TAMOXIFEN', 'EVEROLIMUS'}"
252,P-0002861,"{'BEVACIZUMAB', 'ERLOTINIB'}"


In [25]:
# For every patient, build the prompt for strategies 0 and 5, query the RAG
# pipeline and save {"prompt", "response"} to rag-test/<patient>_<strategy>.json.
# NOTE(review): `prompt` is already produced by get_prompt(strategy=...) and
# answer() forwards the same `strategy` to run_RAG — confirm run_RAG does not
# apply the strategy template a second time.
os.makedirs('external-validation/panel-sequencing/reports/rag-test', exist_ok=True)
for patient in patient_list:
    report_path = 'external-validation/panel-sequencing/reports/test/{p}.txt'.format(p = patient)
    with open(report_path, 'r', encoding='utf-8') as f:
        report = f.read()
    format_input = \
    '''
    The following is a summarized molecular profile from MSK-IMPACT.
    Based on this profile, answer with the appropriate treatments for this patient in the format above.
    {report}
    '''.format(report=report)
    for strategy in (0, 5):
        prompt = get_prompt(strategy = strategy, prompt_chunk= format_input)
        # Query the model BEFORE opening the output file: opening with 'w'
        # truncates it, so a failed request would otherwise leave an empty
        # JSON file behind (and wipe any earlier result).
        response = answer(prompt, strategy=strategy, rag=True)
        out_path = 'external-validation/panel-sequencing/reports/rag-test/{p}_{s}.json'.format(p = patient, s = str(strategy))
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump({"prompt": prompt, "response": response}, f, indent = 4)

----

In [10]:
# Load the selected sample IDs (per the file name: samples without previous
# treatment) and keep the first two for this run. `json` is already imported
# at the top of the notebook, so no import is needed inside the `with` body.
with open('external-validation/panel-sequencing/reports/test/selected-samples-no-previous-treatment.json', encoding='utf-8') as f:
    patients = json.load(f)
# NOTE: these entries are sample IDs (e.g. ending in '-T01-IM3'); later cells
# cut them at '-T01' to obtain the patient ID.
patient_list = patients[:2]
patient_list

['P-0000557-T01-IM3', 'P-0001189-T01-IM3']

In [16]:
# Gold answer: recorded first-line treatments for the selected samples.
# Each sample ID is cut at '-T01' to obtain the patient ID used in the CSV.
df = pd.read_csv('external-validation/panel-sequencing/reports/filtered-data/first_line_treatments_post_msk.csv')
selected_patient_ids = [sample_id.split('-T01')[0] for sample_id in patient_list]
treatment = df.loc[
    df['PATIENT_ID'].isin(selected_patient_ids),
    ['PATIENT_ID', 'FIRST_LINE_TREATMENT'],
]
treatment

Unnamed: 0,PATIENT_ID,FIRST_LINE_TREATMENT
49,P-0000557,{'ERLOTINIB'}
115,P-0001189,{'LAPATINIB'}


In [18]:
# For every selected sample, derive the patient ID, build the prompt for
# strategies 0 and 5, query the RAG pipeline and save {"prompt", "response"}
# to rag-test/<patient>_<strategy>.json.
# NOTE(review): `prompt` is already produced by get_prompt(strategy=...) and
# answer() forwards the same `strategy` to run_RAG — confirm run_RAG does not
# apply the strategy template a second time.
os.makedirs('external-validation/panel-sequencing/reports/rag-test', exist_ok=True)
for sample_id in patient_list:
    # Entries look like 'P-0000557-T01-IM3'; everything before '-T01' is the patient ID.
    patient = sample_id.split('-T01')[0]
    report_path = 'external-validation/panel-sequencing/reports/test/{p}.txt'.format(p = patient)
    with open(report_path, 'r', encoding='utf-8') as f:
        report = f.read()
    format_input = \
    '''
    The following is a summarized molecular profile from MSK-IMPACT.
    Based on this profile, answer with the appropriate treatments for this patient in the format above.
    {report}
    '''.format(report=report)
    for strategy in (0, 5):
        prompt = get_prompt(strategy = strategy, prompt_chunk= format_input)
        # Query the model BEFORE opening the output file: opening with 'w'
        # truncates it, so a failed request would otherwise leave an empty
        # JSON file behind (and wipe any earlier result).
        response = answer(prompt, strategy=strategy, rag=True)
        out_path = 'external-validation/panel-sequencing/reports/rag-test/{p}_{s}.json'.format(p = patient, s = str(strategy))
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump({"prompt": prompt, "response": response}, f, indent = 4)