In [None]:
# Week 8 • Legal QnA RAG (Agentic)

import os
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple, Protocol


from pathlib import Path


from openai import OpenAI
import chromadb
from chromadb import PersistentClient
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

from transformers import AutoTokenizer
import gradio as gr

from dotenv import load_dotenv
load_dotenv(override=True)

In [None]:
@dataclass(frozen=True)
class Settings:
    """Immutable run configuration shared by every pipeline component."""

    # --- Corpus --------------------------------------------------------------
    data_root: str = "knowledge_base/bare_acts"  # same as Week 5

    # --- Vector store / chunking --------------------------------------------
    db_path: str = "vector_db_w8"
    collection: str = "bare_acts"
    embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    chunk_max_tokens: int = 256   # tokens per chunk window
    chunk_overlap: float = 0.50   # fraction of a window shared with the next window

    # --- Retrieval -------------------------------------------------------------
    expansions: int = 5           # number of query rewrites per question
    topk_per_rewrite: int = 10    # hits fetched for each rewrite
    neighbor_radius: int = 2      # ± neighbour chunks added around each hit
    max_blocks_for_llm: int = 20  # cap on merged blocks sent to the LLM (token control)

    # --- Generation ------------------------------------------------------------
    gen_model: str = "gpt-4o-mini"
    # Any OpenAI-compatible model ids you have access to.
    selectable_models: Tuple[str, ...] = ("gpt-4o-mini",)
    temperature: float = 0.2
    max_tokens: int = 220


SET = Settings()


In [None]:
@dataclass
class Trace:
    """Structured record of one pipeline run that can be rendered as Markdown."""

    # Query rewrites, in generation order.
    rewrites: List[str] = field(default_factory=list)
    # One entry per rewrite: {"query": str, "hits": [{source, chunk_id, start, end, preview}]}
    retrievals: List[Dict[str, Any]] = field(default_factory=list)
    # Merged context blocks: [{source, start, end, preview}]
    merged_blocks: List[Dict[str, Any]] = field(default_factory=list)
    # Free-form notes.
    notes: List[str] = field(default_factory=list)

    def as_markdown(self) -> str:
        """Render every non-empty section as Markdown; return "_No logs._" if all are empty."""
        md: List[str] = []

        if self.rewrites:
            md.append("#### 🔁 Query expansions")
            md.extend(f"{n}. `{query}`" for n, query in enumerate(self.rewrites, start=1))
            md.append("")

        if self.retrievals:
            md.append("#### 🔎 Retrieval per rewrite (top-10 each)")
            for entry in self.retrievals:
                md.append(f"- **Rewrite:** `{entry['query']}`")
                md.extend(
                    f"  - [{hit['source']} #{hit['chunk_id']}] {hit['start']}:{hit['end']} — {hit['preview']}"
                    for hit in entry.get("hits", [])
                )
            md.append("")

        if self.merged_blocks:
            md.append("#### 🧩 Context expansion (± neighbors, merged)")
            md.extend(
                f"- [{blk['source']}] {blk['start']}:{blk['end']} — {blk['preview']}"
                for blk in self.merged_blocks
            )
            md.extend([f"\n**Total merged blocks:** {len(self.merged_blocks)}", ""])

        if self.notes:
            md.append("#### 📝 Notes")
            md.extend(f"- {note}" for note in self.notes)
            md.append("")

        if not md:
            return "_No logs._"
        return "\n".join(md)


In [None]:
class LLM(Protocol):
    """Minimal chat-completion interface used by the agent components."""

    def complete(
        self,
        system: str,
        user: str,
        temperature: float,
        max_tokens: int,
        model: str | None = None,   # <- allow per-call model override
    ) -> str:
        """Return the model's text reply to ``user`` under the ``system`` prompt."""


class Index(Protocol):
    """Vector-index interface required by the retriever and the context expander."""

    def rebuild(self, docs: List[Tuple[str, str]]) -> None:
        """(Re)index ``docs``, given as ``(source_id, text)`` pairs."""

    def query(self, q: str, k: int) -> List[Dict]:
        """Return up to ``k`` hits for the query text ``q``."""

    def get_by_ids(self, ids: List[str]) -> Tuple[List[str], List[Dict]]:
        """Return ``(documents, metadatas)`` for the requested chunk ids."""

    def raw_doc(self, src: str) -> str:
        """Return the full original text of source ``src``."""

    @property
    def tokenizer(self) -> AutoTokenizer:
        """Tokenizer used for chunking; callers request offset mappings from it."""


In [None]:
class OpenAILLM(LLM):
    """
    ``LLM`` implementation backed by the OpenAI Chat Completions API.

    ``temperature`` and ``max_tokens`` are forwarded to the API, except for
    reasoning-model families (gpt-5*, o1*, o3*, o4*). Those reject custom values
    for these parameters, so the API defaults are used for them.
    """

    # Model-id prefixes whose chat endpoint rejects custom temperature / max_tokens.
    _REASONING_PREFIXES = ("gpt-5", "o1", "o3", "o4")

    def __init__(self, default_model: str):
        self._client = OpenAI()
        self._default_model = default_model

    @classmethod
    def _accepts_sampling_params(cls, model_id: str) -> bool:
        """True when ``model_id`` accepts explicit ``temperature`` / ``max_tokens``."""
        return not str(model_id).lower().startswith(cls._REASONING_PREFIXES)

    def complete(
        self,
        system: str,
        user: str,
        temperature: float = 0.2,
        max_tokens: int = 220,
        model: str | None = None,
    ) -> str:
        """
        Send one system+user exchange and return the stripped reply text.

        Args:
            system: system prompt.
            user: user message.
            temperature: sampling temperature (ignored for reasoning models).
            max_tokens: completion token cap (ignored for reasoning models).
            model: per-call model override; defaults to the constructor's model.

        Returns:
            The assistant message content, stripped; ``""`` when the API returns none.
        """
        model_id = model or self._default_model
        sampling: Dict[str, Any] = {}
        if self._accepts_sampling_params(model_id):
            # Previously these arguments were accepted but never sent, so e.g. the
            # critic's temperature=0.0 had no effect.
            sampling["temperature"] = temperature
            sampling["max_tokens"] = max_tokens
        resp = self._client.chat.completions.create(
            model=model_id,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            seed=42,
            **sampling,
        )
        return (resp.choices[0].message.content or "").strip()


