In [None]:

# step0: patent.md 

# step1:  split to :  figs_MetaDict.json   full_filtered.md  -split
#  文本信息、图片引用 可以嵌入了
# -> full_split.md
# step2: full_filtered.md -->  struct: # 著录信息  # 权利要求书   # 说明书 
#   ---> str
# -> full_split_split.md
# step3: extractor                     record      claims         specification
#                         def record_extractor  claims_extractor  specification_extractor
#   --->  norm_(str) 
#  
# --> full_split_split_norm.md
# 代码逻辑可能可以复用  code ++

# 以 [xxxx] 开头的段落：去掉 [xxxx]   todo.
# 无关的内容也可以去掉，

In [None]:
# ingestion-1  embed-model 

from llama_index.embeddings.huggingface import HuggingFaceEmbedding 
from llama_index.core import Settings

# Embedding model used for all vector indexing/retrieval in this notebook.
# NOTE(review): `cache_folder` is an absolute, machine-specific Windows path —
# the notebook will not run elsewhere without editing it.
embedding = HuggingFaceEmbedding(
    model_name="Qwen/Qwen3-Embedding-0.6B",
    device="cpu",                 # recommended to pass as a top-level argument
    cache_folder=r"E:\local_models\huggingface\cache\hub",
    trust_remote_code=True,       # recommended to pass as a top-level argument
    model_kwargs={"local_files_only": True},   # set to False to allow network access
)

# Register the embedding model globally; no LLM is configured at this stage.
Settings.embed_model = embedding
Settings.llm = None 

In [None]:
# ingestion-2  load-file -> str 

from pathlib import Path 
import json 
from typing import Tuple, Dict, Any

# load md-text json-figs 
def load_md_and_figs(md_path: Path, figs_name: str="figs.json") -> Tuple[str, Dict[str, Any]]:
    """Load a markdown file and the figure-metadata JSON stored next to it.

    Args:
        md_path: Path of the markdown file to read.
        figs_name: File name of the JSON sidecar; it is looked up in the same
            directory as ``md_path``.

    Returns:
        ``(text, figs)``: the markdown content (undecodable bytes are dropped,
        as ``errors='ignore'`` is used) and the parsed JSON object.

    Raises:
        FileNotFoundError: if the markdown file or the JSON sidecar is missing.
    """
    md_path = Path(md_path)
    text = md_path.read_text(encoding='utf-8', errors='ignore')
    figs_path = md_path.with_name(figs_name)
    # Explicit check instead of a bare `assert`: asserts vanish under `python -O`
    # and carry no message telling the user which file is missing.
    if not figs_path.is_file():
        raise FileNotFoundError(f"figure metadata file not found: {figs_path}")
    with figs_path.open("r", encoding='utf-8') as f:
        figs = json.load(f)
    return text, figs

# test 
# Smoke test: locate the first `full_split_struct.md` under the data root and load
# it together with its `figs.json` sidecar.
data_root = Path.cwd().parent / ".log/SimplePDF"
assert data_root.is_dir(), f"data root not found: {data_root}"
mdfs: Path = next(data_root.rglob('full_split_struct.md'), None)
# `next(..., None)` yields None when nothing matches; check that explicitly so the
# failure is a clear assertion instead of a TypeError from Path(None).
assert mdfs is not None and mdfs.exists(), f"no full_split_struct.md under {data_root}"
figs: Path = mdfs.with_name("figs.json")
assert figs.is_file(), f"figs.json not found next to {mdfs}"
text_, figs_ = load_md_and_figs(md_path=mdfs)
print("纯文本长度：", len(text_))
print("figs.json 键：", list(figs_.keys()))
print("ims_desc 示例：", list((figs_.get("ims_desc") or {}).items())[:3])
# Fixed: this line previously printed `ims_desc` again under the "ims_absp" label.
print("ims_absp 示例：", list((figs_.get("ims_absp") or {}).items())[:3])
print("ims_annos 示例：", [(figs_.get("ims_annos") or "")])


In [None]:
# ingestion-2  content-str -> node 

import uuid 
from typing import List 

from llama_index.core.schema import TextNode
from llama_index.core.node_parser import SentenceSplitter

def _to_ascii_digits(s: str) -> str:
    _DIGIT_TRANS = str.maketrans("０１２３４５６７８９", "0123456789")
    return (s or "").translate(_DIGIT_TRANS)


def build_text_node_from_markdown(text: str, doc_id: str, 
                                  chunk_size: int=700, 
                                  chunk_overlap: int=128) -> List[TextNode]:
    """Split markdown text into chunks and wrap each chunk in a ``TextNode``.

    Args:
        text: Full document text to split.
        doc_id: Identifier of the source document; used in node ids and metadata.
        chunk_size: Forwarded to ``SentenceSplitter``.
        chunk_overlap: Forwarded to ``SentenceSplitter``.

    Returns:
        One node per chunk, in split order. Node ids have the form
        ``"{doc_id}::text::{i}"`` with ``i`` starting at 1, and the metadata
        carries ``doc_id``, ``node_type="text"`` and ``chunk_id=i``.
    """
    sentence_splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    pieces = sentence_splitter.split_text(text)
    return [
        TextNode(
            text=piece,
            id_=f"{doc_id}::text::{seq}",
            metadata={"doc_id": doc_id, "node_type": "text", "chunk_id": seq},
        )
        for seq, piece in enumerate(pieces, start=1)
    ]
    
