In [5]:
!pip install openai



In [6]:
from __future__ import annotations
import json
import os
import random
import re
import statistics
import time
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
from openai import OpenAI

In [7]:
os.environ['QAZCODE_API_KEY'] = 'sk-kDGHTZAOX-jQcN8VXxQucg'

In [8]:
HUB_URL = os.getenv('QAZCODE_HUB_URL', 'https://hub.qazcode.ai')

In [9]:
API_KEY = os.getenv('QAZCODE_API_KEY', '')

In [10]:
MODEL = os.getenv('QAZCODE_MODEL', 'oss-120b')

In [11]:
if not API_KEY:
    raise ValueError('Set QAZCODE_API_KEY in environment before running this notebook.')

In [12]:
client = OpenAI(base_url=HUB_URL, api_key=API_KEY)
print('Client ready:', HUB_URL, 'model=', MODEL)

Client ready: https://hub.qazcode.ai model= oss-120b


Loading the data

In [13]:
from pathlib import Path
import json
from typing import Any

# Find project root by locating data/test_set
ROOT = next(
    p for p in [Path.cwd(), *Path.cwd().parents]
    if (p / "data" / "test_set").is_dir()
)

DATA_DIR = ROOT / "data" / "test_set"

records: list[dict[str, Any]] = []
for p in sorted(DATA_DIR.glob("*.json")):
    obj = json.loads(p.read_text(encoding="utf-8"))
    obj["_path"] = str(p)
    records.append(obj)

all_codes = sorted({r["gt"] for r in records})

In [15]:
import re
from typing import Any
from collections import Counter, defaultdict

TOKEN_RE = re.compile(r"[a-zа-я0-9]+", re.IGNORECASE)
STOPWORDS = {
    "и","в","во","на","по","с","со","к","ко","у","о","об","от","до","за",
    "что","как","это","а","но","или","не","нет","есть","уже","еще","очень",
    "the","a","an","and","or","to","of","for","in","on","is","are",
}

def tokenize(text: Any) -> list[str]:
    if text is None:
        return []
    if not isinstance(text, str):
        text = str(text)
    toks = (t.lower() for t in TOKEN_RE.findall(text))
    return [t for t in toks if len(t) > 2 and t not in STOPWORDS]

def to_counter(text: Any) -> Counter:
    return Counter(tokenize(text))

# assumes `records` exists
query_vectors = [to_counter(r.get("query", "")) for r in records]

def weighted_jaccard(a: Counter, b: Counter) -> float:
    if not a or not b:
        return 0.0
    keys = a.keys() | b.keys()
    inter = sum(min(a[k], b[k]) for k in keys)
    union = sum(max(a[k], b[k]) for k in keys)
    return inter / union if union else 0.0

def retrieve_candidate_codes(symptoms: str, k_neighbors: int = 20, top_codes: int = 20) -> list[str]:
    qv = to_counter(symptoms)

    scored = [
        (weighted_jaccard(qv, rv), records[i]["gt"])
        for i, rv in enumerate(query_vectors)
    ]
    scored = [(s, code) for s, code in scored if s > 0]
    scored.sort(key=lambda x: x[0], reverse=True)

    neighbors = scored[:k_neighbors]
    if not neighbors:
        freq = Counter(r["gt"] for r in records)
        return [c for c, _ in freq.most_common(top_codes)]

    code_score = defaultdict(float)
    for rank, (sim, code) in enumerate(neighbors, start=1):
        code_score[code] += sim / rank

    best = sorted(code_score.items(), key=lambda x: x[1], reverse=True)
    return [c for c, _ in best[:top_codes]]

Вызов LLM и строгий JSON парсинг

In [16]:
SYSTEM_PROMPT = (
    "You are a clinical ICD-10 triage assistant. "
    "Return ONLY valid JSON. "
    "Output schema: {diagnoses: [{rank:int, diagnosis:str, icd10_code:str, explanation:str}]}. "
    "Use only ICD-10 codes from candidate_codes."
)

def safe_json_extract(text: str) -> dict[str, Any]:
    text = (text or "").strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        m = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if not m:
            raise
        return json.loads(m.group(0))