In [None]:
# --- Streaming run logger: timestamped progress lines rendered as HTML for the Gradio log panel ---
from dataclasses import dataclass, field
from time import perf_counter
from html import escape

@dataclass
class RunLogger:
    """Collects timestamped progress lines and renders them as a monospace HTML panel."""

    lines: List[str] = field(default_factory=list)
    t0: float = field(default_factory=perf_counter)  # reference time for the relative stamps

    def add(self, msg: str):
        """Append ``msg`` prefixed with the seconds elapsed since ``t0``, e.g. ``[01.25s] msg``."""
        elapsed = perf_counter() - self.t0
        stamped = "[{:05.2f}s] {}".format(elapsed, msg)
        self.lines.append(stamped)

    def html(self) -> str:
        """Render at most the last 300 lines, HTML-escaped and joined with ``<br/>``."""
        recent = self.lines[-300:]
        body = "<br/>".join(map(escape, recent))
        return f"<div style='font-family:ui-monospace,Menlo,Consolas;white-space:pre-wrap'>{body}</div>"


In [None]:
def _one_line(text: str, limit: int = 120) -> str:
    s = " ".join(text.split())
    return s[:limit] + "…" if len(s) > limit else s

def _hit_view(h: Dict) -> Dict:
    m = h["meta"]
    return {
        "source": m["source"],
        "chunk_id": int(m["chunk_id"]),
        "start": int(m["start_tok"]),
        "end": int(m["end_tok"]),
        "preview": _one_line(h["text"]),
    }

def _hit_id(h: Dict) -> str:
    m = h["meta"]
    # Example: [bns 4720:4860 (#38)]
    return f"[{m['source']} {int(m['start_tok'])}:{int(m['end_tok'])} (#{int(m['chunk_id'])})]"


In [None]:
def load_bare_acts(root: str) -> List[Tuple[str, str]]:
    """
    Read every ``*.txt`` file directly under ``root`` (one bare act per file).

    Returns:
        ``(file stem, UTF-8 text)`` pairs sorted by path.

    Raises:
        FileNotFoundError: ``root`` does not exist.
        RuntimeError: ``root`` contains no ``.txt`` files.
    """
    folder = Path(root)
    if not folder.exists():
        raise FileNotFoundError(f"Folder not found: {folder.resolve()}")

    acts = [
        (txt_path.stem, txt_path.read_text(encoding="utf-8"))
        for txt_path in sorted(folder.glob("*.txt"))
    ]
    if not acts:
        raise RuntimeError(f"No .txt files found under {folder}")
    return acts

# Load every bare-act text file; `acts_raw` is passed to index.rebuild() in a later cell.
acts_raw = load_bare_acts(SET.data_root)
print("Bare Acts loaded:", [s for s, _ in acts_raw][:5], "…")  # show at most the first 5 source ids


In [None]:
def chunk_text_with_spans(
    text: str,
    tokenizer: AutoTokenizer,
    max_tokens: int,
    overlap_ratio: float,
) -> List[Tuple[int, int, int, int, str]]:
    """
    Split ``text`` into overlapping token windows.

    Each window is up to ``max_tokens`` tokens long. Windows advance by
    ``max(1, int(max_tokens * (1 - overlap_ratio)))`` tokens, and the last window
    stops at the end of the text. Token positions are mapped back to characters
    with the tokenizer's offset mapping, which requires a tokenizer that supports
    ``return_offsets_mapping`` (Hugging Face "fast" tokenizers).

    Returns:
        ``(start_tok, end_tok, start_char, end_char, chunk_text)`` tuples. ``end_tok``
        is exclusive, and ``chunk_text`` is the stripped character slice. Windows
        whose slice is blank are skipped. Text with no tokens yields ``[]``.
    """
    encoded = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
    )
    token_ids = encoded["input_ids"]
    offsets = encoded["offset_mapping"]  # one (char_start, char_end) pair per token
    if not token_ids:
        return []

    n_tokens = len(token_ids)
    stride = max(1, int(max_tokens * (1.0 - overlap_ratio)))
    windows: List[Tuple[int, int, int, int, str]] = []
    start = 0
    while start < n_tokens:
        end = min(start + max_tokens, n_tokens)
        if end <= start:  # non-positive max_tokens: nothing can be emitted
            break
        first_char, last_char = offsets[start][0], offsets[end - 1][1]
        piece = text[first_char:last_char].strip()
        if piece:
            windows.append((start, end, first_char, last_char, piece))
        if end >= n_tokens:  # this window already reaches the end of the text
            break
        start += stride
    return windows