def build_figure_node_from_figs(figs: Dict[str, Any], doc_id: str) -> List[TextNode]:
    """Turn the figure metadata of one document into figure ``TextNode`` objects.

    Keys read from ``figs`` (all optional):
      - ``im_abs``: list whose first item is the abstract-figure path; the
        presence of further items is reported as ``"B64(omitted)"``.
      - ``ims_desc`` / ``ims_absp`` / ``ims_bs64``: dicts keyed by figure number
        holding the description, file path and base64 payload of each figure.
      - ``ims_annos``: string copied onto every figure node as ``fig_annos``.

    Args:
        figs: Parsed figure-metadata dict.
        doc_id: Identifier of the owning document; used in node ids and metadata.

    Returns:
        The abstract-figure node (``fig_no="0"``) if present, followed by one
        node per entry of ``ims_desc`` ordered by numeric figure number
        (non-numeric keys sort last).

    Raises:
        AssertionError: if a figure path does not point to an existing file.
    """
    nodes: List[TextNode] = []

    ims_desc: Dict[str, str] = figs.get("ims_desc", {}) or {}
    ims_absp: Dict[str, str] = figs.get("ims_absp", {}) or {}
    ims_bs64: Dict[str, str] = figs.get("ims_bs64", {}) or {}
    ims_annos: str = (figs.get("ims_annos") or "").strip()

    # The metadata key is spelled "fig_bs64"; the exclusion lists previously said
    # "fig_b64", so the exclusion silently never applied.
    embed_excluded = ["fig_path", "fig_bs64", "fig_annos", "display_text"]
    llm_excluded = ["fig_bs64"]  # kept out of the LLM prompt so it cannot crowd out the text

    # Abstract figure. `im_abs` may be an empty list, so test for emptiness before
    # indexing (the old `im_abs[0]` raised IndexError when the key was missing).
    im_abs = figs.get("im_abs") or []
    if isinstance(im_abs, list) and im_abs and im_abs[0]:
        abs_path = im_abs[0] if isinstance(im_abs[0], str) else ""
        assert Path(abs_path).is_file(), f"摘要图{abs_path}路径不存在"
        nodes.append(TextNode(
            text="摘要图",
            id_=f"{doc_id}::figure::0",
            metadata={
                "doc_id": doc_id,
                "node_type": "figure",
                "fig_no": "0",
                "fig_desc": "摘要图",
                "fig_path": abs_path,
                "fig_bs64": "B64(omitted)" if im_abs[1:] else "",
                "fig_annos": ims_annos,
                "display_text": "摘要图",    # used by the front end
            },
            excluded_embed_metadata_keys=list(embed_excluded),
            excluded_llm_metadata_keys=list(llm_excluded),
        ))

    # Regular figures, ordered by numeric figure number.
    def _key_sorter(k: str) -> int:
        try:
            return int(_to_ascii_digits(k))
        except (TypeError, ValueError, AttributeError):
            # Non-numeric keys are pushed to the end.
            return 10**9

    for k in sorted(ims_desc.keys(), key=_key_sorter):
        desc = (ims_desc.get(k) or "").strip()
        pth = (ims_absp.get(k) or "").strip()
        # Report the path of THIS figure (the message used to interpolate `abs_path`,
        # which is the abstract-figure path or undefined).
        assert Path(pth).is_file(), f"图{k}路径{pth}不存在"
        nodes.append(TextNode(
            text=f"图{k}为{desc}",
            id_=f"{doc_id}::figure::{k}",
            metadata={
                "doc_id": doc_id,
                "node_type": "figure",
                "fig_no": k,
                "fig_desc": desc,
                "fig_path": pth,
                "fig_bs64": "B64(omitted)" if (ims_bs64.get(k) or "") else "",
                "fig_annos": ims_annos,
                "display_text": f"图{k} {desc}",  # used by the front end
            },
            excluded_embed_metadata_keys=list(embed_excluded),
            excluded_llm_metadata_keys=list(llm_excluded),
        ))
    return nodes
    
    
# test
# Smoke test: derive a doc id as the UUIDv5 of the resolved markdown path, then
# build text nodes and figure nodes for that single document.
doc_id_demo = str(uuid.uuid5(uuid.NAMESPACE_URL, str(mdfs.resolve())))
test_nodes = build_text_node_from_markdown(text_, doc_id_demo) + build_figure_node_from_figs(figs_, doc_id_demo)
# Cell output: (total node count, number of figure nodes)
len(test_nodes), sum(1 for n in test_nodes if n.metadata["node_type"]=="figure")


In [None]:
# ingestion-3  build_index(nodes) 

from typing import Optional

from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.vector_stores.faiss import FaissVectorStore

import faiss 

# Directory where the FAISS-backed index is persisted; created if it does not exist.
persist_dir = Path.cwd().parent / ".log/faiss_db"
persist_dir.mkdir(parents=True, exist_ok=True)
# NOTE(review): the two paths below are not referenced anywhere else in the visible code.
faiss_index_path = persist_dir / "faiss.index"
meta_path = persist_dir / "vector_meta.json"

def build_index(md_files: List[Path], embeded_dim=1024, mode="build"): 
    """Build a new FAISS-backed vector index from markdown files and persist it.

    Args:
        md_files: Markdown files to ingest; each is loaded with its figure JSON.
        embeded_dim: Dimension passed to ``faiss.IndexFlatL2``.
        mode: Not used by the function body.

    Returns:
        ``(index, nodes)``: the ``VectorStoreIndex`` and the list of all text and
        figure nodes that were indexed.

    Side effects:
        Persists the storage context to the module-level ``persist_dir`` and
        prints the number of indexed nodes.
    """
    collected: List[TextNode] = []
    for md_file in md_files:
        md_text, fig_meta = load_md_and_figs(md_file)
        # Document id = UUIDv5 of the resolved file path.
        doc_uuid = str(uuid.uuid5(uuid.NAMESPACE_URL, str(md_file.resolve())))
        collected.extend(build_text_node_from_markdown(md_text, doc_uuid))
        collected.extend(build_figure_node_from_figs(fig_meta, doc_uuid))

    # Fresh flat L2 index wrapped as the vector store of a new storage context.
    flat_l2 = faiss.IndexFlatL2(embeded_dim)
    context = StorageContext.from_defaults(vector_store=FaissVectorStore(faiss_index=flat_l2))
    built_index = VectorStoreIndex(collected, storage_context=context, show_progress=True)

    context.persist(persist_dir=str(persist_dir))
    print(f"[build]新建索引：nodes={len(collected)}")
    return built_index, collected
    

# Build the index for the single sample document; later cells rely on both globals.
vector_store_index, nodes_cache = build_index(md_files=[mdfs])    


