In [None]:
!pip install gliner

Collecting gliner
  Downloading gliner-0.2.24-py3-none-any.whl.metadata (10 kB)
Collecting onnxruntime (from gliner)
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting coloredlogs (from onnxruntime->gliner)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime->gliner)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading gliner-0.2.24-py3-none-any.whl (151 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.9/151.9 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m87.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import json
import re
from pathlib import Path
from typing import Dict, Any, List, Set, Tuple, Optional

# -----------------------------
# 1) Regex patterns (fast path)
# -----------------------------

# All keyword patterns below are compiled with re.IGNORECASE and capture, in
# the named group "ids", one dotted ID or several IDs joined by ",", "and" or
# "&". Every ID needs 2 to 7 numeric segments (`\d+(?:\.\d+){1,6}`), so a bare
# number such as "Sec. 5" is NOT matched by these patterns.

# Matches:
#   Sec. 2.5.16.1
#   Sec 2.5.7.7
#   Section 2.5.9.1
#   Sections 2.5.9.1 and 2.5.11.1
# Note: the `|Sections?` alternative is redundant, because
# `Sec(?:tion)?s?` already matches "Section" and "Sections".
SEC_PATTERN = re.compile(
    r"\b(?:Sec(?:tion)?s?|Sections?)\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){1,6}(?:\s*(?:,|and|&)\s*\d+(?:\.\d+){1,6})*)",
    flags=re.IGNORECASE
)

# Matches:
#   Eq. 6.2.66
#   Eq 6.2.68
#   Equation 6.2.63
EQ_PATTERN = re.compile(
    r"\b(?:Eq(?:uation)?s?)\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){1,6}(?:\s*(?:,|and|&)\s*\d+(?:\.\d+){1,6})*)",
    flags=re.IGNORECASE
)

# Matches:
#   Table 6.2.23
#   Tables 6.2.22 and 6.2.23
TABLE_PATTERN = re.compile(
    r"\bTable(?:s)?\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){1,6}(?:\s*(?:,|and|&)\s*\d+(?:\.\d+){1,6})*)",
    flags=re.IGNORECASE
)

# Matches:
#   Fig. 3.1
#   Figure 3.1.2
FIG_PATTERN = re.compile(
    r"\b(?:Fig(?:ure)?s?)\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){1,6}(?:\s*(?:,|and|&)\s*\d+(?:\.\d+){1,6})*)",
    flags=re.IGNORECASE
)

# Explicit no-whitespace variant, e.g. "Sec2.5.7.7" or "Section.2.5.7.7".
# It only accepts the singular keyword and captures a single ID in the group
# "id" (not "ids"). NOTE(review): SEC_PATTERN already permits zero whitespace
# after the keyword (`\s*`), so this pattern largely overlaps with it; because
# of the trailing `\b` it can additionally yield a shortened prefix of an ID
# that is directly followed by a letter.
SEC_NO_SPACE_PATTERN = re.compile(
    r"\bSec(?:tion)?\.?(?P<id>\d+(?:\.\d+){1,6})\b",
    flags=re.IGNORECASE
)

# Splits a captured "ids" blob into individual IDs on ",", "and" or "&".
SPLIT_IDS_PATTERN = re.compile(r"\s*(?:,|and|&)\s*", flags=re.IGNORECASE)

def normalize_dotted_id(raw: str) -> Optional[str]:
    """
    Force a raw reference string into strict dotted-numeric form.

    Every character other than an ASCII digit or a dot is removed, leading and
    trailing dots are trimmed, and leading zeros are dropped from each segment
    ("02.05.016" -> "2.5.16").

    Args:
        raw: Candidate reference text, e.g. "2.5.16.1" or " 2.5.7.7), ".

    Returns:
        The normalized ID such as "2.5.16.1", or None when the cleaned text
        does not consist of at least two dot-separated numeric segments
        (e.g. "5", "5.x", "2..5", "").
    """
    # Keep only digits and dots, then trim dots left dangling at either edge.
    cleaned = re.sub(r"[^0-9.]", "", raw).strip(".")
    # The dot check must run AFTER trimming: otherwise inputs such as "5.x" or
    # "[5.]" keep a dot long enough to pass the check and come back as the
    # non-dotted "5".
    if "." not in cleaned:
        return None
    parts = cleaned.split(".")
    # An empty segment ("2..5") is not a valid reference.
    if not all(p.isdigit() for p in parts):
        return None
    # int() round-trip removes leading zeros in each segment.
    return ".".join(str(int(p)) for p in parts)

def extract_by_pattern(text: str, pattern: re.Pattern) -> Set[str]:
    """
    Gather every normalized ID that `pattern` captures in `text`.

    For each match, the "ids" group is split on SPLIT_IDS_PATTERN and each
    piece is passed through normalize_dotted_id; pieces that do not normalize
    are dropped.
    """
    normalized = (
        normalize_dotted_id(piece)
        for match in pattern.finditer(text)
        for piece in SPLIT_IDS_PATTERN.split(match.group("ids"))
    )
    return {ref for ref in normalized if ref}

def extract_refs_regex(text: str) -> Dict[str, Set[str]]:
    """
    Run all keyword patterns over `text` and bucket the IDs by reference kind.

    Returns a dict with the keys "sections", "tables", "figures" and
    "equations", each mapping to a set of normalized dotted IDs. Empty or
    missing text yields four empty sets.
    """
    found: Dict[str, Set[str]] = {
        "sections": set(), "tables": set(), "figures": set(), "equations": set()
    }
    if not text:
        return found

    # Sections come from the regular pattern plus the no-whitespace variant.
    found["sections"] = extract_by_pattern(text, SEC_PATTERN)
    for hit in SEC_NO_SPACE_PATTERN.finditer(text):
        ref = normalize_dotted_id(hit.group("id"))
        if ref:
            found["sections"].add(ref)

    found["tables"] = extract_by_pattern(text, TABLE_PATTERN)
    found["figures"] = extract_by_pattern(text, FIG_PATTERN)
    found["equations"] = extract_by_pattern(text, EQ_PATTERN)

    # Defensive: make sure no None slipped into any bucket.
    for bucket in found.values():
        bucket.discard(None)

    return found

# ----------------------------------------
# 2) Optional: small LLM fallback (local)
# ----------------------------------------
# Only needed if you have messy references you can't capture with rules.
# Example models: "microsoft/Phi-3-mini-4k-instruct" (3.8B),
#                 "Qwen/Qwen2.5-3B-Instruct"
#
# Set USE_LLM_FALLBACK = True to enable the fallback. While it is False (the
# default), llm_fallback_extract returns empty lists without importing
# transformers/torch, and enrich_nodes_with_references never calls it.

USE_LLM_FALLBACK = False

def llm_fallback_extract(text: str) -> Dict[str, List[str]]:
    """
    Ask a small local causal LM to extract cross-references from `text`.

    Intended as a fallback when the regex pass finds nothing. The function is
    a no-op (returns empty lists) unless the module-level USE_LLM_FALLBACK
    flag is true and `text` is non-empty.

    Args:
        text: Clause text to scan.

    Returns:
        Dict with the keys "sections", "tables", "figures" and "equations";
        each value is a de-duplicated, order-preserving list of strict dotted
        numeric IDs. All lists are empty when the fallback is disabled, the
        model output contains no parseable JSON object, or nothing was found.
    """
    def empty() -> Dict[str, List[str]]:
        # Fresh dict per call so callers can never mutate a shared default.
        return {"sections": [], "tables": [], "figures": [], "equations": []}

    if not USE_LLM_FALLBACK or not text:
        return empty()

    # Imported lazily so the regex-only path needs neither package.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch

    model_name = "microsoft/Phi-3-mini-4k-instruct"  # change if you prefer another small model

    # This function is called once per node, so the tokenizer/model are kept
    # on a function attribute instead of being reloaded on every call.
    loaded = getattr(llm_fallback_extract, "_loaded", None)
    if loaded is None or loaded[0] != model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
        llm_fallback_extract._loaded = (model_name, tokenizer, model)
    else:
        _, tokenizer, model = loaded

    prompt = f"""
You extract cross-references from engineering code text.

Return ONLY valid JSON with keys:
sections, tables, figures, equations.
Each value is an array of STRICT dotted numeric ids like "2.5.16.1" or "6.2.66".
Do not include words like "Sec" or "Table". Do not include non-dotted items.
If nothing found, return empty arrays.

TEXT:
{text}
""".strip()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Greedy decoding: `temperature` only applies when sampling, so it is not
    # passed here.
    generated = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id
    )
    # generate() returns prompt tokens followed by the new tokens. Decode only
    # the completion; otherwise any "{" / "}" inside the echoed TEXT would be
    # picked up by the JSON heuristic below.
    prompt_len = inputs["input_ids"].shape[-1]
    raw = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

    # Heuristic: take the span from the first "{" to the last "}".
    json_start = raw.find("{")
    json_end = raw.rfind("}")
    if json_start == -1 or json_end == -1 or json_end <= json_start:
        return empty()

    try:
        parsed = json.loads(raw[json_start:json_end + 1])
    except json.JSONDecodeError:
        return empty()
    if not isinstance(parsed, dict):
        return empty()

    def norm_list(xs) -> List[str]:
        # Only real JSON arrays are accepted; iterating a bare string would
        # feed single characters to the normalizer.
        if not isinstance(xs, list):
            return []
        normalized = (normalize_dotted_id(str(x)) for x in xs)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(n for n in normalized if n))

    return {
        "sections": norm_list(parsed.get("sections")),
        "tables": norm_list(parsed.get("tables")),
        "figures": norm_list(parsed.get("figures")),
        "equations": norm_list(parsed.get("equations")),
    }

