In [None]:
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr

from pathlib import Path
from typing import List, Tuple
from transformers import AutoTokenizer


# ---- environment: values from .env take precedence over existing variables ----
load_dotenv(override=True)

# ---- endpoints that speak the OpenAI wire protocol (Gemini & Groq) ----
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
GROQ_BASE = "https://api.groq.com/openai/v1"

# ---- API keys (None when the variable is not set) ----
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")  # Gemini
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")      # Groq

# ---- one client per provider; stays None when its key is missing/empty ----
openai_client = None
if OPENAI_API_KEY:
    # No explicit key: the SDK picks it up from the environment.
    openai_client = OpenAI()

gemini_client = None
if GOOGLE_API_KEY:
    gemini_client = OpenAI(api_key=GOOGLE_API_KEY, base_url=GEMINI_BASE)

groq_client = None
if GROQ_API_KEY:
    groq_client = OpenAI(api_key=GROQ_API_KEY, base_url=GROQ_BASE)
# ---- model registry (label -> client/model) ----
MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}
def _register(label: str, client: Optional[OpenAI], model_id: str):
    if client is not None:
        MODEL_REGISTRY[label] = {"client": client, "model": model_id}

# Catalogue of (label, provider client, model id); order defines dropdown order.
for _label, _client, _model_id in (
    # OpenAI
    ("OpenAI • GPT-5", openai_client, "gpt-5"),
    ("OpenAI • GPT-5 Nano", openai_client, "gpt-5-nano"),
    ("OpenAI • GPT-4o-mini", openai_client, "gpt-4o-mini"),
    # Gemini (Google)
    ("Gemini • 2.5 Pro", gemini_client, "gemini-2.5-pro"),
    ("Gemini • 2.5 Flash", gemini_client, "gemini-2.5-flash"),
    # Groq
    ("Groq • Llama 3.1 8B", groq_client, "llama-3.1-8b-instant"),
    ("Groq • Llama 3.3 70B", groq_client, "llama-3.3-70b-versatile"),
    ("Groq • GPT-OSS 20B", groq_client, "openai/gpt-oss-20b"),
    ("Groq • GPT-OSS 120B", groq_client, "openai/gpt-oss-120b"),
):
    _register(_label, _client, _model_id)
# Drop the loop temporaries so the notebook namespace is left as before.
del _label, _client, _model_id

# Labels that actually made it into the registry (providers with a key).
AVAILABLE_MODELS = list(MODEL_REGISTRY)
if AVAILABLE_MODELS:
    DEFAULT_MODEL = AVAILABLE_MODELS[0]
else:
    DEFAULT_MODEL = "OpenAI • GPT-4o-mini"

# Status summary for the notebook reader.
print(
    "Providers configured →",
    f"OpenAI:{bool(OPENAI_API_KEY)}  Gemini:{bool(GOOGLE_API_KEY)}  Groq:{bool(GROQ_API_KEY)}",
)
print(
    "Models available     →",
    ", ".join(AVAILABLE_MODELS) or "None (add API keys in .env)",
)


In [None]:
@dataclass(frozen=True)
class LLMRoute:
    """Immutable pairing of a provider client with the model id to request from it."""
    # Client whose `chat.completions.create(...)` is invoked by MultiLLM.complete.
    client: OpenAI
    # Model identifier sent as the `model=` argument of the completion request.
    model: str

class MultiLLM:
    """Routes chat completions to OpenAI-compatible providers (OpenAI, Gemini, Groq)."""

    def __init__(self, registry: Dict[str, Dict[str, object]]):
        """Build one route per registry entry.

        Args:
            registry: mapping of label -> {"client": ..., "model": ...}.

        Raises:
            RuntimeError: if the registry contains no entries.
        """
        routes = {}
        for label, entry in registry.items():
            routes[label] = LLMRoute(client=entry["client"], model=str(entry["model"]))
        self._routes: Dict[str, LLMRoute] = routes
        if len(self._routes) == 0:
            raise RuntimeError("No LLM providers configured. Add API keys in .env.")

    def complete(self, *, model_label: str, system: str, user: str) -> str:
        """Send a system + user message pair and return the stripped reply text.

        Raises:
            ValueError: if `model_label` has no registered route.
        """
        route = self._routes.get(model_label)
        if route is None:
            raise ValueError(f"Unknown model: {model_label}")

        conversation = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        response = route.client.chat.completions.create(
            model=route.model,
            messages=conversation,
        )
        # The message content may be None; treat that as an empty reply.
        text = response.choices[0].message.content
        if not text:
            text = ""
        return text.strip()


