In [1]:
import sys, os, pandas as pd
from glob import glob
import warnings
# Ignore every warning category from here on (until the filters are changed).
# NOTE(review): this also hides deprecation warnings from the API clients.
warnings.filterwarnings('ignore')

In [2]:
# Move the working directory two levels up (presumably to the repository root,
# so that the relative data/ and indexes/ paths used below resolve -- confirm).
# NOTE: the path is relative to the *current* directory, so re-running this
# cell without restarting the kernel moves up two more levels.
os.chdir('../..')

In [3]:
# ================== GENERAL IMPORTS ==================
import os
import json
from dotenv import load_dotenv

# ================== UTIL FUNCTIONS ==================
from utils.embedding import get_context_db, retrieve_context
from utils.prompt import get_prompt
from llm.run_RAGLLM import run_RAG


# ================== MODEL & API IMPORTS ==================
from mistralai.client import MistralClient
from openai import OpenAI
from llm.inference import run_llm
import faiss

# import from 

In [4]:
# Test input: read one report file into memory as plain text.
# The path and the file contents get separate names so `report` always holds
# the report text and never a path.
report_path = 'external-validation/panel-sequencing/reports/test/P-0000036.txt'
with open(report_path, 'r') as f:
    report = f.read()

In [7]:
# --- module-level state shared by init() and answer() ---
# init() flips this to True as its last step; answer() refuses to run before that.
_READY = False
# Client handle, context chunks and vector index: unset until init() fills them in.
_CLIENT = _CONTEXT = _INDEX = None
# Model family, chat model name and embedding model name chosen in init().
_MODEL_TYPE = _MODEL_NAME = _MODEL_EMBED = None

def _cache_paths(embed_name: str, version: str = "v1"):
    """Return the on-disk cache locations for one embedding model.

    Creates the ``indexes`` directory (relative to the current working
    directory) if it does not exist yet.

    Args:
        embed_name: Name of the embedding model; used verbatim in the file names.
        version: Cache version tag appended to the file names.

    Returns:
        A ``(faiss_index_path, context_json_path)`` tuple of relative paths.
    """
    os.makedirs("indexes", exist_ok=True)
    # Both files share one stem so an index and its context always pair up.
    stem = f"indexes/{embed_name}__{version}"
    faiss_path = stem + ".faiss"
    context_path = stem + ".context.json"
    return faiss_path, context_path

In [8]:
def init(
    context_json_path: str = "data/structured_context_chunks.json",
    model_type: str = "gpt",
    model_api: str = "gpt-4o-2024-05-13",
    *,
    force_rebuild: bool = False,
):
    """Create the API client and load (or build) the context index.

    On success the module-level ``_CLIENT``, ``_CONTEXT``, ``_INDEX``,
    ``_MODEL_TYPE``, ``_MODEL_NAME`` and ``_MODEL_EMBED`` are populated and
    ``_READY`` is set to True. If anything fails, the previous module state is
    left untouched.

    The index cache is keyed only by embedding-model name (see
    ``_cache_paths``), so a changed ``context_json_path`` does NOT invalidate
    it -- pass ``force_rebuild=True`` in that case.

    Args:
        context_json_path: JSON file with the context chunks; only read when
            the index has to be (re)built.
        model_type: ``"gpt"`` / ``"gpt_reasoning"`` select the OpenAI client,
            ``"mistral"`` / ``"mistral-7b"`` select the Mistral client.
        model_api: Model name stored in ``_MODEL_NAME`` for later calls.
        force_rebuild: Rebuild the index from ``context_json_path`` and
            overwrite the cache, even if the module is already initialised.

    Raises:
        RuntimeError: The API key environment variable for the chosen
            provider is not set.
        ValueError: ``model_type`` is not one of the accepted values.
    """
    global _READY, _CLIENT, _CONTEXT, _INDEX, _MODEL_TYPE, _MODEL_NAME, _MODEL_EMBED
    # Already initialised: no-op, unless the caller explicitly asked for a
    # rebuild (previously force_rebuild was silently ignored here).
    if _READY and not force_rebuild:
        return

    load_dotenv()
    if model_type in ["gpt", "gpt_reasoning"]:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY not set")
        client = OpenAI(api_key=api_key)
        embed_model = "text-embedding-3-small"
    elif model_type in ["mistral", "mistral-7b"]:
        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            raise RuntimeError("MISTRAL_API_KEY not set")
        client = MistralClient(api_key=api_key)
        embed_model = "mistral-embed"
    else:
        # The message lists every value the branches above accept.
        raise ValueError(
            "Invalid model_type. Please choose from: mistral-7b, mistral, gpt, gpt_reasoning"
        )

    index_path, ctx_path = _cache_paths(embed_model)
    cache_available = os.path.exists(index_path) and os.path.exists(ctx_path)

    if not force_rebuild and cache_available:
        # Reuse the cached index together with the context it was built from.
        with open(ctx_path, "r", encoding="utf-8") as f:
            context = json.load(f)
        index = faiss.read_index(index_path)
    else:
        # Build from source and persist both artefacts so they stay paired.
        with open(context_json_path, "r", encoding="utf-8") as f:
            context = json.load(f)
        index = get_context_db(context, client, embed_model)
        faiss.write_index(index, index_path)
        with open(ctx_path, "w", encoding="utf-8") as f:
            json.dump(context, f)

    # Commit everything in one place, only after every step above succeeded,
    # so a failed (re)initialisation never leaves half-updated state behind.
    _CLIENT, _CONTEXT, _INDEX = client, context, index
    _MODEL_TYPE, _MODEL_NAME, _MODEL_EMBED = model_type, model_api, embed_model
    _READY = True


