# 离线流程

In [None]:
# Get the embeddings for the chunks.
# Passes the chunk file path and the embedding file path to the project helper;
# NOTE(review): whether the embedding file is read or (re)generated is decided
# inside get_chunk_with_embedding, which is not shown here.
from utils.vector_store import get_chunk_with_embedding
chunks_path = "outputs_chunks/article_chunks06.json"
embedding_path = "outputs_chunks/chunk_embedding08.json"
chunks_with_embedding = get_chunk_with_embedding(
    chunks_path, embedding_path=embedding_path
)

In [None]:
import os
from typing import Any, List

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import PrivateAttr

from utils.common_utils import build_doubao_embedding
class DouBaoEmbedding(BaseEmbedding):
    """LlamaIndex embedding adapter that delegates to a Doubao embedding client.

    The client is any callable that accepts ``model=`` and ``input=`` keyword
    arguments and returns an object whose ``data`` items expose an
    ``embedding`` attribute.
    """

    # BaseEmbedding is a pydantic model: state that is not a declared field
    # must be stored as a private attribute, otherwise assignment is rejected.
    _emb_model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "doubao-embedding-text-240715",
        *,
        emb_model: Any,
        **kwargs: Any,
    ) -> None:
        """Create the adapter.

        Args:
            model_name: Model identifier forwarded to the client on each call.
            emb_model: Callable client used to compute embeddings. It is
                keyword-only because a required parameter cannot follow a
                defaulted one positionally.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        super().__init__(model_name=model_name, **kwargs)
        self._emb_model = emb_model

    @property
    def emb_model(self) -> Any:
        """Return the client callable supplied at construction time."""
        return self._emb_model

    def _get_embedding(self, texts: list[str] | str) -> List[float] | List[List[float]]:
        """Embed one string or a list of strings.

        Args:
            texts: A single text, or a list of texts to embed in one call.

        Returns:
            One embedding vector when ``texts`` is a string, otherwise a list
            of embedding vectors in the order returned by the client.
        """
        # Placeholder note kept from the original: replace this with the real
        # Doubao platform API call (e.g. via requests, authentication, ...).
        single_text = isinstance(texts, str)
        if single_text:
            texts = [texts]
        if not texts:
            # Nothing to embed: skip the remote call for an empty batch.
            return []
        response = self._emb_model(
            model=self.model_name,
            input=texts,
        )
        embeddings = [
            embedding_data.embedding for embedding_data in response.data
        ]
        if single_text:
            return embeddings[0]
        return embeddings  # list of float vectors

    async def _aget_embedding(self, text: str) -> List[float]:
        """Async wrapper that runs the synchronous embedding call."""
        return self._get_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document text.

        A list is also tolerated because the call delegates to
        ``_get_embedding``; in that case a list of vectors is returned.
        """
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self._get_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async wrapper around ``_get_text_embedding`` (runs synchronously)."""
        return self._get_text_embedding(text)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async wrapper around ``_get_query_embedding`` (runs synchronously)."""
        return self._get_query_embedding(query)

def get_doubao_embedding(model="doubao-embedding-text-240715"):
    """Build a DouBaoEmbedding backed by the project's Doubao client.

    Args:
        model: Model identifier the embedding adapter should request.

    Returns:
        A DouBaoEmbedding instance wrapping the client returned by
        ``build_doubao_embedding()``.
    """
    emb_model = build_doubao_embedding()

    return DouBaoEmbedding(
        # The constructor parameter is `model_name`; passing `model=` would
        # land in **kwargs and the requested model would never be applied.
        model_name=model,
        emb_model=emb_model,
        # Read from the environment; either value is None when the variable is unset.
        api_key=os.environ.get("COMPLETION_OPENAI_API_KEY"),
        api_base=os.environ.get("COMPLETION_OPENAI_BASE_URL"),
    )


In [None]:
# NOTE: this import rebinds `get_doubao_embedding`, so from here on the name
# refers to the utils.vector_store version rather than the one defined above.
from utils.vector_store import get_doubao_embedding

