In [4]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import faiss
import re

try:
    import PyPDF2

    HAS_PYPDF2 = True
except Exception:
    HAS_PYPDF2 = False

# --- Agent configuration -------------------------------------------------
# Hugging Face model id loaded below as both tokenizer and seq2seq model.
LLM_NAME = "google/flan-t5-base"
# sentence-transformers model id used to embed knowledge chunks and queries.
EMBED_MODEL = "all-MiniLM-L6-v2"
# Default number of chunks rag_tool retrieves per query.
top_k = 2
# When True, the tools and the controller print their intermediate state.
DEBUG = True
# Models are loaded once at import/cell-execution time and shared by every tool.
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME)
embedder = SentenceTransformer(EMBED_MODEL)

# In-memory knowledge base; a chunk's list position is its id in the FAISS index.
knowledge_chunks = [
    "Python is a high-level programming language that emphasizes simplicity and readability.",
    "Java is a versatile, object-oriented programming language designed for portability across platforms.",
    "Python is dynamically typed and concise, while Java is statically typed and verbose.",
    "RAG combines retrieval from a knowledge base with a generator model to produce grounded answers."
]

# Embed every chunk and add the vectors to a flat L2 index in the same order
# as knowledge_chunks, so search results can be mapped back by position.
chunk_embeddings = embedder.encode(knowledge_chunks, convert_to_numpy=True)
d = chunk_embeddings.shape[1]  # embedding width, taken from the encoded matrix
index = faiss.IndexFlatL2(d)
index.add(chunk_embeddings)


def rag_tool(query, k=top_k):
    """Return the text of the `k` knowledge chunks closest to `query`.

    The query is embedded with the shared `embedder`, searched against the
    module-level FAISS `index`, and the matching `knowledge_chunks` are joined
    with single spaces. Ids outside the chunk list are skipped.

    Returns the joined context, or "No relevant context found." when no
    chunk was selected.
    """
    query_vec = embedder.encode([query], convert_to_numpy=True)
    distances, neighbor_ids = index.search(query_vec, k)

    picked = []
    for chunk_id in neighbor_ids[0]:
        # Skip any id that does not address a real chunk.
        if 0 <= chunk_id < len(knowledge_chunks):
            picked.append(knowledge_chunks[chunk_id])
    context = " ".join(picked)

    if DEBUG:
        print("\n[DEBUG:RAG] D:", distances, "I:", neighbor_ids)
        print("[DEBUG:RAG] Retrieved:", picked)

    if not context:
        return "No relevant context found."
    return context