In [9]:
# One-time setup: create the OpenAI client and load the cached index,
# building it first if no cache exists yet.
init(
    context_json_path="data/structured_context_chunks.json",
    model_type="gpt",
    model_api="gpt-4o-2024-05-13",
    force_rebuild=False,
)

In [10]:
# Wrap the raw report text in the query template, then let get_prompt()
# assemble the full prompt around it.
format_input = f'''
The following is a summarized molecular profile from MSK-IMPACT.
Based on this profile, answer with the appropriate treatments for this patient in the format above.
{report}
'''
prompt = get_prompt(strategy=0, prompt_chunk=format_input)
# print() rather than a bare expression so the multi-line prompt is shown
# with real line breaks instead of as an escaped repr.
print(prompt)

Please provide each treatment as a json format with the following JSON schema.
    {
        "Treatment 1": {
            "Disease Name": ,
            "Disease Phase or Condition": ,
            "Drug Name": ,
            "Prior Treatment or Resistance Status": ,
            "Genomic Features":
            "Link to FDA-approved Label": 
            }
    }
    Query: 
The following is a summarized molecular profile from MSK-IMPACT.
Based on this profile, answer with the appropriate treatments for this patient in the format above.
Patient ID: P-0000036
Age: 68
Gender: Female
Sample ID: P-0000036-T01-IM3
Gene Panel: IMPACT341
Cancer Type: Lung Adenocarcinoma
Sample Type: Primary
Tumor Purity: 30.0%
DNA Variants:
AR c.676T>G (p.L226V) - in 14.000000000000002% of 584 reads
ERBB2 c.3044G>A (p.G1015E) - in 42.0% of 354 reads
FBXW7 c.336_344del (p.D112_E114del) - in 40.0% of 378 reads
IRS1 c.1382A>G (p.E461G) - in 44.0% of 418 reads
NOTCH4 c.2443T>G (p.C815G) - in 48.0% of 296 reads
TP53 c.9

In [13]:
def answer(
    text: str,
    strategy: int = 0,
    num_vec: int = 10,
    max_len: int = 2048,
    temp: float = 0.0,
    random_seed: int = 2025,
    *,
    rag = True
):
    """Run the retrieval-augmented pipeline on ``text`` and return the model output.

    The arguments are forwarded positionally to ``run_RAG`` together with the
    module state populated by ``init()``.

    Args:
        text: Prompt / query passed to ``run_RAG``.
        strategy: Prompt strategy id passed through to ``run_RAG``.
        num_vec: Passed through to ``run_RAG`` (presumably the number of
            context vectors to retrieve -- confirm against run_RAG).
        max_len: Passed through to ``run_RAG`` (presumably the generation
            length limit -- confirm).
        temp: Sampling temperature passed through to ``run_RAG``.
        random_seed: Seed passed through to ``run_RAG``.
        rag: Must be truthy; the non-RAG path is not implemented.

    Returns:
        The first element of ``run_RAG``'s result, or ``""`` if it is falsy.

    Raises:
        RuntimeError: ``init(...)`` has not completed yet.
        NotImplementedError: ``rag`` is falsy.
    """
    if not _READY:
        raise RuntimeError("Call init(...) first.")
    if not rag:
        # Previously this fell through to `return out` with `out` unbound and
        # crashed with UnboundLocalError. Fail with an explicit, catchable
        # error instead (NotImplementedError is a RuntimeError subclass).
        # TODO(review): wire up the plain-LLM path (run_llm is imported above).
        raise NotImplementedError("answer(rag=False) is not implemented; call with rag=True.")

    out, _ = run_RAG(
        _CONTEXT,
        text,
        strategy,
        _INDEX,
        _CLIENT,
        num_vec,
        _MODEL_TYPE,
        _MODEL_NAME,
        _MODEL_EMBED,
        max_len,
        temp,
        random_seed,
    )
    # Normalise None / empty results to an empty string.
    return out or ""

In [14]:
answer(prompt, strategy=0, rag=True)

'{\n    "Treatment 1": {\n        "Disease Name": "Lung Adenocarcinoma",\n        "Disease Phase or Condition": "Locally advanced or metastatic",\n        "Drug Name": "Entrectinib",\n        "Prior Treatment or Resistance Status": "No satisfactory standard therapy",\n        "Genomic Features": "ROS1 - SLC34A2 (TRANSLOCATION)",\n        "Link to FDA-approved Label": "https://www.accessdata.fda.gov/drugsatfda_docs/label/2023/212725Orig1s009lbl.pdf"\n    }\n}'

In [None]:
{
    "Treatment 1": {
        "Disease Name": "Lung Adenocarcinoma",
        "Disease Phase or Condition": "Locally advanced or metastatic",
        "Drug Name": "Entrectinib",
        "Prior Treatment or Resistance Status": "No satisfactory standard therapy",
        "Genomic Features": "ROS1 - SLC34A2 (TRANSLOCATION)",
        "Link to FDA-approved Label": "https://www.accessdata.fda.gov/drugsatfda_docs/label/2023/212725Orig1s009lbl.pdf"
    }
}