In [None]:
# ingestion-4.1  vector_index persist      -- faiss_db

from llama_index.core import load_index_from_storage
from llama_index.core.indices.base import BaseIndex

def load_index(embeded_dim=1024, mode="faiss") -> BaseIndex:
    """Reload the persisted FAISS-backed index from the module-level ``persist_dir``.

    The previous implementation handed a brand-new, empty ``IndexFlatL2`` to the
    storage context, so the docstore was restored but the persisted vectors were
    never loaded. The FAISS store is now restored from disk as well.

    Args:
        embeded_dim: Kept for backward compatibility; not needed because the
            index is read back from disk rather than re-created.
        mode: Not used by the function body.

    Returns:
        The index loaded via ``load_index_from_storage``.
    """
    storage_dir = str(persist_dir)

    # Restore the persisted FAISS vectors instead of starting from an empty index.
    vector_store = FaissVectorStore.from_persist_dir(storage_dir)

    storage_context = StorageContext.from_defaults(
        persist_dir=storage_dir,
        vector_store=vector_store,
    )
    return load_index_from_storage(storage_context)

# Round-trip check: reload the index that the build cell persisted.
vector_store_index_ = load_index()

In [None]:
# ingestion-4.2 nodes persist     --  bm25_db

# 补充说明：
# faiss向量知识库  persist持久化  到底保存了哪些东西到本地
# - 向量数据库存的就是向量（向量索引）。
#   无法从向量索引中得到TextNode

# - 在 LlamaIndex 的持久化体系里，节点内容/元数据 存在 docstore.json（DocumentStore），
#   而向量索引结构的元数据（比如向量库命名空间、id 映射等）在 index_store.json；
#   真正的向量由对应的 vector_store 管（FAISS 就是 faiss.index 文件）。

# - 加载顺序通常是：StorageContext.from_defaults(persist_dir=...) → 它会把 docstore.json、index_store.json 读回来；
#   然后 FaissVectorStore.from_persist_dir(...) 或由 LlamaIndex 把 FAISS 文件挂回。

# - 检索时先用 embed_model 得到 query 向量 → FAISS 搜索得到“向量行号/内部 id” → 通过 index_store/docstore 的映射 
#   找回 Node 的 node_id → 再去 docstore 里把 TextNode 取出来给你。

# 向量索引存一份  vector_db/...  <faiss_db>
# nodes存一份    nodes_db/..    <bm25_db>


# bm25_db   -- nodes persist


"""  普通的构建 bm25  nodes_db 
from llama_index.retrievers.bm25 import BM25Retriever 
from typing import Literal

# 构建nodes_db ， 区分开cevtor_db的那一套逻辑
def nodes_persist(md_files: List[Path], mode=Literal["bm25"]) -> BM25Retriever:
    
    nodes_cache: List[TextNode] = []
    for md in md_files:
        text, figs = load_md_and_figs(md)
        doc_id = str(uuid.uuid5(uuid.NAMESPACE_URL, str(md.resolve())))
        nodes_cache += build_text_node_from_markdown(text, doc_id)
        nodes_cache += build_figure_node_from_figs(figs, doc_id)
    
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=nodes_cache,
        similarity_top_k=5,
        # 默认的语言是英文，可能有点影响吧。。
    )
    root_dir = Path.cwd().parent / ".log/bm25_db"
    root_dir.mkdir(parents=True, exist_ok=True)
    bm25_retriever.persist(root_dir)
    return root_dir
     
def build_bm25_retriever(mode=Literal["local"])-> BM25Retriever:
    root_dir = Path.cwd().parent / ".log/bm25_db"
    loaded_bm25_retriever = BM25Retriever.from_persist_dir(str(root_dir))
    return  loaded_bm25_retriever
 
"""


# 结构化保存 nodes   -- bm25

import gzip, json, time, re  
from rank_bm25 import BM25Okapi 
from dataclasses import dataclass 
from llama_index.core import QueryBundle, VectorStoreIndex
from llama_index.core.schema import TextNode, NodeWithScore 


# ------------ 轻量分词（可替换为更强中文分词）-----------------
def _tokenize(s: str) -> List[str]:
    s = s or ""
    # 中文按字切 + 保留英文/数字 token
    toks_cn = [ch for ch in s if "\u4e00" <= ch <= "\u9fff"]
    toks_en = re.findall(r"[A-Za-z0-9_]+", s)
    return toks_cn + toks_en


# ------------  从 nodes_cache 构建若干辅助索引 ----------------------
def build_text_nodes_by_doc(nodes: List[TextNode]) -> Dict[str, List[TextNode]]:
    """Group text nodes by the ``doc_id`` stored in their metadata.

    Only nodes whose metadata ``node_type`` equals ``"text"`` are kept; a missing
    or empty ``doc_id`` is filed under ``"unknown"``. Input order is preserved
    within each group.
    """
    grouped: Dict[str, List[TextNode]] = {}
    for node in nodes:
        meta = node.metadata
        if meta.get("node_type") != "text":
            continue
        key = meta.get("doc_id") or "unknown"
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(node)
    return grouped

def build_fig_index_by_doc(nodes: List[TextNode]) -> Dict[str, Dict[int, TextNode]]:
    """Map ``doc_id -> {figure number -> figure node}``.

    The index is kept per document so that identical figure numbers in different
    documents do not collide. Figure nodes whose ``fig_no`` cannot be parsed as
    an integer are skipped; a missing ``doc_id`` is filed under ``"unknown"``.
    """
    per_doc: Dict[str, Dict[int, TextNode]] = {}
    for node in nodes:
        meta = node.metadata
        if meta.get("node_type") != "figure":
            continue
        key = meta.get("doc_id") or "unknown"
        raw_no = str(meta.get("fig_no", "")).strip()
        try:
            fig_no = int(raw_no)
        except Exception:
            continue
        if key not in per_doc:
            per_doc[key] = {}
        per_doc[key][fig_no] = node
    return per_doc

