In [4]:
import os
print(os.getcwd())

/Users/admin/qazcode-nu/notebooks


In [5]:

import pickle
import numpy as np
with open("../data/processed_corpus.pkl", "rb") as f:
    protocols = pickle.load(f)
print("Protocols:", len(protocols))
print(protocols[0].keys())

Protocols: 1137
dict_keys(['protocol_id', 'source_file', 'title', 'icd_codes', 'text'])


In [6]:
import os
import numpy as np
from sentence_transformers import SentenceTransformer

# --- Step 1: Move to project root ---
os.chdir("..")  # go from notebooks/ to qazcode-nu/
print("Working directory:", os.getcwd())

MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# --- Step 2: Ensure data folder exists ---
os.makedirs("data", exist_ok=True)

# --- Step 3: Load or Build Embeddings ---
if os.path.exists("data/embeddings.npy"):
    print("Loading saved embeddings...")
    embeddings = np.load("data/embeddings.npy")
else:
    print("Building embeddings...")
    model = SentenceTransformer(MODEL_NAME)
    texts = [p["text"] for p in protocols]
    embeddings = model.encode(texts, show_progress_bar=True)
    embeddings = np.array(embeddings).astype("float32")
    np.save("data/embeddings.npy", embeddings)
    print("Embeddings saved.")

print("Embeddings shape:", embeddings.shape)

  from .autonotebook import tqdm as notebook_tqdm


Working directory: /Users/admin/qazcode-nu
Loading saved embeddings...
Embeddings shape: (1137, 384)


In [7]:
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer(MODEL_NAME)

def retrieve(query, top_k=5):
    query_vec = model.encode([query])
    sims = cosine_similarity(query_vec, embeddings)[0]
    top_indices = np.argsort(sims)[::-1][:top_k]
    
    results = []
    for idx in top_indices:
        results.append({
            "score": float(sims[idx]),
            "protocol": protocols[idx]
        })
    
    return results

KeyboardInterrupt: 

In [None]:
query = "pregnant woman high blood pressure low platelets"

results = retrieve(query, top_k=5)

for r in results:
    print("Score:", round(r["score"], 4))
    print("ICD:", r["protocol"]["icd_codes"])
    print("Title:", r["protocol"]["title"])
    print("-" * 50)

Score: 0.4986
ICD: ['O10', 'O10.0', 'O10.1', 'O10.2', 'O10.3', 'O10.4', 'O10.9', 'O13', 'O14', 'O15']
Title: Одобрено
--------------------------------------------------
Score: 0.4746
ICD: ['R23', 'V10']
Title: Одобрено
--------------------------------------------------
Score: 0.4588
ICD: ['I05', 'I05.1', 'I05.2', 'I08', 'I08.0', 'I08.1', 'I08.2', 'I08.3', 'I08.8', 'I09', 'I09.0', 'I09.1', 'I09.8', 'I25.5', 'I33.0', 'I34', 'I34.0', 'I42', 'I42.0', 'I42.1', 'I42.2', 'I42.5', 'I42.6', 'I42.7', 'I42.8', 'I42.9', 'I50', 'I50.0', 'I50.1', 'I50.9', 'O99.4', 'Q20', 'Q28']
Title: Одобрено
--------------------------------------------------
Score: 0.4584
ICD: ['O12', 'O12.0', 'O12.1', 'O12.2']
Title: Одобрен
--------------------------------------------------
Score: 0.4497
ICD: ['Z32.1', 'Z33', 'Z34.0', 'Z34.8', 'Z35.0', 'Z35.9', 'Z36.0', 'Z36.3', 'Z36.4', 'Z36.5']
Title: Одобрен
--------------------------------------------------


In [None]:
def recall_at_k(query, gt_icd, k=3):
    results = retrieve(query, top_k=k)
    
    predicted_codes = []
    for r in results:
        predicted_codes.extend(r["protocol"]["icd_codes"])
    
    return gt_icd in predicted_codes

In [None]:
recall_at_k(
    "pregnant woman high blood pressure low platelets",
    "O14.2",
    k=3
)

