In [None]:
import json
import time
# 查找JSON文件
from pathlib import Path
# 帮组idea 更好的代码补全及错误提示
from typing import Dict, List, Optional

import chromadb

# VectorStoreIndex 管理文档向量表示
from llama_index.core import VectorStoreIndex, StorageContext, Settings

# TextNode是文档的基本单位
from llama_index.core.schema import TextNode

# 对接Huggingface模型的接口，用于生成问答
from llama_index.llms.huggingface import HuggingFaceLLM

# 用于将文本转为向量表示
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Chroma vector-store adapter used by init_vector_store
from llama_index.vector_stores.chroma import ChromaVectorStore

from llama_index.core import PromptTemplate

# ChatML-style prompt (<|im_start|> / <|im_end|> turns):
# - the system turn carries the retrieved legal articles ({context_str});
# - the user turn carries the question ({query_str});
# - the prompt ends with an open `assistant` turn for the model to complete.
# The role name must be spelled exactly "assistant" to match the chat format.
QA_TEMPLATE = (
    "<|im_start|>system\n"
    "你是一个专业的法律助手，请严格根据以下法律条文回答问题：\n"
    "相关法律条文：\n{context_str}\n<|im_end|>\n"
    "<|im_start|>user\n{query_str}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Used by main() as the query engine's text-QA prompt.
response_template = PromptTemplate(QA_TEMPLATE)


In [None]:
class Config:
    """Central settings for the RAG pipeline: model locations, storage paths and retrieval size."""

    # ---- models -------------------------------------------------------------
    # NOTE(review): presumably a local directory (the hub id would be
    # "BAAI/bge-small-zh-v1.5") -- confirm.
    EMBED_MODEL_PATH: str = "BAAI-bge-small-zh-v1.5"
    # NOTE(review): left empty -- must be set to the chat LLM's path/id before running.
    LLM_MODEL_PATH: str = ""

    # ---- storage ------------------------------------------------------------
    DATA_DIR: str = "./data"              # scanned for *.json source files
    VECTOR_DB_DIR: str = "./chroma_db"    # Chroma persistent-client directory
    PERSIST_DIR: str = "./storage"        # persisted docstore / index data

    # ---- retrieval ----------------------------------------------------------
    COLLECTION_NAME: str = "chinese_labor_laws"  # Chroma collection name
    TOP_K: int = 3                               # similarity_top_k for queries

In [None]:
def init_models():
    """Load the embedding model and the chat LLM and register them on ``Settings``.

    Both models are installed as llama_index global defaults
    (``Settings.embed_model`` / ``Settings.llm``), so later index building and
    querying use them without passing them explicitly.

    Returns:
        tuple: ``(embed_model, llm)``.

    Raises:
        ValueError: if ``Config.LLM_MODEL_PATH`` has not been configured.
    """
    # Local import: only needed here to pick the embedding device.
    import torch

    if not Config.LLM_MODEL_PATH:
        raise ValueError("Config.LLM_MODEL_PATH 为空，请先配置LLM模型路径")

    # `hasattr(Settings, 'device')` does not test for a GPU; ask torch directly.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    embed_model = HuggingFaceEmbedding(
        model_name=Config.EMBED_MODEL_PATH,
        device=device,
    )

    llm = HuggingFaceLLM(
        model_name=Config.LLM_MODEL_PATH,
        tokenizer_name=Config.LLM_MODEL_PATH,
        device_map="auto",  # choose device placement automatically
        tokenizer_kwargs={"trust_remote_code": True},
        # transformers only honours `temperature` when sampling is enabled.
        generate_kwargs={"temperature": 0.3, "do_sample": True},
    )

    Settings.embed_model = embed_model
    Settings.llm = llm

    # Smoke test: embed a short string and report the vector dimension.
    test_embedding = embed_model.get_text_embedding("测试文本")
    print(f"Embedding维度验证：{len(test_embedding)}")

    return embed_model, llm

