# QazCode Hackathon: Simple LLM Baseline

Цель: по симптомам вернуть `top-N` диагнозов с `ICD-10` в формате, совместимом с `evaluate.py`.

Пайплайн:
1. Загружаем `data/test_set` и строим простой retrieval по текстам запросов.
2. Выбираем кандидаты ICD локально (быстро и стабильно).
3. Передаем только кандидатов в LLM (`oss-120b`), чтобы ранжировать top-3.
4. Возвращаем JSON с `diagnoses: [{rank, diagnosis, icd10_code, explanation}]`.


In [None]:
from __future__ import annotations

import json
import os
import random
import re
import statistics
import time
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

from openai import OpenAI


## 1) Настройка LLM API

Не сохраняйте ключ в git. Используйте переменные окружения.


In [None]:
# Paste the key once per notebook session (never commit it):
# os.environ['QAZCODE_API_KEY'] = 'sk-...'

# Connection settings are read from the environment; the hub URL and the model
# name have defaults, the API key does not.
HUB_URL = os.environ.get('QAZCODE_HUB_URL', 'https://hub.qazcode.ai')
API_KEY = os.environ.get('QAZCODE_API_KEY', '')
MODEL = os.environ.get('QAZCODE_MODEL', 'oss-120b')

# Stop early when no key is configured instead of failing on the first request.
if API_KEY == '':
    raise ValueError('Set QAZCODE_API_KEY in environment before running this notebook.')

client = OpenAI(api_key=API_KEY, base_url=HUB_URL)
print(f'Client ready: {HUB_URL} model= {MODEL}')


## 2) Загрузка данных


In [None]:
def find_project_root() -> Path:
    """Return the nearest directory that contains ``data/test_set``.

    Only the current working directory, its parent and its grandparent are
    checked, in that order.

    Raises:
        FileNotFoundError: if none of the three directories contains
            ``data/test_set``.
    """
    here = Path.cwd()
    for base in (here, here.parent, here.parent.parent):
        marker = base / 'data' / 'test_set'
        if marker.exists():
            return base
    raise FileNotFoundError('Could not find data/test_set from current working dir.')

# Locate the dataset directory relative to the detected project root.
ROOT = find_project_root()
DATA_DIR = ROOT / 'data' / 'test_set'

# Read every *.json file in sorted path order and remember, under '_path',
# which file each record came from.
records: list[dict[str, Any]] = []
for json_path in sorted(DATA_DIR.glob('*.json')):
    record = json.loads(json_path.read_text(encoding='utf-8'))
    record['_path'] = str(json_path)
    records.append(record)

print('Loaded records:', len(records))

# Distinct values of the 'gt' field across all loaded records.
all_codes = sorted({record['gt'] for record in records})
print('Unique GT codes:', len(all_codes))


## 3) Простой retrieval для ICD кандидатов

Идея: ищем похожие запросы по пересечению токенов и берем коды `gt` соседей с наибольшим суммарным весом (сходство, деленное на ранг соседа).


In [None]:
# A token is a maximal run of Unicode letters/digits: ``[^\W_]`` is ``\w`` minus
# the underscore. The previous class ``[a-zа-я0-9]`` did not cover 'ё' or the
# Kazakh Cyrillic letters (e.g. 'қ', 'ө', 'ұ'), so words containing them were
# cut into fragments.
TOKEN_RE = re.compile(r'[^\W_]+')

# Lower-cased words that carry no diagnostic signal and are dropped after
# tokenization. 'ещё' is listed next to 'еще' because it is now matched whole.
STOPWORDS = {
    'и', 'в', 'во', 'на', 'по', 'с', 'со', 'к', 'ко', 'у', 'о', 'об', 'от', 'до', 'за',
    'что', 'как', 'это', 'а', 'но', 'или', 'не', 'нет', 'есть', 'уже', 'еще', 'ещё', 'очень',
    'the', 'a', 'an', 'and', 'or', 'to', 'of', 'for', 'in', 'on', 'is', 'are',
}

def tokenize(text: Any) -> list[str]:
    """Split ``text`` into lower-cased tokens suitable for similarity scoring.

    ``None`` yields an empty list; any other non-string value is converted
    with ``str()`` first. Tokens are the matches of ``TOKEN_RE``; those of
    two characters or fewer and those listed in ``STOPWORDS`` are discarded.
    """
    if text is None:
        return []
    raw = text if isinstance(text, str) else str(text)
    lowered = map(str.lower, TOKEN_RE.findall(raw))
    return [word for word in lowered if len(word) > 2 and word not in STOPWORDS]

def to_counter(text: Any) -> Counter:
    """Return a bag of words for ``text``: token -> number of occurrences."""
    bag: Counter = Counter()
    bag.update(tokenize(text))
    return bag

# One bag-of-words vector per loaded record, in the same order as ``records``
# (a record without a 'query' field gets an empty vector).
query_vectors: list[Counter] = [to_counter(rec.get('query', '')) for rec in records]

def weighted_jaccard(a: Counter, b: Counter) -> float:
    """Return the weighted Jaccard similarity of two token-count vectors.

    The score is ``sum(min(a[k], b[k])) / sum(max(a[k], b[k]))`` over every
    key present in either counter. It is ``0.0`` when either counter is empty
    or when the denominator is zero.
    """
    if not (a and b):
        return 0.0

    overlap = 0
    total = 0
    for token in set(a).union(b):
        left, right = a[token], b[token]
        if left < right:
            overlap += left
            total += right
        else:
            overlap += right
            total += left

    if not total:
        return 0.0
    return overlap / total

def retrieve_candidate_codes(symptoms: str, k_neighbors: int = 20, top_codes: int = 20) -> list[str]:
    """Return candidate ICD codes for ``symptoms`` using nearest stored queries.

    Every stored query vector is scored against the tokenized symptoms with
    ``weighted_jaccard``. The ``k_neighbors`` best matches with a non-zero
    score vote for their record's ``gt`` code with weight
    ``similarity / neighbor_rank`` (rank starts at 1). The ``top_codes`` codes
    with the largest summed weight are returned, best first.

    When nothing overlaps with the query, the most frequent ``gt`` codes of
    the whole dataset are returned instead.
    """
    query_vec = to_counter(symptoms)

    # (similarity, gt code) for each stored query that shares at least one token.
    hits = []
    for idx, stored in enumerate(query_vectors):
        similarity = weighted_jaccard(query_vec, stored)
        if similarity > 0:
            hits.append((similarity, records[idx]['gt']))

    nearest = sorted(hits, key=lambda hit: hit[0], reverse=True)[:k_neighbors]

    # Closer neighbors contribute more: similarity is discounted by 1 / rank.
    weights: dict = {}
    for position, (similarity, code) in enumerate(nearest, start=1):
        weights[code] = weights.get(code, 0.0) + similarity / position

    if weights:
        ordered = sorted(weights.items(), key=lambda pair: pair[1], reverse=True)
        return [code for code, _ in ordered[:top_codes]]

    # Fallback: the most frequent codes in the dataset.
    frequency = Counter(rec['gt'] for rec in records)
    return [code for code, _ in frequency.most_common(top_codes)]


## 4) Вызов LLM и строгий JSON парсинг


In [None]:
# System message sent with every ranking request. The sentences are joined with
# single spaces; the trailing empty item keeps the final trailing space.
SYSTEM_PROMPT = ' '.join((
    'You are a clinical ICD-10 triage assistant.',
    'Return ONLY valid JSON with key diagnoses.',
    'Each item must contain: rank (int), diagnosis (string), icd10_code (string), explanation (string).',
    'Use only ICD-10 codes from candidate_codes.',
    '',
))

def _safe_json_extract(text: str) -> dict[str, Any]:
    """Extract a JSON object from an LLM reply that may contain extra text.

    The stripped text is first parsed as a whole. If that fails, or the result
    is not a JSON object, the text is scanned left to right and the first
    ``{...}`` span that decodes as JSON is returned. This tolerates code
    fences, leading prose, and trailing prose that itself contains braces.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON object, always a ``dict``. For a top-level array the
        first object found inside it is returned.

    Raises:
        json.JSONDecodeError: if no JSON object can be decoded from ``text``.
    """
    text = text.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as err:
        failure = err
    else:
        if isinstance(parsed, dict):
            return parsed
        # Valid JSON but not an object (list, string, number, ...).
        failure = json.JSONDecodeError('Expected a JSON object', text, 0)

    # raw_decode stops at the end of the first complete value, so text after
    # the object cannot break decoding.
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError as err:
            failure = err
        else:
            return candidate
        start = text.find('{', start + 1)

    raise failure

def llm_rank(symptoms: str, candidate_codes: list[str], top_k: int = 3) -> list[dict[str, Any]]:
    """Ask the LLM to rank candidate ICD-10 codes and return a validated top-k list.

    The symptoms and candidate codes are sent as a JSON user message. The reply
    is parsed and filtered: only dict items whose ``icd10_code`` is one of
    ``candidate_codes`` and has not been seen yet are kept. If fewer than
    ``top_k`` items survive (including when the reply is not parseable JSON),
    the list is padded with unused codes in ``candidate_codes`` order.

    Args:
        symptoms: Free-text symptom description.
        candidate_codes: ICD-10 codes the model is allowed to choose from.
        top_k: Maximum number of diagnoses to return.

    Returns:
        At most ``top_k`` dicts with keys ``rank``, ``diagnosis``,
        ``icd10_code`` and ``explanation``. Ranks are consecutive from 1 and
        codes are unique.

    Raises:
        Whatever the API client raises for a failed request; those errors are
        not swallowed here.
    """
    user_prompt = {
        'task': 'Given symptoms, rank most likely ICD-10 diagnoses',
        'symptoms': symptoms,
        'candidate_codes': candidate_codes,
        'top_k': top_k,
        'output_schema': {
            'diagnoses': [
                {'rank': 1, 'diagnosis': '...', 'icd10_code': '...', 'explanation': '...'}
            ]
        }
    }

    resp = client.chat.completions.create(
        model=MODEL,
        temperature=0.1,
        messages=[
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': json.dumps(user_prompt, ensure_ascii=False)},
        ],
    )

    content = resp.choices[0].message.content or '{}'
    try:
        obj = _safe_json_extract(content)
    except json.JSONDecodeError:
        # Unparseable reply: continue with nothing so the fallback below fills in.
        obj = {}

    diagnoses = obj.get('diagnoses', []) if isinstance(obj, dict) else []
    if not isinstance(diagnoses, list):
        diagnoses = []

    allowed = set(candidate_codes)
    used: set = set()
    cleaned: list[dict[str, Any]] = []

    # Validate first, truncate second: a rejected item must not consume one of
    # the top_k slots, and ranks are assigned by position in the accepted list
    # so they stay consecutive.
    for d in diagnoses:
        if len(cleaned) >= top_k:
            break
        if not isinstance(d, dict):
            continue
        code = str(d.get('icd10_code', '')).strip()
        if code not in allowed or code in used:
            continue
        used.add(code)
        cleaned.append({
            'rank': len(cleaned) + 1,
            'diagnosis': str(d.get('diagnosis') or 'Unknown diagnosis')[:200],
            'icd10_code': code,
            'explanation': str(d.get('explanation') or '')[:500],
        })

    # Fallback when the LLM returned too few usable items or bad JSON.
    for c in candidate_codes:
        if len(cleaned) >= top_k:
            break
        if c in used:
            continue
        used.add(c)
        cleaned.append({
            'rank': len(cleaned) + 1,
            'diagnosis': f'Probable condition for {c}',
            'icd10_code': c,
            'explanation': 'Fallback from retrieval candidate list.',
        })

    return cleaned[:top_k]


## 5) Главная функция предсказания (формат evaluate.py)


In [None]:
def diagnose(symptoms: str, top_k: int = 3, n_candidates: int = 20) -> dict[str, Any]:
    """Predict diagnoses for ``symptoms``.

    Retrieves up to ``n_candidates`` candidate codes locally, asks the LLM to
    rank them, and wraps the ``top_k`` ranked items as
    ``{'diagnoses': [...]}``.
    """
    shortlist = retrieve_candidate_codes(symptoms, top_codes=n_candidates)
    return {'diagnoses': llm_rank(symptoms, shortlist, top_k=top_k)}

# Smoke test: run the whole pipeline on the first loaded record and print at
# most the first 1500 characters of the pretty-printed result.
example_symptoms = records[0]['query']
result = diagnose(example_symptoms, top_k=3)
preview = json.dumps(result, ensure_ascii=False, indent=2)
print(preview[:1500])


## 6) Быстрая локальная проверка на части датасета

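Важно: это proxy-оценка на публичных данных, а не финальный holdout leaderboard. Retrieval строится по тем же записям, на которых считается метрика (запрос находит сам себя среди соседей), поэтому результат оптимистично завышен.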


In [None]:
def evaluate_subset(sample_size: int = 25, seed: int = 42) -> dict[str, Any]:
    """Run ``diagnose`` on a random sample of ``records`` and report proxy metrics.

    Args:
        sample_size: Number of records to evaluate; clamped to
            ``[0, len(records)]``.
        seed: Seed for the shuffle, so the same sample is drawn on every run.

    Returns:
        A dict with ``n`` (records evaluated), ``accuracy_at_1_percent`` (top
        prediction equals ``gt``), ``recall_at_3_percent`` (any of the top
        three predictions is an accepted code) and ``latency_avg_s`` (mean
        wall-clock seconds per ``diagnose`` call, ``None`` for an empty sample).

    Note:
        The sampled records are the same ones ``diagnose`` retrieves
        candidates from, so each query can match itself. Treat the numbers as
        an optimistic upper bound.
    """
    rng = random.Random(seed)
    subset = records[:]
    rng.shuffle(subset)
    # Clamp at 0 so a negative sample_size cannot turn into a slice from the end.
    subset = subset[:max(0, min(sample_size, len(subset)))]

    acc1 = 0
    rec3 = 0
    latencies = []

    for r in subset:
        t0 = time.perf_counter()
        pred = diagnose(r['query'], top_k=3)
        dt = time.perf_counter() - t0
        latencies.append(dt)

        codes = [d.get('icd10_code', '') for d in pred.get('diagnoses', [])]
        if codes and codes[0] == r['gt']:
            acc1 += 1

        # The ground-truth code always counts as accepted, even when
        # 'icd_codes' is missing or None; otherwise recall@3 could be lower
        # than accuracy@1.
        valid = set(r.get('icd_codes') or [])
        valid.add(r['gt'])
        if any(c in valid for c in codes[:3]):
            rec3 += 1

    n = len(subset) or 1
    return {
        'n': len(subset),
        'accuracy_at_1_percent': round(100 * acc1 / n, 2),
        'recall_at_3_percent': round(100 * rec3 / n, 2),
        'latency_avg_s': round(statistics.mean(latencies), 3) if latencies else None,
    }

# Evaluate a 20-record sample. The bare name on the last line is there so the
# notebook cell displays the resulting dict.
metrics = evaluate_subset(sample_size=20)
metrics


## 7) Что дальше для продакшена

1. Вынесите `diagnose()` в `src/mock_server.py` (или новый `src/server.py`) как POST `/diagnose`.
2. Запустите `evaluate.py` на всем `data/test_set`.
3. Для ускорения: кэшируйте ответы LLM и уменьшайте `n_candidates` до 10-15.
4. Для качества: добавьте few-shot примеры из похожих кейсов в prompt.