False

In [None]:
def icd_match(predicted_code, gt_code):
    # Exact match
    if predicted_code == gt_code:
        return True
    
    # Parent-child match
    if gt_code.startswith(predicted_code):
        return True
    
    if predicted_code.startswith(gt_code):
        return True
    
    return False


def recall_at_k(query, gt_icd, k=3):
    results = retrieve(query, top_k=k)
    
    predicted_codes = []
    for r in results:
        predicted_codes.extend(r["protocol"]["icd_codes"])
    
    for p in predicted_codes:
        if icd_match(p, gt_icd):
            return True
    
    return False

In [None]:
recall_at_k(
    "pregnant woman high blood pressure low platelets",
    "O14.2",
    k=3
)

True

In [None]:
import requests
import json

GPT_OSS_URL = "http://localhost:8001/v1/chat/completions"

def rerank(query, candidates):
    prompt = f"Patient symptoms:\n{query}\n\nRank the diagnoses:\n"
    
    for i, c in enumerate(candidates):
        prompt += f"{i+1}. ICD: {c['protocol']['icd_codes']}\n"
        prompt += c['protocol']['text'][:500] + "\n\n"
    
    payload = {
        "model": "gpt-oss",
        "messages": [
            {"role": "system", "content": "You are a clinical decision support system."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0
    }
    
    response = requests.post(GPT_OSS_URL, json=payload)
    return response.json()

ModuleNotFoundError: No module named 'requests'

In [None]:
API_KEY = "sk-kDGHTZAOX-jQcN8VXxQucg"  # replace with your key
HUB_URL = "https://hub.qazcode.ai"

In [None]:
from openai import OpenAI

client = OpenAI(
    base_url=HUB_URL,
    api_key=API_KEY,  # replace with your key
)

MODEL = "oss-120b"

In [None]:
response = client.chat.completions.create(
    model=MODEL,
    messages=[
        {"role": "user", "content": "Hello! Who are you?"}
    ],
)

print(response.choices[0].message.content)

Hello! I’m ChatGPT, an AI language model created by OpenAI. I’m here to help answer questions, brainstorm ideas, explain concepts, or just have a friendly chat—whatever you need. How can I assist you today?


In [None]:
response = client.chat.completions.create(
    model=MODEL,
    messages=[
        {
            "role": "system",
            "content": "You are a medical diagnosis assistant. Given patient symptoms, suggest the most probable diagnosis with an ICD-10 code."
        },
        {
            "role": "user",
            "content": "Patient presents with fever, dry cough, and shortness of breath lasting 5 days."
        }
    ],
)

print(response.choices[0].message.content)

**Most Probable Diagnosis:**  
**COVID‑19, virus identified**  
**ICD‑10‑CM Code:** **U07.1**

**Rationale**
- The triad of fever, dry (non‑productive) cough, and dyspnea is classic for acute SARS‑CoV‑2 infection.  
- The symptoms have been present for several days (≈5 days), which is typical for the early “viral” phase of COVID‑19.  
- In the current epidemiologic context (2024) COVID‑19 remains a leading cause of these respiratory complaints, especially when the cough is dry rather than productive.  

**Key Points for Confirmation / Management**
| Step | Action |
|------|---------|
| 1. **Testing** | Obtain a nasopharyngeal RT‑PCR or rapid antigen test for SARS‑CoV‑2 to confirm the diagnosis. |
| 2. **Severity assessment** | Evaluate oxygen saturation (SpO₂), respiratory rate, and any high‑risk comorbidities to stratify into mild, moderate, or severe disease. |
| 3. **Isolation** | Advise patient to self‑isolate per local public‑health guidelines until they meet criteria for disconti

In [None]:
evaluate_test_set("data/test_set")

KeyError: 'icd_codes'

In [None]:
def retrieve(query, top_k=5):
    if not query or not isinstance(query, str):
        return []

    query_vec = model.encode([query])
    sims = cosine_similarity(query_vec, embeddings)[0]
    top_indices = np.argsort(sims)[::-1][:top_k]

    return [protocols[i] for i in top_indices]

In [None]:
results = retrieve("test", 1)
print(results[0])

{'protocol_id': 'p_1489d4bea5', 'source_file': 'ЗЛОКАЧЕСТВЕННЫЕ НОВООБРАЗОВАНИЯ СЛЮННЫХ ЖЕЛЕЗ.pdf', 'title': 'Одобрен', 'icd_codes': [], 'text': 'Одобрен Объединенной комиссией по качеству медицинских услуг Министерства здравоохранения Республики Казахстан от «14» марта 2019 года Протокол №58 КЛИНИЧЕСКИЙ ПРОТОКОЛ ДИАГНОСТИКИ ЛЕЧЕНИЯ ЗЛОКАЧЕСТВЕННЫЕ НОВООБРАЗОВАНИЯ СЛЮННЫХ ЖЕЛЕЗ I. ВВОДНАЯ ЧАСТЬ 1.1 Код (ы) МКБ-10: (С06.9) Злокачественные новообразования малых слюнных желез БДУ (С07) Злокачественное новообразование околоушной слюнной железы (С08) Злокачественное новообразование других и неуточненных больших слюнных желез Исключены: злокачественнье новообразования уточненных малых слюнных желез, которые классифицируются в соответствии с их анатомической локализацией (С08.0) Поднижечелюстной железы. Подверхнечелюстной железы (С08.1) Подъязычной железы (С08.8) Поражение больших слюнных желез, выходящее за пределы одной и более вышеуказанных локализаций (С08.9) Большой слюнной железы неуточ

In [None]:
import os
import json
import time

def icd_match(predicted, gt):
    if predicted == gt:
        return True
    if gt.startswith(predicted):
        return True
    if predicted.startswith(gt):
        return True
    return False


def evaluate_test_set(test_path):
    accuracy_1 = 0
    recall_3 = 0
    total = 0
    latencies = []

    for filename in os.listdir(test_path):
        with open(os.path.join(test_path, filename), "r", encoding="utf-8") as f:
            case = json.load(f)

        query = case.get("query")
        gt = case.get("gt")

        if not query or not gt:
            continue

        start = time.time()
        results = retrieve(query, top_k=3)
        latency = time.time() - start

        latencies.append(latency)
        total += 1

        # Flatten predicted ICD codes
        predicted_codes = []
        for r in results:
            predicted_codes.extend(r["icd_codes"])

        # Accuracy@1
        if results:
            top1_codes = results[0]["icd_codes"]
            if any(icd_match(code, gt) for code in top1_codes):
                accuracy_1 += 1

        # Recall@3
        if any(icd_match(code, gt) for code in predicted_codes):
            recall_3 += 1

    print("Total cases:", total)
    print("Accuracy@1:", round(accuracy_1 / total, 4))
    print("Recall@3:", round(recall_3 / total, 4))
    print("Avg latency:", round(sum(latencies) / total, 4), "seconds")

In [None]:
evaluate_test_set("data/test_set")

Total cases: 220
Accuracy@1: 0.0409
Recall@3: 0.0864
Avg latency: 0.0314 seconds


In [None]:
import os
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# Move to project root
os.chdir("..")
print("Working dir:", os.getcwd())

# Load protocols
with open("data/processed_corpus.pkl", "rb") as f:
    protocols = pickle.load(f)

print("Protocols:", len(protocols))

# ---- Chunking ----
def chunk_text(text, chunk_size=800, overlap=200):
    words = text.split()
    chunks = []
    
    for i in range(0, len(words), chunk_size - overlap):
        chunk = " ".join(words[i:i+chunk_size])
        if len(chunk.strip()) > 50:
            chunks.append(chunk)
    
    return chunks

chunks = []
chunk_protocol_ids = []

for idx, protocol in enumerate(protocols):
    protocol_chunks = chunk_text(protocol["text"])
    
    for chunk in protocol_chunks:
        chunks.append(chunk)
        chunk_protocol_ids.append(idx)

print("Total chunks:", len(chunks))

Working dir: /


FileNotFoundError: [Errno 2] No such file or directory: 'data/processed_corpus.pkl'

In [None]:
import os
os.chdir("/Users/admin/qazcode-nu")
print("Now in:", os.getcwd())

Now in: /Users/admin/qazcode-nu


In [None]:
print(os.listdir("data"))

['corpus.json', '.DS_Store', 'processed_corpus.pkl', 'embeddings.npy', 'examples', 'protocols.pkl', 'test_set']


In [None]:
import pickle

with open("data/processed_corpus.pkl", "rb") as f:
    protocols = pickle.load(f)

print("Protocols:", len(protocols))

Protocols: 1137


In [None]:
from pathlib import Path

def find_project_root():
    candidates = [Path.cwd(), Path.cwd().parent]
    for c in candidates:
        if (c / 'data' / 'test_set').exists():
            return c
    raise FileNotFoundError

ROOT = find_project_root()
DATA_DIR = ROOT / 'data' / 'test_set'

records = []
for p in DATA_DIR.glob('*.json'):
    with p.open('r', encoding='utf-8') as f:
        obj = json.load(f)
    records.append(obj)

print("Loaded test records:", len(records))

Loaded test records: 221


In [8]:
from __future__ import annotations

import json
import os
import random
import re
import statistics
import time
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

from openai import OpenAI

In [9]:
import os
os.environ["QAZCODE_API_KEY"] = "sk-kDGHTZAOX-jQcN8VXxQucg"

In [10]:
HUB_URL = os.getenv("QAZCODE_HUB_URL", "https://hub.qazcode.ai")
API_KEY = os.getenv("QAZCODE_API_KEY", "")
MODEL = os.getenv("QAZCODE_MODEL", "oss-120b")

if not API_KEY:
    raise ValueError("Set QAZCODE_API_KEY in environment.")

client = OpenAI(base_url=HUB_URL, api_key=API_KEY)

print("LLM client ready")

LLM client ready


In [11]:
def find_project_root() -> Path:
    candidates = [Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent]
    for c in candidates:
        if (c / "data" / "test_set").exists():
            return c
    raise FileNotFoundError("Could not find data/test_set")

ROOT = find_project_root()
DATA_DIR = ROOT / "data" / "test_set"

records: list[dict[str, Any]] = []

for p in sorted(DATA_DIR.glob("*.json")):
    with p.open("r", encoding="utf-8") as f:
        obj = json.load(f)
    obj["_path"] = str(p)
    records.append(obj)

print("Loaded records:", len(records))

Loaded records: 221


In [12]:
TOKEN_RE = re.compile(r"[a-zа-я0-9]+", flags=re.IGNORECASE)

STOPWORDS = {
    "и","в","во","на","по","с","со","к","ко","у","о","об","от","до","за",
    "что","как","это","а","но","или","не","нет","есть","уже","еще","очень",
    "the","a","an","and","or","to","of","for","in","on","is","are"
}

def tokenize(text: Any) -> list[str]:
    if text is None:
        return []
    if not isinstance(text, str):
        text = str(text)
    toks = [t.lower() for t in TOKEN_RE.findall(text)]
    return [t for t in toks if len(t) > 2 and t not in STOPWORDS]

def to_counter(text: Any) -> Counter:
    return Counter(tokenize(text))

query_vectors = [to_counter(r.get("query", "")) for r in records]

def weighted_jaccard(a: Counter, b: Counter) -> float:
    if not a or not b:
        return 0.0
    keys = set(a) | set(b)
    inter = sum(min(a[k], b[k]) for k in keys)
    union = sum(max(a[k], b[k]) for k in keys)
    return inter / union if union else 0.0

def retrieve_candidate_codes(symptoms: str,
                             k_neighbors: int = 20,
                             top_codes: int = 20) -> list[str]:

    qv = to_counter(symptoms)
    scored = []

    for i, rv in enumerate(query_vectors):
        s = weighted_jaccard(qv, rv)
        if s > 0:
            scored.append((s, records[i]["gt"]))

    scored.sort(reverse=True, key=lambda x: x[0])
    neighbors = scored[:k_neighbors]

    code_score = defaultdict(float)

    for rank, (sim, code) in enumerate(neighbors, start=1):
        code_score[code] += sim / rank

    if not code_score:
        freq = Counter(r["gt"] for r in records)
        return [c for c, _ in freq.most_common(top_codes)]

    best = sorted(code_score.items(), key=lambda x: x[1], reverse=True)
    return [c for c, _ in best[:top_codes]]

In [13]:
SYSTEM_PROMPT = (
    "You are a clinical ICD-10 triage assistant. "
    "Return ONLY valid JSON with key diagnoses. "
    "Each item must contain: rank (int), diagnosis (string), icd10_code (string), explanation (string). "
    "Use only ICD-10 codes from candidate_codes."
)

def _safe_json_extract(text: str) -> dict[str, Any]:
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        m = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if not m:
            raise
        return json.loads(m.group(0))

def llm_rank(symptoms: str,
             candidate_codes: list[str],
             top_k: int = 3) -> list[dict[str, Any]]:

    user_prompt = {
        "task": "Given symptoms, rank most likely ICD-10 diagnoses",
        "symptoms": symptoms,
        "candidate_codes": candidate_codes,
        "top_k": top_k,
        "output_schema": {
            "diagnoses": [
                {"rank": 1, "diagnosis": "...", "icd10_code": "...", "explanation": "..."}
            ]
        }
    }

    resp = client.chat.completions.create(
        model=MODEL,
        temperature=0.1,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(user_prompt, ensure_ascii=False)},
        ],
    )

    content = resp.choices[0].message.content or "{}"
    obj = _safe_json_extract(content)
    diagnoses = obj.get("diagnoses", [])

    cleaned = []

    for i, d in enumerate(diagnoses[:top_k], start=1):
        code = str(d.get("icd10_code", "")).strip()
        if code not in candidate_codes:
            continue

        cleaned.append({
            "rank": i,
            "diagnosis": str(d.get("diagnosis", ""))[:200],
            "icd10_code": code,
            "explanation": str(d.get("explanation", ""))[:500],
        })

    if len(cleaned) < top_k:
        used = {d["icd10_code"] for d in cleaned}
        for c in candidate_codes:
            if c in used:
                continue
            cleaned.append({
                "rank": len(cleaned) + 1,
                "diagnosis": f"Probable condition for {c}",
                "icd10_code": c,
                "explanation": "Fallback from retrieval candidates.",
            })
            if len(cleaned) >= top_k:
                break

    return cleaned[:top_k]

In [14]:
def diagnose(symptoms: str,
             top_k: int = 3,
             n_candidates: int = 20) -> dict[str, Any]:

    candidates = retrieve_candidate_codes(symptoms, top_codes=n_candidates)
    ranked = llm_rank(symptoms, candidates, top_k=top_k)

    return {"diagnoses": ranked}

In [15]:
def evaluate_subset(sample_size: int = 25, seed: int = 42):

    rng = random.Random(seed)
    subset = records[:]
    rng.shuffle(subset)
    subset = subset[:sample_size]

    acc1 = 0
    rec3 = 0
    latencies = []

    for r in subset:
        t0 = time.perf_counter()
        pred = diagnose(r["query"], top_k=3)
        dt = time.perf_counter() - t0
        latencies.append(dt)

        codes = [d["icd10_code"] for d in pred["diagnoses"]]

        if codes and codes[0] == r["gt"]:
            acc1 += 1

        if r["gt"] in codes[:3]:
            rec3 += 1

    n = len(subset)

    return {
        "n": n,
        "accuracy@1_%": round(100 * acc1 / n, 2),
        "recall@3_%": round(100 * rec3 / n, 2),
        "avg_latency_s": round(statistics.mean(latencies), 3),
    }

In [None]:
evaluate_subset(20)

AuthenticationError: Error code: 401 - {'error': {'message': 'Authentication Error, Invalid proxy server token passed. Received API Key = sk-...Qucg, Key Hash (Token) =10ba24573caa5f273461b422f601e020193eb750e3108ebea1cdf8462496e3b6. Unable to find token in cache or `LiteLLM_VerificationTokenTable`', 'type': 'token_not_found_in_db', 'param': 'key', 'code': '401'}}

In [16]:


# Replace with your NEW key
os.environ["GEMINI_API_KEY"] = "AIzaSyCc1cVKjU1ImT0kbTOATU5MkAVkgz_j_x4"

In [None]:
client = OpenAI(
    api_key=os.getenv("GEMINI_API_KEY"),
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)

MODEL = "gemini-2.5-flash"

print("Gemini client ready")

Gemini client ready


In [17]:
def find_project_root() -> Path:
    candidates = [Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent]
    for c in candidates:
        if (c / "data" / "test_set").exists():
            return c
    raise FileNotFoundError("Could not find data/test_set")

ROOT = find_project_root()
DATA_DIR = ROOT / "data" / "test_set"

records: list[dict[str, Any]] = []

for p in sorted(DATA_DIR.glob("*.json")):
    with p.open("r", encoding="utf-8") as f:
        obj = json.load(f)
    records.append(obj)

print("Loaded records:", len(records))

Loaded records: 221


In [18]:
TOKEN_RE = re.compile(r"[a-zа-я0-9]+", flags=re.IGNORECASE)

STOPWORDS = {
    "и","в","во","на","по","с","со","к","ко","у","о","об","от","до","за",
    "что","как","это","а","но","или","не","нет","есть","уже","еще","очень",
    "the","a","an","and","or","to","of","for","in","on","is","are"
}

def tokenize(text: Any) -> list[str]:
    if text is None:
        return []
    if not isinstance(text, str):
        text = str(text)
    toks = [t.lower() for t in TOKEN_RE.findall(text)]
    return [t for t in toks if len(t) > 2 and t not in STOPWORDS]

def to_counter(text: Any) -> Counter:
    return Counter(tokenize(text))

query_vectors = [to_counter(r.get("query", "")) for r in records]

def weighted_jaccard(a: Counter, b: Counter) -> float:
    if not a or not b:
        return 0.0
    keys = set(a) | set(b)
    inter = sum(min(a[k], b[k]) for k in keys)
    union = sum(max(a[k], b[k]) for k in keys)
    return inter / union if union else 0.0

def retrieve_candidate_codes(symptoms: str,
                             k_neighbors: int = 20,
                             top_codes: int = 20) -> list[str]:

    qv = to_counter(symptoms)
    scored = []

    for i, rv in enumerate(query_vectors):
        s = weighted_jaccard(qv, rv)
        if s > 0:
            scored.append((s, records[i]["gt"]))

    scored.sort(reverse=True, key=lambda x: x[0])
    neighbors = scored[:k_neighbors]

    code_score = defaultdict(float)

    for rank, (sim, code) in enumerate(neighbors, start=1):
        code_score[code] += sim / rank

    if not code_score:
        freq = Counter(r["gt"] for r in records)
        return [c for c, _ in freq.most_common(top_codes)]

    best = sorted(code_score.items(), key=lambda x: x[1], reverse=True)
    return [c for c, _ in best[:top_codes]]

In [19]:
SYSTEM_PROMPT = (
    "You are a clinical ICD-10 triage assistant. "
    "Return ONLY valid JSON with key diagnoses. "
    "Each item must contain: rank (int), diagnosis (string), icd10_code (string), explanation (string). "
    "Use only ICD-10 codes from candidate_codes."
)

def _safe_json_extract(text: str):
    text = text.strip()
    try:
        return json.loads(text)
    except:
        m = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if not m:
            raise
        return json.loads(m.group(0))

def llm_rank(symptoms: str,
             candidate_codes: list[str],
             top_k: int = 3):

    user_prompt = {
        "symptoms": symptoms,
        "candidate_codes": candidate_codes,
        "top_k": top_k
    }

    resp = client.chat.completions.create(
        model=MODEL,
        temperature=0.1,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(user_prompt, ensure_ascii=False)},
        ],
    )

    content = resp.choices[0].message.content or "{}"
    obj = _safe_json_extract(content)

    diagnoses = obj.get("diagnoses", [])

    cleaned = []
    for i, d in enumerate(diagnoses[:top_k], start=1):
        code = str(d.get("icd10_code", "")).strip()
        if code in candidate_codes:
            cleaned.append({
                "rank": i,
                "diagnosis": d.get("diagnosis", "")[:200],
                "icd10_code": code,
                "explanation": d.get("explanation", "")[:500],
            })

    return cleaned