In [None]:
def _validate_law_records(data, file_name: str) -> None:
    """Raise ValueError unless ``data`` is a list of dicts whose values are all strings."""
    if not isinstance(data, list):
        raise ValueError(f"文件{file_name}根元素应为列表")
    for item in data:
        if not isinstance(item, dict):
            raise ValueError(f"文件{file_name}包含非字典元素")
        for k, v in item.items():
            if not isinstance(v, str):
                raise ValueError(f"文件{file_name}Key：{k}的值不是字符串")


def load_and_validate_json_files(data_dir: str) -> List[Dict]:
    """Load every ``*.json`` file directly inside ``data_dir`` and validate it.

    Each file must contain a JSON list of objects whose values are all strings.
    ``create_nodes`` later treats each key as an article title and each value as
    the article text. Every object becomes one entry
    ``{"content": <object>, "metadata": {"source": <file name>}}``.

    Args:
        data_dir: Directory to scan (non-recursive).

    Returns:
        All entries. Files are processed in sorted path order, so the result is
        deterministic.

    Raises:
        AssertionError: if no ``*.json`` file is found.
        RuntimeError: if a file fails to parse or validate; the original error
            is chained as ``__cause__``.
    """
    # glob() order depends on the filesystem; sort for reproducible node order.
    json_files = sorted(Path(data_dir).glob("*.json"))
    assert json_files, f"未找到JSON 文件{data_dir}"

    all_data = []
    for json_file in json_files:
        with open(json_file, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
                _validate_law_records(data, json_file.name)
            except Exception as e:
                raise RuntimeError(f"加载文件{json_file}失败：{str(e)}") from e
        all_data.extend(
            {"content": item, "metadata": {"source": json_file.name}}
            for item in data
        )
    print(f"成功加载{len(all_data)}条")
    return all_data

                        

In [None]:
# 每条存储为一个TextNode
def create_nodes(raw_data: List[Dict]) -> List[TextNode]:
    """Turn loaded JSON entries into one ``TextNode`` per legal article.

    Args:
        raw_data: Entries shaped like ``load_and_validate_json_files`` output:
            ``{"content": {full_title: article_text, ...},
               "metadata": {"source": file_name}}``.

    Returns:
        One node per ``(full_title, article_text)`` pair. The node id is
        ``"<source_file>::<full_title>"``, so the same input yields the same ids.
    """
    nodes = []
    for entry in raw_data:
        law_dict = entry["content"]
        source_file = entry["metadata"]["source"]

        for full_title, content in law_dict.items():
            node_id = f"{source_file}::{full_title}"

            # Titles are expected as "<law name> <article>"; split on the first
            # space only. str.split always returns at least one element, so an
            # article part exists only when len(parts) > 1.
            parts = full_title.split(" ", 1)
            law_name = parts[0] or "未知法律"
            article = parts[1] if len(parts) > 1 else "未知条款"

            node = TextNode(
                text=content,
                id_=node_id,  # TextNode's id field is `id_` (read back below via .id_)
                metadata={
                    "law_name": law_name,
                    "article": article,
                    "full_title": full_title,
                    "source_file": source_file,
                    "content_type": "legal_article",
                },
            )
            nodes.append(node)

    # Avoid IndexError on nodes[0] when the input is empty.
    if nodes:
        print(f"create {len(nodes)}个文本，（ID示例：{nodes[0].id_}）")
    else:
        print(f"create {len(nodes)}个文本")
    return nodes

In [None]:
# 向量存储

def init_vector_store(nodes: Optional[List[TextNode]]) -> VectorStoreIndex:
    """Build a new Chroma-backed index from ``nodes`` or reload the existing one.

    A new index is built when the Chroma collection is empty and ``nodes`` is
    given. In that case the nodes are also added to the docstore and the
    storage context is persisted to ``Config.PERSIST_DIR``. Otherwise the index
    is recreated over the existing Chroma collection, and the docstore is
    loaded from ``Config.PERSIST_DIR``.

    Args:
        nodes: Nodes to index, or ``None`` to only load existing data.

    Returns:
        The ``VectorStoreIndex`` to query.
    """
    chroma_client = chromadb.PersistentClient(path=Config.VECTOR_DB_DIR)

    chroma_collection = chroma_client.get_or_create_collection(
        name=Config.COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},  # cosine similarity
    )
    # One adapter over the collection, shared by both branches.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    if chroma_collection.count() == 0 and nodes is not None:
        print(f"创建新索引{len(nodes)}")
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        # Keep the raw nodes in the docstore as well, so they are persisted
        # alongside the vectors.
        storage_context.docstore.add_documents(nodes)

        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True,
        )

        # index.storage_context is the storage_context passed above, so a
        # single persist call writes everything.
        storage_context.persist(persist_dir=Config.PERSIST_DIR)
    else:
        print("load 已有数据")
        # NOTE(review): needs Config.PERSIST_DIR from an earlier build. If the
        # Chroma directory exists but the collection is empty (nodes is None),
        # this path loads an empty index.
        storage_context = StorageContext.from_defaults(
            persist_dir=Config.PERSIST_DIR,
            vector_store=vector_store,
        )

        # Rebuild the index object on top of the existing Chroma collection.
        index = VectorStoreIndex.from_vector_store(
            vector_store,
            embed_model=Settings.embed_model,
        )

    print("\n存储结果验证：")
    doc_count = len(storage_context.docstore.docs)
    print(f"DocStore数{doc_count}")

    if doc_count > 0:
        sample_key = next(iter(storage_context.docstore.docs))
        print(f"示例节点Id：{sample_key}")
    else:
        print("文档为空")

    return index
        

