In [None]:
import argparse
import json
import os
import sys
import textwrap
from typing import List, Optional

import chromadb
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer

In [None]:
# Paths default to the author's local checkout. Override them with environment variables
# instead of editing this cell: hard-coded absolute paths do not exist on other machines.
PROJECT_ROOT = os.environ.get(
    "RAG_LEX_PROJECT_ROOT",
    "/Users/mateuszbulanda-gorol/Desktop/Projects/rag_lex_project",
)
CHROMA_PATH = os.environ.get("RAG_CHROMA_PATH", os.path.join(PROJECT_ROOT, "vectordb", "chroma_db"))
COLLECTION_NAME = "isap_acts"
EMB_MODEL = "clarin-pl/roberta-polish-sentence-transformer-v1"
# NOTE(review): this default points at the `models` directory, not at a specific `.gguf` file,
# while llama-cpp's `Llama(model_path=...)` loads a model file — confirm, or set RAG_LLAMA_GGUF_PATH.
LLAMA_GGUF_PATH = os.environ.get("RAG_LLAMA_GGUF_PATH", os.path.join(PROJECT_ROOT, "models"))
LLAMA_N_CTX = 4096        # context window (tokens) passed to Llama
TOP_K = 5                 # number of passages requested from Chroma per query
MAX_GEN_TOKENS = 512      # generation cap for the answer
# Character budget for the retrieved context block.
# NOTE(review): 12000 chars of Polish text may not fit in LLAMA_N_CTX - MAX_GEN_TOKENS tokens — confirm.
PROMPT_MAX_CHARS = 12000