def llm_rank(symptoms: str, candidate_codes: list[str], top_k: int = 3) -> list[dict[str, Any]]:
    payload = {
        "symptoms": symptoms,
        "candidate_codes": candidate_codes,
        "top_k": top_k,
        "output_schema": {
            "diagnoses": [{"rank": 1, "diagnosis": "...", "icd10_code": "...", "explanation": "..."}]
        },
    }

    resp = client.chat.completions.create(
        model=MODEL,
        temperature=0.1,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(payload, ensure_ascii=False)},
        ],
    )

    content = (resp.choices[0].message.content or "{}")
    diagnoses = safe_json_extract(content).get("diagnoses", [])

    cleaned: list[dict[str, Any]] = []
    for d in diagnoses:
        code = str(d.get("icd10_code", "")).strip()
        if code in candidate_codes:
            cleaned.append({
                "rank": len(cleaned) + 1,
                "diagnosis": str(d.get("diagnosis", "Unknown diagnosis"))[:200],
                "icd10_code": code,
                "explanation": str(d.get("explanation", ""))[:500],
            })
        if len(cleaned) >= top_k:
            break

    # fallback if LLM output is missing/invalid/out-of-list
    if len(cleaned) < top_k:
        used = {x["icd10_code"] for x in cleaned}
        for c in candidate_codes:
            if c in used:
                continue
            cleaned.append({
                "rank": len(cleaned) + 1,
                "diagnosis": f"Probable condition for {c}",
                "icd10_code": c,
                "explanation": "Fallback from retrieval candidate list.",
            })
            if len(cleaned) >= top_k:
                break

    return cleaned

Главная функция предсказания (формат evaluate.py)

In [18]:

def diagnose(symptoms: str, top_k: int = 3, n_candidates: int = 20) -> dict[str, Any]:
    candidates = retrieve_candidate_codes(symptoms, top_codes=n_candidates)
    ranked = llm_rank(symptoms, candidates, top_k=top_k)
    return {"diagnoses": ranked}

# smoke test
example_symptoms = records[0].get("query", "")
result = diagnose(example_symptoms, top_k=3)
print(json.dumps(result, ensure_ascii=False, indent=2)[:1500])

{
  "diagnoses": [
    {
      "rank": 1,
      "diagnosis": "Травма (растяжение/разрыв) связок и мышц грудного (торакального) отдела позвоночника",
      "icd10_code": "S33.0",
      "explanation": "Боль в средней части спины, усиливающаяся при наклонах, кашле, глубоком вдохе и длительном положении, типична для растяжения или разрыва мягких тканей грудного отдела позвоночника. Травма падения с ударом между лопатками часто приводит к подобным повреждениям без перелома, но с выраженной мышечной спазмованностью и болевым синдромом."
    },
    {
      "rank": 2,
      "diagnosis": "Перелом ребра",
      "icd10_code": "S22.1",
      "explanation": "Боль, усиливающаяся при дыхании, кашле и движениях туловища, а также ощущение «прокалывания» могут указывать на перелом одного или нескольких ребер. При падении на спину ребра в области между лопатками часто получают травму, что объясняет болезненность при глубоком вдохе и ограничение подвижности."
    },
    {
      "rank": 3,
      "diagnosis

Быстрая локальная проверка на части датасета

In [19]:
import random, time, statistics
from typing import Any

def evaluate_subset(sample_size: int = 25, seed: int = 42) -> dict[str, Any]:
    rng = random.Random(seed)
    subset = records[:]
    rng.shuffle(subset)
    subset = subset[:min(sample_size, len(subset))]

    acc1 = rec3 = 0
    latencies: list[float] = []

    for r in subset:
        t0 = time.perf_counter()
        pred = diagnose(r.get("query", ""), top_k=3)
        latencies.append(time.perf_counter() - t0)

        codes = [d.get("icd10_code", "") for d in pred.get("diagnoses", [])]

        if codes and codes[0] == r.get("gt"):
            acc1 += 1

        valid = set(r.get("icd_codes", []))
        if any(c in valid for c in codes[:3]):
            rec3 += 1

    n = len(subset) or 1
    return {
        "n": len(subset),
        "accuracy_at_1_percent": round(100 * acc1 / n, 2),
        "recall_at_3_percent": round(100 * rec3 / n, 2),
        "latency_avg_s": round(statistics.mean(latencies), 3) if latencies else None,
    }

metrics = evaluate_subset(sample_size=20)
metrics

{'n': 20,
 'accuracy_at_1_percent': 50.0,
 'recall_at_3_percent': 90.0,
 'latency_avg_s': 26.594}

In [None]:
{'n': 20,
 'accuracy_at_1_percent': 50.0,
 'recall_at_3_percent': 90.0,
 'latency_avg_s': 26.594}