RAG Pipeline to Check Accuracy

1. Preprocess and Chunk the EHR Data

In [16]:
from typing import List
import re

def chunk_text(text: str, chunk_size: int = 500) -> List[str]:
    """Split EHR text into manageable chunks"""
    sentences = re.split(r'(?<=[.?!])\s+', text)
    chunks = []
    current_chunk = []

    for sentence in sentences:
        if sum(len(s) for s in current_chunk) + len(sentence) < chunk_size:
            current_chunk.append(sentence)
        else:
            chunks.append(" ".join(current_chunk))
            current_chunk = [sentence]

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks

2. Embed and Store Chunks in Vector DB


In [3]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl (31.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m48.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.11.0


In [20]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def embed_chunks(chunks: List[str]):
    embeddings = model.encode(chunks, convert_to_numpy=True)
    return embeddings

def store_in_faiss(embeddings, chunks):
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


3. Extract Claims from Generated Summary

In [21]:
def extract_claims(summary: str) -> List[str]:
    return re.split(r'(?<=[.?!])\s+', summary.strip())

4. Retrieve Top-K EHR Chunks for Each Claim

In [23]:
def retrieve_chunks(claims: List[str], index, chunks: List[str], k: int = 3):
    claim_embeddings = model.encode(claims, convert_to_numpy=True)
    D, I = index.search(claim_embeddings, k)

    retrieved = []
    for idx_list in I:
        retrieved.append([chunks[i] for i in idx_list])

    return retrieved

5. Verify Claim Support (LLM-based or Rule-based)

In [24]:
import os
from dotenv import load_dotenv

load_dotenv(dotenv_path=".env") 

api_key = os.getenv("NEBIUS_API_KEY")
if not api_key:
    raise ValueError("API key is missing!")

In [25]:
import openai
from openai import OpenAI
client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=api_key,
)

def verify_claim(claim: str, evidence: list) -> str:
    context = "\n\n".join(evidence)
    prompt = f"""
    Given the following EHR context:
    {context}

    Does this context support the claim: "{claim}"?
    Answer with: Supported / Not Supported / Uncertain and a brief reason.
    """

    response = client.chat.completions.create(
        model="Qwen/Qwen2.5-Coder-7B",
        messages=[
            {"role": "user", "content": prompt}
        ],
        temperature=0.2,
    )

    return response.choices[0].message.content.strip()

In [27]:
# Create the patient record (EHR)
ehr_text = """
Patient was admitted on 2023-04-02 with fever, hypotension, and positive blood cultures.
Blood cultures confirmed methicillin-resistant Staphylococcus aureus (MRSA).
The patient was treated with intravenous vancomycin for 5 days and improved clinically.
Discharge diagnosis was sepsis due to MRSA.
"""

# Create the generated discharge summary
summary = """
The patient was admitted for sepsis due to MRSA and received intravenous vancomycin for 5 days.
"""

# Write to files
with open("patient_record.txt", "w") as f:
    f.write(ehr_text)

with open("generated_summary.txt", "w") as f:
    f.write(summary)

print("Files created successfully.")

Files created successfully.


In [28]:
ehr_text = open("patient_record.txt").read()
summary = open("generated_summary.txt").read()

chunks = chunk_text(ehr_text)
chunk_embeddings = embed_chunks(chunks)
index = store_in_faiss(chunk_embeddings, chunks)

claims = extract_claims(summary)
retrieved_chunks = retrieve_chunks(claims, index, chunks)

for claim, evidence in zip(claims, retrieved_chunks):
    verdict = verify_claim(claim, evidence)
    print(f"Claim: {claim}\nVerdict: {verdict}\n")

Claim: The patient was admitted for sepsis due to MRSA and received intravenous vancomycin for 5 days.
Verdict: Supported. The context clearly states that the patient was admitted for sepsis due to MRSA and received intravenous vancomycin for 5 days.

    Given the following EHR context:
    
Patient was admitted on 2023-04-02 with fever, hypotension, and positive blood cultures. Blood cultures confirmed methicillin-resistant Staphylococcus aureus (MRSA). The patient was treated with intravenous vancomycin for 5 days and improved clinically. Discharge diagnosis was sepsis due to MRSA. 


Patient was admitted on 2023-04-02 with fever, hypotension, and positive blood cultures. Blood cultures confirmed methicillin-resistant Staphylococcus aureus (MRSA). The patient was treated with intravenous vancomycin for 5 days and improved clinically. Discharge diagnosis was sepsis due to MRSA. 


Patient was admitted on 2023-04-02 with fever, hypotension, and positive blood cultures. Blood cultures 