In [1]:
import json
import os
import re
from typing import List, Tuple

# from langchain_community.vectorstores import Milvus
# from langchain_milvus import Milvus
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_milvus.utils.sparse import BM25SparseEmbedding
from langchain_milvus.vectorstores import Milvus

# Folder whose entries are listed and parsed as JSON / JSON-lines source files.
json_folder_path = 'data/policy'
# Passed as `model_name` to HuggingFaceEmbeddings when building the dense embeddings.
embedding_model_path = r"autodl-tmp/embedding_model/BAAI/bge-large-zh-v1___5"
# Directory prefix of the local Milvus ".db" files used in the connection URIs below.
persist_directory = 'vectordb/milvus_mix'

# 从文件夹读取文件名称列表，获取所有的文件完整路径
def get_file_name_form_folder(json_folder_path: str) -> List[str]:
    """Return the joined path of every entry directly inside ``json_folder_path``.

    Every name reported by ``os.listdir`` is joined onto ``json_folder_path``.
    Nothing is filtered here, so sub-directories are returned as well.

    ``os.listdir`` yields names in arbitrary order; the result is sorted so
    that repeated runs process the entries in the same order.

    Args:
        json_folder_path: Directory to list (not recursive).

    Returns:
        Sorted list of paths, one per directory entry.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. the folder is missing).
    """
    return sorted(
        os.path.join(json_folder_path, entry_name)
        for entry_name in os.listdir(json_folder_path)
    )


# 读取和解析 JSON 文件
def parse_file_to_document(file_path_list_all: List[str]) -> List[Document]:
    """Parse JSON / JSON-lines files into ``Document`` objects.

    Paths that are not regular files are skipped with a printed message.
    A file with the ``json`` extension is parsed as a single JSON value; any
    other file is parsed line by line, one JSON value per non-blank line.
    Either way the parsed data is normalised to a list of records, and one
    ``Document`` is produced per JSON object. Parsed values that are not
    objects are skipped with a printed message.

    Args:
        file_path_list_all: Paths to inspect (files and/or directories).

    Returns:
        The documents built from all readable records, in input order.

    Raises:
        json.JSONDecodeError: If a file (or one of its lines) is not valid JSON.
    """
    documents = []
    for file_path in file_path_list_all:
        if not os.path.isfile(file_path):
            print(f"Skipping directory: {file_path}")
            continue

        extension = os.path.splitext(file_path)[1].lstrip(".").lower()
        # errors="ignore" silently drops bytes that are not valid UTF-8.
        with open(file_path, "r", encoding='utf-8', errors="ignore") as f:
            if extension == "json":
                data = json.load(f)
            else:
                data = [json.loads(line) for line in f if line.strip()]

        # A JSON-lines file (or a JSON array) yields a list; a single JSON
        # object is wrapped so both shapes go through the same loop.
        records = data if isinstance(data, list) else [data]
        for record in records:
            if not isinstance(record, dict):
                print(f"Skipping non-object record in: {file_path}")
                continue
            documents.append(_record_to_document(record))
    return documents


def _record_to_document(record: dict) -> Document:
    """Build one ``Document`` from a parsed JSON object.

    Reads the ``title``, ``time``, ``infosource`` and ``context`` keys; all of
    them are optional.
    """
    raw_time = record.get("time", "")
    if isinstance(raw_time, str):
        # Indexing a string would keep only its first character.
        time_value = raw_time
    else:
        # Otherwise treat it as a sequence and keep its first element.
        time_value = raw_time[0] if raw_time else ""

    metadata = {
        "title": (record.get("title") or "").strip(),
        "time": time_value,
        "infosource": record.get("infosource", ""),
    }

    # `context` may be a string or a sequence of strings; the pieces are
    # concatenated and every newline is removed.
    context = record.get("context") or ""
    page_content = "".join(context).replace("\n", "")
    return Document(page_content=page_content, metadata=metadata)


# 文本分割
def split_text(file_path_list_all, chunk_size=800, chunk_overlap=200):
    """Parse the given files and split the resulting documents into chunks.

    Args:
        file_path_list_all: Paths handed to ``parse_file_to_document``.
        chunk_size: Maximum number of characters per chunk.
        chunk_overlap: Number of characters shared by neighbouring chunks.

    Returns:
        The result of ``RecursiveCharacterTextSplitter.split_documents``: a
        single flat list of chunk documents (not one list per source
        document).
    """
    docs_list = parse_file_to_document(file_path_list_all)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(docs_list)