# Build the embedding model with the helper's default arguments.
doubao_embedding = get_doubao_embedding()

In [None]:
# 将 chunks 存储起来
# embedding_path = "outputs_chunks/chunk_embedding08.json"
# allnodes = get_nodes(embedding_path)

def storage_embedding_nodes(
        embedding_path,
        chroma_db="llama_index/chroma_db",
        chroma_name="sc_collection01",
        storage_dir="./vector_index01",
        embedding_model=None
    ):
    """Load nodes, store them in Chroma and a docstore, and persist to disk.

    Args:
        embedding_path: Path handed to ``get_nodes`` to load the nodes.
        chroma_db: Directory of the persistent Chroma database.
        chroma_name: Name of the Chroma collection (created if missing).
        storage_dir: Directory the storage context is persisted to.
        embedding_model: Embedding model used for the index. When None, one is
            built with ``get_doubao_embedding()``.
    """
    allnodes = get_nodes(embedding_path)
    docstore = SimpleDocumentStore()
    docstore.add_documents(allnodes)

    db = chromadb.PersistentClient(path=chroma_db)
    chroma_collection = db.get_or_create_collection(name=chroma_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store,
        docstore=docstore
    )
    # Nodes are written to the vector store directly, without going through an index.
    # NOTE(review): presumably the nodes already carry embeddings -- confirm in get_nodes.
    vector_store.add(allnodes)

    # Honor the caller-supplied model; only fall back to the default when none is given.
    if embedding_model is None:
        embedding_model = get_doubao_embedding()
    # The index object is built but not used further in this function.
    VectorStoreIndex.from_vector_store(
        vector_store,
        storage_context=storage_context,
        show_progress=True,
        embed_model=embedding_model
    )
    storage_context.persist(persist_dir=storage_dir)


In [1]:
# NOTE: rebinds `storage_embedding_nodes` to the utils.vector_store version.
from utils.vector_store import storage_embedding_nodes

# Offline step: store the embedded chunks using explicit Chroma / storage locations
# (these override the function's default paths and collection name).
embedding_path = "outputs_chunks/chunk_embedding08.json"
storage_embedding_nodes(
    embedding_path,
    chroma_db="llama_index/chroma_db01",
    chroma_name="sc_collection",
    storage_dir="llama_index/vector_index01",
)

# 在线检索流程

In [2]:
# Small demo of rerank_chunks: three hand-written chunks are passed together
# with one question.
from utils.retrieve_nodes import rerank_chunks
question = "请用中文回答：你好，你叫什么名字？"
# Each doc is a dict with the chunk text, an integer id and a source label.
docs = [
    {"chunk": "我叫张三", "chunk_id": 1, "source": "chunk_1"}, 
    {"chunk": "hello", "chunk_id": 2, "source": "chunk_2"}, 
    {"chunk": "rainning", "chunk_id": 3, "source": "chunk_3"}, 
]
sorted_chunks = rerank_chunks(question, docs)
# Bare expression so the notebook displays the result
# (the captured output shows the same dicts with an added 'score' key).
sorted_chunks

[{'chunk': 'hello',
  'chunk_id': 2,
  'source': 'chunk_2',
  'score': 2.1747677326202393},
 {'chunk': '我叫张三',
  'chunk_id': 1,
  'source': 'chunk_1',
  'score': -1.6213587522506714},
 {'chunk': 'rainning',
  'chunk_id': 3,
  'source': 'chunk_3',
  'score': -9.40652084350586}]

In [None]:
# Retrieve and rerank nodes for one query using the project helper,
# called here with the query only.
from utils.retrieve_nodes import get_reranked_nodes
query = "你是做什么工作的?"
source_nodes = get_reranked_nodes(query)
# Display how many nodes came back and the first one.
len(source_nodes), source_nodes[0]