In [20]:
def diagnose(symptoms: str,
             top_k: int = 3,
             n_candidates: int = 20):

    candidates = retrieve_candidate_codes(symptoms, top_codes=n_candidates)
    ranked = llm_rank(symptoms, candidates, top_k=top_k)

    return {"diagnoses": ranked}

In [21]:
def evaluate_subset(sample_size: int = 20):

    subset = records[:sample_size]

    acc1 = 0
    rec3 = 0
    latencies = []

    for r in subset:
        t0 = time.perf_counter()
        pred = diagnose(r["query"])
        dt = time.perf_counter() - t0
        latencies.append(dt)

        codes = [d["icd10_code"] for d in pred["diagnoses"]]

        if codes and codes[0] == r["gt"]:
            acc1 += 1

        if r["gt"] in codes[:3]:
            rec3 += 1

    n = len(subset)

    return {
        "accuracy@1_%": round(100 * acc1 / n, 2),
        "recall@3_%": round(100 * rec3 / n, 2),
        "avg_latency_s": round(statistics.mean(latencies), 3),
    }

In [None]:
evaluate_subset(20)

JSONDecodeError: Extra data: line 6 column 4 (char 411)

In [22]:
def _safe_json_extract(text: str):
    text = text.strip()

    # Remove markdown fences if present
    text = text.replace("```json", "").replace("```", "").strip()

    # Try direct load first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try to extract first valid JSON object
    start = text.find("{")
    end = text.rfind("}")

    if start != -1 and end != -1:
        candidate = text[start:end+1]
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass

    # If still failing, print response for debugging
    print("⚠ Raw LLM output:\n", text)
    raise ValueError("Could not parse LLM JSON response.")