# Module-level lookup tables derived from `nodes_cache` (produced by the index-building cell):
#   TEXT_NODES_BY_DOC: doc_id -> text nodes of that document
#   FIG_INDEX_BY_DOC:  doc_id -> {figure number -> figure node}
#   NODE_LOOKUP:       node_id -> node (text and figure nodes alike)
TEXT_NODES_BY_DOC: Dict[str, List[TextNode]] = build_text_nodes_by_doc(nodes_cache)
FIG_INDEX_BY_DOC: Dict[str, Dict[int, TextNode]] = build_fig_index_by_doc(nodes_cache)
NODE_LOOKUP: Dict[str, TextNode] = {n.node_id: n for n in nodes_cache}


# ------------ BM25 分片持久化（按 doc 分片） --------------------
@dataclass
class BM25Shard:
    """One document's BM25 model together with the node ids of its corpus rows."""
    # BM25 model built from the shard's token lists.
    bm25: BM25Okapi
    # Node ids in the same order as the corpus the BM25 model was built from.
    node_ids: List[str]          # aligned one-to-one with the bm25 corpus order

# ---- 持久化 I/O ----
def save_bm25_shard(doc_id: str, text_nodes: List[TextNode], bm25_dir: Path) -> Path:
    """Write one document's tokenized nodes to ``<bm25_dir>/shards/<doc_id>.jsonl.gz``.

    Each line of the gzip file is a JSON object with the node's ``node_id`` and
    the ``tokens`` produced by ``_tokenize`` from the node content. The shard
    directory is created if necessary and an existing file is overwritten.

    Returns:
        Path of the written shard file.
    """
    target_dir = Path(bm25_dir) / "shards"
    target_dir.mkdir(parents=True, exist_ok=True)
    shard_file = target_dir / f"{doc_id}.jsonl.gz"

    with gzip.open(shard_file, "wt", encoding="utf-8") as sink:
        for node in text_nodes:
            record = {"node_id": node.node_id, "tokens": _tokenize(node.get_content())}
            line = json.dumps(record, ensure_ascii=False)
            sink.write(line + "\n")
    return shard_file

def load_bm25_shard(doc_id: str, bm25_dir: Path) -> Optional[BM25Shard]:
    """Load one document's shard and rebuild its BM25 model.

    Returns:
        ``None`` when ``<bm25_dir>/shards/<doc_id>.jsonl.gz`` does not exist;
        otherwise a ``BM25Shard`` whose ``node_ids`` follow the line order of the
        shard file.
    """
    shard_file = Path(bm25_dir) / "shards" / f"{doc_id}.jsonl.gz"
    if not shard_file.exists():
        return None

    with gzip.open(shard_file, "rt", encoding="utf-8") as source:
        records = [json.loads(raw_line) for raw_line in source]

    ids = [record["node_id"] for record in records]
    token_lists = [record["tokens"] for record in records]
    return BM25Shard(bm25=BM25Okapi(token_lists), node_ids=ids)

def write_manifest(bm25_dir: Path, doc_ids: List[str], version: str = "v1") -> None:
    """Write ``manifest.json`` describing the BM25 store into ``bm25_dir``.

    The manifest records ``version``, ``created_at`` (current Unix time as an
    integer), ``doc_count`` and ``doc_ids``. The directory is created if needed
    and an existing manifest is overwritten.
    """
    root = Path(bm25_dir)
    root.mkdir(parents=True, exist_ok=True)
    payload = dict(
        version=version,
        created_at=int(time.time()),
        doc_count=len(doc_ids),
        doc_ids=doc_ids,
    )
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    (root / "manifest.json").write_text(serialized, "utf-8")

def read_manifest(bm25_dir: Path) -> Dict[str, Any]:
    """Read ``manifest.json`` from ``bm25_dir``.

    Returns:
        The parsed manifest, or an empty dict when the file does not exist.
    """
    # The return annotation previously used the builtin function `any`
    # instead of `typing.Any`.
    manifest_path = Path(bm25_dir) / "manifest.json"
    if not manifest_path.exists():
        return {}
    return json.loads(manifest_path.read_text("utf-8"))

# ---- 构建/更新 BM25 库（按文档分片） ----
def build_or_update_bm25_db(
    bm25_dir: Path,
    text_nodes_by_doc: Dict[str, List[TextNode]],
    overwrite: bool = False,
) -> Dict[str, Path]:
    """Create or refresh the per-document BM25 shards and rewrite the manifest.

    Args:
        bm25_dir: Root directory of the BM25 store.
        text_nodes_by_doc: ``doc_id -> text nodes`` to persist.
        overwrite: When false, an existing shard file is left untouched.

    Returns:
        ``doc_id -> shard path`` for every document in ``text_nodes_by_doc``,
        whether its shard was written now or already existed.

    Note:
        The manifest is rewritten with the doc ids passed in this call only.
    """
    root = Path(bm25_dir)
    shards_root = root / "shards"
    shards_root.mkdir(parents=True, exist_ok=True)

    shard_paths: Dict[str, Path] = {}
    for doc_key, doc_nodes in text_nodes_by_doc.items():
        target = shards_root / f"{doc_key}.jsonl.gz"
        needs_write = overwrite or not target.exists()
        if needs_write:
            save_bm25_shard(doc_key, doc_nodes, root)
        shard_paths[doc_key] = target

    write_manifest(root, list(text_nodes_by_doc.keys()))
    return shard_paths

# ---- 查询时仅在“候选文档”上做 BM25 检索（懒加载分片） ----
def bm25_query_on_docs(
    query: str,
    doc_ids: List[str],
    bm25_dir: Path,
    node_lookup: Dict[str, TextNode],
    per_doc_topk: int = 12,
) -> List[Tuple[TextNode, float]]:
    """Run BM25 on the shards of the given candidate documents only.

    Each shard is loaded from disk on demand; documents without a shard are
    skipped. For every document the ``per_doc_topk`` best-scoring rows are
    resolved to nodes through ``node_lookup`` (unknown ids are dropped).

    Returns:
        ``(node, score)`` pairs, grouped by document in the order of
        ``doc_ids`` and sorted by descending score within each document.
    """
    query_tokens = _tokenize(query)
    results: List[Tuple[TextNode, float]] = []
    for doc_key in doc_ids:
        shard = load_bm25_shard(doc_key, bm25_dir)
        if not shard:
            continue
        doc_scores = shard.bm25.get_scores(query_tokens)
        ranked = sorted(range(len(doc_scores)), key=lambda pos: doc_scores[pos], reverse=True)
        for pos in ranked[:per_doc_topk]:
            node = node_lookup.get(shard.node_ids[pos])
            if node is None:
                continue
            results.append((node, float(doc_scores[pos])))
    return results