In [None]:
def init_embedding_model(model_name: str):
    """Create the sentence-embedding model used to encode queries.

    Args:
        model_name: Identifier handed unchanged to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    message = f"[init] Loading embedding model: {model_name}"
    print(message)
    embedder = SentenceTransformer(model_name)
    return embedder

In [None]:
def init_chroma_client(chroma_path: str, collection_name: Optional[str] = None):
    """Open the persistent ChromaDB store at ``chroma_path`` and fetch an existing collection.

    Args:
        chroma_path: Directory of the persistent Chroma database.
        collection_name: Collection to open. ``None`` (default) means ``COLLECTION_NAME``,
            read at call time so edits to the config cell take effect without re-running this cell.

    Returns:
        ``(client, collection)``.
    """
    if collection_name is None:
        collection_name = COLLECTION_NAME
    print(f"[init] Opening ChromaDB persistent client at: {chroma_path}")
    client = chromadb.PersistentClient(path=chroma_path)
    # get_collection (unlike get_or_create_collection) only opens an already-built collection.
    collection = client.get_collection(collection_name)
    return client, collection

In [None]:
def init_llama(model_path: str, n_ctx: int):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"LLAMA model file not found at: {model_path}")
    print(f"[init] Loading model (gguf) from {model_path}")
    llm = Llama(model_path=model_path, n_ctx=n_ctx, n_threads=8, n_gpu_layers=-1)
    return llm

In [None]:
def _first_result_row(results, key: str) -> list:
    """Return the list stored under ``key`` for the first (only) query, or ``[]``.

    Chroma query results hold one inner list per query embedding; a key may also be
    ``None`` when that field was not returned.
    """
    rows = results.get(key) or []
    if not rows:
        return []
    return list(rows[0] or [])


def build_retrieval_centext(collection, emb_model, query: str, top_k: int, char_limit: Optional[int] = None):
    """Retrieve the passages closest to ``query`` and format them as prompt context.

    Steps:
        1. Embed the query with ``emb_model.encode``.
        2. Query the Chroma ``collection`` for the ``top_k`` nearest chunks.
        3. Compose a context string in which every passage is prefixed with its citation.

    Args:
        collection: Chroma collection (anything exposing
            ``query(query_embeddings=..., n_results=..., include=...)``).
        emb_model: Embedding model whose ``encode([text])`` result has ``.tolist()``.
        query: The user question.
        top_k: Number of passages requested from Chroma.
        char_limit: Maximum length of the returned context, separators included.
            ``None`` (default) means ``PROMPT_MAX_CHARS``, read at call time so that editing
            the config cell takes effect without re-running this cell.

    Returns:
        ``(context_text, hits)``. ``hits`` is a list of dicts with keys ``rank`` (1-based),
        ``score`` (the raw Chroma distance, or ``None``), ``text``, ``meta`` and ``citation``.
        Passages are added whole, in rank order, until the next one would exceed
        ``char_limit``. ``hits`` always lists every retrieved passage, including those left
        out of the context.
    """
    if char_limit is None:
        char_limit = PROMPT_MAX_CHARS

    print("[retrieve] Embedding query...")
    q_emb = emb_model.encode([query]).tolist()
    print("[retrieve] Querying ChromaDB...")
    results = collection.query(
        query_embeddings=q_emb,
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )
    docs = _first_result_row(results, "documents")
    metadatas = _first_result_row(results, "metadatas")
    distances = _first_result_row(results, "distances")

    hits = []
    for i, doc in enumerate(docs):
        # A missing/None metadata entry or distance must not crash or drop the passage.
        meta = (metadatas[i] if i < len(metadatas) else None) or {}
        dist = distances[i] if i < len(distances) else None
        # Prefer an act reference "<publisher>_<year>_poz.<pos>" when year and position are set.
        citation = meta.get("filename", "unknown")
        if meta.get("year") and meta.get("pos"):
            citation = f"{meta.get('publisher', '')}_{meta.get('year', '')}_poz.{meta.get('pos', '')}"
        chunk_id = meta.get("chunk_id")
        hits.append({
            "rank": i + 1,
            "score": float(dist) if dist is not None else None,
            "text": doc or "",
            "meta": meta,
            "citation": f"{citation}#{chunk_id}" if chunk_id is not None else citation,
        })

    separator = "\n---\n"
    parts = []
    total_chars = 0
    for h in hits:
        block = f"[Źródło: {h['citation']}]\n{h['text'].strip()}\n"
        # Count the separator too, so the joined text really stays within char_limit.
        cost = len(block) + (len(separator) if parts else 0)
        if total_chars + cost > char_limit:
            break
        parts.append(block)
        total_chars += cost

    context_text = separator.join(parts)
    return context_text, hits


# Correctly spelled alias; the original (misspelled) name stays for existing callers.
build_retrieval_context = build_retrieval_centext

In [None]:
def build_prompt(question: str, context: str, max_chars: Optional[int] = None) -> str:
    """Build the Polish instruction prompt: answer only from the given sources and cite them.

    ``textwrap.shorten`` was used here before. It had two problems:
      - it collapsed every newline, destroying the ``[Źródło: ...]`` / ``---`` layout of the context;
      - when it truncated, it cut the END of the prompt, i.e. the question itself.
    Now only the context is shortened, and the question and answer cue are always kept.

    Args:
        question: The user question (never truncated).
        context: Retrieved passages, inserted verbatim when they fit.
        max_chars: Target maximum prompt length in characters. ``None`` (default) means
            ``LLAMA_N_CTX * 4``, the original heuristic budget, read at call time.
            NOTE(review): this budget does not reserve room for generated tokens — confirm.

    Returns:
        The prompt text. It can still exceed ``max_chars`` when the instructions plus the
        question alone are longer than the budget; the context is then dropped entirely.
    """
    if max_chars is None:
        max_chars = LLAMA_N_CTX * 4

    header = (
        "Jesteś ekspertem prawa polskiego. Odpowiedz na zadane pytanie **wyłącznie** na podstawie "
        "poniższych fragmentów ustaw (nie wymyślaj dodatkowych informacji).\n"
        "Jeśli fragmenty są sprzeczne lub niewystarczające, zaznacz to i wskaż, jakie dodatkowe "
        "źródła byłyby potrzebne.\n"
        'Podaj odpowiedź po polsku. Na końcu dodaj krótkie odniesienie do użytych źródeł '
        '(np. "Źródła: DU_2020_poz.123#4").\n'
        "\n"
        "Poniżej fragmenty aktów:\n"
    )
    footer = f"\n\nPytanie: {question}\n\nOdpowiedź:"

    placeholder = "..."
    budget = max_chars - len(header) - len(footer)
    if len(context) > budget:
        keep = budget - len(placeholder)
        context = context[:keep] + placeholder if keep > 0 else ""

    return header + context + footer

In [None]:
def generate_answer(llm, prompt: str, max_tokens: Optional[int] = None):
    """Run a deterministic (temperature 0) completion and return the generated text.

    Args:
        llm: A ``llama_cpp.Llama`` instance, or anything exposing ``create_completion``.
        prompt: Full prompt text.
        max_tokens: Generation cap. ``None`` (default) means ``MAX_GEN_TOKENS``, read at call
            time so that editing the config cell takes effect without re-running this cell.

    Returns:
        ``(text, resp)``: the stripped text of the first choice (``""`` if there is none)
        and the raw response dict.
    """
    if max_tokens is None:
        max_tokens = MAX_GEN_TOKENS
    print("[gen] Generating answer from LLM...")
    # `Llama` has no `.create` method; its text-completion API is `create_completion`.
    resp = llm.create_completion(prompt=prompt, max_tokens=max_tokens, temperature=0.0)
    # Guard against an empty `choices` list or a None text field.
    choices = resp.get("choices") or [{}]
    text = (choices[0].get("text") or "").strip()
    return text, resp

In [None]:
def pretty_print_hits(hits: List[dict], max_chars: int = 400):
    """Print a short, human-readable summary of the retrieved passages.

    Args:
        hits: Hit dicts with keys ``rank``, ``citation``, ``score`` and ``text``.
        max_chars: Number of leading characters of each passage to show (default 400).
    """
    print("\n[Retrieved passages]")
    for h in hits:
        print(f"Rank {h['rank']} | citation: {h['citation']} | score: {h['score']}")
        text = h["text"] or ""
        snippet = text[:max_chars].replace("\n", " ")
        # Mark the snippet as cut only when something was actually dropped.
        print(f"  {snippet}{' ...' if len(text) > max_chars else ''}")
        print("---")

In [None]:
def main(args):
    """Run the full RAG pipeline for ``args.query``.

    Loads the embedding model, the Chroma collection and the LLM, then retrieves context,
    builds the prompt and prints the hits, a prompt preview, the answer and the raw response.

    Args:
        args: ``argparse.Namespace`` with a ``query`` attribute (see the argument-parser cell).
    """
    emb_model = init_embedding_model(EMB_MODEL)
    _client, collection = init_chroma_client(CHROMA_PATH)
    llm = init_llama(LLAMA_GGUF_PATH, n_ctx=LLAMA_N_CTX)

    query = args.query
    print(f"\n[query] {query}\n")

    context_text, hits = build_retrieval_centext(collection, emb_model, query, top_k=TOP_K)
    if not hits:
        print("[warn] Brak pasujących fragmentów w bazie.")
    else:
        pretty_print_hits(hits)
        # Hits exist, but none fit the character budget, so the context is empty.
        if not context_text.strip():
            print("[warn] Znaleziono fragmenty, ale żaden nie zmieścił się w limicie znaków kontekstu.")

    prompt = build_prompt(query, context_text)
    # Preview only the beginning; the full prompt can be many thousands of characters.
    print("\n[prompt preview]\n", prompt[:2000], "\n...")

    answer, raw = generate_answer(llm, prompt, max_tokens=MAX_GEN_TOKENS)

    print("\n[FINAL ANSWER]\n")
    print(answer)
    # json.dumps, not json.dump: the latter needs a file object and raised TypeError here.
    # default=str keeps the dump from failing on any non-JSON-serializable field.
    print("\n[RAW LLM RESPONSE METADATA]\n", json.dumps(raw, indent=2, ensure_ascii=False, default=str))

In [None]:
parser = argparse.ArgumentParser(description="Simple RAG: Chroma + PLLuM(gguf via llama-cpp)")
parser.add_argument("--query", "-q", type=str, required=True, help="Pytanie po polsku")

# Inside a Jupyter kernel, sys.argv holds the kernel's own flags (e.g. `-f kernel-....json`).
# Parsing them fails on the required --query and raises SystemExit, so the notebook
# uses EXAMPLE_QUERY instead. When run as a script, the real command line is parsed.
EXAMPLE_QUERY = "Jakie są zasady wprowadzenia stanu wyjątkowego w Polsce?"
if "ipykernel" in sys.modules:
    args = parser.parse_args(["--query", EXAMPLE_QUERY])
else:
    args = parser.parse_args()

In [None]:
main(args)

# Command-line usage, assuming the notebook is exported as `rag_llama_chroma.py`.
# A bare shell command is not valid Python inside this cell and would raise SyntaxError.
#   python rag_llama_chroma.py -q "Jakie są zasady wprowadzenia stanu wyjątkowego w Polsce?"
