# Stroke RAG Agent Notebook

Sections: Environment Setup ¬∑ PDF Downloading + Processing ¬∑ Text Cleaning & Chunking ¬∑ Embedding + Vector Store Creation ¬∑ Agent Definition ¬∑ Reflection Function ¬∑ Workflow Function ¬∑ Chat Loop ¬∑ Final Instructions for Usage


## Environment Setup


In [1]:
# Install dependencies (run once)
# !pip install --quiet openai python-dotenv requests PyMuPDF nltk langchain-text-splitters faiss-cpu

In [2]:
import os, re, json, tempfile, pathlib, math
from typing import List, Dict, Tuple

import requests
import fitz  # PyMuPDF
import nltk
from nltk.corpus import stopwords
from langchain_text_splitters import RecursiveCharacterTextSplitter
import faiss        # Facebook AI Similarity Search -> library for efficient similarity search and clustering of dense vectors
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI

# Download stopwords quietly and load env vars/API key
nltk.download("stopwords", quiet=True)
load_dotenv()

# Initialize OpenAI client
client = OpenAI()

# URLs to fetch
PDF_URLS = [
    "https://dnvp9c1uo2095.cloudfront.net/cms-content/030-046l_S2e_Akuttherapie-des-ischaemischen-Schlaganfalls_2022-11-verlaengert_1718363551944.pdf",
    "https://dnvp9c1uo2095.cloudfront.net/cms-content/030133_LL_Sekunda%CC%88rprophylaxe_Teil_1_2022_final_korr_1739803472035.pdf",
    "https://dnvp9c1uo2095.cloudfront.net/cms-content/030143_LL_Sekunda%CC%88rprophylaxe_Teil2_2022_V1.1_1670949892924.pdf",
    "https://tempis.de/download/tempis-sop-2025/?tmstv=1765451597",
]

# Controls for extraction and chunking
START_TITLES = ["Was gibt es Neues?", "Die wichtigsten Empfehlungen auf einen Blick"]
EXCLUDE_HEADINGS = ["Literatur", "Referenzen", "References", "Bibliografie"]
CHUNK_SIZE = 950
CHUNK_OVERLAP = 180
EMBED_MODEL = "text-embedding-3-large"      # OpenAI‚Äôs latest embedding model, high-dimensional vectors (1536 dimensions)
CHAT_MODEL = "gpt-4.1-mini"

# Temp directory to store downloaded PDFs
TEMP_DIR = pathlib.Path(tempfile.mkdtemp(prefix="rag_pdfs_"))
print(f"Temp dir: {TEMP_DIR}")

Temp dir: /tmp/rag_pdfs_qufsgbvy


## PDF Downloading + Processing


In [3]:
def download_pdfs(urls: List[str], dest_dir: pathlib.Path) -> List[pathlib.Path]:
    """Download all PDFs to a temp folder and return their paths."""
    files = []
    dest_dir.mkdir(parents=True, exist_ok=True)
    for i, url in enumerate(urls, 1):
        fname = dest_dir / f"doc_{i}.pdf"
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        fname.write_bytes(resp.content)
        files.append(fname)
        print(f"Downloaded: {fname}")
    return files

pdf_paths = download_pdfs(PDF_URLS, TEMP_DIR)

Downloaded: /tmp/rag_pdfs_qufsgbvy/doc_1.pdf
Downloaded: /tmp/rag_pdfs_qufsgbvy/doc_2.pdf
Downloaded: /tmp/rag_pdfs_qufsgbvy/doc_3.pdf
Downloaded: /tmp/rag_pdfs_qufsgbvy/doc_4.pdf


## Text Cleaning & Chunking


In [4]:
# Load stopwords (German + English) for light heuristics
stop_words = set(stopwords.words("german")) | set(stopwords.words("english"))

def page_contains_start(text: str) -> bool:
    """Detect first relevant section titles to start extraction."""
    return any(title.lower() in text.lower() for title in START_TITLES)

def heading_is_excluded(text: str) -> bool:
    """Check if a heading should be skipped (references)."""
    return any(h.lower() in text.lower() for h in EXCLUDE_HEADINGS)