# --------- 混合检索（FAISS→候选文档→BM25），可加“图N偏好/联动”---------
def _vector_hits(index: VectorStoreIndex, query: str, top_k: int = 24) -> List[Tuple[TextNode, float]]:
    """Retrieve the ``top_k`` vector hits for ``query`` as ``(node, score)`` pairs.

    A missing or ``None`` score is reported as ``0.0``.
    NOTE(review): the index in this notebook is built on ``IndexFlatL2``; confirm
    whether the returned score is a similarity (higher is better) or a distance.
    """
    hits = index.as_retriever(similarity_top_k=top_k).retrieve(QueryBundle(query))
    return [(hit.node, float(getattr(hit, "score", 0.0) or 0.0)) for hit in hits]

def _choose_top_docs_from_vec(vec_pairs: List[Tuple[TextNode, float]], k_docs: int = 3) -> List[str]:
    """Pick the ``k_docs`` documents with the highest best-hit score.

    Each document is scored by the maximum score among its hits (starting from a
    floor of ``0.0``); a missing ``doc_id`` is treated as ``"unknown"``.
    """
    best_per_doc: Dict[str, float] = {}
    for node, score in vec_pairs:
        key = node.metadata.get("doc_id") or "unknown"
        previous = best_per_doc.get(key, 0.0)
        best_per_doc[key] = max(previous, score)   # sum / mean would also be possible
    ranked = sorted(best_per_doc.items(), key=lambda item: item[1], reverse=True)
    return [key for key, _ in ranked[:k_docs]]

def _minmax_norm(pairs: List[Tuple[TextNode, float]]) -> Dict[str, float]:
    """Min-max normalize scores and return ``node_id -> normalized score``.

    An empty input gives an empty dict. When all scores are equal the range is
    replaced by ``1.0``, so every node maps to ``0.0``.
    """
    if not pairs:
        return {}
    raw_scores = [score for _, score in pairs]
    lowest, highest = min(raw_scores), max(raw_scores)
    span = (highest - lowest) or 1.0
    return {node.node_id: (score - lowest) / span for node, score in pairs}

def _fig_nums_in_text(s: str) -> List[int]:
    PAT = re.compile(r"图\s*([0-9０-９]+)")
    trans = str.maketrans("０１２３４５６７８９","0123456789")
    out = []
    for m in PAT.finditer(s or ""):
        try:
            out.append(int(m.group(1).translate(trans)))
        except: pass
    return out

def hybrid_search_sharded_persisted(
    index: VectorStoreIndex,
    query: str,
    bm25_dir: Path,
    *,
    top_k: int = 8,
    mmr_boost_fig_mention: float = 0.02,
    top_docs_for_bm25: int = 3,
    per_doc_topk: int = 20,
    w_vec: float = 0.7,
    w_bm25: float = 0.3,
    prefer_fig: bool = True,
    fig_link: bool = True,
) -> List[Tuple[TextNode, float]]:
    """Hybrid retrieval: vector search picks candidate documents, BM25 re-scores them.

    Args:
        index: Vector index to query.
        query: User query.
        bm25_dir: Root directory of the persisted BM25 shards.
        top_k: Number of results returned.
        mmr_boost_fig_mention: Score bonus added to a figure node whose number is
            mentioned in the query (only when ``prefer_fig`` is true).
        top_docs_for_bm25: Number of candidate documents passed to BM25.
        per_doc_topk: BM25 hits kept per candidate document.
        w_vec: Weight of the normalized vector score.
        w_bm25: Weight of the normalized BM25 score.
        prefer_fig: Enable the query-mentioned-figure bonus.
        fig_link: Enable adding figure nodes referenced by text hits.

    Returns:
        Up to ``top_k`` ``(node, fused score)`` pairs, best first.

    Relies on the module-level ``NODE_LOOKUP`` and ``FIG_INDEX_BY_DOC`` tables.
    """
    # 1) Vector neighbours (at least 24, fetched generously to pick candidate documents)
    vec_pairs = _vector_hits(index, query, top_k=max(top_k, 24))
    cand_docs  = _choose_top_docs_from_vec(vec_pairs, k_docs=top_docs_for_bm25)

    # 2) BM25 on the candidate documents (shards are loaded lazily from disk)
    bm25_pairs = bm25_query_on_docs(
        query, cand_docs, bm25_dir=bm25_dir, node_lookup=NODE_LOOKUP, per_doc_topk=per_doc_topk
    )

    # 3) Normalize each score list, then fuse with a weighted sum per node id
    vnorm = _minmax_norm(vec_pairs)
    bnorm = _minmax_norm(bm25_pairs)
    pool: Dict[str, Tuple[TextNode, float]] = {}
    for n, _ in vec_pairs:
        pool[n.node_id] = (n, w_vec * vnorm.get(n.node_id, 0.0))
    for n, _ in bm25_pairs:
        if n.node_id in pool:
            # Node found by both retrievers: add the BM25 share to the vector share
            old_n, old_s = pool[n.node_id]
            pool[n.node_id] = (old_n, old_s + w_bm25 * bnorm.get(n.node_id, 0.0))
        else:
            pool[n.node_id] = (n, w_bm25 * bnorm.get(n.node_id, 0.0))

    merged = sorted(pool.values(), key=lambda x: x[1], reverse=True)

    # 4) "图N" preference: when the query names a figure number, give the matching
    #    figure nodes a small bonus
    if prefer_fig:
        mention = set(_fig_nums_in_text(query))
        if mention:
            boosted: List[Tuple[TextNode, float]] = []
            for n, s in merged:
                if n.metadata.get("node_type") == "figure":
                    try:
                        no = int(str(n.metadata.get("fig_no", "")).strip())
                    except Exception:
                        no = None
                    if no in mention:
                        boosted.append((n, s + mmr_boost_fig_mention))
                        continue
                boosted.append((n, s))
            merged = sorted(boosted, key=lambda x: x[1], reverse=True)

    # 5) "图N" linking: when a text hit mentions "图N", add the matching figure node
    #    of the same document
    if fig_link:
        # Node ids already present in the result pool
        have_ids = set(n.node_id for n, _ in merged)
        # Collect the figure numbers mentioned by the leading text hits, per document
        wanted: Dict[str, set[int]] = {}
        for n, _ in merged[:top_k*2]:
            if n.metadata.get("node_type") == "text":
                doc_id = n.metadata.get("doc_id") or "unknown"
                wanted.setdefault(doc_id, set()).update(_fig_nums_in_text(n.get_content()))
        # Append the linked figures.
        # The extras are scored just above the lowest merged score, so they survive
        # the final `[:top_k]` cut only when few higher-scored items exist.
        extras: List[Tuple[TextNode, float]] = []
        base = merged[-1][1] if merged else 0.0
        for doc_id, fig_set in wanted.items():
            fig_map = FIG_INDEX_BY_DOC.get(doc_id, {})
            for fno in sorted(fig_set):
                fn = fig_map.get(fno)
                if fn and fn.node_id not in have_ids:
                    extras.append((fn, base + 1e-6))
                    have_ids.add(fn.node_id)
        if extras:
            merged = sorted(merged + extras, key=lambda x: x[1], reverse=True)

    return merged[:top_k]