# -----------------------------------
# 3) Walk your clause JSON and enrich
# -----------------------------------

def iter_list_texts(node: Dict[str, Any]) -> List[str]:
    """
    Collect the non-empty "text" values of all list items under node["lists"].

    Each entry of node["lists"] is a group with an "items" list; an item may
    carry nested items under "children". Items are visited depth-first, a
    parent before its children, in document order.
    """
    def texts_of(entries):
        for entry in entries or []:
            value = entry.get("text") or ""
            if value:
                yield value
            yield from texts_of(entry.get("children") or [])

    collected = []
    for group in node.get("lists") or []:
        collected.extend(texts_of(group.get("items") or []))
    return collected

def merge_refs(a: Dict[str, Set[str]], b: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
    """
    Return the per-category union of two reference dicts.

    Both inputs must provide the keys "sections", "tables", "figures" and
    "equations"; the result holds a new set for each of them.
    """
    categories = ("sections", "tables", "figures", "equations")
    return {name: set(a[name]).union(b[name]) for name in categories}

def enrich_nodes_with_references(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Attach a "references" entry to every node in data["nodes"].

    A node's title, text and list-item texts are joined with newlines and
    scanned with extract_refs_regex. When USE_LLM_FALLBACK is on and the regex
    pass found nothing at all, llm_fallback_extract supplies the references
    instead. The result is stored on the node as numerically sorted lists
    under "sections", "tables", "figures" and "equations".

    The nodes are modified in place and the same `data` object is returned.
    """
    categories = ("sections", "tables", "figures", "equations")

    def numeric_order(ref_id: str) -> List[int]:
        # Sort "2.10" after "2.9" by comparing integer segments.
        return [int(segment) for segment in ref_id.split(".")]

    for node in data.get("nodes", {}).values():
        pieces = [node[field] for field in ("title", "text") if node.get(field)]
        pieces.extend(iter_list_texts(node))
        blob = "\n".join(pieces).strip()

        refs = extract_refs_regex(blob)

        # Optional LLM fallback, only when the regex pass came back empty.
        if USE_LLM_FALLBACK and not any(refs[name] for name in categories):
            llm_refs = llm_fallback_extract(blob)
            refs = {name: set(llm_refs[name]) for name in categories}

        # Sets are not JSON-serializable, so store sorted lists.
        node["references"] = {
            name: sorted(refs[name], key=numeric_order) for name in categories
        }

    return data

# Driver: read the structured clause JSON, add regex-derived references to
# every node, and write the enriched document next to the input.
# Both paths are hard-coded absolute paths under /content.
if __name__ == "__main__":
    in_path = Path("/content/structured_clauses.json")
    out_path = Path("/content/structured_clauses_with_refs.json")

    # enrich_nodes_with_references mutates the loaded dict and returns it.
    data = json.loads(in_path.read_text(encoding="utf-8"))
    data = enrich_nodes_with_references(data)

    # ensure_ascii=False keeps non-ASCII characters unescaped in the output file.
    out_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Wrote: {out_path}")


Wrote: /content/structured_clauses_with_refs.json


In [None]:
!pip -q install transformers accelerate torch


In [None]:
!pip -q install bitsandbytes


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import json, re
from pathlib import Path
from typing import Any, Dict, List, Optional

# -----------------------------
# Regex patterns (fast path)
# -----------------------------
# These definitions reuse names from an earlier cell (SEC_PATTERN, EQ_PATTERN,
# TABLE_PATTERN, FIG_PATTERN, SPLIT_IDS_PATTERN) and replace them once this
# cell runs.
#
# Each keyword pattern is case-insensitive and captures, in the group "ids",
# one or more IDs of 1 to 9 numeric segments (`\d+(?:\.\d+){0,8}`), so a bare
# number such as "Sec 5" matches. IDs may be joined by list separators
# (",", "and", "&") or range separators ("to", "through", "-", "–").
# Note: in SEC_PATTERN the `|Sections?` alternative is redundant, because
# `Sec(?:tion)?s?` already matches "Section" and "Sections".
SEC_PATTERN = re.compile(
    r"\b(?:Sec(?:tion)?s?|Sections?)\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){0,8}(?:\s*(?:,|and|&|to|through|-|–)\s*\d+(?:\.\d+){0,8})*)",
    flags=re.IGNORECASE
)
EQ_PATTERN = re.compile(
    r"\b(?:Eq(?:uation)?s?)\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){0,8}(?:\s*(?:,|and|&|to|through|-|–)\s*\d+(?:\.\d+){0,8})*)",
    flags=re.IGNORECASE
)
TABLE_PATTERN = re.compile(
    r"\bTable(?:s)?\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){0,8}(?:\s*(?:,|and|&|to|through|-|–)\s*\d+(?:\.\d+){0,8})*)",
    flags=re.IGNORECASE
)
FIG_PATTERN = re.compile(
    r"\b(?:Fig(?:ure)?s?)\.?\s*"
    r"(?P<ids>\d+(?:\.\d+){0,8}(?:\s*(?:,|and|&|to|through|-|–)\s*\d+(?:\.\d+){0,8})*)",
    flags=re.IGNORECASE
)
# "Chapter <n>": a single integer captured in the group "ch"; at least one
# whitespace character is required between the keyword and the number.
CHAPTER_PATTERN = re.compile(r"\bChapter\s+(?P<ch>\d+)\b", flags=re.IGNORECASE)

# First-level split of an "ids" blob into list chunks ...
SPLIT_IDS_PATTERN = re.compile(r"\s*(?:,|and|&)\s*", flags=re.IGNORECASE)
# ... and second-level split of one chunk into range endpoints.
RANGE_SEP_PATTERN = re.compile(r"\s*(?:to|through|-|–)\s*", flags=re.IGNORECASE)

# -----------------------------
# Normalization helpers
# -----------------------------
def normalize_numeric_path(raw: str, force_chapter_dot: bool = False) -> Optional[str]:
    """
    Reduce `raw` to a canonical numeric path, or None if that is impossible.

    Surrounding whitespace/punctuation is trimmed, every character other than
    a digit or a dot is removed, and leading zeros are dropped per segment:
      - "2.5.16.1"      -> "2.5.16.1"
      - "02.05.016.001" -> "2.5.16.1"
      - "6"             -> "6", or "6.0" when force_chapter_dot is True
    None is returned for None input, for text without digits, and for paths
    with an empty segment such as "2..5".
    """
    if raw is None:
        return None

    trimmed = str(raw).strip().strip(").,;:[]{}()")
    digits_and_dots = re.sub(r"[^0-9.]", "", trimmed).strip(".")
    if digits_and_dots == "":
        return None

    segments = digits_and_dots.split(".")
    for segment in segments:
        if not segment.isdigit():
            return None

    canonical = [str(int(segment)) for segment in segments]
    if force_chapter_dot and len(canonical) == 1:
        return f"{canonical[0]}.0"
    return ".".join(canonical)

def try_expand_simple_range(a: str, b: str) -> List[str]:
    """
    Expand the range a..b into every ID in between, but only when it is safe.

    A range is expanded when both endpoints have the same number of segments
    (at least two), share the same prefix, and their last segments differ by
    at most 200. Example: "2.5.9.1" and "2.5.9.3" give
    ["2.5.9.1", "2.5.9.2", "2.5.9.3"]; reversed endpoints still produce an
    ascending list. In every other case the two endpoints are returned as
    given ([a] when they are equal).
    """
    if a == b:
        return [a]

    segs_a, segs_b = a.split("."), b.split(".")
    comparable = len(segs_a) == len(segs_b) and len(segs_a) >= 2
    if not (comparable and segs_a[:-1] == segs_b[:-1]):
        return [a, b]

    try:
        low, high = sorted((int(segs_a[-1]), int(segs_b[-1])))
    except ValueError:
        return [a, b]

    # Guardrail against absurdly wide ranges.
    if high - low > 200:
        return [a, b]

    stem = ".".join(segs_a[:-1])
    return [f"{stem}.{last}" for last in range(low, high + 1)]

def extract_ids_from_match(ids_blob: str, force_chapter_dot: bool) -> List[str]:
    """
    Turn the "ids" capture of a reference match into a list of normalized IDs.

    The blob is first split into list chunks (SPLIT_IDS_PATTERN) and each
    chunk into range pieces (RANGE_SEP_PATTERN):
      - exactly two pieces form a range, expanded via try_expand_simple_range
        when both endpoints normalize;
      - any other number of pieces (one, or three and more such as "6-2-3")
        is normalized piece by piece with no expansion.

    Args:
        ids_blob: Text captured by the "ids" group, e.g. "2.5.9.1 to 2.5.9.3, 2.6".
        force_chapter_dot: Forwarded to normalize_numeric_path.

    Returns:
        Normalized IDs, de-duplicated, in first-seen order.
    """
    collected: List[str] = []
    for chunk in SPLIT_IDS_PATTERN.split(ids_blob):
        pieces = [p.strip() for p in RANGE_SEP_PATTERN.split(chunk) if p.strip()]
        # Normalize every piece on its own. Normalizing the whole chunk would
        # strip the separators and fuse the digits ("6-2-3" -> "623").
        normalized = [
            normalize_numeric_path(p, force_chapter_dot=force_chapter_dot) for p in pieces
        ]
        if len(pieces) == 2 and all(normalized):
            collected.extend(try_expand_simple_range(normalized[0], normalized[1]))
        else:
            collected.extend(n for n in normalized if n)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(collected))

def extract_refs_regex(text: str, force_chapter_dot: bool = False) -> Dict[str, List[str]]:
    """
    Regex pass: find section, table, figure, equation and chapter references.

    Returns a dict with the keys "sections", "tables", "figures", "equations"
    and "chapter_refs"; every value is a list of normalized IDs without
    duplicates, in order of first appearance. Empty text yields empty lists.
    """
    result: Dict[str, List[str]] = {
        "sections": [], "tables": [], "figures": [], "equations": [], "chapter_refs": []
    }
    if not text:
        return result

    keyword_patterns = (
        ("sections", SEC_PATTERN),
        ("tables", TABLE_PATTERN),
        ("figures", FIG_PATTERN),
        ("equations", EQ_PATTERN),
    )
    for category, pattern in keyword_patterns:
        hits: List[str] = []
        for match in pattern.finditer(text):
            hits.extend(extract_ids_from_match(match.group("ids"), force_chapter_dot))
        # dict.fromkeys keeps the first occurrence of each ID, in order.
        result[category] = list(dict.fromkeys(hits))

    chapter_hits = (
        normalize_numeric_path(match.group("ch"), force_chapter_dot=force_chapter_dot)
        for match in CHAPTER_PATTERN.finditer(text)
    )
    result["chapter_refs"] = list(dict.fromkeys(ch for ch in chapter_hits if ch))

    return result

# -----------------------------
# JSON walking (your schema)
# -----------------------------
def collect_list_texts(node: Dict[str, Any]) -> List[str]:
    """
    Return the truthy "text" values of all list items in node["lists"], as str.

    Items are visited group by group, depth-first, a parent before the items
    under its "children" key, preserving document order.
    """
    collected: List[str] = []
    for group in node.get("lists") or []:
        # Explicit stack; entries are pushed reversed so pop() yields them in
        # their original order.
        pending = list(group.get("items") or [])[::-1]
        while pending:
            item = pending.pop()
            value = item.get("text")
            if value:
                collected.append(str(value))
            pending.extend(list(item.get("children") or [])[::-1])
    return collected

def get_node_text_blob(node: Dict[str, Any]) -> str:
    """
    Build the text that is scanned for references for one node.

    The node's title, its text and all list-item texts (in that order) are
    stripped individually, blank pieces are dropped, and the rest is joined
    with newlines.
    """
    raw_pieces: List[str] = [
        str(node[field]) for field in ("title", "text") if node.get(field)
    ]
    raw_pieces += collect_list_texts(node)
    stripped = (piece.strip() for piece in raw_pieces)
    return "\n".join(piece for piece in stripped if piece).strip()

# -----------------------------
# Small LLM extraction
# -----------------------------
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

def build_llm_prompt(text: str, current_id: str) -> str:
    """
    Build the instruction prompt sent to the LLM for one clause.

    Args:
        text: Clause text to scan; it is embedded verbatim between triple
            double-quotes at the end of the prompt. NOTE(review): the text is
            not escaped, so clause text that itself contains triple
            double-quotes would blur that delimiter.
        current_id: ID of the clause being processed, shown to the model as
            "CURRENT CLAUSE ID".

    Returns:
        The prompt with surrounding whitespace stripped. It contains a literal
        JSON template (an object with five empty arrays), so any code that
        searches model output for "{" must not include the echoed prompt.
    """
    return f"""
You extract cross-references from a technical specification clause.

CURRENT CLAUSE ID: {current_id}

Identify references to:
- other clauses/sections/subsections (numeric paths like 2.5.16.1)
- chapters (e.g., "Chapter 6" -> 6)
- tables (e.g., "Table 6.2.24" -> 6.2.24)
- figures (e.g., "Figure 3.1.2" -> 3.1.2)
- equations (e.g., "Eq. 6.2.66" -> 6.2.66)

RULES:
1) Return ONLY valid JSON. No commentary.
2) DO NOT guess or hallucinate.
3) Output IDs as STRICT numeric paths (digits and dots only) OR chapter-only numeric (e.g., "6").
4) If text says "this chapter/this section/above/below" without numbers, output nothing for that.
5) If a range is present (2.5.9.1-2.5.9.5), include BOTH endpoints only (do not expand).
6) Deduplicate.