def looks_like_reference_block(text: str) -> bool:
    """Heuristic to drop pages/blocks that are mostly references."""
    ref_patterns = [r"\[[0-9]{1,3}\]", r"\([0-9]{1,3}\)", r"\d{4}\."]
    hits = sum(len(re.findall(p, text)) for p in ref_patterns)
    density = hits / max(1, len(text.split()))
    return density > 0.08 or heading_is_excluded(text[:80])

def clean_text(text: str) -> str:
    """Light cleaning: drop URLs, citations, page numbers, extra whitespace."""
    text = re.sub(r"https?://\S+", " ", text)
    text = re.sub(r"\b\d+\s*/\s*\d+\b", " ", text)  # page numbers like 12/34
    text = re.sub(r"\s+\d+\s+", " ", text)  # lone numbers
    text = re.sub(r"\[[0-9]+\]|\([0-9]+\)|\^{[0-9]+}", " ", text)  # citation markers
    text = re.sub(r"\s+", " ", text).strip()
    return text

def extract_relevant_text(pdf_path: pathlib.Path) -> str:
    """Extract text starting at first target section; skip reference-like pages."""
    doc = fitz.open(pdf_path)
    collecting = False
    blocks = []
    for page in doc:
        txt = page.get_text("text") or ""
        if not collecting and page_contains_start(txt):
            collecting = True
        if not collecting:
            continue  # skip intro pages
        if looks_like_reference_block(txt):
            continue  # drop reference pages/sections
        cleaned = clean_text(txt)
        if cleaned:
            blocks.append(cleaned)
    doc.close()
    return "\n".join(blocks)

# Extract from all PDFs
documents = []
for p in pdf_paths:
    extracted = extract_relevant_text(p)
    documents.append(extracted)
    print(f"Extracted {len(extracted)} chars from {p.name}")

# Combine docs and split into overlapping chunks for retrieval
all_text = "\n".join(documents)
splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    separators=["\n\n", "\n", ". ", ".", " "]
)
chunks = splitter.split_text(all_text)
print(f"Total chunks: {len(chunks)}")

# Simple metadata mapping per chunk
metadatas = [{"source": f"chunk_{i}"} for i in range(len(chunks))]

Extracted 550234 chars from doc_1.pdf
Extracted 261721 chars from doc_2.pdf
Extracted 313425 chars from doc_3.pdf
Extracted 0 chars from doc_4.pdf
Total chunks: 1567


## Embedding + Vector Store Creation


In [5]:
def embed_texts(texts: List[str]) -> np.ndarray:
    """Batch-embed all chunks using OpenAI embeddings."""
    embeddings = []
    for i in range(0, len(texts), 64):
        batch = texts[i:i+64]
        resp = client.embeddings.create(model=EMBED_MODEL, input=batch)
        embeddings.extend([item.embedding for item in resp.data])
    return np.array(embeddings, dtype="float32")

# Compute embeddings and build FAISS index
emb_matrix = embed_texts(chunks)
index = faiss.IndexFlatL2(emb_matrix.shape[1])
index.add(emb_matrix)

# Store mappings for retrieval
id2text = {i: t for i, t in enumerate(chunks)}
id2meta = {i: m for i, m in enumerate(metadatas)}

print("FAISS index ready", index.ntotal)

FAISS index ready 1567


## Agent Definition


In [6]:
# Persona/system prompt guiding clinician vs patient tone
SYSTEM_PERSONA = (
    "You are a highly experienced neurologist specialized in stroke medicine. "
    "When talking to clinicians, use precise medical terminology and cite key recommendations succinctly. "
    "When talking to patients or families, simplify explanations, avoid jargon, and be empathetic. "
    "Always base answers on retrieved context; if insufficient, say so and request clarification."
)

def retrieve(query: str, k: int = 5) -> List[Tuple[str, Dict]]:
    """Vector search top-k chunks for a query."""
    q_emb = np.array(client.embeddings.create(model=EMBED_MODEL, input=[query]).data[0].embedding, dtype="float32")
    q_emb = np.expand_dims(q_emb, axis=0)
    scores, idxs = index.search(q_emb, k)
    results = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:
            continue
        results.append((id2text[idx], id2meta[idx]))
    return results

def build_context(chunks: List[Tuple[str, Dict]]) -> str:
    """Format retrieved chunks into a readable context block."""
    parts = []
    for i, (text, meta) in enumerate(chunks, 1):
        parts.append(f"[Chunk {i}] {text}")
    return "\n\n".join(parts)