# ---------  统一打印/渲染（annos 只打印一次）---------
def render_answer(query: str, pairs: List[Tuple[TextNode, float]], show: int = 8) -> None:
    """Print the query and the first ``show`` hits, one line per hit.

    Figure hits show figure number, description and path; text hits show the
    first 180 characters of their content with newlines flattened. The
    ``fig_annos`` legend is printed at most once, under the first figure hit
    that carries one, truncated to 200 characters.
    """
    print("Q:", query)
    legend_done = False
    for rank, (node, score) in enumerate(pairs[:show], start=1):
        nid = node.id_
        meta = node.metadata or {}

        if meta.get("node_type", "text") != "figure":
            body = (node.get_content() or "").replace("\n", " ").strip()
            print(f"{rank:>2}. [TEXT] {nid} | {body[:180]}{'…' if len(body) > 180 else ''} | score={score:.4f}")
            continue

        fig_no = meta.get("fig_no")
        fig_desc = meta.get("fig_desc", "")
        fig_path = meta.get("fig_path", "")
        print(f"{rank:>2}. [FIG {fig_no}] {nid} | {fig_desc} | path={fig_path} | score={score:.4f}")
        if legend_done:
            continue
        if (meta.get("fig_annos") or "").strip():
            legend = meta["fig_annos"]
            suffix = "…" if len(legend) > 200 else ""
            print("    └─ 附图标记说明：", legend[:200] + suffix)
            legend_done = True


# --------- 构建/更新 BM25_DB，然后小测试 ---------
# Build (or reuse) the per-document BM25 shards, then run two demo queries.
bm25_dir = Path.cwd().parent / ".log" / "bm25_db"
# overwrite=False: shard files that already exist on disk are kept as they are.
_ = build_or_update_bm25_db(bm25_dir, TEXT_NODES_BY_DOC, overwrite=False)

# Demo 1: a general question
q1 = "这个专利的核心结构与技术要点是什么？"
hits1 = hybrid_search_sharded_persisted(vector_store_index, q1, bm25_dir=bm25_dir, top_k=8, prefer_fig=True, fig_link=True)
render_answer(q1, hits1, show=8)

# Demo 2: a query that names a figure ("图N" preference + linking)
q2 = "请解释图2的含义，并给出相关段落"   # author's note: finding 图2 on its own is somewhat hard
hits2 = hybrid_search_sharded_persisted(vector_store_index, q2, bm25_dir=bm25_dir, top_k=8, prefer_fig=True, fig_link=True)
render_answer(q2, hits2, show=8)


In [None]:
# (ingestion-4.3  load persist_vector_db + persist_nodes_db    -- faissdb  -- bm25_db )


def load_vector_db_local(mode="faiss"):
    """Placeholder for loading the persisted vector store; not implemented yet.

    Currently ignores ``mode`` and returns ``None``.
    """
    return None


def load_nodes_db_local(mode="bm25"):
    """Placeholder for loading the persisted node store; not implemented yet.

    Currently ignores ``mode`` and returns ``None``.
    """
    return None


In [None]:
# ingestion-5 retrieve 

import re 
from dataclasses import dataclass 
from typing import Iterable 

from llama_index.core import QueryBundle
from llama_index.retrievers.bm25 import BM25Retriever 

# retrieve config 
# Defaults used by the search helpers and the result printer in the following cells.
TOP_K = 8
VECTOR_MODE = "mmr"       # "default" | "mmr"
MMR_ALPHA = 0.5           # only used when VECTOR_MODE="mmr"; author's note: larger = more diverse (TODO confirm)
USE_BM25 = True           # BM25 is enabled by default
HYBRID_W_VEC = 0.70       # hybrid fusion weight: vector score
HYBRID_W_BM25 = 0.30      # hybrid fusion weight: BM25 score
PREFER_FIG_BOOST = 0.02   # small bonus for a figure node when the query mentions "图N"
TEXT_PREVIEW_CHARS = 400  # number of characters shown in printed previews

# 数字
def _to_ascii_int(s: str, int_out: bool=True) -> Optional[int]:
    _DIGIT_TRANS = str.maketrans("０１２３４５６７８９", "0123456789")
    try:
        return int(s.translate(_DIGIT_TRANS))
    except Exception:
        if int_out:
            return 10**9
        else:
            return None