# 数据库创建
def create_my_db(split_docs):
    """Build a dense-vector Milvus collection from ``split_docs``.

    The documents are embedded with the HuggingFace model at
    ``embedding_model_path`` and stored in the local file
    ``<persist_directory>/milvus_demo.db``. ``drop_old=True`` is passed, so
    an existing collection is replaced rather than extended.

    Args:
        split_docs: Chunk documents to index.

    Returns:
        The ``Milvus`` vector store created by ``Milvus.from_documents``.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_path)

    # The local ".db" file lives inside persist_directory; make sure that
    # directory exists before the connection is opened.
    os.makedirs(persist_directory, exist_ok=True)

    vectordb = Milvus.from_documents(
        documents=split_docs,
        embedding=embeddings,
        connection_args={
            "uri": persist_directory + "/milvus_demo.db",
        },
        # Remote-server alternative:
        # connection_args={"host": "127.0.0.1", "port": "6006"}
        drop_old=True,
    )
    return vectordb

# 数据库创建
def create_mix_db(split_docs, bm25_language="en"):
    """Build a hybrid (dense + BM25 sparse) Milvus collection from ``split_docs``.

    Dense vectors come from the HuggingFace model at ``embedding_model_path``;
    sparse vectors come from a ``BM25SparseEmbedding`` fitted on the
    ``page_content`` of ``split_docs``. Data is stored in the local file
    ``<persist_directory>/milvus_mix_demo.db``.

    ``drop_old`` is not passed, so the call relies on the library default for
    whether an existing collection is kept.
    NOTE(review): if it is kept, re-running adds the documents again with a
    freshly fitted BM25 model - confirm this is intended.

    Args:
        split_docs: Chunk documents to index.
        bm25_language: Language code forwarded to ``BM25SparseEmbedding``.
            NOTE(review): the embedding model name suggests a Chinese corpus;
            if so a Chinese analyzer (e.g. ``"zh"``) is probably the better
            choice - confirm it is available in this environment.

    Returns:
        The ``Milvus`` vector store created by ``Milvus.from_documents``.
    """
    dense_embeddings = HuggingFaceEmbeddings(model_name=embedding_model_path)

    # BM25 is fitted on raw strings, not on Document objects.
    corpus_texts = [doc.page_content for doc in split_docs]
    sparse_embeddings = BM25SparseEmbedding(
        corpus=corpus_texts,
        language=bm25_language,
    )

    # The local ".db" file lives inside persist_directory; make sure that
    # directory exists before the connection is opened.
    os.makedirs(persist_directory, exist_ok=True)

    vectordb = Milvus.from_documents(
        documents=split_docs,
        embedding=[dense_embeddings, sparse_embeddings],
        connection_args={
            "uri": persist_directory + "/milvus_mix_demo.db",
        },
        # Remote-server alternative:
        # connection_args={"host": "127.0.0.1", "port": "6006"}
        # One vector field per embedding, in the same order as `embedding`.
        vector_field=["dense_vector", "sparse_vector"],
        # drop_old=True,
        auto_id=True,
    )
    return vectordb


def add_new_data_to_db(new_split_docs, persist_directory, db_file="milvus_demo.db"):
    """Append ``new_split_docs`` to an existing local Milvus collection.

    The previous implementation referenced ``Chroma``, which this file never
    imports, so every call raised ``NameError``. It now uses the ``Milvus``
    store that the rest of the file builds.

    NOTE(review): this targets the dense-only database written by
    ``create_my_db`` by default. The hybrid database also needs the BM25
    model that was fitted at creation time, which is not kept anywhere in
    this file, so it is not handled here.

    Args:
        new_split_docs: Chunk documents to add.
        persist_directory: Directory that contains the local Milvus file.
        db_file: File name of the local Milvus database inside
            ``persist_directory``.

    Returns:
        The ``Milvus`` vector store the documents were added to.
    """
    vectordb = Milvus(
        embedding_function=HuggingFaceEmbeddings(model_name=embedding_model_path),
        connection_args={"uri": persist_directory + "/" + db_file},
        # No explicit ids are supplied for the new documents.
        # NOTE(review): assumes the collection was created with auto ids - confirm.
        auto_id=True,
    )
    vectordb.add_documents(new_split_docs)
    return vectordb


# 测试生成的document是否正确
# def t0():
#     file_path_list_all = get_file_name_form_folder(json_folder_path)
#     documents = parse_file_to_document(file_path_list_all)
#     print(documents)
#     for i, document in enumerate(documents):
#         print(f"{i+1}：document: {document}")


# # 运行主函数
# if __name__ == "__main__":
# List the source folder, parse and chunk its files, then build the hybrid store.
# `split_texts` is kept at module level because later cells read it.
file_path_list_all = get_file_name_form_folder(json_folder_path)
split_texts = split_text(file_path_list_all)
# create_my_db(split_texts)  # dense-only alternative: create the vector database from the chunks
create_mix_db(split_texts)


Skipping directory: data/policy/.ipynb_checkpoints


  from tqdm.autonotebook import tqdm, trange
  return self.fget.__get__(instance, owner)()


In [5]:
# Text of every chunk, in the same order as `split_texts`.
data_context = [chunk.page_content for chunk in split_texts]

In [8]:
# Dump the chunk texts as plain text, one chunk per line.
# Raw text written to a ".xlsx" path is not a readable workbook, and that path
# is also the target of the later `to_excel` call, so a ".txt" file is used.
# Mode "w" keeps a re-run from appending a second copy of the same data.
with open('BM25初始化数据.txt', 'w', newline='', encoding='utf-8') as f:
    f.writelines(f"{chunk_text}\n" for chunk_text in data_context)


In [11]:
import pandas as pd
# Rebinds `data_context` to a DataFrame built from its current value;
# the next cell calls `.to_excel` on this DataFrame.
data_context = pd.DataFrame(data_context)

In [14]:
# Write the DataFrame to an Excel workbook, leaving out the index column.
data_context.to_excel('BM25初始化数据.xlsx',index=False)