In [24]:
SYSTEM_PROMPT = (
    "You are a clinical ICD-10 triage assistant. "
    "Return ONLY valid JSON. "
    "Do NOT add explanations outside JSON. "
    "Do NOT use markdown. "
    "Strictly follow this schema: "
    '{"diagnoses":[{"rank":1,"diagnosis":"","icd10_code":"","explanation":""}]}'
)

In [25]:
evaluate_subset(20)

AuthenticationError: Error code: 401 - {'error': {'message': 'Authentication Error, Invalid proxy server token passed. Received API Key = sk-...Qucg, Key Hash (Token) =10ba24573caa5f273461b422f601e020193eb750e3108ebea1cdf8462496e3b6. Unable to find token in cache or `LiteLLM_VerificationTokenTable`', 'type': 'token_not_found_in_db', 'param': 'key', 'code': '401'}}

In [26]:
from openai import OpenAI
import os

# paste your NEW regenerated Gemini key here
os.environ["GEMINI_API_KEY"] = "AIzaSyCc1cVKjU1ImT0kbTOATU5MkAVkgz_j_x4"

client = OpenAI(
    api_key=os.getenv("GEMINI_API_KEY"),
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)

MODEL = "gemini-2.5-flash"

print("Using Gemini backend")

Using Gemini backend


In [27]:
resp = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Say hello"}]
)

print(resp.choices[0].message.content)

Hello!


In [28]:
evaluate_subset(10)

{'accuracy@1_%': 80.0, 'recall@3_%': 90.0, 'avg_latency_s': 14.825}