In [None]:
class ChromaBareActsIndex(Index):
    """
    Chroma-backed ``Index`` over bare-act text files.

    Each document is split with ``chunk_text_with_spans``. Each chunk is stored
    under the id ``f"{source}-{chunk_id}"`` with token/char span metadata.
    ``rebuild`` also caches the raw document texts in memory so ``raw_doc`` can
    serve them later.
    """

    # Records per ``collection.add`` call. Chroma rejects batches larger than its
    # client's maximum batch size, so big corpora are added in slices of this size.
    add_batch_size: int = 1000

    def __init__(self, settings: Settings):
        self.set = settings
        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.set.embed_model)
        self._tokenizer = AutoTokenizer.from_pretrained(self.set.embed_model)
        self._client: PersistentClient = PersistentClient(path=self.set.db_path)
        self._col = self._client.get_or_create_collection(
            name=self.set.collection,
            embedding_function=self.embed_fn
        )
        # source id -> full text; filled by rebuild(), read by raw_doc().
        self._doc_cache: Dict[str, str] = {}

    @property
    def tokenizer(self) -> AutoTokenizer:
        """Tokenizer of the embedding model (also used for chunking)."""
        return self._tokenizer

    def rebuild(self, docs: List[Tuple[str, str]]) -> None:
        """
        Drop and re-create the collection, then chunk, embed and add every document.

        Args:
            docs: ``(source_id, text)`` pairs.
        """
        self._doc_cache = {src: txt for src, txt in docs}
        try:
            self._client.delete_collection(self.set.collection)
        except Exception:
            # Best effort: the collection may not exist yet on a fresh store.
            pass
        self._col = self._client.get_or_create_collection(
            name=self.set.collection,
            embedding_function=self.embed_fn
        )

        ids, texts, metas = [], [], []
        for src, text in docs:
            spans = chunk_text_with_spans(
                text, tokenizer=self._tokenizer,
                max_tokens=self.set.chunk_max_tokens,
                overlap_ratio=self.set.chunk_overlap
            )
            for idx, (start_tok, end_tok, start_char, end_char, ch) in enumerate(spans):
                ids.append(f"{src}-{idx}")
                texts.append(ch)
                metas.append({
                    "source": src,
                    "chunk_id": idx,
                    "start_tok": int(start_tok),
                    "end_tok": int(end_tok),
                    "start_char": int(start_char),
                    "end_char": int(end_char),
                })

        # A single add() with thousands of chunks can exceed Chroma's max batch size.
        batch = max(1, int(self.add_batch_size))
        for lo in range(0, len(ids), batch):
            hi = lo + batch
            self._col.add(ids=ids[lo:hi], documents=texts[lo:hi], metadatas=metas[lo:hi])
        print(f"Indexed {len(texts)} chunks from {len(docs)} files → {self.set.collection}")

    def query(self, q: str, k: int) -> List[Dict]:
        """
        Similarity search for ``q``.

        Returns:
            Up to ``k`` dicts ``{"text": chunk_text, "meta": metadata + {"id": chunk_id}}``,
            in the order Chroma returns them. Hits without metadata are skipped.
        """
        res = self._col.query(query_texts=[q], n_results=k)
        docs  = res.get("documents", [[]])[0]
        metas = res.get("metadatas", [[]])[0]
        ids   = res.get("ids", [[]])[0]
        out = []
        for _id, d, m in zip(ids, docs, metas):
            if not m:
                continue
            m = dict(m)
            m["id"] = _id
            out.append({"text": d, "meta": m})
        return out

    def get_by_ids(self, ids: List[str]) -> Tuple[List[str], List[Dict]]:
        """
        Return documents & metadatas in the SAME ORDER as the requested ids.

        Ids that are not in the collection are silently omitted.
        Compatible with Chroma clients that don't allow 'ids' in `include`.
        """
        if not ids:
            return [], []

        # 'ids' is NOT a valid item in `include`; ask only for docs+metas
        got = self._col.get(ids=ids, include=["documents", "metadatas"])

        ret_ids = (got.get("ids") or [])
        docs    = (got.get("documents") or [])
        metas   = (got.get("metadatas") or [])

        # If the client returns 'ids', reorder using it
        if ret_ids:
            table = {i: (d, m) for i, d, m in zip(ret_ids, docs, metas)}
            ordered_docs, ordered_metas = [], []
            for want in ids:
                if want in table:
                    d, m = table[want]
                    ordered_docs.append(d)
                    ordered_metas.append(m)
            return ordered_docs, ordered_metas

        # Fallback: some clients omit 'ids' in the response; do deterministic 1-by-1
        ordered_docs, ordered_metas = [], []
        for _id in ids:
            sub = self._col.get(ids=[_id], include=["documents", "metadatas"])
            d = (sub.get("documents") or [None])[0]
            m = (sub.get("metadatas") or [None])[0]
            if d is not None and m is not None:
                ordered_docs.append(d)
                ordered_metas.append(m)
        return ordered_docs, ordered_metas

    def raw_doc(self, src: str) -> str:
        """
        Full original text of source ``src``.

        Raises:
            KeyError: ``src`` is not in the in-memory cache. This happens for unknown
            sources, or when ``rebuild`` has not run in this session.
        """
        try:
            return self._doc_cache[src]
        except KeyError:
            raise KeyError(
                f"Raw text for source {src!r} is not cached; call rebuild() in this session first."
            ) from None


In [None]:
# Build the Chroma index. NOTE: rebuild() deletes and re-creates the collection, so every
# run of this cell re-chunks and re-embeds all acts. It also fills the in-memory raw-text
# cache that raw_doc() (used by ContextExpander) reads.
index = ChromaBareActsIndex(SET)
index.rebuild(acts_raw)  # re-run whenever the files under SET.data_root change


In [None]:
import json, re

