In [1]:
!pip install -q bitsandbytes transformers langchain langchain_community langchain_huggingface langchain-chroma rank_bm25 chromadb


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/67.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m45.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.8/19.8 MB[0m [31m67.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m284.2/284.2 kB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m88.3 MB/s[0m eta [36m0:00:00

In [2]:
# Fetch the KRUS CSV data into ./KRUS-data- in the current working directory.
# NOTE(review): CSV_DIR in the next cell must point inside this checkout — confirm the path.
!git clone https://github.com/barbwozn/KRUS-data-.git

Cloning into 'KRUS-data-'...
remote: Enumerating objects: 385, done.[K
remote: Counting objects: 100% (111/111), done.[K
remote: Compressing objects: 100% (90/90), done.[K
remote: Total 385 (delta 29), reused 4 (delta 0), pack-reused 274 (from 1)[K
Receiving objects: 100% (385/385), 216.92 KiB | 929.00 KiB/s, done.
Resolving deltas: 100% (117/117), done.


In [None]:
# -*- coding: utf-8 -*-
# ============================================================
#  KRUS — RAG 3.1 (pod schemat: dataset, measure, value, region, period, typ)
#  • CLEAN: value/period/measure/typ (nawias → typ, literówki)
#  • TAGI: region/typ (synonimy) → metadane + wstrzyknięcie do tekstu
#  • RETRIEVE: BM25 + Dense + MQ + HyDE → RRF → (opcjonalnie) Rerank (cross-encoder z RERANKER_MODEL)
#  • REGUŁY: najnowszy okres + ogółem jeśli brak specyfiki
#  • LLM: PLLuM-12B-chat (4-bit) — można wyłączyć i zwracać liczby bez LLM
# ============================================================

import os, re, json, unicodedata, random
from typing import List, Dict, Any, Tuple, Optional
import pandas as pd
import torch

# --- LangChain
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import ContextualCompressionRetriever

# --- Cross-encoder (reranker) – opcjonalnie
_HAS_RERANK = True
try:
    from langchain_community.cross_encoders import HuggingFaceCrossEncoder
    from langchain.retrievers.document_compressors import CrossEncoderReranker
except Exception:
    _HAS_RERANK = False

# --- Chroma (nowe/stare API)
try:
    from langchain_chroma import Chroma
    _CHROMA_NEW = True
except ImportError:
    from langchain_community.vectorstores import Chroma
    _CHROMA_NEW = False

# -----------------------------
# Ścieżki
# -----------------------------
# Every setting can be overridden through an environment variable so the
# notebook is not tied to one machine; the defaults are the original values.
CSV_DIR = os.environ.get("KRUS_CSV_DIR", "C:/Users/admin/Desktop/KRUS-data-/KRUS-data-/all_dane")
PERSIST_DIR = os.environ.get("KRUS_PERSIST_DIR", "chroma_statystyki")
RERANKER_MODEL = os.environ.get("KRUS_RERANKER_MODEL", "radlab/polish-cross-encoder")
EMBEDDER_MODEL = os.environ.get("KRUS_EMBEDDER_MODEL", "intfloat/multilingual-e5-large")
# -----------------------------
# Normalizacja / utils
# -----------------------------
def strip_acc(s: str) -> str:
    """Return ``s`` (coerced to str) with diacritics removed.

    Combining marks are dropped after NFD decomposition. The Polish stroke
    letters "ł"/"Ł" have no canonical decomposition, so they are mapped to
    "l"/"L" explicitly; otherwise "ogółem" would become "ogołem" and never
    match input typed without diacritics.
    """
    decomposed = unicodedata.normalize("NFD", str(s))
    without_marks = "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
    return without_marks.replace("ł", "l").replace("Ł", "L")

def norm_text(s: str) -> str:
    """Canonical comparison form: accents stripped, lower-cased, whitespace runs collapsed to one space."""
    folded = strip_acc(str(s)).lower()
    # split()/join collapses every whitespace run and trims both ends in one step
    return " ".join(folded.split())

def _serialize_tags(tags: List[str]) -> str:
    """Join normalised, de-duplicated, sorted tags into one comma-separated string.

    Falsy and whitespace-only entries are skipped.
    """
    unique = set()
    for tag in tags:
        if tag and str(tag).strip():
            unique.add(norm_text(tag))
    return ",".join(sorted(unique))

def _parse_tags_field(val) -> set[str]:
    """Decode a tags metadata field into a set of normalised tags.

    Accepts ``None`` (empty set), a Python list, a JSON-encoded list, or a
    string delimited by commas, pipes or semicolons.
    """
    if val is None:
        return set()
    if isinstance(val, list):
        return set(map(norm_text, val))

    text = str(val).strip()
    if text.startswith("[") and text.endswith("]"):
        # Looks like JSON; on any decoding problem fall back to delimiter splitting.
        try:
            decoded = json.loads(text)
        except Exception:
            decoded = None
        if isinstance(decoded, list):
            return set(map(norm_text, decoded))

    return {norm_text(piece) for piece in re.split(r"[,\|;]\s*", text) if piece}

# -----------------------------
# Smart CSV reader (PL encodings)
# -----------------------------
def _read_csv_smart(path: str) -> pd.DataFrame:
    for enc in ["utf-8-sig", "cp1250", "iso-8859-2", "latin-1"]:
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception:
            pass
    return pd.read_csv(path)

# -----------------------------
# Synonimy regionów/typów
# -----------------------------
# Normalised spellings that denote the whole country rather than one region.
REGION_GENERAL = {norm_text(x) for x in ["ogółem","razem","polska","kraj","ogolnie","ogółem/razem","cały kraj","caly kraj"]}
# Region synonym groups; expand_region_tags() only reads the values.
REGION_SYNS = {
    "ogolem": ["ogółem","razem","polska","kraj","ogolnie","cały kraj","caly kraj","ogółem/razem"],
}
# Canonical benefit-type key -> alternative spellings; used by
# expand_type_tags() (tagging) and parse_type() (query analysis).
TYPE_SYNS = {
    "emerytura": ["emerytury","świadczenia emerytalne","swiadczenia emerytalne","emerytalne"],
    "renta_niezdolnosc": ["renty z tytułu niezdolności do pracy","renta niezdolność","renty_niezdolnosc_do_pracy","niezdolnosc_do_pracy"],
    "renta_rodzinna": ["renty rodzinne","renta rodzinna"],
    "swiadczenia_zabiegowe_robotnicze": ["świadczenia zabiegowe robotnicze","swiadczenia zabiegowe robotnicze"],
}

# Voivodeship names in their original spelling.
VOIVODESHIPS = [
    "dolnośląskie","kujawsko-pomorskie","lubelskie","lubuskie","łódzkie",
    "małopolskie","mazowieckie","opolskie","podkarpackie","podlaskie",
    "pomorskie","śląskie","świętokrzyskie","warmińsko-mazurskie",
    "wielkopolskie","zachodniopomorskie"
]
# Lookup: normalised name -> original spelling.
VOIV_NORM = {norm_text(v): v for v in VOIVODESHIPS}

def expand_region_tags(val: Optional[str]) -> List[str]:
    """Return the sorted tag list for a region cell.

    A missing, blank or country-wide value yields every REGION_GENERAL tag.
    Any other value yields its normalised form plus every REGION_SYNS group
    that contains it.
    """
    is_blank = val is None or str(val).strip() == ""
    if is_blank or norm_text(val) in REGION_GENERAL:
        return sorted(REGION_GENERAL)

    normalized = norm_text(val)
    collected = {normalized}
    for synonyms in REGION_SYNS.values():
        group = set(map(norm_text, synonyms))
        if normalized in group:
            collected.update(group)
    return sorted(collected)

def expand_type_tags(val: Optional[str]) -> List[str]:
    """Return the tag list for a benefit-type cell.

    A blank value gives ``[]``. A value matching a TYPE_SYNS key or one of its
    synonyms gives the canonical key followed by the sorted normalised
    synonyms. Anything else gives just the normalised value.
    """
    if val is None or not str(val).strip():
        return []

    normalized = norm_text(val)
    for canonical, synonyms in TYPE_SYNS.items():
        group = {norm_text(item) for item in synonyms}
        if normalized != canonical and normalized not in group:
            continue
        return [canonical, *sorted(group)]
    return [normalized]

# -----------------------------
# Parsing liczb i okresów
# -----------------------------
def parse_value(x) -> Optional[float]:
    """Parse a numeric cell into a float, or return None when it is not a number.

    Handles:
      * any Unicode whitespace used as a thousands separator (regular space,
        NBSP, narrow NBSP, thin space, ...),
      * a decimal comma ("1234,56"),
      * mixed separators, where the right-most one is the decimal mark
        ("1.234,56" and "1,234.56" both give 1234.56),
      * the placeholders "nan", "brak", "null" (any letter case) and non-finite
        results, which all map to None.
    """
    if x is None:
        return None
    # str.split() with no argument splits on every Unicode whitespace character.
    s = "".join(str(x).split())
    if s == "" or s.lower() in {"nan", "brak", "null"}:
        return None

    last_comma, last_dot = s.rfind(","), s.rfind(".")
    if last_comma != -1 and last_dot != -1:
        # Both separators present: the earlier kind is the thousands separator.
        s = s.replace(".", "") if last_comma > last_dot else s.replace(",", "")
    s = s.replace(",", ".")

    try:
        val = float(s)
    except ValueError:
        return None
    # float() accepts "inf"/"nan" spellings; those are not usable statistics.
    if val != val or val in (float("inf"), float("-inf")):
        return None
    return val

def norm_period(p: Optional[str]) -> Optional[str]:
    """Normalise a period label to ``YYYY-Qn`` when it looks like a quarter.

    Blank input gives None. Otherwise the text is upper-cased, all whitespace
    is removed and "/" and "_" become "-". Labels that are not recognised as a
    20xx quarter are returned in that cleaned-up form.
    """
    if not p or str(p).strip()=="":
        return None

    compact = str(p).replace("\u00A0"," ").strip().upper()
    compact = re.sub(r"\s+", "", compact).replace("/", "-").replace("_","-")

    # Already canonical.
    if re.match(r"^20\d{2}-Q[1-4]$", compact) is not None:
        return compact

    # Year and quarter glued together, e.g. "2025Q1".
    glued = re.match(r"^(20\d{2})Q([1-4])$", compact)
    if glued is None:
        return compact  # leave unrecognised labels as they are
    year, quarter = glued.groups()
    return f"{year}-Q{quarter}"

def clean_measure_and_type(measure: str, typ: Optional[str]) -> Tuple[str, Optional[str]]:
    """Clean a measure label and derive the benefit type.

    * Missing values (None, empty, or float NaN as produced by pandas for empty
      cells) are treated as absent instead of becoming the string "nan".
    * Known typos in the measure are corrected.
    * If no type is given, the first parenthesised fragment of the measure is
      used as the type.
    * Parenthesised fragments are always removed from the measure.

    Returns ``(measure_clean, typ_clean_or_None)``.
    """
    def _is_nan(v) -> bool:
        # NaN is the only float that differs from itself (covers numpy floats too).
        return isinstance(v, float) and v != v

    m = "" if _is_nan(measure) else str(measure or "").strip()
    t = None if (typ is None or _is_nan(typ) or str(typ).strip() == "") else str(typ).strip()

    m = m.replace("przciętna", "przeciętna")
    m = m.replace(" w zl", " w zł")

    paren = re.search(r"\(([^)]+)\)", m)
    if paren and t is None:
        t = paren.group(1).strip()

    m = re.sub(r"\s*\([^)]+\)\s*", " ", m)
    m = re.sub(r"\s+", " ", m).strip()
    return m, t

def make_page_text(row: dict) -> str:
    """Render one record as the pipe-separated text that gets indexed.

    Always emits dataset, measure, value, region, period and typ (missing
    region/period/typ are shown as "-"); region and type synonym tags are
    appended when present.
    """
    period = row.get("period_norm") or row.get("period_raw") or "-"
    segments = [
        f"dataset: {row.get('dataset')}",
        f"measure: {row.get('measure_clean')}",
        f"value: {row.get('value_float')}",
        f"region: {row.get('region') or '-'}",
        f"period: {period}",
        f"typ: {row.get('typ_clean') or '-'}",
    ]
    if row.get("tags_region"):
        segments.append("region_syn: " + row["tags_region"])
    if row.get("tags_type"):
        segments.append("type_syn: " + row["tags_type"])
    return " | ".join(segments)

# -----------------------------
# Budowa Document z CSV (schemat standardowy)
# -----------------------------
def build_documents_from_standard_csv(dataset_name: str, csv_path: str) -> List[Document]:
    """Turn a standard-schema CSV into one Document per row.

    The CSV must contain the columns dataset, measure, value, region, period
    and typ. Each row is cleaned (measure/typ, numeric value, period), tagged
    with region/type synonyms and rendered with make_page_text().

    Empty cells arrive from pandas as NaN; every text column is read through
    one helper so that NaN becomes None instead of the literal string "nan"
    (previously only region and period were guarded).

    Args:
        dataset_name: Fallback dataset label when the row's own is empty.
        csv_path: Path of the CSV file to load.

    Returns:
        Documents whose metadata carries dataset, source_file, row_index,
        okres, region, type, measure, value, tags_region and tags_type.

    Raises:
        ValueError: If a required column is missing.
    """
    df = _read_csv_smart(csv_path)
    required = ["dataset","measure","value","region","period","typ"]
    for c in required:
        if c not in df.columns:
            raise ValueError(f"Brak wymaganej kolumny: {c}")

    def _cell(row, col: str) -> Optional[str]:
        """Stripped text of a cell, or None when it is missing/NaN/blank."""
        raw = row.get(col)
        if raw is None or pd.isna(raw):
            return None
        text = str(raw).strip()
        return text or None

    docs: List[Document] = []
    for i, row in df.iterrows():
        dataset = _cell(row, "dataset") or dataset_name
        measure_clean, typ_clean = clean_measure_and_type(
            _cell(row, "measure") or "", _cell(row, "typ")
        )

        value_float = parse_value(row.get("value"))
        region_raw = _cell(row, "region")
        period_raw = _cell(row, "period")
        period_norm_ = norm_period(period_raw)

        tags_region = _serialize_tags(expand_region_tags(region_raw))
        tags_type   = _serialize_tags(expand_type_tags(typ_clean))

        rec = {
            "dataset": dataset,
            "measure_clean": measure_clean,
            "typ_clean": typ_clean,
            "value_float": value_float,
            "region": region_raw,
            "period_raw": period_raw,
            "period_norm": period_norm_,
            "tags_region": tags_region,
            "tags_type": tags_type,
        }
        meta: Dict[str, Any] = {
            "dataset": dataset,
            "source_file": os.path.basename(csv_path),
            "row_index": int(i),
            # Prefer the normalised period; fall back to the raw label.
            "okres": period_norm_ or period_raw,
            "region": region_raw,
            "type": typ_clean,
            "measure": measure_clean,
            "value": value_float,
            "tags_region": tags_region,
            "tags_type": tags_type,
        }
        docs.append(Document(page_content=make_page_text(rec), metadata=meta))
    return docs

# -----------------------------
# EMBEDDINGS — multilingual E5 (patrz EMBEDDER_MODEL); klasa zachowuje historyczną nazwę MXBAIEmbeddings
# -----------------------------
from transformers import AutoModel, AutoTokenizer

class MXBAIEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around a Hugging Face encoder.

    Texts are prefixed with "query: " or "passage: ", mean-pooled over the
    attention mask and L2-normalised. The model id defaults to EMBEDDER_MODEL.
    """

    def __init__(self, model_id: str = EMBEDDER_MODEL, device: Optional[str] = None):
        """Load tokenizer and model onto ``device`` (CUDA when available, else CPU)."""
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 on CUDA, float32 elsewhere
        weights_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.tok = AutoTokenizer.from_pretrained(model_id)
        self.net = AutoModel.from_pretrained(model_id, torch_dtype=weights_dtype).to(self.device)
        self.net.eval()

    @torch.inference_mode()
    def _encode(self, texts: List[str], is_query=False, batch_size=32) -> List[List[float]]:
        """Embed ``texts`` batch by batch and return plain Python float lists."""
        prefix = "query: " if is_query else "passage: "
        vectors = []
        for start in range(0, len(texts), batch_size):
            chunk = [prefix + text for text in texts[start:start + batch_size]]
            enc = self.tok(chunk, padding=True, truncation=True, return_tensors="pt").to(self.device)
            hidden = self.net(**enc).last_hidden_state
            mask = enc["attention_mask"].unsqueeze(-1)
            # masked mean over the token axis; clamp guards against an all-zero mask
            summed = (hidden * mask).sum(dim=1)
            token_counts = torch.clamp(mask.sum(dim=1), min=1)
            pooled = torch.nn.functional.normalize(summed / token_counts, p=2, dim=1)
            vectors.extend(pooled.detach().cpu().tolist())
        return vectors

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed passages (uses the "passage: " prefix)."""
        return self._encode(texts, is_query=False)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query (uses the "query: " prefix)."""
        only_vector = self._encode([text], is_query=True)
        return only_vector[0]

# -----------------------------
# Załaduj wszystkie CSV z folderu
#  • jeżeli mają dokładnie schemat standardowy → użyj build_documents_from_standard_csv
#  • inaczej pomiń lub dopisz własny builder
# -----------------------------
def _looks_standard(df: pd.DataFrame) -> bool:
    needed = {"dataset","measure","value","region","period","typ"}
    return needed.issubset(set(df.columns))

# Discover CSV files: file stem -> full path. Names are sorted because
# os.listdir() returns entries in arbitrary order, which would otherwise make
# the document order differ between runs and machines.
CSV_SOURCES = {}
if os.path.isdir(CSV_DIR):
    for fn in sorted(os.listdir(CSV_DIR)):
        if fn.lower().endswith(".csv"):
            CSV_SOURCES[os.path.splitext(fn)[0]] = os.path.join(CSV_DIR, fn)
else:
    # Say so explicitly instead of silently ending up with zero documents.
    print(f"[UWAGA] Katalog z danymi nie istnieje: '{CSV_DIR}'")

# Build documents from every file that follows the standard schema.
all_docs: List[Document] = []
for name, path in CSV_SOURCES.items():
    try:
        df_head = _read_csv_smart(path).head(1)
        if _looks_standard(df_head):
            docs = build_documents_from_standard_csv(name, path)
            all_docs.extend(docs)
        else:
            # Files with another schema are skipped; add a dedicated builder here if needed.
            print(f"[UWAGA] Pomijam '{path}' — niestandardowy schemat kolumn.")
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole load.
        print(f"[BŁĄD] {path}: {e}")

print(f"Zbudowano dokumentów: {len(all_docs)}")

# -----------------------------
# Wektorownia Chroma
# -----------------------------
# Instantiate the embedder (loads tokenizer and model) and open the
# "statystyki" Chroma collection stored under PERSIST_DIR.
emb = MXBAIEmbeddings()
vectorstore = Chroma(
    collection_name="statystyki",
    embedding_function=emb,
    persist_directory=PERSIST_DIR
)

def _collection_count(vs) -> int:
    coll = getattr(vs, "_collection", None)
    try: return coll.count() if coll is not None else 0
    except Exception: return 0

# Index the documents only when the collection is still empty.
# NOTE(review): a non-empty persisted collection is reused as-is, even if the
# CSV files changed since it was built — confirm this is intended.
if _collection_count(vectorstore) == 0 and all_docs:
    vectorstore.add_documents(all_docs)
    # Call persist() when the installed Chroma wrapper has it; otherwise do nothing.
    getattr(vectorstore, "persist", lambda: None)()

# -----------------------------
# BM25 retriever
# -----------------------------
# The lexical index is built from the in-memory documents. An empty corpus
# would otherwise fail inside the BM25 implementation with an unrelated error,
# so fail early with an actionable message.
if not all_docs:
    raise RuntimeError("Brak dokumentów do indeksu BM25 — sprawdź CSV_DIR i schemat plików CSV.")
bm25 = BM25Retriever.from_documents(all_docs)
bm25.k = 80

# -----------------------------
# MultiQuery + HyDE (prosty, bez LLM)
# -----------------------------
def make_mq_prompts(q: str) -> List[str]:
    """Build rule-based paraphrases of the query for multi-query retrieval.

    Returns the original query plus normalised variants, de-duplicated in a
    fixed order. (A set was used before; its iteration order varies between
    interpreter runs, which changed the ranking fed into fusion.) Replacements
    are word-bounded so "ile" is not rewritten inside longer words.
    """
    qn = norm_text(q)
    candidates = [
        f"{q}",
        re.sub(r"\bile wynosi\b", "podaj wartość", qn),
        re.sub(r"\bile\b", "jaka jest liczba", qn),
        qn + " (ogółem, Polska)",
        qn + " (najnowszy okres)",
    ]
    # dict preserves insertion order, so this is an order-stable de-duplication
    return list(dict.fromkeys(candidates))

def make_hyde(q: str) -> str:
    """Compose a keyword-driven pseudo-document (HyDE without an LLM).

    Topic hints are added for stems found in the normalised query; the text
    always ends with a "latest period / whole country" suffix.
    """
    qn = norm_text(q)
    # (stems that trigger the hint, hint text) — order defines output order
    stem_rules = (
        (("emery",), "świadczenia emerytalne"),
        (("renta",), "renty"),
        (("macierz",), "zasiłki macierzyńskie"),
        (("ubezpiec",), "ubezpieczeni"),
        (("płatnik", "platnik"), "płatnicy składek"),
        (("zdrowotn",), "ubezpieczenie zdrowotne"),
    )
    hints = [hint for stems, hint in stem_rules if any(stem in qn for stem in stems)]
    if hints:
        base = "pseudo: " + " ; ".join(hints)
    else:
        base = "pseudo: statystyki KRUS"
    return base + " ; czas: najnowszy dostępny ; region: ogółem/Polska"

# -----------------------------
# RRF — Reciprocal Rank Fusion
# -----------------------------
def rrf_merge(runs: List[List[Document]], k: int = 50, k_rrf: int = 60) -> List[Document]:
    """Fuse ranked lists with Reciprocal Rank Fusion.

    Documents are identified by ``(source_file, row_index)``. Each run adds
    ``1 / (k_rrf + rank)`` to a document's score, with rank counted from 1. If
    a key occurs more than once in a run, its last position is the one used.
    The first Document object seen for a key is the one returned.

    Args:
        runs: Ranked result lists, best first.
        k: Maximum number of documents to return.
        k_rrf: RRF smoothing constant.

    Returns:
        Up to ``k`` documents ordered by descending fused score.
    """
    def _identity(doc: Document) -> Tuple[str, int]:
        return (doc.metadata.get("source_file","?"), doc.metadata.get("row_index",-1))

    totals: Dict[Tuple[str,int], float] = {}
    first_seen: Dict[Tuple[str,int], Document] = {}
    for run in runs:
        positions: Dict[Tuple[str,int], int] = {}
        for index, doc in enumerate(run):
            ident = _identity(doc)
            positions[ident] = index
            first_seen.setdefault(ident, doc)
        for ident, index in positions.items():
            totals[ident] = totals.get(ident, 0.0) + 1.0 / (k_rrf + index + 1)

    # sorted() is stable, so ties keep first-seen order
    ranked = sorted(first_seen.values(), key=lambda doc: totals[_identity(doc)], reverse=True)
    return ranked[:k]

# -----------------------------
# Dense search helpers
# -----------------------------
def dense_search(query: str, k: int = 80) -> List[Document]:
    """Return the ``k`` nearest documents from the vector store for one query."""
    search_kwargs = {"k": k}
    return vectorstore.as_retriever(search_kwargs=search_kwargs).invoke(query)

def dense_search_on(texts: List[str], k: int = 40) -> List[Document]:
    """Run a dense search for every text and concatenate the hits.

    Duplicates (same ``(source_file, row_index)``) are dropped; the first
    occurrence is kept, so the result follows first-appearance order.
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    unique_hits = {}
    for text in texts:
        for doc in retriever.invoke(text):
            ident = (doc.metadata.get("source_file","?"), doc.metadata.get("row_index",-1))
            unique_hits.setdefault(ident, doc)
    return list(unique_hits.values())

# -----------------------------
# Reranker (opcjonalnie)
# -----------------------------
# Build the optional cross-encoder reranker.
# The previous code called `CrossEncoder(..., device=device)`, but neither name
# is defined in this file; the NameError was swallowed, so the reranker was
# always disabled without notice. Use the imported HuggingFaceCrossEncoder,
# choose the device explicitly, and report why reranking is off if setup fails.
if _HAS_RERANK:
    try:
        cross_encoder = HuggingFaceCrossEncoder(
            model_name=RERANKER_MODEL,
            model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
        )
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=30)
    except Exception as exc:
        # Best effort: retrieval still works without a reranker.
        print(f"[UWAGA] Reranker wyłączony: {exc}")
        _HAS_RERANK = False
        reranker = None
else:
    reranker = None

# -----------------------------
# Okres → sortowanie
# -----------------------------
def period_key(p: Optional[str]) -> Tuple[int,int,int]:
    """Map a period label to a sortable ``(year, quarter, day_number)`` tuple.

    Recognised forms (20xx years only):
      * quarters: "2025-Q1", "2025q1", "2025 q1"  -> (2025, 1, 0)
      * ISO-like dates: "2025-03-31"              -> (2025, 1, 93)
      * day-first dates: "31.03.2025", "31-03-2025", "31/03/2025"
        -> same key as the ISO form. These used to fall through to (0, 0, 0)
        and therefore sorted as the oldest period.
      * a bare year: "2025"                       -> (2025, 0, 0)

    Anything else, including blank input, gives (0, 0, 0).
    ``day_number`` is ``(month - 1) * 31 + day``: monotonic within a year, not
    a real day-of-year.
    """
    if not p or str(p).strip()=="":
        return (0,0,0)
    s = str(p).lower().strip()

    def _from_date(y: int, mo: int, d: int) -> Tuple[int,int,int]:
        return (y, (mo-1)//3 + 1, (mo-1)*31 + d)

    m = re.match(r"^(20\d{2})[-_/ ]?q([1-4])$", s)
    if m:
        return (int(m.group(1)), int(m.group(2)), 0)

    m = re.match(r"^(20\d{2})[-_/](\d{1,2})[-_/](\d{1,2})$", s)
    if m:
        y, mo, d = map(int, m.groups())
        return _from_date(y, mo, d)

    # Day-first dates; the pattern itself restricts day to 1-31 and month to 1-12.
    m = re.match(r"^(0?[1-9]|[12]\d|3[01])[-./](0?[1-9]|1[0-2])[-./](20\d{2})$", s)
    if m:
        d, mo, y = map(int, m.groups())
        return _from_date(y, mo, d)

    m = re.match(r"^(20\d{2})$", s)
    if m:
        return (int(m.group(1)), 0, 0)
    return (0,0,0)

# -----------------------------
# Heurystyki zapytania: region/typ/najnowszy
# -----------------------------
# Words signalling that the user asks about the current/most recent period; read by wants_latest().
REL_WORDS_NOW = {"teraz","obecnie","w tym roku","aktualnie","bieżący","biezacy","najnowszy","ostatni","najświeższy","nowszy"}

def parse_region(q: str) -> Optional[str]:
    """Detect the voivodeship mentioned in a query.

    Returns the normalised voivodeship name (a key of VOIV_NORM) or None.

    1. Literal mention: if several names occur as substrings, the longest
       wins, so "zachodniopomorskie" is no longer reported as "pomorskie".
    2. Inflected mention: every word following "w"/"we" (optionally after
       "województwie"/"woj.") is compared with each name by common prefix; the
       best match is accepted when at least 5 characters agree. Previously only
       the first "w ..." phrase was examined and it had to run up to
       punctuation or the end of the query.
    """
    qn = norm_text(q)

    literal_hits = [nn for nn in VOIV_NORM if nn in qn]
    if literal_hits:
        return max(literal_hits, key=len)

    best, best_len = None, 0
    # "wojew[oó]dztwie" tolerates either outcome of accent folding.
    pattern = r"\bwe?\s+(?:wojew[oó]dztwie\s+|woj\.\s*|woj\s+)?([a-ząćęłńóśźż\-]+)"
    for m in re.finditer(pattern, qn):
        cand = m.group(1)
        for nn in VOIV_NORM:
            common = os.path.commonprefix([nn, cand])
            if len(common) > best_len:
                best, best_len = nn, len(common)
    return best if best_len >= 5 else None

def parse_type(q: str) -> Optional[str]:
    """Return the TYPE_SYNS key whose longest key/synonym occurs in the normalised query, or None."""
    qn = norm_text(q)
    winner, winner_len = None, 0
    for canonical, synonyms in TYPE_SYNS.items():
        for phrase in (canonical, *synonyms):
            needle = norm_text(phrase)
            # strictly longer needles only, so the earliest of equal-length matches is kept
            if len(needle) > winner_len and needle in qn:
                winner, winner_len = canonical, len(needle)
    return winner

def wants_latest(q: str) -> bool:
    """Decide whether the newest period should be preferred for this query.

    * True when the query contains a "now/latest" keyword. The keywords are
      normalised the same way as the query, so accented entries of
      REL_WORDS_NOW can match.
    * False when the query names an explicit period (a 20xx year or a Q1–Q4
      token). This used to return True unconditionally, pushing the requested
      year out of the context.
    * True otherwise (no period given, so default to the newest).
    """
    qn = norm_text(q)
    if any(norm_text(w) in qn for w in REL_WORDS_NOW):
        return True
    names_period = re.search(r"(?<!\d)20\d{2}(?!\d)|(?<![a-z])q[1-4](?!\d)", qn)
    return names_period is None

# -----------------------------
# SMART RETRIEVE
# -----------------------------
def retrieve(query: str, k_final: int = 24) -> List[Document]:
    """Hybrid retrieval: dense + multi-query + HyDE + BM25, RRF fusion, optional rerank, then rule-based ordering.

    Steps:
      1. Collect four candidate runs and fuse them with rrf_merge().
      2. Rerank with the cross-encoder when available, else keep the top 30.
      3. Keep candidates matching the region/type detected in the query; if
         fewer than 10 remain, fall back to the unfiltered list.
      4. Stable-sort by preference, then recency (only when wants_latest()).

    Changes from the previous version:
      * Country-wide rows are preferred only when no region was requested.
        Before, they were boosted above the requested voivodeship whenever the
        fallback pool was in use.
      * In the fallback pool, rows matching the requested region/type come first.
      * BM25 is queried through ``invoke`` like the dense retrievers
        (``get_relevant_documents`` is deprecated).

    Args:
        query: User question.
        k_final: Maximum number of documents to return.
    """
    q = query

    d0 = dense_search(q, k=80)
    d1 = dense_search_on(make_mq_prompts(q), k=40)
    d2 = dense_search_on([make_hyde(q)], k=40)
    bm25.k = 80
    d3 = bm25.invoke(q)

    fused = rrf_merge([d0, d1, d2, d3], k=80, k_rrf=60)

    if _HAS_RERANK and reranker is not None:
        reranked = list(reranker.compress_documents(documents=fused, query=q))
    else:
        reranked = fused[:30]

    region_req = parse_region(q)
    type_req   = parse_type(q)
    latest     = wants_latest(q)
    rr = norm_text(region_req) if region_req else None
    tt = norm_text(type_req) if type_req else None

    def _region_ok(meta: Dict[str, Any]) -> bool:
        """Does the row's region text or one of its region tags match the request?"""
        reg = norm_text(meta.get("region")) if meta.get("region") else ""
        return (rr in reg) or (rr in _parse_tags_field(meta.get("tags_region")))

    def _type_ok(meta: Dict[str, Any]) -> bool:
        """Does the row's type text or one of its type tags match the request?"""
        typ = norm_text(meta.get("type")) if meta.get("type") else ""
        return (tt in typ) or (tt in _parse_tags_field(meta.get("tags_type")))

    pool = []
    for d in reranked:
        m = d.metadata or {}
        if rr and not _region_ok(m):
            continue
        if tt and not _type_ok(m):
            continue
        pool.append(d)

    if len(pool) < 10:
        pool = list(reranked)  # the filter was too aggressive; use everything

    def _preference(d: Document):
        m = d.metadata or {}
        if rr:
            region_pref = int(_region_ok(m))
        else:
            # No region requested: prefer country-wide rows.
            tags_r = _parse_tags_field(m.get("tags_region"))
            region_pref = int(len(REGION_GENERAL.intersection(tags_r)) > 0)
        type_pref = int(_type_ok(m)) if tt else 0
        recency = period_key(m.get("okres")) if latest else (0, 0, 0)
        return (region_pref + type_pref, recency)

    # sorted() is stable, so relevance order survives among equal keys.
    pool = sorted(pool, key=_preference, reverse=True)
    return pool[:k_final]

# -----------------------------
# LLM (PLLuM-12B-chat) — opcjonalnie
# -----------------------------
from transformers import AutoTokenizer as HF_AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_huggingface import HuggingFacePipeline

def build_pllum_llm(model_id: str = "CYFRAGOVPL/PLLuM-12B-chat", use_4bit: bool = True, max_new_tokens: int = 220):
    """Load the causal LM and wrap it as a LangChain ``HuggingFacePipeline``.

    Args:
        model_id: Hugging Face model id.
        use_4bit: Load with 4-bit NF4 quantisation (bitsandbytes) when True.
        max_new_tokens: Generation length limit.

    Generation is greedy (``do_sample=False``). The former ``temperature=0.0``
    argument was dropped: it only applies to sampling and contradicts
    ``do_sample=False``.
    """
    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # model repository; keep it only if this model really needs it.
    load_kwargs: Dict[str, Any] = {"device_map": "auto", "trust_remote_code": True}
    if use_4bit:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    llm_model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    tok = HF_AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tok.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tok.pad_token = tok.eos_token

    gen_pipe = pipeline(
        "text-generation",
        model=llm_model,
        tokenizer=tok,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
        return_full_text=False,  # return only the completion, not the prompt
    )
    return HuggingFacePipeline(pipeline=gen_pipe)

# Switch for the generation step: when False no model is loaded and
# answer_pllum() returns the top record directly.
USE_LLM = True
llm_pllum = build_pllum_llm() if USE_LLM else None

def _trim(s: str, n: int = 500) -> str:
    s = str(s)
    return s if len(s) <= n else s[:n].rstrip() + "…"

def format_context_for_llm(docs: List[Document], max_docs: int = 6, max_snip_chars: int = 420) -> str:
    """Render up to ``max_docs`` documents as numbered header + snippet blocks separated by blank lines."""
    blocks = []
    position = 0
    for doc in docs[:max_docs]:
        position += 1
        info = doc.metadata or {}
        period_label = info.get("okres", "")
        region_label = info.get("region", "")
        header = (
            f"[{position}] dataset={info.get('dataset', '')} | source={info.get('source_file', '')} | "
            f"okres={period_label or '-'} | region={region_label or '-'} | row={info.get('row_index', '')}"
        )
        snippet = _trim(doc.page_content, max_snip_chars)
        blocks.append(f"{header}\n{snippet}")
    return "\n\n".join(blocks)

# System instruction for the LLM (Polish): answer briefly, numerically and only
# from the supplied fragments.
# NOTE(review): the text names 2025 explicitly as the newest period; it will go
# stale when newer data is added.
SYS_MSG = (
    "Jesteś asystentem danych KRUS. Odpowiadasz zwięźle, liczbowo i WYŁĄCZNIE na podstawie podanych fragmentów.\n"
    "Zasady:\n"
    "• Jeśli pytanie nie precyzuje daty/czasu ⇒ wybierz NAJNOWSZY okres w danych czyli z rokiem 2025 lub w formacie XX.XX.2025.\n"
    "• Jeśli brak województwa/typu ⇒ preferuj rekordy ogólne (Ogółem/Polska).\n"
    "• Gdy w danych brakuje odpowiedzi ⇒ powiedz wprost, że brak informacji.\n"
    "• W odpowiedzi NIE dodawaj źródeł ani numerów fragmentów."
)

# User-turn template; {question} and {context} are filled in by build_prompt().
USER_TEMPLATE = (
    "<PYTANIE>\n{question}\n</PYTANIE>\n\n"
    "<FRAGMENTY>\n{context}\n</FRAGMENTY>\n\n"
    "Sformułuj 1–2 zdaniową odpowiedź po polsku, podaj tylko liczby/jednostki i okres/region. "
    "Nie dodawaj źródeł ani numerów fragmentów."
)

def build_prompt(question: str, context: str) -> str:
    """Assemble the full prompt: system message plus the filled-in user template inside [INST] markers.

    NOTE(review): the markers are written by hand; confirm they match the
    chat template expected by the configured model.
    """
    user_part = USER_TEMPLATE.format(question=question, context=context)
    return "".join(("<s>[INST] <<SYS>>", SYS_MSG, "<</SYS>>\n", user_part, "[/INST]"))

def answer_pllum(question: str, k_ctx: int = 8) -> str:
    """Answer a question from retrieved KRUS records, with a source list.

    Flow: retrieve → rank → build context → (optionally) ask the LLM → append
    "[ŹRÓDŁA]". With ``USE_LLM`` off, the best record is returned directly.

    Ranking uses a ``(match_points, recency)`` tuple compared lexicographically:
      * match_points: +10 region match (or +5 for a country-wide row when no
        region was requested), +10 type match, +3 when a numeric value exists;
      * recency: period_key() of the row, used only when wants_latest().
    Previously the period was flattened into the same integer
    (year*1000 + quarter*10 + day). That swamped the match points and let the
    day component outweigh the quarter. The tuple keeps the same order as
    retrieve(): match first, then newest.

    Missing metadata values now fall back to the intended placeholders;
    ``dict.get(key, default)`` never applied them because the keys exist with
    value None.

    Args:
        question: User question.
        k_ctx: Number of documents passed to the LLM as context.

    Returns:
        The answer text followed by a "[ŹRÓDŁA]" section, or a message that no
        data or no LLM answer is available.
    """
    # At least 12 candidates, so ranking has something to choose from.
    docs = retrieve(question, k_final=max(12, k_ctx))
    if not docs:
        return "Brak danych pasujących do pytania.\n[ŹRÓDŁA]\n– brak"

    region_req = parse_region(question)
    type_req = parse_type(question)
    latest = wants_latest(question)
    region_norm = norm_text(region_req) if region_req else None
    type_norm = norm_text(type_req) if type_req else None

    scored_docs = []
    for doc in docs:
        mv = doc.metadata or {}
        match_points = 0

        tags_r = _parse_tags_field(mv.get("tags_region"))
        if region_norm:
            doc_region = norm_text(mv.get("region") or "")
            if region_norm in doc_region or region_norm in tags_r:
                match_points += 10
        elif len(REGION_GENERAL.intersection(tags_r)) > 0:
            # No region requested: prefer country-wide rows.
            match_points += 5

        if type_norm:
            doc_type = norm_text(mv.get("type") or "")
            tags_t = _parse_tags_field(mv.get("tags_type"))
            if type_norm in doc_type or type_norm in tags_t:
                match_points += 10

        # A row without a numeric value cannot answer a "how much" question.
        if mv.get("value") is not None:
            match_points += 3

        recency = period_key(mv.get("okres")) if latest else (0, 0, 0)
        scored_docs.append(((match_points, recency), doc))

    # Best first; the sort is stable, so retrieval order breaks ties.
    scored_docs.sort(key=lambda item: item[0], reverse=True)
    best_docs = [doc for _, doc in scored_docs]
    context_docs = best_docs[:k_ctx]

    # --- No-LLM mode: return the top record as-is ---
    if not USE_LLM:
        mv = best_docs[0].metadata or {}
        return (
            f"{mv.get('value')} (okres: {mv.get('okres') or '-'}, "
            f"region: {mv.get('region') or 'ogółem'}, "
            f"measure: {mv.get('measure')})\n\n"
            f"[ŹRÓDŁA]\n- #1 dataset={mv.get('dataset')} | source={mv.get('source_file')} | "
            f"okres={mv.get('okres')} | region={mv.get('region') or '-'}"
        )

    # --- LLM mode ---
    ctx = format_context_for_llm(context_docs, max_docs=len(context_docs))
    prompt = build_prompt(question, ctx)
    try:
        raw_answer = llm_pllum.invoke(prompt)
    except Exception as e:
        # Degrade gracefully: show which records were available.
        headers_only = "\n".join(line.split("\n", 1)[0] for line in ctx.split("\n\n"))
        return f"Nie udało się wygenerować odpowiedzi LLM ({e}).\nDostępny kontekst:\n{headers_only}"

    # --- Pick the documents to cite ---
    # Default: the three best context documents the model saw.
    used_docs = context_docs[:3]

    # Refinement: if the answer has some content, cite the context documents
    # sharing the most words with it.
    answer_words = set(norm_text(raw_answer).split())
    if len(answer_words) > 3:
        word_matched_docs = []
        for doc in context_docs:
            doc_words = set(norm_text(doc.page_content).split())
            overlap = len(answer_words.intersection(doc_words))
            if overlap > 0:
                word_matched_docs.append((overlap, doc))
        if word_matched_docs:
            word_matched_docs.sort(key=lambda x: x[0], reverse=True)
            used_docs = [doc for overlap, doc in word_matched_docs[:3]]

    # --- Build the sources section ---
    sources = []
    for i, doc in enumerate(used_docs, start=1):
        mv = doc.metadata or {}
        value_str = f", value={mv.get('value')}" if mv.get('value') is not None else ""
        sources.append(
            f"- #{i} dataset={mv.get('dataset') or '-'}, "
            f"source={mv.get('source_file') or '-'}, "
            f"okres={mv.get('okres') or '-'}, "
            f"region={mv.get('region') or 'ogółem'}, "
            f"type={mv.get('type') or '-'}"
            f"{value_str}"
        )

    final_answer = raw_answer.strip() + "\n\n[ŹRÓDŁA]\n" + "\n".join(sources)
    return final_answer


Zbudowano dokumentów: 3030


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Device set to use cuda:0



Q: ile wynosi przeciętna liczba świadczeń emerytalnych?
Odpowiedź: 760307 emerytur w Polsce w 2025 pierwszym kwartale.

[ŹRÓDŁA]
- #1 dataset=EMERYTURY I RENTY REALIZOWANE PRZEZ KRUS, source=all_data.csv, okres=2025-Q1, region=-,type=-, value=760307.0
- #2 dataset=EMERYTURY I RENTY REALIZOWANE PRZEZ KRUS, source=all_data.csv, okres=2025-Q1, region=-,type=-, value=-
- #3 dataset=EMERYTURY I RENTY REALIZOWANE PRZEZ KRUS, source=all_data.csv, okres=2025-Q1, region=-,type=-, value=25034.0

Q: ile wynosi przeciętne świadczenie emerytalne?
Odpowiedź: 2184,49 złotych (emerytury, 1 kwartał 2025 roku, ogółem w Polsce).

[ŹRÓDŁA]
- #1 dataset=EMERYTURY I RENTY REALIZOWANE PRZEZ KRUS, source=all_data.csv, okres=2025-Q1, region=-,type=-, value=2178.42
- #2 dataset=EMERYTURY I RENTY REALIZOWANE PRZEZ KRUS, source=all_data.csv, okres=2025-Q1, region=-,type=-, value=2184.49
- #3 dataset=EMERYTURY I RENTY REALIZOWANE PRZEZ KRUS, source=all_data.csv, okres=2025-Q1, region=-,type=-, value=760307.0

Q: 

In [8]:
# Smoke test: run each sample question through the pipeline and print the
# answer together with its source list.
if __name__ == "__main__":
    tests = [
        "ile zlozono wnioskow o przyznanie emerytury?",
    ]
    for q in tests:
        print("\nQ:", q)
        print(answer_pllum(q, k_ctx=8))



Q: ile zlozono wnioskow o przyznanie emerytury?
W pierwszym kwartale 2025 roku złożono 8915 wniosków o przyznanie emerytury.

[ŹRÓDŁA]
- #1 dataset=WNIOSKI O PRZYZNANIE EMERYTUR I RENT WEDŁUG RODZAJÓW, source=all_data.csv, okres=2025-Q1, region=-type=-, value=8915.0
- #2 dataset=WNIOSKI O PRZYZNANIE EMERYTUR I RENT WEDŁUG RODZAJÓW, source=all_data.csv, okres=2025-Q1, region=-type=-, value=10257.0
- #3 dataset=WNIOSKI O PRZYZNANIE EMERYTUR I RENT WEDŁUG RODZAJÓW, source=all_data.csv, okres=2025-Q1, region=-type=-, value=1678.0