Return JSON exactly:
{{
  "sections": [],
  "chapter_refs": [],
  "tables": [],
  "figures": [],
  "equations": []
}}

TEXT:
\"\"\"{text}\"\"\"
""".strip()

def parse_json_object_from_generation(gen: str) -> Optional[Dict[str, Any]]:
    """
    Extract a JSON object from free-form model output.

    The text is scanned left to right; at every "{" a JSON object is decoded
    if possible, and scanning resumes after it. The LAST object decoded this
    way is returned, because the model's answer comes after anything that
    merely precedes it (an echoed prompt containing a JSON template, braces
    inside quoted text). Slicing from the first "{" to the last "}" fails
    whenever the output contains more than one brace group.

    Args:
        gen: Decoded generation text; may be empty or None.

    Returns:
        The parsed object as a dict, or None when no JSON object is found.
    """
    if not gen:
        return None

    decoder = json.JSONDecoder()
    last_obj: Optional[Dict[str, Any]] = None
    pos = gen.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(gen, pos)
        except json.JSONDecodeError:
            # Not the start of a valid object; try the next brace.
            pos = gen.find("{", pos + 1)
            continue
        if isinstance(obj, dict):
            last_obj = obj
        # Resume after the decoded object so its nested objects are skipped.
        pos = gen.find("{", end)
    return last_obj

def normalize_llm_output(parsed: Dict[str, Any], force_chapter_dot: bool = False) -> Dict[str, List[str]]:
    """
    Validate and normalize the JSON object returned by the LLM.

    Args:
        parsed: Parsed model output; expected to map "sections",
            "chapter_refs", "tables", "figures" and "equations" to arrays.
            Missing keys, non-list values and a non-dict `parsed` are treated
            as "nothing found".
        force_chapter_dot: Forwarded to normalize_numeric_path.

    Returns:
        Dict with the five keys above; every value is a list of normalized
        numeric paths, de-duplicated in first-seen order. Values written as
        ranges ("a-b", "a–b") contribute their endpoints only (no expansion).
    """
    def norm_list(xs: Any) -> List[str]:
        if not isinstance(xs, list):
            return []
        cleaned: List[str] = []
        for v in xs:
            # Nested containers have no single ID; stringifying them would
            # fuse unrelated digits into a bogus path.
            if isinstance(v, (dict, list)):
                continue
            # Split on dashes and normalize every piece on its own, so that
            # "a-b" keeps both endpoints and "6-2-3" never becomes "623".
            for piece in re.split(r"\s*[-–]\s*", str(v).strip()):
                n = normalize_numeric_path(piece, force_chapter_dot=force_chapter_dot)
                if n:
                    cleaned.append(n)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(cleaned))

    source = parsed if isinstance(parsed, dict) else {}
    return {
        "sections": norm_list(source.get("sections", [])),
        "chapter_refs": norm_list(source.get("chapter_refs", [])),
        "tables": norm_list(source.get("tables", [])),
        "figures": norm_list(source.get("figures", [])),
        "equations": norm_list(source.get("equations", [])),
    }

class LLMExtractor:
    """
    Thin wrapper around a Hugging Face causal LM used to extract references.

    The tokenizer and model are loaded once in the constructor;
    `extract_batch` builds one prompt per clause, generates, and parses a JSON
    object out of each completion.
    """

    def __init__(self, model_name: str, max_new_tokens: int = 256, temperature: float = 0.0):
        """
        Args:
            model_name: Hugging Face model id or local path.
            max_new_tokens: Generation budget per prompt.
            temperature: 0 (or less) selects greedy decoding; a positive value
                enables sampling at that temperature.
        """
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Decoder-only models continue from the last position of each row, so
        # batched prompts must be padded on the left.
        self.tokenizer.padding_side = "left"

        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        # The former `load_in_8bit=` argument was computed from an expression
        # that could never be true, so the model was always loaded
        # unquantized; that effective behavior is kept. To opt in to 8-bit
        # loading, pass a quantization config to from_pretrained instead.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
        )

    def extract_batch(self, texts: List[str], current_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Run the model on a batch of clauses.

        Args:
            texts: Clause texts.
            current_ids: Clause IDs, parallel to `texts`.

        Returns:
            One dict per prompt, in order: the parsed JSON object, or {} when
            the completion contained no parseable object. An empty batch
            returns [] without touching the model.
        """
        prompts = [build_llm_prompt(t, cid) for t, cid in zip(texts, current_ids)]
        if not prompts:
            return []

        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=3500)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if self.temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=self.temperature)
        else:
            # Greedy decoding; `temperature` is a sampling-only flag and is
            # therefore not passed.
            gen_kwargs["do_sample"] = False

        gen = self.model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt tokens followed by new tokens. Decode only
        # the completion: the prompt itself contains a JSON template that
        # would otherwise be mistaken for (part of) the answer.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = self.tokenizer.batch_decode(gen[:, prompt_len:], skip_special_tokens=True)
        return [parse_json_object_from_generation(d) or {} for d in decoded]