class QueryExpander:
    """
    Expands a user question into N retrieval-friendly keyphrase queries.
    - Tries LLM (strict JSON array) first.
    - If parsing fails or returns too few items, falls back to deterministic domain-aware rewrites.
    - Exposes .last_info to let the UI log whether fallback was used.
    """
    def __init__(self, llm: LLM, n: int):
        self.llm = llm
        self.n = n
        self.last_info: Dict[str, Any] = {}

    @staticmethod
    def _extract_json_array(text: str) -> List[str]:
        """
        Try to safely extract a JSON array from messy model output.
        """
        try:
            # Fast path: direct JSON
            parsed = json.loads(text)
            return [x for x in parsed if isinstance(x, str)]
        except Exception:
            pass

        # Fallback: regex the first [...] block
        m = re.search(r"\[(.|\n|\r)*\]", text)
        if m:
            try:
                parsed = json.loads(m.group(0))
                return [x for x in parsed if isinstance(x, str)]
            except Exception:
                return []
        return []

    @staticmethod
    def _dedupe_keep_order(items: List[str]) -> List[str]:
        seen = set()
        out = []
        for it in items:
            k = it.strip().lower()
            if k and k not in seen:
                seen.add(k)
                out.append(it.strip())
        return out

    def _deterministic_fallback(self, question: str) -> List[str]:
        """
        Domain-aware, retrieval-ready keyphrases for Bare Acts.
        Keeps them short (good for vector/BM25), includes act names & synonyms.
        """
        q = re.sub(r"[?]+$", "", question).strip()
        # Heuristic tokens
        base = q.lower()
        # Try to infer a core noun/verb pair for variety
        variants = [
            f"{q} section",
            f"{q} provision bare act",
            f"{q} indian penal code",
            f"{q} bharatiya nyaya sanhita",
            f"{q} punishment section key words",
        ]
        # Add generic legal synonyms if helpful
        synonyms = [
            "murder punishment section",
            "culpable homicide punishment",
            "offence of murder penalty",
            "ipc section for murder",
            "bns murder punishment"
        ]
        pool = self._dedupe_keep_order(variants + synonyms)
        return pool[: self.n] if len(pool) >= self.n else (pool + [q])[: self.n]

    def expand(self, question: str, *, model_override: str | None = None) -> List[str]:
        sys = (
            "You generate EXACTLY the requested number of short, retrieval-friendly queries "
            "for Indian Bare Acts (IPC/BNS/Constitution). Keep them concise (4–20 words), "
            "keyphrase-style (no punctuation, no quotes), and diversify wording and act names. "
            "Do NOT add commentary or mention any bare act or section number. Respond ONLY as a JSON array of strings.\n\n"
            "Good examples:\n"
            '["murder punishment", "punishment for murder", '
            '"attempt to murder", "culpable homicide amounting to murder", "provision murder penalty"]'
        )
        user = f"Question:\n{question}\n\nReturn {self.n} diverse keyphrase queries as a JSON array."

        raw = self.llm.complete(system=sys, user=user, temperature=0.2, max_tokens=300, model=model_override)
        queries = self._extract_json_array(raw)
        queries = self._dedupe_keep_order(queries)

        used_fallback = False
        if len(queries) < self.n:
            used_fallback = True
            queries = self._deterministic_fallback(question)

        # final safety: trim and cap
        queries = [re.sub(r"[^\w\s\-./]", "", q).strip() for q in queries]
        queries = [q for q in queries if q]
        queries = queries[: self.n]

        self.last_info = {
            "used_fallback": used_fallback,
            "raw": raw,
            "final": queries,
        }
        return queries


In [None]:
class Retriever:
    """Thin wrapper that runs a single top-k similarity search against an ``Index``."""

    def __init__(self, index: Index, k: int):
        self.index, self.k = index, k

    def topk(self, query: str) -> List[Dict]:
        """Return up to ``self.k`` hits for ``query``, as produced by ``Index.query``."""
        limit = self.k
        return self.index.query(query, k=limit)