# "图n"
# Matches "图" followed by optional whitespace and a run of ASCII or full-width digits.
FIG_PAT = re.compile(r"图\s*([0-9０-９]+)")
def _fig_nums_in_text(s: str) -> List[int]:
    """Extract the figure numbers referenced as "图N" in ``s``, in order of appearance.

    A falsy input yields an empty list.
    """
    out = []
    for m in FIG_PAT.finditer(s or ""):
        # int_out=False makes a failed parse return None so that it is skipped;
        # with the default, the sentinel 10**9 would be appended as a figure
        # number and the None check below could never trigger.
        n = _to_ascii_int(m.group(1), int_out=False)
        if n is not None:
            out.append(n)
    return out

# ---------------------   统一结果结构  ----------------
@dataclass 
class HitRow:
    """Uniform, printable view of one retrieval hit (text or figure)."""
    node_id: str 
    node_type: str   # "text" or "figure", as set by _build_hit_row
    score: float 
    text_preview: str   # flattened, truncated node content
    # Figure-only fields; they keep their None defaults for text hits.
    fig_no: Optional[str] = None  # abstract figure -> 0, regular figures: 1 2 3 ...
    fig_desc: Optional[str] = None
    fig_path: Optional[str] = None
    fig_bs64: Optional[str] = None
    fig_annos: Optional[str] = None
    
    

def _coerce_hits(hits: Iterable[Any]) -> List[Tuple[TextNode, float]]:
    """Normalize heterogeneous retrieval results into ``(node, score)`` pairs.

    Accepted item shapes:
      - objects with a ``node`` attribute (e.g. ``NodeWithScore``): the score is
        taken from ``score``; a missing or ``None`` score becomes ``0.0``;
      - bare ``TextNode`` objects: scored ``0.0``;
      - 2-item sequences ``(node, score)``.
    Items of any other shape are skipped (best effort).
    """
    pairs: List[Tuple[TextNode, float]] = []
    for hit in hits:
        if hasattr(hit, "node"):
            # `score` can be present but None; `or 0.0` keeps float() from raising
            # TypeError (same convention as `_vector_hits`).
            pairs.append((hit.node, float(getattr(hit, "score", 0.0) or 0.0)))
        elif isinstance(hit, TextNode):
            pairs.append((hit, 0.0))
        else:
            try:
                node, score = hit
                pairs.append((node, float(score)))
            except Exception:
                # Deliberate best effort: ignore items that cannot be interpreted.
                continue
    return pairs


def _build_hit_row(n: TextNode, score: float, preview_chars: int=TEXT_PREVIEW_CHARS) -> HitRow:
    """Convert a node and its score into a ``HitRow``.

    The preview is the node content with newlines replaced by spaces, stripped,
    and cut to ``preview_chars`` characters (an ellipsis marks truncation).
    Nodes whose metadata ``node_type`` is ``"figure"`` also get the figure
    fields filled from metadata; every other node becomes a ``"text"`` row.
    """
    meta = n.metadata or {}
    flattened = (n.get_content() or "").replace("\n", " ").strip()
    ellipsis = "…" if len(flattened) > preview_chars else ""
    preview = flattened[:preview_chars] + ellipsis

    if meta.get("node_type", "text") != "figure":
        return HitRow(node_id=n.node_id, node_type="text", score=score, text_preview=preview)

    return HitRow(
        node_id=n.node_id,
        node_type="figure",
        score=score,
        text_preview=preview,
        fig_no=str(meta.get("fig_no")),
        fig_desc=meta.get("fig_desc") or "",
        fig_path=meta.get("fig_path") or "",
        fig_bs64=meta.get("fig_bs64") or "",
        fig_annos=meta.get("fig_annos") or "",
    )


def print_rows(query: str, rows: List[HitRow], show: int = 10, show_annos_once: bool = True):
    """Print the query followed by the first ``show`` rows, one line per row.

    When ``show_annos_once`` is true, the ``fig_annos`` legend is printed once,
    under the first figure row that has one, truncated to 200 characters.
    """
    print(f"Q: {query}\n")
    legend_shown = False
    for rank, row in enumerate(rows[:show], start=1):
        if row.node_type != "figure":
            print(f"{rank:>2}. [TEXT ] score={row.score:.4f} | {row.node_id} |{row.text_preview}")
            continue

        print(f"{rank:>2}. [FIG  ] score={row.score:.4f} | {row.node_id} | 图{row.fig_no} {row.fig_desc} | path={row.fig_path}")
        if show_annos_once and not legend_shown and row.fig_annos:
            legend = row.fig_annos
            print(f"    └─ 附图标记说明：{legend[:200]}{'…' if len(legend)>200 else ''}")
            legend_shown = True


# bm25 retriever  -->  针对目标文档 -- 检索到的这些文档  <-- doc_id
# bm25 & 图n快速映射
from llama_index.retrievers.bm25 import BM25Retriever 
import Stemmer

# 需要：你已在之前单元拿到了 index, nodes_cache
# Prerequisite: earlier cells must already have produced `vector_store_index` and `nodes_cache`.
assert "vector_store_index" in globals(), "请先运行你构建索引的单元，获得 `vector_store_index`。"
assert "nodes_cache" in globals(), "请先运行你构建节点的单元，获得 `nodes_cache`。"

BM25_TOP_K_DEFAULT = 15    # more candidates leave more room for the fusion step
BM25_RET = None 
if USE_BM25:
    # NOTE(review): stemmer and language are set to English while the demo queries
    # in this notebook are Chinese — confirm this tokenization is adequate.
    BM25_RET = BM25Retriever.from_defaults(nodes=nodes_cache, 
                                           similarity_top_k=BM25_TOP_K_DEFAULT,
                                           stemmer=Stemmer.Stemmer("english"),
                                            language="english",
                                           )