def sort_numeric_paths(xs: List[str]) -> List[str]:
    """
    Return a new list of dotted IDs ordered segment by segment as integers.

    "2.9" therefore sorts before "2.10". An empty string sorts first, and IDs
    with equal numeric value keep their input order.
    """
    ordered = list(xs)
    ordered.sort(key=lambda path: [int(seg) for seg in path.split(".")] if path else [])
    return ordered

def union_preserve_order(a: List[str], b: List[str]) -> List[str]:
    """
    Concatenate `a` and `b` and drop repeated entries, keeping the first
    occurrence of each in its original position.
    """
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(a + b))

def enrich_with_refs(
    data: Dict[str, Any],
    llm: Optional[LLMExtractor],
    batch_size: int = 4,
    force_chapter_dot: bool = False,
    llm_on_all_nodes: bool = False,
) -> Dict[str, Any]:
    """
    Attach cross-reference information to every node in data["nodes"].

    Every node is first scanned with extract_refs_regex. If `llm` is given,
    the model is additionally run on all nodes (`llm_on_all_nodes=True`), or
    only on nodes where the regex pass found nothing but the text contains a
    reference cue (a whole keyword such as "Section", "Table", "Fig.", "Eq."
    or a dotted number).

    Args:
        data: Document with a "nodes" mapping of node id -> node dict.
        llm: Extractor to use, or None for a regex-only run.
        batch_size: Number of nodes per LLM batch (values below 1 are treated as 1).
        force_chapter_dot: Forwarded to the normalizers ("6" -> "6.0").
        llm_on_all_nodes: Run the LLM on every node instead of cue-selected ones.

    Returns:
        The same `data` object. Each node receives a "references" dict with
        sorted ID lists ("sections", "chapter_refs", "tables", "figures",
        "equations"), the raw per-source results under "provenance", and
        "llm_meta" (model name and whether its output parsed).
    """
    categories = ["sections", "chapter_refs", "tables", "figures", "equations"]

    def empty_refs() -> Dict[str, List[str]]:
        return {name: [] for name in categories}

    nodes: Dict[str, Any] = data.get("nodes") or {}
    node_ids = list(nodes.keys())

    blobs = {nid: get_node_text_blob(nodes[nid]) for nid in node_ids}
    regex_refs = {
        nid: extract_refs_regex(blobs[nid], force_chapter_dot=force_chapter_dot)
        for nid in node_ids
    }

    llm_targets: List[str] = []
    if llm is not None:
        if llm_on_all_nodes:
            llm_targets = node_ids
        else:
            # Keywords are matched as whole words. A plain substring test
            # fires on ordinary prose ("eq" in "required", "sec" in "second",
            # "fig" in "configuration") and would send almost every node
            # through the slow LLM path.
            cue = re.compile(
                r"\b(?:chapters?|clauses?|(?:sub-?)?sec(?:t|tion)?s?|tables?"
                r"|fig(?:ure)?s?|eqn?s?|equations?)\b|\d+\.\d+",
                re.I,
            )
            for nid in node_ids:
                if any(regex_refs[nid][name] for name in categories):
                    continue
                if cue.search(blobs[nid]):
                    llm_targets.append(nid)

    llm_outputs, llm_meta = {}, {}
    if llm is not None and llm_targets:
        step = max(1, batch_size)  # range() rejects a zero step
        for i in range(0, len(llm_targets), step):
            batch = llm_targets[i:i + step]
            parsed_list = llm.extract_batch([blobs[n] for n in batch], batch)
            for nid, parsed in zip(batch, parsed_list):
                if parsed:
                    llm_outputs[nid] = normalize_llm_output(parsed, force_chapter_dot=force_chapter_dot)
                else:
                    llm_outputs[nid] = empty_refs()
                llm_meta[nid] = {"model": llm.model_name, "parsed_ok": bool(parsed)}

    for nid in node_ids:
        rr = regex_refs[nid]
        lr = llm_outputs.get(nid) or empty_refs()

        references: Dict[str, Any] = {
            name: sort_numeric_paths(union_preserve_order(rr[name], lr[name]))
            for name in categories
        }
        # "llm" is None when the model was not run for this node.
        references["provenance"] = {"regex": rr, "llm": lr if nid in llm_outputs else None}
        references["llm_meta"] = llm_meta.get(nid, {"model": None, "parsed_ok": False})
        nodes[nid]["references"] = references

    data["nodes"] = nodes
    return data


In [None]:
# Driver cell: load the model, enrich the clause JSON and write the result.
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"   # or "microsoft/Phi-3-mini-4k-instruct"

# temperature=0.0 makes LLMExtractor use greedy decoding (do_sample is False).
llm = LLMExtractor(model_name=MODEL_NAME, max_new_tokens=256, temperature=0.0)
input_path = "/content/structured_clauses.json"

data = json.loads(Path(input_path).read_text(encoding="utf-8"))

# Set llm_on_all_nodes=True if you want max recall (slower/more expensive)
# Note: enrich_with_refs modifies `data` in place and returns that same
# object, so `data2` and `data` refer to the same dict afterwards.
data2 = enrich_with_refs(
    data,
    llm=llm,
    batch_size=4,
    force_chapter_dot=False,   # set True if you want "6" -> "6.0"
    llm_on_all_nodes=False
)

# Unlike input_path, this is a relative path: the file is written to the
# kernel's current working directory.
output_path = "structured_clauses_with_refs.json"
Path(output_path).write_text(json.dumps(data2, ensure_ascii=False, indent=2), encoding="utf-8")
print("Wrote:", output_path)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


KeyboardInterrupt: 