In [None]:

# MiniLM embedding model & tokenizer (BERT WordPiece)
# This name is used both for the tokenizer loaded below and as the default
# `embed_model` of BareActsIndex, so chunk sizes are measured with the same
# vocabulary that the embedder uses.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Use the model's practical window with 50% overlap
MAX_TOKENS = 256          # all-MiniLM-L6-v2 effective limit used by Sentence-Transformers
OVERLAP_RATIO = 0.50      # 50% sliding window overlap
# NOTE(review): chunk_text encodes with add_special_tokens=False, so a full
# 256-token window presumably exceeds the limit by the special tokens the
# embedder adds at encode time — confirm whether that truncation matters.

# Default tokenizer for chunk_text(); loaded once when this cell runs.
TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)

def chunk_text(
    text: str,
    tokenizer: AutoTokenizer = TOKENIZER,
    max_tokens: int = MAX_TOKENS,
    overlap_ratio: float = OVERLAP_RATIO,
) -> List[str]:
    """
    Split `text` into overlapping chunks measured in tokenizer tokens.

    The text is encoded without special tokens, cut into windows of at most
    `max_tokens` ids, and each window is turned back into a string through the
    tokenizer (so a chunk is the tokenizer's rendering of its tokens, not a
    slice of the original text). Consecutive windows start
    `int(max_tokens * (1 - overlap_ratio))` tokens apart, with a minimum
    stride of 1. Chunks that are empty after stripping are dropped.

    Returns an empty list when the text encodes to no tokens.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    total = len(token_ids)
    if total == 0:
        return []

    stride = max(1, int(max_tokens * (1.0 - overlap_ratio)))
    chunks: List[str] = []
    begin = 0
    while begin < total:
        end = begin + max_tokens
        piece = token_ids[begin:end]
        if not piece:
            break
        pieces_as_tokens = tokenizer.convert_ids_to_tokens(piece)
        rendered = tokenizer.convert_tokens_to_string(pieces_as_tokens).strip()
        if rendered:
            chunks.append(rendered)
        # This window already reached the end of the text: no further window needed.
        if end >= total:
            break
        begin += stride
    return chunks

def load_bare_acts(root: str = "knowledge_base/bare_acts") -> List[Tuple[str, str]]:
    """Return list of (source_id, text). `source_id` is filename stem.

    Reads every regular `*.txt` file directly under `root` (non-recursive) as
    UTF-8, in sorted path order.

    Args:
        root: folder containing the bare-act text files.

    Raises:
        FileNotFoundError: if `root` does not exist.
        RuntimeError: if `root` contains no `.txt` files.
    """
    base = Path(root)
    if not base.exists():
        raise FileNotFoundError(f"Folder not found: {base.resolve()}")
    pairs: List[Tuple[str, str]] = [
        (p.stem, p.read_text(encoding="utf-8"))
        for p in sorted(base.glob("*.txt"))
        # Skip directories that merely happen to be named "*.txt".
        if p.is_file()
    ]
    if not pairs:
        # Report the folder that was actually searched, not the default one.
        raise RuntimeError(f"No .txt files found under {base.resolve()}")
    return pairs

# Load all acts from the default folder when this cell runs; load_bare_acts
# raises if the folder is missing or holds no .txt files.
acts_raw = load_bare_acts()
# Show which sources were found and the chunking settings that will be applied.
print("Bare Acts loaded:", [s for s, _ in acts_raw])
print(f"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%")


In [None]:
import chromadb
from chromadb import PersistentClient
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from transformers import AutoTokenizer
from typing import Dict, List, Tuple

class BareActsIndex:
    """Owns the vector DB lifecycle & retrieval (token-aware chunking).

    Wraps one collection of a persistent Chroma client. Documents stored in
    the collection are the chunks produced by `chunk_text`, each tagged with
    its source id and chunk index.
    """
    def __init__(
        self,
        db_path: str = "vector_db",
        collection: str = "bare_acts",
        embed_model: str = EMBED_MODEL_NAME,
        max_tokens: int = MAX_TOKENS,
        overlap_ratio: float = OVERLAP_RATIO,
    ):
        """Open (or create) the collection and load the embedder and tokenizer.

        Args:
            db_path: directory passed to the persistent Chroma client.
            collection: name of the collection to open or create.
            embed_model: model name used for both embedding and tokenization.
            max_tokens: window size forwarded to `chunk_text` on rebuild.
            overlap_ratio: window overlap forwarded to `chunk_text` on rebuild.
        """
        self.db_path = db_path
        self.collection_name = collection
        self.embed_model = embed_model
        self.max_tokens = max_tokens
        self.overlap_ratio = overlap_ratio

        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)

        self.client: PersistentClient = PersistentClient(path=db_path)
        self.col = self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.embed_fn,
        )

    def rebuild(self, docs: List[Tuple[str, str]], batch_size: int = 1000):
        """Idempotent rebuild: clears and re-adds chunks with metadata.

        Args:
            docs: (source_id, text) pairs; each text is chunked with this
                index's tokenizer, `max_tokens` and `overlap_ratio`.
            batch_size: maximum number of chunks sent per `add()` call.
                Chroma rejects a single add larger than its maximum batch
                size, so large corpora are inserted in slices.
        """
        try:
            self.client.delete_collection(self.collection_name)
        except Exception as exc:
            # Deliberately best-effort (e.g. the collection is already gone),
            # but report it: if the old collection survives, stale chunks may
            # remain alongside the new ones.
            print(
                f"Could not delete collection {self.collection_name!r} ({exc}); "
                "continuing with rebuild."
            )

        self.col = self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.embed_fn,
        )

        ids, texts, metas = [], [], []
        for src, text in docs:
            for idx, ch in enumerate(
                chunk_text(
                    text,
                    tokenizer=self.tokenizer,
                    max_tokens=self.max_tokens,
                    overlap_ratio=self.overlap_ratio,
                )
            ):
                # Ids are unique as long as source ids are unique.
                ids.append(f"{src}-{idx}")
                texts.append(ch)
                metas.append({"source": src, "chunk_id": idx})

        # Insert in bounded slices; an empty `ids` simply performs no add.
        batch = max(1, int(batch_size))
        for lo in range(0, len(ids), batch):
            hi = lo + batch
            self.col.add(ids=ids[lo:hi], documents=texts[lo:hi], metadatas=metas[lo:hi])

        print(
            f"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} "
            f"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)"
        )

    def query(self, q: str, k: int = 6) -> List[Dict]:
        """Return up to `k` stored chunks for query text `q`.

        Each result is `{"text": <chunk>, "meta": <metadata dict>}`, in the
        order returned by the collection. Returns an empty list when the
        collection is empty or `k` is not positive.
        """
        # Never request more results than the collection holds (or fewer than 1).
        n_results = min(int(k), self.col.count())
        if n_results <= 0:
            return []

        res = self.col.query(query_texts=[q], n_results=n_results)
        # Tolerate a missing or None field by treating it as "no results".
        docs = (res.get("documents") or [[]])[0]
        metas = (res.get("metadatas") or [[]])[0]
        return [{"text": d, "meta": m} for d, m in zip(docs, metas)]

# Create the index object and (re)build it from the loaded acts.
# Note: rebuild() deletes the collection and re-adds every chunk, so each
# execution of this cell re-embeds the whole corpus.
index = BareActsIndex()
index.rebuild(acts_raw)


In [None]:
class PromptBuilder:
    """Central place for the system prompt and the user-prompt template."""

    # Fixed system instructions sent with every request.
    SYSTEM = " ".join([
        "You are a precise legal assistant for Indian Bare Acts.",
        "Answer ONLY from the provided context. If the answer is not in context, say you don't know.",
        "Cite sources inline in square brackets as [file #chunk] (e.g., [bns #12]).",
        "Prefer exact quotes for critical provisions/sections.",
    ])

    @staticmethod
    def build_user(query: str, contexts: List[Dict]) -> str:
        """Render the user message: question, labelled context chunks, instructions.

        Each context is expected to carry `text` and a `meta` dict with
        `source` and `chunk_id`; chunks are separated by a `---` rule.
        """
        blocks = []
        for item in contexts:
            meta = item["meta"]
            header = f"[{meta['source']} #{meta['chunk_id']}]"
            blocks.append(f"{header}\n{item['text']}")
        joined = "\n\n---\n\n".join(blocks)

        sections = [
            f"Question:\n{query}\n\n",
            f"Context (do not use outside this):\n{joined}\n\n",
            "Instructions:\n- Keep answers concise and faithful to the text.\n",
            "- Use [file #chunk] inline where relevant.",
        ]
        return "".join(sections)

def _snippet(txt: str, n: int = 220) -> str:
    s = " ".join(txt.strip().split())
    return (s[:n] + "…") if len(s) > n else s

class RagQAService:
    """Retrieve context, ask the LLM, and append a references list to the reply."""

    def __init__(self, index: BareActsIndex, llm: MultiLLM):
        """Keep the retrieval index and LLM router; create the prompt builder."""
        self.index = index
        self.llm = llm
        self.builder = PromptBuilder()

    def answer(self, *, question: str, model_label: str, k: int = 6) -> str:
        """Answer `question` from the top-`k` retrieved chunks.

        Returns the model reply followed by a `**References**` section with
        one bullet per retrieved chunk (source, chunk index, short snippet).
        """
        hits = self.index.query(question, k=k)
        prompt = self.builder.build_user(question, hits)
        reply = self.llm.complete(
            model_label=model_label,
            system=self.builder.SYSTEM,
            user=prompt,
        )

        # One bullet per retrieved chunk: file, chunk index, snippet.
        bullets = []
        for hit in hits:
            meta = hit["meta"]
            bullets.append(
                f"- [{meta['source']} #{meta['chunk_id']}] {_snippet(hit['text'])}"
            )
        references = "\n".join(bullets)
        return f"{reply}\n\n**References**\n{references}"


In [None]:
# Wire the LLM router and the RAG service together.
# MultiLLM raises RuntimeError here when MODEL_REGISTRY is empty (no API keys).
llm = MultiLLM(MODEL_REGISTRY)
qa_service = RagQAService(index=index, llm=llm)

# Readiness message only: no retrieval or LLM request happens in this cell,
# so no tokens are spent.
if AVAILABLE_MODELS:
    print("Ready. Default model:", DEFAULT_MODEL)


In [None]:
def chat_fn(message: str, history: List[Dict], model_label: str, top_k: int) -> str:
    """Chat callback: answer `message` through the RAG service.

    Args:
        message: the user's question.
        history: prior chat messages (unused; every question is answered
            independently).
        model_label: label of the model selected in the UI.
        top_k: number of context chunks to retrieve (coerced to int).

    Returns:
        The answer text, or a "⚠️ ..." message when the question is blank
        or anything in the pipeline raises.
    """
    try:
        # Do not spend an embedding query and an LLM call on a blank question.
        if not message or not message.strip():
            return "⚠️ Please enter a question."
        return qa_service.answer(question=message, model_label=model_label, k=int(top_k))
    except Exception as e:
        # Surface the failure in the chat window instead of crashing the UI.
        return f"⚠️ {e}"

# Question pre-filled in the chat textbox.
DEFAULT_QUESTION = "Which sections deals with punishment for murder ?"

# Assemble the UI: a model dropdown and a Top-K slider on one row, then a chat
# interface that forwards both values to chat_fn as additional inputs.
with gr.Blocks(title="Legal QnA • Bare Acts (RAG + Multi-LLM)") as app:
    gr.Markdown("### 🧑‍⚖️ Legal Q&A on Bare Acts (RAG) — Multi-Provider LLM")
    with gr.Row():
        model_dd = gr.Dropdown(
            # Show a single placeholder entry (with no selection) when no provider is configured.
            choices=AVAILABLE_MODELS or ["OpenAI • GPT-4o-mini"],
            value=DEFAULT_MODEL if AVAILABLE_MODELS else None,
            label="Model"
        )
        # Value is handed to chat_fn as `top_k`, i.e. how many chunks to retrieve.
        topk = gr.Slider(2, 12, value=6, step=1, label="Top-K context")

    chat = gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        # Order matches chat_fn's trailing parameters: (model_label, top_k).
        additional_inputs=[model_dd, topk],
        textbox=gr.Textbox(
            value=DEFAULT_QUESTION,
            label="Ask a legal question",
            placeholder="Type your question about BNS/IPC/Constitution…"
        ),
    )

# Start the app; inbrowser=True requests that it be opened in a browser tab.
app.launch(inbrowser=True)