In [None]:
def main():
    """Run the interactive labor-law Q&A loop.

    Steps:
    1. Load the models.
    2. Parse the JSON corpus into nodes, but only when ``Config.VECTOR_DB_DIR``
       does not exist yet.
    3. Build or load the index.
    4. Answer questions read from stdin until the user enters ``q``/``Q``.

    After each answer, the retrieved source articles are printed.
    """
    # init_models registers both models on Settings; the return values are unused here.
    init_models()

    if not Path(Config.VECTOR_DB_DIR).exists():
        print("\n init data....")
        raw_data = load_and_validate_json_files(Config.DATA_DIR)
        nodes = create_nodes(raw_data)
    else:
        nodes = None

    print("\n init vector db storage")
    start_time = time.time()

    # Build or load the index.
    index = init_vector_store(nodes)

    print(f"索引加载耗时：{time.time()-start_time:.2f}")

    query_engine = index.as_query_engine(
        similarity_top_k=Config.TOP_K,
        text_qa_template=response_template,
        verbose=True,  # detailed logging
    )

    while True:
        question = input("\n请输入问题，按q退出：").strip()
        # lower() must be *called*; comparing the bound method to 'q' is always False.
        if question.lower() == 'q':
            break
        if not question:
            continue

        response = query_engine.query(question)

        print(f"\n回答：{response.response}")

        print("\n支持依据:")
        for idx, node in enumerate(response.source_nodes, 1):
            meta = node.metadata
            print(f"\n[{idx}]{meta['full_title']}")
            print(f"file:{meta['source_file']}")
            print(f"law：{meta['law_name']}")
            print(f"content:{node.text[:100]}...")
            # The score may be None; formatting None with :.4f would raise.
            score_text = f"{node.score:.4f}" if node.score is not None else "N/A"
            print(f"score：{score_text}")
       

In [None]:
# Start the interactive loop only when this module is the top-level program.
# A notebook kernel also runs cells with __name__ == "__main__".
_RUNNING_AS_MAIN = "__main__" == __name__
if _RUNNING_AS_MAIN:
    main()