def calculator_tool(query):
    """Extract an arithmetic expression from `query` and evaluate it safely.

    Every run of digits, operators, parentheses and dots in the query is
    concatenated into one expression; `^` is treated as exponentiation.
    The expression is parsed with `ast` and only numeric literals, unary
    +/- and the binary operators + - * / // ** are evaluated. Unlike a raw
    `eval`, nothing else can execute, and exponentiation is bounded so an
    input such as "9^9^9^9" fails fast instead of hanging the kernel.

    Returns:
        "The result is <value>." on success,
        "No math expression found." when the query has no math characters,
        "Sorry, I couldn't calculate that." for anything invalid or too large.
    """
    import ast
    import operator

    fragments = re.findall(r"[0-9+\-*/().^]+", query)
    if not fragments:
        return "No math expression found."
    expression = "".join(fragments).strip().replace("^", "**")

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Upper bound on the size of an integer power (~3000 decimal digits).
    max_result_bits = 10_000

    def _evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Expression):
            return _evaluate(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = _evaluate(node.left)
            right = _evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > max_result_bits
            ):
                # Refuse powers whose exact integer result would be enormous.
                raise ValueError("power too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("unsupported expression")

    try:
        result = _evaluate(ast.parse(expression, mode="eval"))
        return f"The result is {result}."
    except Exception:
        # Syntax errors, division by zero, overflow, oversized powers, ...
        return "Sorry, I couldn't calculate that."


def _read_pdf_text(pdf_path: str, max_chars: int = 6000) -> str:
    """Extract up to `max_chars` characters of text from the PDF at `pdf_path`.

    Never raises: every failure is reported as a plain message string instead.

    Returns:
        The concatenated page text (pages joined by newlines, truncated to
        `max_chars`), or one of:
        "PyPDF2 not installed."   - the optional dependency is missing,
        "No readable text."       - no page produced any text,
        "Pdf not found at path"   - the file does not exist,
        "Something went wrong. <error>" - any other failure.
    """
    if not HAS_PYPDF2:
        return "PyPDF2 not installed."
    try:
        with open(pdf_path, "rb") as handle:
            pdf_pages = PyPDF2.PdfReader(handle).pages
            # extract_text() may yield nothing for a page; treat that as "".
            page_texts = [(page.extract_text() or "") for page in pdf_pages]
        full_text = "\n".join(page_texts)
        if full_text.strip():
            return full_text[:max_chars]
        return "No readable text."
    except FileNotFoundError:
        return "Pdf not found at path"
    except Exception as e:
        return f"Something went wrong. {e}"

In [11]:
def pdf_tool(query):
    """Summarise the PDF whose path is embedded in `query`.

    The path must follow one of the prefixes `pdf:`, `path=` or `file=`
    (case-insensitive) and end in `.pdf`. The file is read with
    `_read_pdf_text` and the text is summarised by the shared seq2seq LLM.

    Returns:
        The generated summary, or - without calling the LLM - the failure
        message when no path is present or the PDF could not be read.
    """
    path_match = re.search(r"(?:pdf:|path=|file=)\s*([^\s]+\.pdf)", query, re.IGNORECASE)
    if not path_match:
        return "Pdf not found at path"
    pdf_path = path_match.group(1)
    raw_text = _read_pdf_text(pdf_path)
    if DEBUG:
        print("\n[DEBUG:PDF] PDF:", pdf_path)
        print("[DEBUG:PDF] Text:", raw_text)

    # _read_pdf_text reports failures as message strings rather than raising.
    # Match exactly those messages: the previous check ("error" substring /
    # startswith("not installed")) matched none of them, so failure messages
    # were summarised by the LLM, while real documents that merely contained
    # the word "error" were returned unsummarised.
    read_failures = ("PyPDF2 not installed.", "No readable text.", "Pdf not found at path")
    if (
        not isinstance(raw_text, str)
        or raw_text in read_failures
        or raw_text.startswith("Something went wrong.")
    ):
        return raw_text

    prompt = f"Summarize the following document in 5-7 bullet points:\n\n{raw_text}\n\nSummary:"
    # NOTE(review): truncation=True clips the end of the prompt when the
    # document exceeds the tokenizer limit, which presumably also drops the
    # trailing "Summary:" cue - confirm whether that matters for quality.
    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = llm_model.generate(**inputs, max_length=220)
    summary = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary


# Dispatch table: tool name returned by decide_tool -> callable taking the raw query.
TOOLS = dict(
    RAG=rag_tool,
    CALC=calculator_tool,
    PDF=pdf_tool,
)


def decide_tool(query: str) -> str:
    """
    Choose which tool should handle `query`.
    Returns one of: 'RAG', 'CALC', 'PDF'

    Routing rules, in priority order:
      1. A literal `.pdf` path in the query always selects PDF. The LLM is
         not consulted, because its answer could not change the outcome.
      2. Otherwise the LLM is asked; an answer containing "PDF" selects PDF.
      3. A query that looks like arithmetic (an operator between numeric
         operands or parentheses) or that contains the word "calculate",
         "evaluate" or "eval" selects CALC.
      4. Any LLM answer that is not a known tool name falls back to RAG.

    The arithmetic guard requires operands around the operator on purpose:
    a bare substring test for "-", "/" or "eval" sends ordinary questions
    such as "What is an object-oriented language?" or "Explain retrieval"
    to the calculator.
    """
    if re.search(r"\.pdf\b", query, re.IGNORECASE):
        decision = "PDF"
    else:
        instruction = f"""
You are a controller agent. Choose the best tool for the user's question.
Available tools:
- RAG: knowledge questions about programming, technology, and general facts.
- CALC: arithmetic or math expressions.
- PDF: when the user asks to read or summarize a PDF. The query will include a path like pdf:./file.pdf or path=./file.pdf

Question: {query}
Respond with exactly one of: RAG, CALC, PDF
"""
        inputs = llm_tokenizer(instruction, return_tensors="pt", truncation=True)
        out = llm_model.generate(**inputs, max_length=8)
        decision = llm_tokenizer.decode(out[0], skip_special_tokens=True).strip().upper()

        # Lightweight guards on top of the model's answer.
        has_arithmetic = re.search(r"[\d)](?:\s*[-+*/^])+\s*[\d(.]", query) is not None
        has_math_keyword = re.search(r"\b(?:calculate|evaluate|eval)\b", query, re.IGNORECASE) is not None
        if "PDF" in decision:
            decision = "PDF"
        elif has_arithmetic or has_math_keyword:
            decision = "CALC"
        elif decision not in TOOLS:
            decision = "RAG"

    if DEBUG:
        print("\n[DEBUG:AGENT] Decision:", decision)
    return decision


def agent(query: str) -> str:
    """Route `query` to a tool and return the final answer text.

    CALC and PDF results are returned as-is. For RAG, the retrieved context
    is passed to the shared LLM together with the question, and the generated
    text is returned.
    """
    tool_name = decide_tool(query)
    tool_result = TOOLS[tool_name](query)

    # CALC/PDF already produce final text; only RAG needs a generation step.
    if tool_name != "RAG":
        return tool_result

    prompt = f"Use ONLY the context below to answer the question concisely.\n\nContext:\n{tool_result}\n\nQuestion: {query}\nAnswer:"
    encoded = llm_tokenizer(prompt, return_tensors="pt", truncation=True)
    generated = llm_model.generate(**encoded, max_length=180)
    return llm_tokenizer.decode(generated[0], skip_special_tokens=True)

In [12]:
if __name__ == "__main__":
    # One query per tool. The last one needs a real PDF on disk:
    # e.g., put 'sample.pdf' in the same folder before running.
    demo_queries = [
        "What is Python?",
        "Compare Python and Java.",
        "What is 2 + 5 * 3?",
        "summarize pdf: ./sample.pdf",
    ]
    for position, user_query in enumerate(demo_queries):
        # Blank line between exchanges, but not before the first one.
        separator = "\n" if position else ""
        print(f"{separator}User: {user_query}")
        print("Agent:", agent(user_query))

User: What is Python?

[DEBUG:AGENT] Decision: RAG

[DEBUG:RAG] D: [[0.29500717 0.9449347 ]] I: [[0 2]]
[DEBUG:RAG] Retrieved: ['Python is a high-level programming language that emphasizes simplicity and readability.', 'Python is dynamically typed and concise, while Java is statically typed and verbose.']
Agent: high-level programming language

User: Compare Python and Java.

[DEBUG:AGENT] Decision: RAG

[DEBUG:RAG] D: [[0.43986782 0.7197268 ]] I: [[2 0]]
[DEBUG:RAG] Retrieved: ['Python is dynamically typed and concise, while Java is statically typed and verbose.', 'Python is a high-level programming language that emphasizes simplicity and readability.']
Agent: Python

User: What is 2 + 5 * 3?

[DEBUG:AGENT] Decision: CALC
Agent: The result is 17.

User: summarize pdf: ./sample.pdf

[DEBUG:AGENT] Decision: PDF

[DEBUG:RAG] PDF: ./sample.pdf
[DEBUG:RAG] Text: Quantum mechanics  is the fundamental physical  theory  that describes the behavior of matter 
and of light; its unusual characte

In [19]:
"""
- Advanced chunking (sliding window + simple merging)
- Dense retrieval with FAISS (sentence-transformers)
- Sparse retrieval with BM25 (rank_bm25)
- Hybrid retrieval combining both
- Cross-encoder reranking (sentence-transformers CrossEncoder)
- Simple evaluation helpers & debug logging
"""

from typing import List, Tuple, Dict
import numpy as np

# models
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import faiss
from rank_bm25 import BM25Okapi

# LLM part for generation (optional; you can plug your generator)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------------------------
# CONFIG (tweak these)
# ---------------------------
EMBED_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model for dense embeddings
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # cross-encoder used to rerank candidates
QA_LLM = "google/flan-t5-base"  # seq2seq model used only for the final answer (optional)
CHUNK_TOKENS = 50  # approx. whitespace tokens per sliding window
CHUNK_OVERLAP = 30  # tokens shared by consecutive windows; must stay below CHUNK_TOKENS or the window cannot advance
TOP_K_DENSE = 10  # default number of FAISS hits requested per query
TOP_K_SPARSE = 10  # default number of BM25 hits kept per query
HYBRID_K = 10  # size of the merged candidate list handed to the reranker
FINAL_TOPK = 3  # top-k after rerank -> contexts fed to the LLM


# ---------------------------
# UTIL: simple tokenizer (whitespace)
# ---------------------------
def simple_tokenize(text: str) -> List[str]:
    """Split `text` on runs of whitespace and return the resulting tokens.

    `str.split()` with no separator already discards empty pieces, so the
    result never contains blank tokens.
    """
    return text.split()


# ---------------------------
# 1) Advanced chunking
#    - sliding window on tokens (approx)
#    - merge very short lines with next line (basic heading handling)
# ---------------------------
def chunk_text_advanced(text: str,
                        chunk_tokens: int = CHUNK_TOKENS,
                        overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split `text` into overlapping, line-aligned chunks.

    Steps:
      1. Blank lines are dropped; a short line (<= 6 whitespace tokens, a
         likely heading) is merged with the line that follows it.
      2. A window of `chunk_tokens` tokens slides over the token stream in
         steps of `chunk_tokens - overlap`. Each window is rendered as the
         whole (merged) lines that any of its tokens belong to, so a chunk
         can be longer than `chunk_tokens` tokens.
      3. Exact duplicate chunks are removed, keeping first occurrences.

    Args:
        text: Raw corpus text.
        chunk_tokens: Window size in whitespace tokens; must be >= 1.
        overlap: Tokens shared by consecutive windows; must be smaller than
            `chunk_tokens`.

    Returns:
        The chunks in document order (empty list for empty/blank text).

    Raises:
        ValueError: if `chunk_tokens < 1` or `overlap >= chunk_tokens`;
            with such values the window can never advance and the loop
            would not terminate.
    """
    if chunk_tokens < 1:
        raise ValueError("chunk_tokens must be a positive integer")
    if overlap >= chunk_tokens:
        raise ValueError("overlap must be smaller than chunk_tokens")
    step = chunk_tokens - overlap

    # Preprocess lines: merge short heading-like lines with the next line.
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    merged_lines = []
    i = 0
    while i < len(lines):
        line = lines[i]
        # heuristic: short line (<= 6 tokens) is likely a heading -> merge with next
        if len(simple_tokenize(line)) <= 6 and (i + 1) < len(lines):
            merged_lines.append(line + " " + lines[i + 1])
            i += 2
        else:
            merged_lines.append(line)
            i += 1

    # One entry per token, each pointing at the merged line it came from.
    token_to_text = []
    for merged in merged_lines:
        token_to_text.extend([merged] * len(simple_tokenize(merged)))
    n = len(token_to_text)

    # Sliding window over the token stream.
    chunks = []
    start = 0
    while start < n:
        end = min(start + chunk_tokens, n)
        # Collapse consecutive tokens of the same line back into that line.
        out_lines = []
        prev = None
        for source_line in token_to_text[start:end]:
            if source_line != prev:
                out_lines.append(source_line)
            prev = source_line
        chunks.append(" ".join(out_lines))
        if end == n:
            break
        start += step

    # Drop exact duplicates while preserving order. Comparing the full text
    # matters: keying on a 200-character prefix discards distinct chunks that
    # begin with the same long line, silently losing the text after it.
    return list(dict.fromkeys(chunks))

In [20]:

# ---------------------------
# 2) Build dense embeddings + FAISS index
# ---------------------------
def build_faiss_index(chunks: List[str], embedder: SentenceTransformer):
    """Embed `chunks` and load the vectors into a flat L2 FAISS index.

    Returns:
        (index, embeddings): the populated index and the float32 embedding
        matrix, with rows in the same order as `chunks`.
    """
    # The vectors are cast to float32 before being added to the index.
    vectors = embedder.encode(chunks, convert_to_numpy=True, show_progress_bar=True).astype('float32')
    # simple exact index; replace with HNSW/IVF for big corpora
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index, vectors


# ---------------------------
# 3) Build sparse index (BM25)
# ---------------------------
def build_bm25_index(chunks: List[str]):
    """Build a BM25 index over lower-cased, whitespace-tokenized `chunks`.

    Returns:
        (bm25, tokenized): the BM25Okapi index and the token lists it was
        built from, in the same order as `chunks`.
    """
    tokenized = []
    for chunk in chunks:
        tokenized.append(simple_tokenize(chunk.lower()))
    return BM25Okapi(tokenized), tokenized


# ---------------------------
# 4) Hybrid retrieval: dense top-k + sparse top-k union
# ---------------------------
def hybrid_retrieve(query: str,
                    chunks: List[str],
                    embedder: SentenceTransformer,
                    faiss_index,
                    bm25_index,
                    tokenized_chunks,
                    top_k_dense: int = TOP_K_DENSE,
                    top_k_sparse: int = TOP_K_SPARSE,
                    hybrid_k: int = HYBRID_K
                    ) -> Tuple[List[int], Dict[int, Dict]]:
    """Retrieve candidate chunks by fusing dense (FAISS) and sparse (BM25) rankings.

    The two rankings are merged with reciprocal rank fusion: every hit
    contributes 1 / (60 + rank) and candidates are ordered by the summed
    score. Simply listing all dense hits before the sparse ones and cutting
    at `hybrid_k` would let the dense list fill every slot whenever
    `top_k_dense >= hybrid_k`, so BM25 hits would never be returned.

    Args:
        query: User question.
        chunks: Corpus chunks; positions are the ids used by both indexes.
        embedder: Model used to embed the query.
        faiss_index: Dense index built over `chunks`.
        bm25_index: BM25 index built over `chunks`.
        tokenized_chunks: Unused; kept for interface compatibility.
        top_k_dense: Number of dense hits to request.
        top_k_sparse: Number of sparse hits to keep.
        hybrid_k: Maximum number of candidates returned.

    Returns:
        (candidates, metadata): up to `hybrid_k` chunk ids, best first, and a
        dict mapping every considered id to its chunk text, `dense_rank` and
        `sparse_rank` (0-based, or None when absent from that ranking).
    """
    n_chunks = len(chunks)

    # dense search; FAISS pads missing results with -1, which is dropped here
    q_emb = embedder.encode([query], convert_to_numpy=True).astype('float32')
    _, neighbor_ids = faiss_index.search(q_emb, top_k_dense)
    dense_indices = [int(i) for i in neighbor_ids[0] if i >= 0]

    # sparse search; stable sort keeps ties in corpus order
    sparse_scores = np.asarray(bm25_index.get_scores(simple_tokenize(query.lower())))
    sparse_ranked = np.argsort(-sparse_scores, kind="stable")[:top_k_sparse].tolist()

    dense_rank = {idx: rank for rank, idx in enumerate(dense_indices)}
    sparse_rank = {idx: rank for rank, idx in enumerate(sparse_ranked)}

    # reciprocal rank fusion over both rankings
    rrf_offset = 60
    fused = {}
    for ranking in (dense_indices, sparse_ranked):
        for rank, idx in enumerate(ranking):
            if 0 <= idx < n_chunks:
                fused[idx] = fused.get(idx, 0.0) + 1.0 / (rrf_offset + rank + 1)
    # sorted() is stable, so equal scores keep first-seen order (dense first)
    candidate_set = sorted(fused, key=lambda idx: -fused[idx])

    # if not enough candidates, pad from the full list in corpus order
    if len(candidate_set) < hybrid_k:
        already_chosen = set(candidate_set)
        for i in range(n_chunks):
            if len(candidate_set) >= hybrid_k:
                break
            if i not in already_chosen:
                candidate_set.append(i)

    # prepare candidate metadata
    metadata = {
        idx: {
            "chunk": chunks[idx],
            "dense_rank": dense_rank.get(idx),
            "sparse_rank": sparse_rank.get(idx),
        }
        for idx in candidate_set
    }
    return candidate_set[:hybrid_k], metadata

In [21]:

# ---------------------------
# 5) Cross-encoder reranker
# ---------------------------
def rerank_with_cross_encoder(query: str, candidate_indices: List[int], chunks: List[str], reranker: CrossEncoder,
                              top_k: int = FINAL_TOPK):
    """Score each candidate chunk against `query` and keep the best `top_k`.

    Returns:
        (top_indices, top_scores): chunk ids and their reranker scores,
        ordered from most to least relevant.
    """
    query_chunk_pairs = []
    for idx in candidate_indices:
        query_chunk_pairs.append((query, chunks[idx]))
    relevance = reranker.predict(query_chunk_pairs)  # one score per pair

    # Highest score first; ties keep their incoming order (stable sort).
    ranked = sorted(zip(candidate_indices, relevance), key=lambda pair: pair[1], reverse=True)[:top_k]
    top_indices = [idx for idx, _ in ranked]
    top_scores = [score for _, score in ranked]
    return top_indices, top_scores


# ---------------------------
# 6) Prompt builder & generator (LLM)
# ---------------------------
def build_prompt_with_context(query: str, contexts: List[str], history: List[Tuple[str, str]] = None):
    """Assemble the QA prompt from conversation history, contexts and the question.

    Args:
        query: The user's question.
        contexts: Retrieved passages; joined with blank lines between them.
        history: Optional (user, bot) turns; only the last three are included.

    Returns:
        The prompt string, ending in "Answer:".
    """
    recent_turns = history[-3:] if history else []
    history_text = "".join(f"User: {user}\nBot: {bot}\n" for user, bot in recent_turns)
    context_block = "\n\n".join(contexts)
    return (
        "You are a helpful assistant. Use ONLY the context below to answer the question concisely.\n"
        "Conversation so far:\n"
        f"{history_text}\n"
        "\n"
        "Context:\n"
        f"{context_block}\n"
        "\n"
        f"Question: {query}\n"
        "Answer:"
    )


def generate_answer_via_llm(prompt: str, qa_tokenizer, qa_model, max_new_tokens: int = 200):
    """Generate an answer for `prompt` with the given tokenizer/model pair.

    Args:
        prompt: Full prompt text; truncated by the tokenizer if too long.
        qa_tokenizer: Callable tokenizer that also provides `decode`.
        qa_model: Model exposing `generate(**inputs, ...)`.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the first generated sequence, without special tokens.
    """
    inputs = qa_tokenizer(prompt, return_tensors="pt", truncation=True)
    # Forward the limit under its own name. It was previously passed as
    # `max_length`, which is a total-length limit rather than a budget of
    # newly generated tokens, so the parameter did not mean what it says.
    outputs = qa_model.generate(**inputs, max_new_tokens=max_new_tokens)
    return qa_tokenizer.decode(outputs[0], skip_special_tokens=True)


# ---------------------------
# 7) Simple evaluation helper (precision@k) - requires gold_indices per query
# ---------------------------
def precision_at_k(retrieved_indices: List[int], gold_indices: List[int], k: int):
    """Fraction of the top-`k` retrieved ids that appear in `gold_indices`.

    The denominator is the number of items actually present in the top-k
    slice, so a result list shorter than `k` is not penalised for the
    missing slots.

    Returns:
        A float in [0, 1]. Returns 0.0 when there are no gold ids, when
        nothing was retrieved, or when `k` is not positive (these cases
        previously raised ZeroDivisionError or yielded meaningless values).
    """
    if not gold_indices or k <= 0:
        return 0.0
    retrieved_topk = retrieved_indices[:k]
    if not retrieved_topk:
        return 0.0
    gold = set(gold_indices)  # O(1) membership tests
    hits = sum(1 for r in retrieved_topk if r in gold)
    return hits / len(retrieved_topk)


# ---------------------------
# Example usage (run pipeline)
# ---------------------------
def demo_pipeline(corpus_text: str, queries: List[Dict], debug: bool = True):
    """Run the full pipeline: chunk -> index -> hybrid retrieve -> rerank -> answer.

    Args:
        corpus_text: Big text to chunk.
        queries: List of dicts { "q": "query text", "gold": [gold chunk indices] (optional) }.
        debug: When True, print intermediate candidates, scores and answers.

    Returns:
        One dict per query with keys "query", "top_indices", "top_scores",
        "answer" and "precision@3" (None when no gold indices were given).
    """
    # 1 chunking
    chunks = chunk_text_advanced(corpus_text)
    if debug:
        print(f"[INFO] Created {len(chunks)} chunks from corpus.")

    # 2-3 dense and sparse indexes over the same chunks
    embedder = SentenceTransformer(EMBED_MODEL)
    faiss_index, _ = build_faiss_index(chunks, embedder)
    bm25, tokenized_chunks = build_bm25_index(chunks)

    # 4-5 reranker and answer-generation models
    reranker = CrossEncoder(RERANKER_MODEL)
    qa_tokenizer = AutoTokenizer.from_pretrained(QA_LLM)
    qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_LLM)

    results = []
    for spec in queries:
        question = spec["q"]
        gold = spec.get("gold", None)

        # hybrid retrieve
        cand_idxs, _ = hybrid_retrieve(question, chunks, embedder, faiss_index, bm25, tokenized_chunks)
        if debug:
            print("\n[DEBUG] Hybrid candidates:", cand_idxs)
            for ci in cand_idxs:
                print(f" - idx {ci}: {chunks[ci][:140]}...")

        # rerank
        top_idxs, top_scores = rerank_with_cross_encoder(question, cand_idxs, chunks, reranker)
        if debug:
            print("[DEBUG] Reranked top:", list(zip(top_idxs, top_scores)))

        # answer from the top reranked contexts
        prompt = build_prompt_with_context(question, [chunks[i] for i in top_idxs])
        answer = generate_answer_via_llm(prompt, qa_tokenizer, qa_model)

        # evaluation only when gold indices were provided
        p_at_3 = precision_at_k(top_idxs, gold, 3) if gold is not None else None

        results.append({
            "query": question,
            "top_indices": top_idxs,
            "top_scores": top_scores,
            "answer": answer,
            "precision@3": p_at_3
        })
        if debug:
            print("[RESULT ANSWER]", answer)
            if p_at_3 is not None:
                print("[EVAL] precision@3:", p_at_3)

    return results

In [22]:

# ---------------------------
# If run as script: demo with sample text
# ---------------------------
if __name__ == "__main__":
    # Tiny corpus covering the three topics the sample questions ask about.
    sample_corpus = """
Python is a high-level programming language that emphasizes simplicity and readability.
It is dynamically typed and widely used for scripting, web development, and data science.

Java is a versatile, object-oriented programming language designed for portability across platforms.
It is statically typed and commonly used in enterprise applications and Android development.

RAG (Retrieval-Augmented Generation) combines retrieval from a knowledge base with a generator model to produce grounded answers.
"""
    sample_questions = (
        "What is Python?",
        "Describe Java and where it is used.",
        "What is RAG?",
    )
    # No gold indices are supplied, so precision@3 is not computed.
    demo_pipeline(sample_corpus, [{"q": question} for question in sample_questions], debug=True)


[INFO] Created 2 chunks from corpus.


Batches:   0%|          | 0/1 [00:00<?, ?it/s]


[DEBUG] Hybrid candidates: [0, 1]
 - idx 0: Python is a high-level programming language that emphasizes simplicity and readability. It is dynamically typed and widely used for scriptin...
 - idx 1: It is dynamically typed and widely used for scripting, web development, and data science. Java is a versatile, object-oriented programming l...
[DEBUG] Reranked top: [(0, 9.42434), (1, -4.8227196)]
[RESULT ANSWER] a high-level programming language

[DEBUG] Hybrid candidates: [1, 0]
 - idx 1: It is dynamically typed and widely used for scripting, web development, and data science. Java is a versatile, object-oriented programming l...
 - idx 0: Python is a high-level programming language that emphasizes simplicity and readability. It is dynamically typed and widely used for scriptin...
[DEBUG] Reranked top: [(1, 5.878658), (0, 4.2013035)]
[RESULT ANSWER] Java is a versatile, object-oriented programming language designed for portability across platforms. It is statically typed and commonly use