def agent_answer(query: str, context: str) -> str:
    """Generate a draft answer using persona + retrieved context."""
    msgs = [
        {"role": "system", "content": SYSTEM_PERSONA},
        {"role": "system", "content": "Use the provided context. If context is thin, say so and request clarification."},
        {"role": "system", "content": f"Context:\n{context}"},
        {"role": "user", "content": query},
    ]
    resp = client.chat.completions.create(model=CHAT_MODEL, messages=msgs)
    return resp.choices[0].message.content.strip()

## Conversation Memory


In [7]:
# Simple rolling memory of recent Q&A pairs to keep context alive
MAX_MEMORY = 5  # keep last 5 exchanges
memory_buffer: List[Tuple[str, str]] = []

def add_to_memory(question: str, answer: str) -> None:
    memory_buffer.append((question, answer))
    if len(memory_buffer) > MAX_MEMORY:
        memory_buffer.pop(0)

def build_memory_context() -> str:
    if not memory_buffer:
        return ""
    formatted = []
    for i, (q, a) in enumerate(memory_buffer, 1):
        formatted.append(f"[Memory {i}] Q: {q}\nA: {a}")
    return "\n\n".join(formatted)

## Reflection Function


In [8]:
def reflect_and_improve(response: str, query: str, context: str) -> str:
    """Check medical quality; if feedback not OK, regenerate with fixes."""
    reflection = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=[
            {
                "role": "system",
                "content": "You are checking whether the response is medically accurate, complete, safe, and well-structured. Respond with 'OK' or provide concrete corrections and missing points.",
            },
            {"role": "user", "content": f"Query: {query}\n\nContext:\n{context}\n\nDraft response:\n{response}"},
        ],
    )
    feedback = reflection.choices[0].message.content.strip()
    if feedback.upper() == "OK":
        return response
    # Regenerate with feedback injected
    regen_msgs = [
        {"role": "system", "content": SYSTEM_PERSONA},
        {"role": "system", "content": "Use the provided context. If context is thin, say so and ask for clarification."},
        {"role": "system", "content": f"Context:\n{context}"},
        {"role": "system", "content": f"Please fix issues noted: {feedback}"},
        {"role": "user", "content": query},
    ]
    regen = client.chat.completions.create(model=CHAT_MODEL, messages=regen_msgs)
    return regen.choices[0].message.content.strip()

## Workflow Function


In [9]:
def workflow(query: str, k: int = 5) -> str:
    """Full RAG pipeline: retrieve ‚Üí draft ‚Üí reflect ‚Üí final (with memory)."""
    retrieved = retrieve(query, k=k)
    retrieval_context = build_context(retrieved)
    memory_context = build_memory_context()

    # Combine memory + retrieval context for richer answers
    if memory_context:
        combined_context = f"{memory_context}\n\n{retrieval_context}"
    else:
        combined_context = retrieval_context

    draft = agent_answer(query, combined_context)
    final = reflect_and_improve(draft, query, combined_context)

    # Persist this turn in memory
    add_to_memory(query, final)
    return final

## Chat Loop


In [None]:
print("Type 'exit' or 'quit' to stop.")
while True:
    q = input("üë§: ")
    if q.strip().lower() in ["exit", "quit"]:
        print("Bye! Have a healthy day!")
        break
    print("Checking the guidelines...")
    answer = workflow(q)
    print(f"üßë‚Äç‚öïÔ∏è: {answer}\n")

# Reset conversation memory after the chat loop ends
memory_buffer.clear()
print("Conversation memory cleared.")

Type 'exit' or 'quit' to stop.
Bye!
Conversation memory cleared.


## Final Instructions for Usage

- Ensure your `.env` contains `OPENAI_API_KEY` (and organization/project if required).
- Run cells top-to-bottom. First run will download PDFs, embed, and build FAISS (may take a few minutes).
- Use the chat loop to interact. Type `exit` or `quit` to end.
- Persona: senior stroke neurologist; adapts to clinicians vs patients; cites context and asks for clarification when retrieval is thin.
- Reflection: every answer is checked; if not `OK`, it regenerates with the feedback.
- Memory: the last 5 Q&A turns are remembered and injected ahead of retrieval context.


## Potential improvements

- add a tool which allows it to search online or on pubmed when asked about studies / trials