# Figure-link index (when a text hit mentions 图n, the matching figure node can be added).
# Keyed by figure number only, so with several documents a later node with the
# same number replaces the earlier one.
FIG_NODE_INDEX: Dict[int, TextNode] = {}
for n in nodes_cache:
    if (n.metadata or {}).get("node_type") == "figure":
        no_raw = (n.metadata or {}).get("fig_no")
        try:
            k = int(str(no_raw).strip())
            FIG_NODE_INDEX[k]=n
        except Exception:
            # Figure nodes without an integer figure number are left out of the index.
            pass 


In [None]:
# ========= ③ 检索模式：vector / bm25 / hybrid =========

def vector_search(
    index: VectorStoreIndex,
    query: str,
    *,
    top_k: int = TOP_K,
    vector_mode: str = VECTOR_MODE,
    mmr_alpha: float = MMR_ALPHA,
) -> List[Tuple[TextNode, float]]:
    """Run a pure vector retrieval and return ``(node, score)`` pairs.

    ``vector_mode`` is forwarded as ``vector_store_query_mode``; ``mmr_alpha`` is
    forwarded as ``alpha`` only when the mode is ``"mmr"`` (otherwise ``None``).
    NOTE(review): confirm that ``alpha`` is the parameter the vector store uses
    for MMR, and that the FAISS store honours the query mode.
    """
    alpha_arg = mmr_alpha if vector_mode == "mmr" else None
    retriever = index.as_retriever(
        similarity_top_k=top_k,
        vector_store_query_mode=vector_mode,
        alpha=alpha_arg,
    )
    raw_hits = retriever.retrieve(QueryBundle(query))
    return _coerce_hits(raw_hits)

def bm25_search(
    query: str,
    *,
    top_k: int = TOP_K,
) -> List[Tuple[TextNode, float]]:
    """Run the module-level BM25 retriever and keep the first ``top_k`` hits.

    Returns an empty list when BM25 is disabled (``USE_BM25`` false) or the
    retriever has not been built.
    """
    if not USE_BM25 or BM25_RET is None:
        return []
    leading_hits = BM25_RET.retrieve(query)[:top_k]
    return _coerce_hits(leading_hits)

def _normalize_scores(pairs: List[Tuple[TextNode, float]]) -> List[Tuple[TextNode, float, float]]:
    """Min-max normalize scores to [0, 1]; returns ``(node, raw, normalized)`` triples.

    An empty input gives an empty list. When all scores are equal the range is
    replaced by ``1.0``, so every normalized value is ``0.0``.
    """
    if not pairs:
        return []
    raw_scores = [score for _, score in pairs]
    highest, lowest = max(raw_scores), min(raw_scores)
    span = (highest - lowest) or 1.0
    return [(node, score, (score - lowest) / span) for node, score in pairs]

def hybrid_search(
    index: VectorStoreIndex,
    query: str,
    *,
    top_k: int = TOP_K,
    w_vec: float = HYBRID_W_VEC,
    w_bm25: float = HYBRID_W_BM25,
    vector_mode: str = VECTOR_MODE,
    mmr_alpha: float = MMR_ALPHA,
) -> List[Tuple[TextNode, float]]:
    """Fuse vector and BM25 results with a weighted sum of normalized scores.

    Both retrievers are queried with ``top_k``; each score list is min-max
    normalized on its own, then combined per node id as
    ``w_vec * vector + w_bm25 * bm25`` (a missing side contributes nothing).

    Returns:
        Up to ``top_k`` ``(node, fused score)`` pairs, best first.
    """
    vector_triples = _normalize_scores(
        vector_search(index, query, top_k=top_k, vector_mode=vector_mode, mmr_alpha=mmr_alpha)
    )
    bm25_triples = _normalize_scores(bm25_search(query, top_k=top_k))

    fused: Dict[str, Tuple[TextNode, float]] = {}
    for node, _raw, norm in vector_triples:
        fused[node.node_id] = (node, w_vec * norm)
    for node, _raw, norm in bm25_triples:
        existing = fused.get(node.node_id)
        if existing is None:
            fused[node.node_id] = (node, w_bm25 * norm)
        else:
            # Found by both retrievers: add the BM25 share to the vector share.
            fused[node.node_id] = (existing[0], existing[1] + w_bm25 * norm)

    ranked = sorted(fused.values(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]

# ---- 简单试跑 ----
# Quick trial: run the same question through each retrieval mode and print the top 5 rows.
for mode_name, fn in [
    ("vector", lambda q: vector_search(vector_store_index, q, top_k=5)),
    ("bm25",   lambda q: bm25_search(q, top_k=5)),
    ("hybrid", lambda q: hybrid_search(vector_store_index, q, top_k=5)),
]:
    q = "该专利发明了什么？"
    pairs = fn(q)
    # Convert (node, score) pairs into printable rows
    rows = [_build_hit_row(n, s) for n, s in pairs]
    print(f"\n=== {mode_name.upper()} ===")
    print_rows(q, rows, show=5)


In [None]:
# z    llm 

# local_llm

# Local LLM: a HuggingFace model run on CPU.
model_name = "Qwen/Qwen3-1.7B"
from llama_index.llms.huggingface import HuggingFaceLLM 
from llama_index.core import Settings 

local_llm = HuggingFaceLLM(
    model_name=model_name,
    tokenizer_name=model_name,
    context_window=1400,
    max_new_tokens=300,
    generate_kwargs={"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    device_map='cpu'
)


# Cloud LLM: OpenAI-compatible client; keys are read from the environment / .env file.
from openai import OpenAI 
import os 
from dotenv import load_dotenv
load_dotenv()

# NOTE(review): both variables read the same `GLM_API_KEY` environment variable —
# confirm this is intended and not a copy-paste slip.
DEEPSEEK_API_KEY = os.getenv("GLM_API_KEY")   # https://api.deepseek.com
QWEN_API_KEY = os.getenv("GLM_API_KEY")       # https://dashscope.aliyuncs.com/compatible-mode/v1

# OpenAI-compatible client pointed at the DashScope endpoint.
client = OpenAI(
    api_key=QWEN_API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)


# Make the local model the global default LLM (replaces the earlier `Settings.llm = None`).
Settings.llm = local_llm



# 自定义retriever 