<a href="https://colab.research.google.com/github/fzamzami/aoc_id/blob/master/Ethics_Compliance_v0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install RapidFuzz pypdf python-docx #compressed-tensors

Collecting RapidFuzz
  Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Collecting pypdf
  Downloading pypdf-6.3.0-py3-none-any.whl.metadata (7.1 kB)
Collecting python-docx
  Downloading python_docx-1.2.0-py3-none-any.whl.metadata (2.0 kB)
Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m67.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pypdf-6.3.0-py3-none-any.whl (328 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m328.9/328.9 kB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading python_docx-1.2.0-py3-none-any.whl (252 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: RapidFuzz, python-docx, pypdf
Successfully installed RapidFuzz-3.14.3 pypdf-6.3.0 python-do

In [None]:
import os
import re
import json
import uuid
import gc
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
import pandas as pd
from pathlib import Path
import numpy as np
import torch
from pydantic import BaseModel, Field, field_validator
from rapidfuzz import fuzz
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
from docx import Document as DocxDocument

In [7]:
# =========================
# GLOBAL CONFIG (edit here)
# =========================
OUT_DIR = "./out"
FN_THEMES = "10_themes.json"
FN_CLAUSES_DEDUP = "25_clauses_dedup.jsonl"

# ---------- program doc to audit ----------
PROGRAM_INPUT_PATH = "sampledoc.txt"  # change to your .pdf / .docx / .txt / .md

FN_PROG_CHUNKS = "40_program_chunks.jsonl"
FN_PROG_EMB    = "45_program_embeddings.npz"
FN_RETRIEVAL   = "60_retrieval.jsonl"
FN_HYDE_CACHE  = "55_hyde.jsonl"          # cache HyDE generations
FN_JUDGE_RAW   = "70_judgments_raw.jsonl"
FN_JUDGE_MAJ   = "80_judgments_majority.jsonl"

MODEL_NAME       = "humain-ai/ALLaM-7B-Instruct-preview"
EMBED_MODEL_NAME = "intfloat/multilingual-e5-base"

_HAS_CUDA = torch.cuda.is_available()

# ==== Use your A100 ====
USE_GPU_LLM       = True
USE_GPU_EMBEDDER  = True

LLM_DEVICE  = "cuda" if (_HAS_CUDA and USE_GPU_LLM) else "cpu"
EMB_DEVICE  = "cuda" if (_HAS_CUDA and USE_GPU_EMBEDDER) else "cpu"
DTYPE       = torch.bfloat16 if LLM_DEVICE == "cuda" else torch.float32

# Speed knobs (safe defaults for A100 40/80GB)
EMBED_BATCH_SIZE  = 512         # big batches to saturate GPU
HYDE_BATCH_SIZE   = 32          # LLM batch size for HyDE
JUDGE_BATCH_SIZE  = 16          # LLM batch size for Judge

# Token budgets
PROG_INPUT_TOKENS    = 900
PROG_OVERLAP_TOKENS  = 100
TOP_K                = 3        # fewer chunks => smaller judge prompts & faster
USE_HYDE             = True
HYDE_MAX_NEW_TOKENS  = 96
CONTEXT_TOKEN_CAP    = 900      # clamp context pack to keep VRAM in check
MIN_SNIPPET_LEN      = 200

# Judge parameters (3x self-consistency; low-temp for stability)
JUDGE_SAMPLES        = 3
JUDGE_TEMPERATURE    = 0.2
JUDGE_TOP_P          = 0.95
JUDGE_MAX_NEW_TOKENS = 196

# Enable TF32 on A100
if _HAS_CUDA:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# =========================
# Schemas (same as Stage-1)
# =========================
class EvidenceSpan(BaseModel):
    quote: str
    chunk_index: int
    start_char: int
    end_char: int

class ThemeItem(BaseModel):
    name: str
    short_description: str
    category: str
    severity_guidance: str
    example_phrases: List[str] = Field(default_factory=list)
    evidence: List[EvidenceSpan] = Field(default_factory=list)

SEVERITY_ENUM = {"low", "medium", "high"}

class ClauseItem(BaseModel):
    clause_id: Optional[str] = None
    source_chunk_index: int
    text: str
    themes: List[str] = Field(default_factory=list)
    category: str
    severity: str
    tags: List[str] = Field(default_factory=list)
    question_yes_no: str
    ideal_decision: str = Field(..., pattern="^(include|exclude)$")
    reasoning: str
    language: str
    evidence: List[EvidenceSpan] = Field(default_factory=list)

    @field_validator("severity")
    @classmethod
    def _sev_ok(cls, v):
        v2 = v.lower().strip()
        if v2 not in SEVERITY_ENUM:
            raise ValueError(f"severity must be one of {SEVERITY_ENUM}")
        return v2

class ProgramChunk(BaseModel):
    chunk_id: str
    page: int
    start_char: int
    end_char: int
    text: str
    section: str = ""
    language: str = "en"

class RetrievalRecord(BaseModel):
    clause_id: str
    topk: List[Dict[str, Any]]
    hyde_text: str = ""

VALID_LABELS = {"Compliant", "Violation", "Missing", "Needs-Review"}

class JudgeAnswer(BaseModel):
    label: str = Field(..., pattern="^(Compliant|Violation|Missing|Needs-Review)$")
    rationale: str
    evidence: List[Dict[str, Any]] = Field(default_factory=list)

class JudgeSample(BaseModel):
    clause_id: str
    sample_id: int
    answer: JudgeAnswer
    context_chunk_ids: List[str] = Field(default_factory=list)

class JudgeMajority(BaseModel):
    clause_id: str
    final_label: str
    votes: Dict[str, int]
    samples: List[int] = Field(default_factory=list)

# =========================
# Utils
# =========================
def ensure_outdir(path: str):
    os.makedirs(path, exist_ok=True)

def free_gpu():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

def read_text_any(input_path: str) -> Tuple[str, List[Tuple[int, str]]]:
    ext = os.path.splitext(input_path)[1].lower()
    if ext in [".txt", ".md"]:
        with open(input_path, "r", encoding="utf-8") as f:
            s = f.read()
        return s, [(1, s)]
    if ext == ".pdf":
        reader = PdfReader(input_path)
        pages = []
        for i, p in enumerate(reader.pages, start=1):
            pages.append((i, (p.extract_text() or "")))
        return "\n".join(t for _, t in pages), pages
    if ext == ".docx":
        doc = DocxDocument(input_path)
        s = "\n".join(p.text for p in doc.paragraphs)
        return s, [(1, s)]
    raise ValueError(f"Unsupported file: {ext}")

def token_windows(tok: AutoTokenizer, text: str, target_tokens: int, overlap_tokens: int):
    enc = tok(text, add_special_tokens=False, return_offsets_mapping=True)
    offs = enc["offset_mapping"]
    n = len(offs)
    if n == 0:
        return
    i = 0
    while i < n:
        j = min(n, i + target_tokens)
        start = offs[i][0]
        end = offs[j - 1][1]
        yield start, end, text[start:end]
        if j == n:
            break
        i = max(0, j - overlap_tokens)

def index_chunks_by_id(chunks: List[ProgramChunk]) -> Dict[str, ProgramChunk]:
    return {c.chunk_id: c for c in chunks}

def clamp_context_by_tokens(tok: AutoTokenizer, texts: List[str], cap_tokens: int) -> List[str]:
    acc, total = [], 0
    for t in texts:
        n = len(tok(t, add_special_tokens=False)["input_ids"])
        if total + n > cap_tokens:
            break
        acc.append(t); total += n
    return acc

# =========================
# LLM wrapper (GPU-first, left padding, flash_attn2/SDPA)
# =========================
@dataclass
class HFConfig:
    model_name: str = MODEL_NAME
    temperature: float = 0.0
    max_new_tokens: int = 196
    top_p: float = 1.0
    retries: int = 1

class HFChat:
    def __init__(self, cfg: HFConfig):
        self.cfg = cfg
        self.tok = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
        # Left padding for decoder-only models => correct + fast batching
        self.tok.padding_side = "left"
        if self.tok.pad_token_id is None and self.tok.eos_token_id is not None:
            self.tok.pad_token = self.tok.eos_token

        attn_kwargs = {}
        if LLM_DEVICE == "cuda":
            # Prefer flash_attention_2; fallback to SDPA if not available
            try:
                attn_kwargs["attn_implementation"] = "flash_attention_2"
            except Exception:
                pass

        if LLM_DEVICE == "cuda":
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    cfg.model_name,
                    torch_dtype=DTYPE,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                    **attn_kwargs
                )
            except Exception:
                self.model = AutoModelForCausalLM.from_pretrained(
                    cfg.model_name,
                    torch_dtype=DTYPE,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                cfg.model_name, torch_dtype=torch.float32
            ).to("cpu")

        # small speed ups
        self.model.eval()
        if hasattr(self.model.config, "use_cache"):
            self.model.config.use_cache = True

    def _build_prompts(self, system: str, users: List[str]) -> Tuple[Dict[str, torch.Tensor], List[int]]:
        prompts = []
        if getattr(self.tok, "chat_template", None):
            for u in users:
                msgs = [{"role": "system", "content": system},
                        {"role": "user", "content": u}]
                prompts.append(self.tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
        else:
            for u in users:
                prompts.append(f"<|system|>\n{system}\n<|user|>\n{u}\n<|assistant|>\n")
        # pad_to_multiple_of 8 -> better tensor core utilization
        enc = self.tok(prompts, return_tensors="pt", padding=True, truncation=False, pad_to_multiple_of=8)
        in_lens = enc["attention_mask"].sum(dim=1).tolist()
        return {k: v.to(LLM_DEVICE) for k, v in enc.items()}, in_lens

    @staticmethod
    def _extract_json(text: str) -> Optional[Dict[str, Any]]:
        m = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if not m:
            return None
        raw = m.group(0)
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            raw2 = re.sub(r",\s*([}\]])", r"\1", raw)
            try:
                return json.loads(raw2)
            except Exception:
                return None

    def generate_json_batch(self, system: str, users: List[str],
                            max_new_tokens_override: Optional[int] = None,
                            temperature: Optional[float] = None,
                            top_p: Optional[float] = None) -> List[Optional[Dict[str, Any]]]:
        max_new = max_new_tokens_override or self.cfg.max_new_tokens
        temp    = self.cfg.temperature if temperature is None else temperature
        topp    = self.cfg.top_p if top_p is None else top_p

        inputs, in_lens = self._build_prompts(system, users)
        with torch.inference_mode():
            out = self.model.generate(
                **inputs,
                max_new_tokens=max_new,
                temperature=temp,
                top_p=topp,
                do_sample=(temp > 0),
                eos_token_id=self.tok.eos_token_id,
                pad_token_id=self.tok.pad_token_id,
                use_cache=True,
            )
        seqs = out
        results = []
        for i in range(seqs.size(0)):
            gen_ids = seqs[i, in_lens[i]:]
            text = self.tok.decode(gen_ids, skip_special_tokens=True)
            results.append(self._extract_json(text))
        return results

# =========================
# Stage 2: Program Ingest + Embeddings (GPU)
# =========================
def stage_40_program_chunks(program_path: str, llm_tok: AutoTokenizer) -> List[ProgramChunk]:
    ensure_outdir(OUT_DIR)
    out_path = os.path.join(OUT_DIR, FN_PROG_CHUNKS)
    if os.path.exists(out_path):
        items = []
        with open(out_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    items.append(ProgramChunk(**json.loads(line)))
        return items

    full, pages = read_text_any(program_path)
    chunks: List[ProgramChunk] = []
    with open(out_path, "w", encoding="utf-8") as out_f:
        for page_idx, page_text in pages:
            for w_i, (start, end, txt) in enumerate(token_windows(llm_tok, page_text, PROG_INPUT_TOKENS, PROG_OVERLAP_TOKENS)):
                if len(txt.strip()) < MIN_SNIPPET_LEN:  # skip tiny fragments
                    continue
                cid = f"pg{page_idx:03d}_w{w_i:03d}"
                item = ProgramChunk(chunk_id=cid, page=page_idx, start_char=start, end_char=end, text=txt)
                chunks.append(item)
                out_f.write(json.dumps(item.model_dump(), ensure_ascii=False) + "\n")
    return chunks

def stage_45_program_embeddings(chunks: List[ProgramChunk]) -> Tuple[np.ndarray, np.ndarray]:
    out_path = os.path.join(OUT_DIR, FN_PROG_EMB)
    if os.path.exists(out_path):
        z = np.load(out_path, allow_pickle=True)
        return z["embeddings"], z["chunk_ids"]

    print(f"[emb] loading SentenceTransformer on {EMB_DEVICE} ...")
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=EMB_DEVICE)
    # IMPORTANT for E5: prefix "passage:" (trained that way)
    passages = [f"passage: {c.text}" for c in chunks]
    vecs = embedder.encode(
        passages,
        batch_size=EMBED_BATCH_SIZE,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=True,
    ).astype(np.float32)

    chunk_ids = np.array([c.chunk_id for c in chunks], dtype=object)
    np.savez_compressed(out_path, embeddings=vecs, chunk_ids=chunk_ids)

    # Free embedder (we'll reuse GPU for LLM)
    del embedder
    free_gpu()
    return vecs, chunk_ids

# =========================
# Stage 3: Retrieval (+ HyDE, cached)
# =========================
HYDE_SYSTEM = (
    "You generate a short hypothetical passage that would be found in a program document relevant to a compliance clause. "
    "Keep it factual-sounding but it's allowed to be hypothetical. Output JSON only."
)
HYDE_USER_TMPL = """Clause (verbatim):
\"\"\"{clause_text}\"\"\"

Binary check:
{question}

Return JSON:
{{"hypo_doc": "70-120 words that would likely appear in a program that addresses or violates this clause; avoid bullet points."}}"""

def load_clauses() -> List[ClauseItem]:
    path = os.path.join(OUT_DIR, FN_CLAUSES_DEDUP)
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                items.append(ClauseItem(**json.loads(line)))
    return items

def load_themes() -> List[ThemeItem]:
    path = os.path.join(OUT_DIR, FN_THEMES)
    with open(path, "r", encoding="utf-8") as f:
        arr = json.load(f)
    return [ThemeItem(**t) for t in arr]

def build_query_for_clause(c: ClauseItem) -> str:
    parts = [c.question_yes_no.strip(), c.text.strip()]
    if c.tags:
        parts.append(", ".join(c.tags[:4]))
    if c.themes:
        parts.append(", ".join(c.themes[:4]))
    q = " ".join(p for p in parts if p).strip()
    return f"query: {q[:600]}"  # E5 "query:" prefix

def hyde_batch_cached(llm: HFChat, clauses: List[ClauseItem]) -> List[str]:
    """Cache HyDE text per clause_id to avoid re-generation on reruns."""
    cache_path = os.path.join(OUT_DIR, FN_HYDE_CACHE)
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    obj = json.loads(line)
                    cache[obj["clause_id"]] = obj.get("hypo_doc", "")

    hyde_texts = [""] * len(clauses)
    prompts, idxs = [], []
    for i, c in enumerate(clauses):
        if c.clause_id in cache:
            hyde_texts[i] = cache[c.clause_id]
        else:
            prompts.append(HYDE_USER_TMPL.format(clause_text=c.text, question=c.question_yes_no))
            idxs.append(i)

    # batched new generations
    for i in range(0, len(prompts), HYDE_BATCH_SIZE):
        batch_prompts = prompts[i:i + HYDE_BATCH_SIZE]
        objs = llm.generate_json_batch(HYDE_SYSTEM, batch_prompts,
                                       max_new_tokens_override=HYDE_MAX_NEW_TOKENS,
                                       temperature=0.2, top_p=0.95)
        for j, obj in enumerate(objs or []):
            slot = idxs[i + j]
            text = ""
            if obj and isinstance(obj, dict) and "hypo_doc" in obj:
                text = str(obj["hypo_doc"])[:800]
            hyde_texts[slot] = text
            cache[clauses[slot].clause_id] = text

    # persist cache
    if len(idxs) > 0:
        with open(cache_path, "w", encoding="utf-8") as f:
            for cid, t in cache.items():
                f.write(json.dumps({"clause_id": cid, "hypo_doc": t}, ensure_ascii=False) + "\n")

    return hyde_texts

def stage_60_retrieval(chunks: List[ProgramChunk], prog_emb: np.ndarray, clauses: List[ClauseItem], llm: HFChat) -> List[RetrievalRecord]:
    out_path = os.path.join(OUT_DIR, FN_RETRIEVAL)
    if os.path.exists(out_path):
        items = []
        with open(out_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    items.append(RetrievalRecord(**json.loads(line)))
        return items

    # Reclaim GPU for embedder if needed (LLM is already loaded; embed on GPU anyway)
    print(f"[retrieval] loading E5 on {EMB_DEVICE} ...")
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=EMB_DEVICE)

    hyde_texts = [""] * len(clauses)
    if USE_HYDE:
        print("[hyde] generating hypothetical passages (batched) ...")
        hyde_texts = hyde_batch_cached(llm, clauses)

    queries = [build_query_for_clause(c) for c in clauses]
    q_vecs = embedder.encode(queries, batch_size=EMBED_BATCH_SIZE, convert_to_numpy=True,
                             normalize_embeddings=True, show_progress_bar=False).astype(np.float32)

    if USE_HYDE:
        hyde_q = [f"query: {t}" if t else "query: " for t in hyde_texts]
        h_vecs = embedder.encode(hyde_q, batch_size=EMBED_BATCH_SIZE, convert_to_numpy=True,
                                 normalize_embeddings=True, show_progress_bar=False).astype(np.float32)
        q_vecs = (q_vecs + h_vecs) / 2.0

    # free embedder right away to leave VRAM to LLM for Judge
    del embedder
    free_gpu()

    prog_T = prog_emb.T
    recs: List[RetrievalRecord] = []
    with open(out_path, "w", encoding="utf-8") as out_f:
        for i, c in enumerate(clauses):
            sims = (q_vecs[i] @ prog_T).astype(np.float32)
            top_list = []
            if sims.size > 0:
                top_k = min(TOP_K, sims.shape[0])
                top_idx = np.argpartition(-sims, top_k - 1)[:top_k]
                top_idx = top_idx[np.argsort(-sims[top_idx])]
                for idx in top_idx:
                    ch = chunks[idx]
                    top_list.append({"chunk_id": ch.chunk_id, "page": ch.page, "score": float(sims[idx])})
            rec = RetrievalRecord(clause_id=c.clause_id, topk=top_list, hyde_text=(hyde_texts[i] if USE_HYDE else ""))
            recs.append(rec)
            out_f.write(json.dumps(rec.model_dump(), ensure_ascii=False) + "\n")
    return recs

# =========================
# Stage 4: Judge (batched on GPU)
# =========================
JUDGE_SYSTEM = (
    "You are a compliance judge. Use only the provided program excerpts as evidence. "
    "Return JSON only. If evidence is insufficient or contradictory, answer 'Needs-Review' or 'Missing'."
)

JUDGE_USER_TMPL = """CLAUSE:
- id: {clause_id}
- severity: {severity}
- category: {category}
- tags: {tags}
- themes: {themes}
- yes/no question: {question}
- clause text (exact): \"\"\"{clause_text}\"\"\"

PROGRAM CONTEXT (each item = [chunk_id | page]):
{context_block}

Return compact JSON:
{{
  "label": "Compliant|Violation|Missing|Needs-Review",
  "rationale": "one or two short sentences grounded in the evidence",
  "evidence": [
    {{"chunk_id":"...", "page": N, "quote": "exact short quote"}}
  ]
}}"""

def build_context_block(retr: RetrievalRecord, chunk_idx: Dict[str, ProgramChunk], tok: AutoTokenizer) -> Tuple[str, List[str]]:
    snippets, ordered_ids = [], []
    for item in retr.topk:
        ch = chunk_idx[item["chunk_id"]]
        snippet = ch.text.strip().replace("\n", " ")
        if len(snippet) < MIN_SNIPPET_LEN:
            continue
        pretty = f"[{ch.chunk_id} | page {ch.page}] {snippet}"
        snippets.append(pretty); ordered_ids.append(ch.chunk_id)
    kept = clamp_context_by_tokens(tok, snippets, CONTEXT_TOKEN_CAP)
    kept_ids = ordered_ids[:len(kept)]
    return "\n".join(kept) if kept else "(no usable context)", kept_ids

def judge_batch(llm: HFChat, clauses: List[ClauseItem], recs: List[RetrievalRecord],
                chunk_idx: Dict[str, ProgramChunk], sample_id: int) -> List[JudgeSample]:
    rec_by_id = {r.clause_id: r for r in recs}
    users, meta = [], []

    for c in clauses:
        r = rec_by_id.get(c.clause_id)
        ctx_block, kept_ids = ("(no retrieved context)", [])
        if r:
            ctx_block, kept_ids = build_context_block(r, chunk_idx, llm.tok)
        users.append(JUDGE_USER_TMPL.format(
            clause_id=c.clause_id, severity=c.severity, category=c.category,
            tags=", ".join(c.tags[:5]), themes=", ".join(c.themes[:5]),
            question=c.question_yes_no, clause_text=c.text, context_block=ctx_block
        ))
        meta.append((c.clause_id, kept_ids))

    # chunk into micro-batches to saturate GPU but avoid OOM
    out: List[JudgeSample] = []
    for i in range(0, len(users), JUDGE_BATCH_SIZE):
        sub_users = users[i:i+JUDGE_BATCH_SIZE]
        objs = llm.generate_json_batch(JUDGE_SYSTEM, sub_users,
                                       max_new_tokens_override=JUDGE_MAX_NEW_TOKENS,
                                       temperature=JUDGE_TEMPERATURE, top_p=JUDGE_TOP_P)
        for (cid, kept_ids), obj in zip(meta[i:i+JUDGE_BATCH_SIZE], objs):
            try:
                ans = JudgeAnswer(**obj) if obj else JudgeAnswer(label="Needs-Review", rationale="Invalid/empty JSON.", evidence=[])
            except Exception:
                ans = JudgeAnswer(label="Needs-Review", rationale="Bad JSON from model.", evidence=[])
            out.append(JudgeSample(clause_id=cid, sample_id=sample_id, answer=ans, context_chunk_ids=kept_ids))
    return out

def evidence_self_check(sample: JudgeSample, chunk_idx: Dict[str, ProgramChunk]) -> bool:
    ok_any = False
    for ev in sample.answer.evidence:
        cid = ev.get("chunk_id", "")
        quote = (ev.get("quote") or "").strip()
        if not cid or not quote:
            continue
        ch = chunk_idx.get(cid)
        if not ch:
            continue
        text = ch.text
        if quote in text or fuzz.partial_ratio(quote, text) >= 85:
            ok_any = True
        else:
            return False
    if not sample.answer.evidence:
        return sample.answer.label in {"Missing", "Needs-Review"}
    return ok_any

def majority_vote(samples: List[JudgeSample]) -> JudgeMajority:
    cnt = defaultdict(int)
    cid = samples[0].clause_id if samples else "unknown"
    idxs = []
    for s in samples:
        lab = s.answer.label if s.answer.label in VALID_LABELS else "Needs-Review"
        cnt[lab] += 1
        idxs.append(s.sample_id)
    best = sorted(cnt.items(), key=lambda x: (-x[1], x[0]))
    if len(best) >= 2 and best[0][1] == best[1][1]:
        final = "Needs-Review"
    else:
        final = best[0][0] if best else "Needs-Review"
    return JudgeMajority(clause_id=cid, final_label=final, votes=dict(cnt), samples=idxs)

def stage_70_80_judgment(clauses: List[ClauseItem], recs: List[RetrievalRecord],
                         chunks: List[ProgramChunk], llm: HFChat) -> Tuple[List[JudgeSample], List[JudgeMajority]]:
    path_raw = os.path.join(OUT_DIR, FN_JUDGE_RAW)
    path_maj = os.path.join(OUT_DIR, FN_JUDGE_MAJ)
    if os.path.exists(path_raw) and os.path.exists(path_maj):
        raw, maj = [], []
        with open(path_raw, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    obj = json.loads(line)
                    raw.append(JudgeSample(
                        clause_id=obj["clause_id"], sample_id=obj["sample_id"],
                        answer=JudgeAnswer(**obj["answer"]), context_chunk_ids=obj.get("context_chunk_ids", [])
                    ))
        with open(path_maj, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    maj.append(JudgeMajority(**json.loads(line)))
        return raw, maj

    chunk_idx = index_chunks_by_id(chunks)

    all_samples: List[JudgeSample] = []
    for s_id in range(JUDGE_SAMPLES):
        batch = judge_batch(llm, clauses, recs, chunk_idx, sample_id=s_id)
        for smp in batch:
            if not evidence_self_check(smp, chunk_idx):
                smp.answer.label = "Needs-Review"
                smp.answer.rationale = (smp.answer.rationale or "")[:180] + " Evidence unsupported by provided excerpts."
        all_samples.extend(batch)

    by_clause: Dict[str, List[JudgeSample]] = defaultdict(list)
    for s in all_samples:
        by_clause[s.clause_id].append(s)

    majors: List[JudgeMajority] = []
    for cid, arr in by_clause.items():
        majors.append(majority_vote(arr))

    with open(path_raw, "w", encoding="utf-8") as f:
        for s in all_samples:
            f.write(json.dumps({
                "clause_id": s.clause_id,
                "sample_id": s.sample_id,
                "answer": s.answer.model_dump(),
                "context_chunk_ids": s.context_chunk_ids
            }, ensure_ascii=False) + "\n")
    with open(path_maj, "w", encoding="utf-8") as f:
        for m in majors:
            f.write(json.dumps(m.model_dump(), ensure_ascii=False) + "\n")
    return all_samples, majors

# =========================
# ORCHESTRATION
# =========================
def run_audit():
    ensure_outdir(OUT_DIR)

    clauses = load_clauses()
    print(f"[load] clauses: {len(clauses)} from {os.path.join(OUT_DIR, FN_CLAUSES_DEDUP)}")
    if len(clauses) == 0:
        raise RuntimeError("No clauses found. Re-run Stage-1 to extract clauses first.")

    _themes = load_themes()
    print(f"[load] themes: {len(_themes)} from {os.path.join(OUT_DIR, FN_THEMES)}")

    print(f"[llm] loading {MODEL_NAME} on {LLM_DEVICE} (bf16/TF32, flash_attn2/SDPA)...")
    llm = HFChat(HFConfig(model_name=MODEL_NAME, temperature=0.0, max_new_tokens=196, top_p=1.0))

    # Stage 2: ingest + embed
    prog_chunks = stage_40_program_chunks(PROGRAM_INPUT_PATH, llm.tok)
    print(f"[stage-40] program chunks: {len(prog_chunks)} -> {os.path.join(OUT_DIR, FN_PROG_CHUNKS)}")
    prog_emb, prog_ids = stage_45_program_embeddings(prog_chunks)
    print(f"[stage-45] program embeddings: {prog_emb.shape} -> {os.path.join(OUT_DIR, FN_PROG_EMB)}")

    # Stage 3: retrieval (+ HyDE cache)
    recs = stage_60_retrieval(prog_chunks, prog_emb, clauses, llm)
    print(f"[stage-60] retrieval records: {len(recs)} -> {os.path.join(OUT_DIR, FN_RETRIEVAL)}")

    # Stage 4: judgment (batched on GPU)
    raw, maj = stage_70_80_judgment(clauses, recs, prog_chunks, llm)
    print(f"[stage-70] raw judge samples: {len(raw)} -> {os.path.join(OUT_DIR, FN_JUDGE_RAW)}")
    print(f"[stage-80] majority results: {len(maj)} -> {os.path.join(OUT_DIR, FN_JUDGE_MAJ)}")

    # Free LLM after use
    del llm
    free_gpu()

    label_counts = defaultdict(int)
    for m in maj:
        label_counts[m.final_label] += 1
    print("[summary] final labels:", dict(label_counts))


In [12]:
def visualize():

  # -------------------------------------------------------------------
  # 1. Load JSONL files
  # -------------------------------------------------------------------
  clauses_path = OUT_DIR + "/25_clauses_dedup.jsonl"
  judgments_path = OUT_DIR + "/70_judgments_raw.jsonl"

  def load_jsonl(path):
      rows = []
      with open(path, "r", encoding="utf-8") as f:
          for line in f:
              if line.strip():
                  rows.append(json.loads(line))
      return rows

  clauses = pd.DataFrame(load_jsonl(clauses_path))
  judgments = pd.DataFrame(load_jsonl(judgments_path))

  # -------------------------------------------------------------------
  # 2. Normalize judgment.answer.* fields into flat columns
  # -------------------------------------------------------------------
  def normalize_answer(row):
      ans = row.get("answer", {})
      if isinstance(ans, dict):
          return pd.Series({
              "answer_label": ans.get("label", ""),
              "answer_rationale": ans.get("rationale", "")
          })
      return pd.Series({"answer_label": "", "answer_rationale": ""})

  judgments_norm = judgments.join(judgments.apply(normalize_answer, axis=1))

  # -------------------------------------------------------------------
  # 3. Merge on clause_id
  # -------------------------------------------------------------------
  merged = pd.merge(
      judgments_norm,
      clauses,
      on="clause_id",
      how="left",
      suffixes=("_judgment", "_clause")
  )

  # -------------------------------------------------------------------
  # 4. Select only the requested columns
  # -------------------------------------------------------------------
  keep_cols = [
      "sample_id",
      "themes",
      "severity",
      "tags",
      "question_yes_no",
      "answer_label",
      "answer_rationale"
  ]

  df = merged[keep_cols].copy()

  # -------------------------------------------------------------------
  # 5. Clean list columns into comma-separated strings
  # -------------------------------------------------------------------
  list_cols = ["themes", "tags"]
  for col in list_cols:
      df[col] = df[col].apply(lambda x: ", ".join(x) if isinstance(x, list) else x)

  # -------------------------------------------------------------------
  # 6. Rename requested columns
  # -------------------------------------------------------------------
  df = df.rename(columns={
      "question_yes_no": "compliant_question",
      "answer_label": "current_answer"
  })

  # -------------------------------------------------------------------
  # 7. Create dropdown filters (themes, category, severity, current_answer)
  # -------------------------------------------------------------------
  dropdown_js = """
  function updateFilters() {
      var table = $('#data').DataTable();

      var filters = {
          themes: $('#filter_themes').val(),
          severity: $('#filter_severity').val(),
          current_answer: $('#filter_current_answer').val()
      };

      table.columns().every(function() {
          var colIdx = this.index();
          var colName = $(table.column(colIdx).header()).text().trim();

          if (filters[colName]) {
              this.search('^' + filters[colName] + '$', true, false).draw();
          } else {
              this.search('').draw();
          }
      });
  }
  """

  # -------------------------------------------------------------------
  # 8. Build HTML with DataTables
  # -------------------------------------------------------------------
  html_output = OUT_DIR + "/joined_filtered_dropdown_final_renamed.html"

  html = f"""
  <html>
  <head>
      <script src="https://code.jquery.com/jquery-3.5.1.js"></script>
      <link rel="stylesheet"
            href="https://cdn.datatables.net/1.13.6/css/jquery.dataTables.min.css"/>
      <script src="https://cdn.datatables.net/1.13.6/js/jquery.dataTables.min.js"></script>

      <script>
      $(document).ready(function() {{
          $('#data').DataTable({{
              scrollX: true,
              pageLength: 25
          }});

          {dropdown_js}
      }});
      </script>
  </head>
  <body>

  <h2>Interactive Table: Clauses + Judgments</h2>

  <div style="margin-bottom:20px;">
      <label>Themes:</label>
      <select id="filter_themes" onchange="updateFilters()">
          <option value="">All</option>
          {''.join(f'<option value="{v}">{v}</option>' for v in sorted(df['themes'].unique()))}
      </select>



      <label>Severity:</label>
      <select id="filter_severity" onchange="updateFilters()">
          <option value="">All</option>
          {''.join(f'<option value="{v}">{v}</option>' for v in sorted(df['severity'].unique()))}
      </select>

      <label>Current Answer:</label>
      <select id="filter_current_answer" onchange="updateFilters()">
          <option value="">All</option>
          {''.join(f'<option value="{v}">{v}</option>' for v in sorted(df['current_answer'].unique()))}
      </select>
  </div>

  <table id="data" class="display" style="width:100%">
      <thead>
          <tr>
              {''.join(f'<th>{c}</th>' for c in df.columns)}
          </tr>
      </thead>
      <tbody>
  """

  # Add rows
  for _, row in df.iterrows():
      html += "<tr>" + "".join(f"<td>{row[col]}</td>" for col in df.columns) + "</tr>\n"

  html += """
      </tbody>
  </table>

  </body>
  </html>
  """

  with open(html_output, "w", encoding="utf-8") as f:
      f.write(html)

In [11]:
if __name__ == "__main__":
    run_audit()
    visualize()

[load] clauses: 135 from ./out/25_clauses_dedup.jsonl
[load] themes: 0 from ./out/10_themes.json
[llm] loading humain-ai/ALLaM-7B-Instruct-preview on cuda (bf16/TF32, flash_attn2/SDPA)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/686 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.03G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

[stage-40] program chunks: 1 -> ./out/40_program_chunks.jsonl
[emb] loading SentenceTransformer on cuda ...


modules.json:   0%|          | 0.00/387 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/418 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[stage-45] program embeddings: (1, 768) -> ./out/45_program_embeddings.npz
[retrieval] loading E5 on cuda ...
[hyde] generating hypothetical passages (batched) ...
[stage-60] retrieval records: 135 -> ./out/60_retrieval.jsonl
[stage-70] raw judge samples: 405 -> ./out/70_judgments_raw.jsonl
[stage-80] majority results: 135 -> ./out/80_judgments_majority.jsonl
[summary] final labels: {'Missing': 61, 'Needs-Review': 72, 'Compliant': 2}


In [None]:
import json
import pandas as pd
from pathlib import Path

# -----------------------
# Paths
# -----------------------
clauses_path = "/mnt/data/20_clauses_raw.jsonl"
judgments_path = "/mnt/data/70_judgments_raw.jsonl"
out_html = "/mnt/data/joined_filtered_dropdown_final_renamed.html"

# -----------------------
# Load JSONL
# -----------------------
def load_jsonl(path):
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except Exception:
                # skip malformed lines
                continue
    return pd.DataFrame(rows)

clauses = load_jsonl(clauses_path)
judgments = load_jsonl(judgments_path)

# -----------------------
# Normalize `answer` into flat columns
# -----------------------
def safe_extract_answer(ans, key):
    try:
        return ans.get(key, "") if isinstance(ans, dict) else ""
    except Exception:
        return ""

judgments["answer_label"] = judgments["answer"].apply(lambda a: safe_extract_answer(a, "label"))
judgments["answer_rationale"] = judgments["answer"].apply(lambda a: safe_extract_answer(a, "rationale"))

# -----------------------
# Merge on clause_id
# -----------------------
if "clause_id" in judgments.columns and "clause_id" in clauses.columns:
    merged = pd.merge(judgments, clauses, on="clause_id", how="inner")
else:
    merged = pd.DataFrame()

# -----------------------
# Prepare final dataframe and flatten lists to strings
# -----------------------
keep = [
    "sample_id",
    "themes",
    "category",
    "severity",
    "tags",
    "question_yes_no",
    "expected_compliant_answer",
    "answer_label",
    "answer_rationale",
]

# Ensure columns exist
for c in keep:
    if c not in merged.columns:
        merged[c] = ""

df = merged[keep].copy()

# Flatten list columns into comma-separated strings
for col in ["themes", "tags"]:
    if col in df.columns:
        df[col] = df[col].apply(lambda x: ", ".join(x) if isinstance(x, list) else ("" if pd.isna(x) else x))

# Rename columns as requested previously
df = df.rename(columns={
    "question_yes_no": "compliant_question",
    "answer_label": "current_answer",
    "answer_rationale": "answer_rationale"  # keep same name
})

# -----------------------
# Helper to build <option> list safely
# -----------------------
def build_options(series):
    vals = [v for v in pd.unique(series.dropna()) if v != ""]
    # Convert to strings and dedupe/sort
    vals = sorted({str(v).strip() for v in vals})
    return "\n".join(f'<option value="{html_escape(v)}">{html_escape(v)}</option>' for v in vals)

def html_escape(s):
    return (str(s)
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&#39;"))

# -----------------------
# Build HTML with robust JS
# -----------------------
# Determine if filter columns exist and compute option HTML
themes_opts = build_options(df["themes"]) if "themes" in df.columns else ""
category_opts = build_options(df["category"]) if "category" in df.columns else ""
severity_opts = build_options(df["severity"]) if "severity" in df.columns else ""
current_answer_opts = build_options(df["current_answer"]) if "current_answer" in df.columns else ""

# Build table header and rows
headers_html = "".join(f"<th>{html_escape(c)}</th>" for c in df.columns)
rows_html = ""
for _, r in df.iterrows():
    rows_html += "<tr>"
    for c in df.columns:
        cell = r[c]
        if pd.isna(cell):
            cell = ""
        rows_html += f"<td>{html_escape(cell)}</td>"
    rows_html += "</tr>\n"

html_doc = f"""<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Joined Filtered Table (fixed dropdowns)</title>
<link rel="stylesheet" href="https://cdn.datatables.net/1.13.6/css/jquery.dataTables.min.css"/>
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
<script src="https://cdn.datatables.net/1.13.6/js/jquery.dataTables.min.js"></script>
<style>
  table.dataTable tbody td {{ vertical-align: top; white-space: normal; }}
  .filters {{ margin-bottom: 12px; gap:10px; }}
  label {{ margin-right:6px; }}
</style>
</head>
<body>
<h2>Joined Filtered Table (dropdowns fixed)</h2>

<div class="filters">
  <label>Themes:</label>
  <select id="filter_themes"><option value="">All</option>
    {themes_opts}
  </select>

  <label>Category:</label>
  <select id="filter_category"><option value="">All</option>
    {category_opts}
  </select>

  <label>Severity:</label>
  <select id="filter_severity"><option value="">All</option>
    {severity_opts}
  </select>

  <label>Current Answer:</label>
  <select id="filter_current_answer"><option value="">All</option>
    {current_answer_opts}
  </select>

  <button id="clear_filters">Clear</button>
</div>

<table id="joined_table" class="display" style="width:100%">
  <thead>
    <tr>{headers_html}</tr>
  </thead>
  <tbody>
    {rows_html}
  </tbody>
</table>

<script>
(function() {{
  // escape regex helper
  function escRE(s) {{
    return s.replace(/[.*+?^${{}}()|[\\]\\\\]/g, "\\\\$&");
  }}

  $(document).ready(function() {{
    var table = $('#joined_table').DataTable({{
      scrollX: true,
      pageLength: 25,
      autoWidth: false
    }});

    // build header->index map (trimmed lowercase keys for robustness)
    var headerToIndex = {{}};
    table.columns().every(function(idx) {{
      var txt = $($(table.column(idx).header())).text().trim();
      headerToIndex[txt.toLowerCase()] = idx;
    }});

    // helper to apply exact-match filter on a header name
    function applyExactFilter(headerName, value) {{
      var idx = headerToIndex[headerName.toLowerCase()];
      if (typeof idx === 'undefined') return;
      if (!value) {{
        table.column(idx).search('').draw();
      }} else {{
        var rx = '^' + escRE(value) + '$';
        table.column(idx).search(rx, true, false).draw();
      }}
    }}

    // wire selects
    $('#filter_themes').on('change', function() {{
      applyExactFilter('themes', this.value);
    }});
    $('#filter_category').on('change', function() {{
      applyExactFilter('category', this.value);
    }});
    $('#filter_severity').on('change', function() {{
      applyExactFilter('severity', this.value);
    }});
    $('#filter_current_answer').on('change', function() {{
      // header name for current answer column (exact text in the table header)
      // try 'current_answer' then 'answer_label' as fallback
      if (headerToIndex['current_answer'] !== undefined) {{
        applyExactFilter('current_answer', this.value);
      }} else {{
        applyExactFilter('answer.label', this.value);
      }}
    }});

    // clear filters button
    $('#clear_filters').on('click', function() {{
      $('#filter_themes, #filter_category, #filter_severity, #filter_current_answer').val('');
      table.search('').columns().search('').draw();
    }});
  }});
}})();
</script>

</body>
</html>
"""

# Save output
Path(out_html).write_text(html_doc, encoding="utf-8")
print("Wrote:", out_html)