resource module not available on Windows


  from .autonotebook import tqdm as notebook_tqdm
  warn(


LLM is explicitly disabled. Using MockLLM.
Retrieving nodes for query: 你是做什么工作的?
20 nodes retrieved


You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


len(response.source_nodes): 15


(15,
 NodeWithScore(node=TextNode(id_='2244e12b-fdd7-45d5-bb69-459b2c5b429b', embedding=None, metadata={'source': '2005_OLED行业一瞥_王力_llm_correct.md'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, metadata_template='{key}: {value}', metadata_separator='\n', text='电话:编辑部 68514016, 58882976广告部 68578428, 68578429, 58882861会展部 68580742, 68511818-25发行部 58882980, 68511818-34, 33\n电邮:编辑部 article@edw.com.cn广告部 ads@edw.com.cn发行部 faxing@edw.com.cn会展部 seminar@edw.com.cn\n传真: (010) 68580564 Overseas Agent 海外广告代理:\n美国地区：AlignPoint Media Inc.\nTel: 1-925-998-4342, 1-510-828-7899\nFax: 1-866-235-4856\nEmail: eepw@alignpoint.com\n香港地区：Alegra International Ltd.\nE-mail: eepw@alegra.com.hk\n日本地区：Chugai Co., Ltd.\nTel: 81-3-3255-8411 Fax: 81-3-3255-8412 Contact Person: Mizoguchi Hiroyasu', mimetype='text/plain', start_char_idx=None, end_char_idx=None, metadata_seperator='\n', text_template='{metadata_str}\n\n{content}'), score=-6.721523761749268))

In [None]:
# Build a retriever from the persisted docstore / Chroma collection / storage dir
# and run a sample query against it.
from utils.retrieve_nodes import get_retriever

retriever = get_retriever(
    docstore_path="llama_index/docstore.json",
    chroma_db="llama_index/chroma_db01",
    chroma_name="sc_collection",
    storage_dir="llama_index/vector_index01",
    similarity_top_k=10
)
retrieved_nodes = retriever.retrieve("What is the capital of France?")


resource module not available on Windows


  from .autonotebook import tqdm as notebook_tqdm
  warn(


Retrieving nodes for query: What is the capital of France?
20 nodes retrieved


In [None]:
# Build the retriever again and pass it, together with the query, to get_reranked_nodes.
query = "你是做什么工作的?"
from utils.retrieve_nodes import get_retriever
retriever = get_retriever(
    docstore_path="llama_index/docstore.json",
    chroma_db="llama_index/chroma_db01",
    chroma_name="sc_collection",
    storage_dir="llama_index/vector_index01",
    similarity_top_k=10
)
# from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker
# reranker_model = r"C:\Users\Administrator\.cache\modelscope\hub\models\BAAI\bge-reranker-large"
# reranker = FlagEmbeddingReranker(
#     model=reranker_model, top_n=15
# )
# NOTE(review): get_reranked_nodes is not imported in this cell; it relies on a
# binding from an earlier cell -- confirm that version accepts (query, retriever).
source_nodes = get_reranked_nodes(query, retriever)

In [None]:
def get_reranked_nodes(query, retrieved_nodes, retriever, reranker):
    """Run `query` through `retriever` and return the nodes after reranking.

    Args:
        query: Query passed to the query engine.
        retrieved_nodes: Accepted but not used by this function.
        retriever: Retriever the query engine pulls candidate nodes from.
        reranker: Node postprocessor applied to the retrieved nodes.

    Returns:
        The ``source_nodes`` of the query engine response.

    Side effects:
        Sets the global ``Settings.llm`` to None and prints the node count.
    """
    # No LLM is involved: the engine is built with llm=None and "no_text" mode.
    Settings.llm = None
    engine_options = {
        "llm": None,
        "response_mode": "no_text",
        "retriever": retriever,
        "node_postprocessors": [reranker],
    }
    engine = RetrieverQueryEngine.from_args(**engine_options)
    reranked = engine.query(query).source_nodes
    print(f"len(response.source_nodes): {len(reranked)}")
    return reranked

In [6]:
# Add the rerank model
from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker
from llama_index.core.query_engine import RetrieverQueryEngine

# Local path to the bge-reranker-large model files (machine-specific).
reranker_model = r"C:\Users\Administrator\.cache\modelscope\hub\models\BAAI\bge-reranker-large"
reranker = FlagEmbeddingReranker(
    model=reranker_model, top_n=15
)

from llama_index.core import Settings
# Explicitly disable the global LLM setting
Settings.llm = None
# Query engine that only retrieves and reranks ("no_text" response mode, llm=None);
# `retriever` must already exist from an earlier cell.
query_engine = RetrieverQueryEngine.from_args(
    llm=None,
    response_mode="no_text",
    retriever=retriever, 
    node_postprocessors=[reranker]
)

LLM is explicitly disabled. Using MockLLM.


In [8]:
# Run one query through the retrieve + rerank engine built in the previous cell.
response = query_engine.query("技术开发项目中，可根据条件裁剪的角色有？")
# Display the number of source nodes and the first one.
len(response.source_nodes), response.source_nodes[0]

Retrieving nodes for query: 技术开发项目中，可根据条件裁剪的角色有？
20 nodes retrieved


You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


(15,
 NodeWithScore(node=TextNode(id_='a3937eaa-b782-4512-863e-43c343b7ce31', embedding=None, metadata={'source': '2-对联苯-8-羟基喹啉锌...及其应用于新型白光OLED_赵婷_llm_correct.md'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, metadata_template='{key}: {value}', metadata_separator='\n', text='14 Ding, B.-D.; Zhang, J.-M.; Zhu, W.-Q.; Zheng, X.-Y.; Wu, Y.-Z.; Jiang, X.-Y.; Zhang, Z.-L.; Xu, S.-H.\n\nChin.\n\nJ.\n\nLumin.\n\n2003, 24, 606 (in Chinese).\n\n(丁邦东, 张积梅, 朱文清, 郑新友, 吴有智, 蒋雪茵, 张志林, 许少鸿, 发光学报, 2003, 24, 606.)\n15 Flora, W.\n\nH.; Hall, H.\n\nK.; Armstrong, N.\n\nR.\n\nJ.\n\nPhys.\n\nChem.\n\nB 2003, 107, 1142.\n\n16 Kim, D.-E.; Kim, W.-S.; Kim, B.-S.; Lee, B.-J.; Kwon, Y.-S.\n\nColloids Surf.\n\nA: Physicochem.\n\nEng.\n\nAspects 2007, doi: 10.1016/j.colsurfa.2007.05.042.\n\n17 Shi, Y.-M.; Deng, Z.-B.; Xu, D.-H.; Xiao, J.', mimetype='text/plain', start_char_idx=None, end_char_idx=None, metadata_seperator='\n', text_template='{metadata_str}\n\n{content}'), sco

In [None]:
# NOTE(review): get_vector_retriever is not defined or imported anywhere in the
# visible cells -- confirm where it comes from.
# embedding_path is passed as an empty string here.
custom_retriever = get_vector_retriever(
    docstore_path="llama_index/docstore.json",
    embedding_path="",
    chroma_db="llama_index/chroma_db01",
    chroma_name="sc_collection",
    storage_dir="llama_index/vector_index01",
)
retrieved_nodes = custom_retriever.retrieve("What is the capital of France?")


In [None]:
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import Settings, StorageContext, VectorStoreIndex


def get_retriever(
    docstore_path,
    embedding_path=None,
    chroma_db="llama_index/chroma_db01",
    chroma_name="sc_collection",
    storage_dir="llama_index/vector_index01",
    similarity_top_k=10
):
    """Build a hybrid retriever combining vector search and BM25.

    Args:
        docstore_path: Path of the persisted docstore to load.
        embedding_path: Not used by this function; kept for signature
            compatibility and optional so callers may omit it.
        chroma_db: Directory of the persistent Chroma database.
        chroma_name: Name of the Chroma collection (created if missing).
        storage_dir: Persist directory passed to the storage context.
        similarity_top_k: Number of results requested from each sub-retriever.

    Returns:
        A CustomRetriever wrapping the vector retriever and the BM25 retriever.
    """
    docstore = SimpleDocumentStore.from_persist_path(docstore_path)
    db = chromadb.PersistentClient(path=chroma_db)
    chroma_collection = db.get_or_create_collection(name=chroma_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(
        persist_dir=storage_dir,
        vector_store=vector_store,
        docstore=docstore
    )

    doubao_embedding = get_doubao_embedding()
    vector_index = VectorStoreIndex.from_vector_store(
        vector_store,
        storage_context=storage_context,
        embed_model=doubao_embedding,
        show_progress=True,
    )
    vector_retriever = vector_index.as_retriever(
        similarity_top_k=similarity_top_k,
        verbose=True
    )
    # BM25 works on the docstore contents rather than on the vector store.
    bm25_retriever = BM25Retriever.from_defaults(
        docstore=docstore,
        similarity_top_k=similarity_top_k,
    )

    custom_retriever = CustomRetriever(
        vector_retriever,
        bm25_retriever,
    )
    return custom_retriever

# from llama_index.retrievers.bm25 import BM25Retriever
# bm25_retriever = BM25Retriever.from_defaults(
#     docstore=docstore,
#     similarity_top_k=10,
# )
# bm25_retriever.persist("llama_index/bm25_retriever.json")
# loaded_bm25_retriever = BM25Retriever.from_persist_dir("llama_index/bm25_retriever.json")

from llama_index.core.retrievers import (
    BaseRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import NodeWithScore
from llama_index.core import QueryBundle
from typing import List

# 4. 创建自定义的检索器
class CustomRetriever(BaseRetriever):
    """Custom retriever that combines vector retrieval with BM25 retrieval.

    In "OR" mode the result is the union of both result sets; in "AND" mode it
    is their intersection. When a node is returned by both retrievers, the
    BM25 result (and therefore its score) is the one kept.
    """

    def __init__(self,
                 vector_retriever: VectorIndexRetriever,
                 bm25_retriever: BM25Retriever,
                 mode: str = "OR",
    ) -> None:
        """Store the two sub-retrievers and validate the combination mode.

        Args:
            vector_retriever: Retriever performing vector similarity search.
            bm25_retriever: Retriever performing BM25 search.
            mode: Either "AND" (intersection) or "OR" (union).

        Raises:
            ValueError: If ``mode`` is neither "AND" nor "OR".
        """
        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        if mode not in ["AND", "OR"]:
            raise ValueError("mode must be either AND or OR")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query from both retrievers and merge them.

        Args:
            query_bundle: Query to run against both sub-retrievers.

        Returns:
            The merged nodes, ordered with vector results first (in their
            original order) followed by nodes found only by BM25.
        """
        print(f"Retrieving nodes for query: {query_bundle.query_str}")
        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        bm25_nodes = self._bm25_retriever.retrieve(query_bundle)

        vector_ids = {node.node.node_id for node in vector_nodes}
        bm25_ids = {node.node.node_id for node in bm25_nodes}

        # BM25 entries are applied second, so they win for ids seen by both.
        combined_dict = {node.node.node_id: node for node in vector_nodes}
        combined_dict.update({node.node.node_id: node for node in bm25_nodes})

        # The mode was validated in __init__, so anything other than "AND" is "OR".
        if self._mode == "AND":
            retrieve_ids = vector_ids & bm25_ids
        else:
            retrieve_ids = vector_ids | bm25_ids

        # Walk the dict (insertion-ordered) instead of the id set, so the
        # result order is deterministic across runs.
        retrieve_nodes = [
            node
            for node_id, node in combined_dict.items()
            if node_id in retrieve_ids
        ]
        print(f"{len(retrieve_nodes)} nodes retrieved")
        return retrieve_nodes

# Combine the two retrievers with the default mode ("OR") and run a sample query.
# NOTE(review): `vector_retriever` and `bm25_retriever` are only created as locals
# inside get_retriever in this cell; this relies on module-level bindings from
# elsewhere -- confirm they exist before running.
custom_retriever = CustomRetriever(
    vector_retriever, 
    bm25_retriever, 
)
retrieved_nodes = custom_retriever.retrieve("What is the capital of France?")