In [None]:
class ContextExpander:
    def __init__(
        self,
        index: "Index",
        radius: int,
        max_blocks: int,
        pad_words: int = 100,
        words_to_tokens: float = 1.4,
    ):
        """
        Expand retrieval hits into larger, de-duplicated context blocks by:
          1) adding ±radius neighbour chunks around every hit (by chunk_id),
          2) converting those chunk spans to padded token ranges,
          3) merging overlapping token ranges per source,
          4) keeping the ``max_blocks`` blocks that contain the best-ranked hits.

        Args:
            index: provides ``get_by_ids``, ``raw_doc`` and ``tokenizer``.
            radius: neighbour chunks added on each side of a hit (negative → 0).
            max_blocks: maximum number of blocks returned.
            pad_words: approximate words of padding on each side of a chunk span.
            words_to_tokens: rough words→tokens ratio used to convert ``pad_words``.
        """
        self.index = index
        self.radius = max(0, int(radius))
        self.max_blocks = int(max_blocks)
        # approx tokens to pad on both sides of each span
        self.pad_tokens = int(pad_words * words_to_tokens)
        # chunk ids collected by the last expand_and_merge call (for logging/inspection)
        self.last_ids: List[str] = []
        # src -> (doc text, input_ids, offset_mapping); reused while the doc text is unchanged
        self._enc_cache: Dict[str, Tuple[str, List[int], List[Tuple[int, int]]]] = {}

    @staticmethod
    def _merge_spans(spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        """Merge overlapping or touching ``(start, end)`` spans; result is sorted by start."""
        if not spans:
            return []
        spans = sorted(spans, key=lambda x: x[0])
        merged = [spans[0]]
        for s, e in spans[1:]:
            ls, le = merged[-1]
            if s <= le:
                merged[-1] = (ls, max(le, e))
            else:
                merged.append((s, e))
        return merged

    @staticmethod
    def _dedupe_keep_order(items: List[str]) -> List[str]:
        """Remove exact duplicates, keeping first occurrences in order."""
        seen = set()
        out: List[str] = []
        for it in items:
            if it not in seen:
                seen.add(it)
                out.append(it)
        return out

    def _neighbor_ids_for_hit(self, meta: Dict[str, Any]) -> List[str]:
        """
        Build vector-store IDs for the hit's chunk and its ±radius neighbors.
        We rely on the indexing convention: id == f"{source}-{chunk_id}".
        Negative chunk ids are skipped; ids past the end of a document are simply
        not found later by ``get_by_ids``.
        """
        src = meta["source"]
        cid = int(meta["chunk_id"])
        ids: List[str] = []
        for d in range(-self.radius, self.radius + 1):
            n = cid + d
            if n < 0:
                continue
            ids.append(f"{src}-{n}")
        return ids

    def _encode(self, src: str, full_doc: str) -> Tuple[List[int], List[Tuple[int, int]]]:
        """Tokenize ``full_doc`` with char offsets, cached per source while the text object is unchanged."""
        cached = self._enc_cache.get(src)
        if cached is None or cached[0] is not full_doc:
            enc = self.index.tokenizer(
                full_doc, add_special_tokens=False, return_offsets_mapping=True
            )
            cached = (full_doc, enc["input_ids"], enc["offset_mapping"])
            self._enc_cache[src] = cached
        return cached[1], cached[2]

    @staticmethod
    def _hit_rank_spans(hits: List[Dict]) -> Dict[str, List[Tuple[int, int, int]]]:
        """Map source -> [(start_tok, end_tok, rank)] where rank is the hit's position in ``hits``."""
        out: Dict[str, List[Tuple[int, int, int]]] = {}
        for rank, h in enumerate(hits):
            m = h.get("meta") or {}
            try:
                out.setdefault(m["source"], []).append((int(m["start_tok"]), int(m["end_tok"]), rank))
            except (KeyError, TypeError, ValueError):
                continue  # hit without span metadata: it simply gets no rank
        return out

    def expand_and_merge(self, hits: List[Dict]) -> List[Dict]:
        """
        Turn retrieval hits into merged context blocks.

        Args:
            hits: dicts with a ``meta`` mapping holding ``source``, ``chunk_id``,
                ``start_tok``, ``end_tok`` and optionally ``id``. Earlier hits are
                treated as more relevant.

        Returns:
            Up to ``max_blocks`` dicts ``{source, start, end, text}``, sorted by
            (source, start). When the cap bites, the blocks containing the
            earliest-ranked hits are kept. The old behaviour truncated the
            alphabetically sorted list, which dropped relevant blocks from sources
            that sort late.
        """
        # 1) Collect all candidate ids (hits + ±neighbors), dedupe, keep order
        all_ids: List[str] = []
        for h in hits:
            m = h["meta"]
            hit_id = m.get("id") or f"{m['source']}-{int(m['chunk_id'])}"
            all_ids.append(hit_id)
            all_ids.extend(self._neighbor_ids_for_hit(m))

        all_ids = self._dedupe_keep_order(all_ids)
        self.last_ids = all_ids[:]  # capture for logging/inspection

        # 2) Fetch metas for those ids (order-preserving as much as possible)
        _, metas = self.index.get_by_ids(all_ids)

        # 3) Build padded token spans per source from metas
        spans_by_src: Dict[str, List[Tuple[int, int]]] = {}
        for m in metas:
            if not m:
                continue
            src = m["source"]
            s_tok = int(m["start_tok"])
            e_tok = int(m["end_tok"])
            ps = max(0, s_tok - self.pad_tokens)
            pe = e_tok + self.pad_tokens
            spans_by_src.setdefault(src, []).append((ps, pe))

        # 4) Merge, slice into text, and score each block by its best hit rank
        hit_spans = self._hit_rank_spans(hits)
        unranked = len(hits)  # blocks containing no hit sort last
        ranked: List[Tuple[int, Dict]] = []
        for src, spans in spans_by_src.items():
            full_doc = self.index.raw_doc(src)
            ids, offs = self._encode(src, full_doc)
            if not ids:
                continue

            for s_tok, e_tok in self._merge_spans(spans):
                s_tok = max(0, min(s_tok, len(ids)))
                e_tok = max(0, min(e_tok, len(ids)))
                if e_tok <= s_tok:
                    continue

                # token span -> char span (end token is exclusive)
                s_char = offs[s_tok][0]
                e_char = offs[e_tok - 1][1]
                rank = min(
                    (r for hs, he, r in hit_spans.get(src, []) if hs < e_tok and he > s_tok),
                    default=unranked,
                )
                ranked.append(
                    (
                        rank,
                        {
                            "source": src,
                            "start": int(s_tok),
                            "end": int(e_tok),
                            "text": full_doc[s_char:e_char],
                        },
                    )
                )

        # 5) Keep the most relevant blocks, then present them in document order
        ranked.sort(key=lambda rb: (rb[0], rb[1]["source"], rb[1]["start"]))
        kept = [blk for _, blk in ranked[: max(0, self.max_blocks)]]
        kept.sort(key=lambda b: (b["source"], b["start"]))
        return kept


In [None]:
class PromptBuilder:
    """Builds the system and user prompts for the answering LLM."""

    SYSTEM = "You are a precise legal assistant for Indian Bare Acts.\n"

    @staticmethod
    def build_user(question: str, blocks: List[Dict]) -> str:
        """
        Compose the user prompt: the question, then the context blocks (each headed
        by ``[source start:end]`` and separated by ``---``), then answer instructions.
        """
        sections = []
        for blk in blocks:
            sections.append(f"[{blk['source']} {blk['start']}:{blk['end']}]\n{blk['text']}")
        context = "\n\n---\n\n".join(sections)

        instructions = "\n".join((
            "- Describe the context whether it has information to answer the question or not or whether any part can be used to answer the question.",
            "- Answer the original question in new paragraph with label Answer:",
            "- Quote or paraphrase ONLY from the context above.",
            "- Inline-cite with [source start:end] when using any snippet.",
            "- If the answer is not in context, please describe what the agentic rag found related to original question and try to tell answer based on the data.",
        ))
        return (
            f"Question:\n{question}\n\n"
            f"Context (use only this):\n{context}\n\n"
            f"Instructions:\n{instructions}"
        )





In [None]:
class Answerer:
    """Produces the final answer by sending PromptBuilder's prompts to the LLM."""

    def __init__(self, llm: LLM, set: Settings):
        self.llm = llm
        self.set = set

    def answer(self, question: str, blocks: List[Dict], model: str | None = None) -> str:
        """
        Answer ``question`` from the context ``blocks``.

        Sampling settings come from ``self.set``. ``model`` overrides
        ``self.set.gen_model`` when given.
        """
        cfg = self.set
        prompt = PromptBuilder.build_user(question, blocks)
        chosen_model = model if model else cfg.gen_model
        return self.llm.complete(
            system=PromptBuilder.SYSTEM,
            user=prompt,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            model=chosen_model,
        )


In [None]:
class Critic:
    """LLM-based verifier: checks whether an answer is supported by the retrieved context."""

    def __init__(self, llm: "LLM"):
        self.llm = llm

    @staticmethod
    def _default() -> Dict[str, Any]:
        """Result used when the critic output is unusable: fail open so pass 1 stands."""
        return {"ok": True, "missing": [], "followups": []}

    @staticmethod
    def _str_list(value: Any) -> List[str]:
        """Coerce a JSON field to a list of non-empty strings (a bare string becomes one item)."""
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, (list, tuple)):
            return []
        return [x.strip() for x in value if isinstance(x, str) and x.strip()]

    def review(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:
        """
        Ask the LLM to verify ``answer`` against the context ``blocks``.

        The block *texts* are included in the prompt. Previously only the block ids
        were sent, so the verifier could not actually check support.

        Returns:
            ``{"ok": bool, "missing": [str], "followups": [str]}``. ``ok`` strings such
            as "false" are interpreted. Unparsable output yields ok=True with empty lists.
        """
        context = "\n\n---\n\n".join(
            f"[{b['source']} {b['start']}:{b['end']}]\n{b.get('text', '')}" for b in blocks
        ) or "(no context blocks)"
        sys = (
            "You are a meticulous legal verifier. "
            "Return ONLY JSON with keys: ok (bool), missing (list of short missing facts), followups (list of short retrieval keyphrases)."
        )
        user = f"""Question:
{question}

Proposed answer:
{answer}

Context blocks (each starts with its id):
{context}

Verify that every factual claim is supported by the context blocks (by their ids).
If support is weak or missing, set ok=false and propose concise retrieval keyphrases in followups.

Return JSON only, e.g.:
{{"ok": true, "missing": [], "followups": []}}
"""
        raw = self.llm.complete(system=sys, user=user, temperature=0.0, max_tokens=220)

        # Take the span from the first "{" to the last "}", across newlines.
        m = re.search(r"\{.*\}", raw or "", flags=re.DOTALL)
        if not m:
            return self._default()
        try:
            parsed = json.loads(m.group(0))
        except ValueError:
            return self._default()
        if not isinstance(parsed, dict):
            return self._default()

        ok = parsed.get("ok", True)
        if isinstance(ok, str):
            # Some models emit "false"/"true" as strings; a non-empty string is otherwise truthy.
            ok = ok.strip().lower() not in ("false", "no", "0")
        return {
            "ok": bool(ok),
            "missing": self._str_list(parsed.get("missing")),
            "followups": self._str_list(parsed.get("followups")),
        }


In [None]:
import modal
_remote_expand = modal.Function.from_name("legal-query-expander-qwen3-v2", "expand")


class ModalFirstExpander(QueryExpander):
    """
    QueryExpander that first asks the Modal-hosted expander (``_remote_expand``).

    It falls back to ``QueryExpander.expand`` (OpenAI plus the deterministic
    fallback) when the remote call fails or yields fewer than ``max(1, n // 2)``
    usable strings.
    """

    def expand(self, question: str, *, model_override: str | None = None) -> List[str]:
        """
        Return up to ``self.n`` sanitized queries, preferring the remote expander.

        Side effects:
            Sets ``self.last_info`` on both paths, so callers never see a stale
            "used_fallback" flag from an earlier call.
        """
        raw: Any = None
        if _remote_expand:
            try:
                raw = _remote_expand.remote(question, self.n)
            except Exception:
                # Best effort: any remote failure routes to the local expansion below.
                raw = None

        # Only accept a list/tuple of strings. A bare string would otherwise be
        # iterated character by character, and non-strings would crash re.sub.
        candidates: List[str] = []
        if isinstance(raw, (list, tuple)):
            candidates = [x for x in raw if isinstance(x, str) and x.strip()]

        if len(candidates) < max(1, self.n // 2):
            return super().expand(question, model_override=model_override)

        # Same character whitelist as the base class's final clean-up.
        cleaned = [re.sub(r"[^\w\s\-./]", "", q).strip() for q in candidates]
        queries = [q for q in cleaned if q][: self.n]
        self.last_info = {
            "used_fallback": False,
            "source": "modal",
            "raw": raw,
            "final": queries,
        }
        return queries



In [None]:
class LegalAgent:
    """
    Agentic RAG pipeline over the bare-acts index.

    Pass 1: expand the question → retrieve top-k per rewrite → expand/merge context → answer.
    A ``Critic`` then reviews the answer. If the review is not ok and proposes follow-up
    queries, pass 2 retrieves for up to 3 of them, re-merges context from all hits and
    answers again.
    """

    def __init__(self, expander: QueryExpander, retriever: Retriever, ctx_expander: ContextExpander, answerer: Answerer, set: Settings):
        self.expander = expander
        self.retriever = retriever
        self.ctx_expander = ctx_expander
        self.answerer = answerer
        self.set = set

    @staticmethod
    def _str_list(value: Any) -> List[str]:
        """Normalize a critic field to a list of non-empty strings (None, a bare string, or mixed lists)."""
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, (list, tuple)):
            return []
        return [x.strip() for x in value if isinstance(x, str) and x.strip()]

    @staticmethod
    def _review_ok(review: Dict[str, Any]) -> bool:
        """Interpret the critic's ``ok`` flag (missing → True; "false"/"no"/"0" strings → False)."""
        ok = review.get("ok", True)
        if isinstance(ok, str):
            return ok.strip().lower() not in ("false", "no", "0")
        return bool(ok)

    def _critique(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:
        """Run the Critic with the answerer's LLM; a non-dict result is treated as an empty review."""
        review = Critic(self.answerer.llm).review(question, answer, blocks)
        return review if isinstance(review, dict) else {}

    def _retrieve_all(self, queries: List[str]) -> List[Dict]:
        """Concatenate the top-k hits of every query, in query order."""
        hits: List[Dict] = []
        for q in queries:
            hits.extend(self.retriever.topk(q))
        return hits

    def run(self, question: str, model: str | None = None) -> str:
        """Run the full pipeline and return the final answer (pass-2 answers carry a refinement note)."""
        # Pass 1
        rewrites = self.expander.expand(question, model_override=model)
        hits = self._retrieve_all(rewrites)
        blocks = self.ctx_expander.expand_and_merge(hits)
        answer1 = self.answerer.answer(question, blocks, model=model)

        # Self-critique
        review = self._critique(question, answer1, blocks)
        followups = self._str_list(review.get("followups"))
        if self._review_ok(review) or not followups:
            return answer1  # Good enough

        # Pass 2 — adapt plan using follow-up rewrites
        extra_hits = self._retrieve_all(followups[:3])  # keep bounded
        blocks2 = self.ctx_expander.expand_and_merge(hits + extra_hits)
        answer2 = self.answerer.answer(question, blocks2, model=model)
        return answer2 + "\n\n_Refined after self-critique (agentic step)._"

    def run_stream(self, question: str, model: str | None = None):
        """
        Generator: yields ``(answer_or_none, logs_html)`` tuples as the pipeline progresses.

        ``answer_or_none`` is None for progress-only updates. The last tuple always
        carries the final answer. Includes a self-critique step and, when needed, a
        logged second pass.
        """
        log = RunLogger()
        log.add(f"Question: {question}")
        yield None, log.html()

        # 1) Expand queries
        rewrites = self.expander.expand(question, model_override=model)
        if getattr(self.expander, "last_info", {}).get("used_fallback", False):
            log.add("Query expansion: LLM output unparsable → using deterministic fallback.")
        log.add(f"Expanded into {len(rewrites)} queries:")
        for i, q in enumerate(rewrites, 1):
            log.add(f"  {i}. {q}")
        yield None, log.html()

        # 2) Retrieve top-k per rewrite
        all_hits: List[Dict] = []
        for i, q in enumerate(rewrites, 1):
            hits = self.retriever.topk(q)
            all_hits.extend(hits)
            top3 = ", ".join(_hit_id(h) for h in hits[:3]) or "—"
            log.add(f"Retrieval {i}/{len(rewrites)}: got {len(hits)} hits → {top3}")
            yield None, log.html()

        # 3) Context expansion / merging (with neighbor ids logged)
        blocks = self.ctx_expander.expand_and_merge(all_hits)
        used = self.ctx_expander.last_ids
        peek = ", ".join(used[:8]) + (" …" if len(used) > 8 else "")
        log.add(f"Neighbor addition: collected {len(used)} chunk-ids → {peek}")

        approx_words = int(self.ctx_expander.pad_tokens / 1.4)  # inverse of words_to_tokens≈1.4
        log.add(
            f"Context expansion: merged {len(blocks)} block(s) "
            f"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words} words)."
        )
        for b in blocks:
            log.add(f"  [{b['source']} {b['start']}:{b['end']}]")
            log.add(b["text"][:50].replace("\n", " "))
        yield None, log.html()

        # 4) LLM answer — pass 1
        log.add(f"Asking LLM (pass 1): {model or self.set.gen_model}")
        yield None, log.html()
        answer1 = self.answerer.answer(question, blocks, model=model)
        log.add("Answer 1 ready.")
        yield answer1, log.html()

        # 5) Self-critique (fields normalized: a bare string followup must not be split into chars)
        log.add("Running self-critique…")
        yield None, log.html()
        review = self._critique(question, answer1, blocks)
        ok = self._review_ok(review)
        missing = self._str_list(review.get("missing"))
        followups = self._str_list(review.get("followups"))

        def _preview(lst: List[str], n: int = 5) -> str:
            """Comma-join the first ``n`` items, noting how many were omitted; "—" if empty."""
            if not lst:
                return "—"
            more = f" … (+{len(lst) - n} more)" if len(lst) > n else ""
            return ", ".join(lst[:n]) + more

        log.add(f"Self-critique result: ok={ok}; missing={_preview(missing)}; followups={_preview(followups)}")
        yield None, log.html()

        if ok or not followups:
            log.add("Critique passed or no follow-ups. Finalizing pass 1.")
            yield answer1, log.html()
            return

        # 6) Pass 2 — follow-up retrievals
        extra_hits: List[Dict] = []
        limited = followups[:3]  # keep bounded
        for i, q in enumerate(limited, 1):
            hits = self.retriever.topk(q)
            extra_hits.extend(hits)
            top3 = ", ".join(_hit_id(h) for h in hits[:3]) or "—"
            log.add(f"Follow-up retrieval {i}/{len(limited)}: got {len(hits)} hits → {top3}")
            yield None, log.html()

        blocks2 = self.ctx_expander.expand_and_merge(all_hits + extra_hits)
        used2 = self.ctx_expander.last_ids
        peek2 = ", ".join(used2[:8]) + (" …" if len(used2) > 8 else "")
        log.add(f"Neighbor addition (pass 2): collected {len(used2)} chunk-ids → {peek2}")

        approx_words2 = int(self.ctx_expander.pad_tokens / 1.4)
        log.add(
            f"Context expansion (pass 2): merged {len(blocks2)} block(s) "
            f"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words2} words)."
        )
        for b in blocks2[:10]:  # don’t spam the log
            log.add(f"  [{b['source']} {b['start']}:{b['end']}]")
        yield None, log.html()

        # 7) LLM answer — pass 2
        log.add(f"Asking LLM (pass 2): {model or self.set.gen_model}")
        yield None, log.html()
        answer2 = self.answerer.answer(question, blocks2, model=model)
        final_answer = answer2 + "\n\n_Refined after self-critique (agentic step)._"
        log.add("Answer 2 ready (refined).")
        yield final_answer, log.html()


In [None]:
def make_agent(gen_model: str) -> LegalAgent:
    """
    Wire up a fresh LegalAgent whose LLM defaults to ``gen_model``.

    Uses the module-level ``index`` and ``SET``. The query expander and the answerer
    share one OpenAILLM instance.
    """
    shared_llm = OpenAILLM(gen_model)
    return LegalAgent(
        ModalFirstExpander(llm=shared_llm, n=SET.expansions),
        Retriever(index=index, k=SET.topk_per_rewrite),
        ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm),
        Answerer(llm=shared_llm, set=SET),
        SET,
    )


In [None]:
# Module-level agent using the default model. The Gradio handler below does not use it;
# chat_stream builds a fresh agent for every request.
agent = make_agent(SET.gen_model)
print("Agent ready (global).")


In [None]:
import gradio as gr

# Model ids offered in the UI dropdown (its initial value is SET.gen_model).
# NOTE(review): the gpt-5 entries are reasoning models — confirm they accept the
# request parameters OpenAILLM sends.
MODEL_CHOICES = [
    "gpt-4o-mini",  # default
    "gpt-4o",
    "gpt-5",
    "gpt-5-nano",
]

# Example question pre-filled in the chat input box.
DEFAULT_Q = "is death by ignorance like giving medicine to old person amounts to a murder ?"

def chat_stream(chat_history: List[Tuple[str, str]], message: str, model_choice: str, topk: int, radius: int, max_blocks: int):
    """
    Gradio streaming handler: runs the agent and yields ``(chat_history, logs_html)`` updates.

    A fresh agent is built per request (avoids state bleed between requests). The
    Advanced-panel values override its retrieval settings.

    Robustness:
      - blank messages are ignored instead of running the whole pipeline;
      - an empty model selection falls back to ``SET.gen_model``;
      - the latest answer stays visible while later steps (self-critique) run;
      - any pipeline error is shown in the chat instead of leaving a stale placeholder.
    """
    history = list(chat_history or [])
    if not (message or "").strip():
        yield history, "<i>Please type a question first.</i>"
        return

    model_id = model_choice or SET.gen_model

    # Add the user turn with a placeholder response
    history.append((message, "🛠️ running pipeline…"))
    logs = "<i>starting…</i>"
    yield history, logs

    latest_answer = None
    try:
        # Always create a fresh agent per request (avoids NameError & state bleed)
        _agent = make_agent(model_id)
        _agent.retriever.k = int(topk)
        _agent.ctx_expander.radius = int(radius)
        _agent.ctx_expander.max_blocks = int(max_blocks)

        # Stream logs + answers; logs already contain inline styles, pass them through
        for ans, logs in _agent.run_stream(message, model=model_id):
            if ans is not None:
                latest_answer = ans
                shown = ans
            elif latest_answer is not None:
                shown = f"{latest_answer}\n\n⏳ working…"
            else:
                shown = "⏳ working…"
            history[-1] = (message, shown)
            yield history, logs
    except Exception as exc:
        err = f"{type(exc).__name__}: {exc}"
        history[-1] = (message, f"⚠️ Error while answering: {err}")
        yield history, f"{logs}<br/><b>Error:</b> {escape(err)}"

with gr.Blocks(title="Week 8 • Legal QnA RAG (Agentic)") as app:
    # Header: title plus a one-line summary of the pipeline stages.
    gr.Markdown("### 🧑‍⚖️ Legal Q&A on Bare Acts — Agentic RAG (Week 8)")
    gr.Markdown("Flow: **query expansion → multi-retrieval → context expansion (±neighbors, merged) → LLM answer**")

    with gr.Row():
        # Left (wider) column: chat transcript and the question input row.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", height=420)
            with gr.Row():
                msg = gr.Textbox(value=DEFAULT_Q, label="Ask a legal question", scale=5)
                send = gr.Button("Send", variant="primary", scale=1)

        # Right column: streaming agent log rendered as HTML.
        with gr.Column(scale=2):
            gr.Markdown("#### Agent Logs")
            logs_html = gr.HTML(
                value="<div style='font-family:ui-monospace,Menlo,Consolas;white-space:pre-wrap;border:1px solid #ddd;border-radius:8px;padding:10px;height:380px;overflow:auto;background:#fafafa'>Idle</div>"
            )

    # Collapsed panel with per-request overrides passed to chat_stream.
    with gr.Accordion("Advanced", open=False):
        with gr.Row():
            model_dd = gr.Dropdown(choices=MODEL_CHOICES, value=SET.gen_model, label="LLM Model", scale=2)
            topk = gr.Slider(2, 20, value=SET.topk_per_rewrite, step=1, label="Top-K per rewrite", scale=2)
            radius = gr.Slider(0, 4, value=SET.neighbor_radius, step=1, label="Neighbor radius (±)", scale=2)
            cap = gr.Slider(4, 60, value=SET.max_blocks_for_llm, step=1, label="Max merged blocks for LLM", scale=2)

    # The Send button and Enter in the textbox both stream through chat_stream,
    # then clear the textbox once the run finishes.
    for _trigger in (send.click, msg.submit):
        _trigger(
            fn=chat_stream,
            inputs=[chatbot, msg, model_dd, topk, radius, cap],
            outputs=[chatbot, logs_html],
            show_progress=False,
        ).then(lambda: "", None, msg)

app.launch(inbrowser=True)
