In [1]:
import pymysql
import numpy as np

# ✅ 连接 MySQL
conn = pymysql.connect(
    host="localhost",
    user="root",
    password="1234",
    database="xunfei",
    charset="utf8mb4",
    cursorclass=pymysql.cursors.DictCursor
)

cursor = conn.cursor()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ✅ 表名与source_type映射
table_mapping = {
    "products": "商品",
    "products2": "商品",
}


In [3]:
def record_to_text(record: dict) -> str:
    return "。".join([f"{key}：{str(value)}" for key, value in record.items() if value]) + "。"


In [8]:
from FlagEmbedding import BGEM3FlagModel
model = BGEM3FlagModel('BAAI/bge-m3',
                      use_fp16=False,
                      pooling_method='cls',
                      devices=['cuda:0'])

def get_embeddings(text):
    embeddings = model.encode(
        text,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False
    )
    return embeddings

Fetching 30 files: 100%|████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 30002.17it/s]


In [21]:
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection
connections.connect("default", host="localhost", port="19530")

fields = [
    FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=100, is_primary=True),
    FieldSchema(name="source_type", dtype=DataType.VARCHAR, max_length=20),
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
    FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024)
]

schema = CollectionSchema(fields, description="Policy Paragraph Embeddings")
collection_name = "AllProducts"

if utility.has_collection(collection_name):
    Collection(collection_name).drop()
collection = Collection(collection_name, schema, consistency_level="Strong")

dense_index = {"index_type": "HNSW", "metric_type": "L2"}
collection.create_index("dense_vector", dense_index)
# col = Collection(col_name)
collection.load()

In [None]:
from tqdm import tqdm
project_id = 1
for table_name, source_type in table_mapping.items():
    cursor.execute(f"SELECT * FROM `{table_name}`")
    rows = cursor.fetchall()
    print(f"读取表 {table_name}：{len(rows)} 条")

    # batch_data = {
    #     # "project_id": [],
    #     "source_type": [],
    #     "text": [],
    #     "dense_vec": []
    # }
    project_ids = []
    source_types = []
    texts = []
    dense_vectors = []
    BATCH_SIZE = 100
    
    for row in tqdm(rows, desc='向量生成中...'):
        # if ("项目编号" or "招标项目编号") not in row or not row["项目编号"]:
        #     continue  # 项目编号为空时跳过

        text = record_to_text(row).replace('"', '')
        if len(text.encode("utf-8")) > 8192:
            continue

        # print(text)
        # break
        product
        vector = get_embeddings(text)
        dense_vector = vector['dense_vecs']
        source_types.append(source_type)
        texts.append(text)
        dense_vectors.append(dense_vector)
        try:
            if len(project_ids) > BATCH_SIZE:
                collection.insert([project_ids, source_types, texts, dense_vectors])
                project_ids = []
                source_types = []
                texts = []
                dense_vectors = []
        except:
            continue
        # batch_data["project_id"].append(row["项目编号"])
        # batch_data["source_type"].append(source_type)
        # batch_data["text"].append(text)
        # batch_data["dense_vec"].append(vector)
    if source_types:
        collection.insert([project_ids, source_types, texts, dense_vectors])

    # 写入 Milvus
    # try:
    #     if batch_data["text"]:
    #         collection.insert([
    #             # batch_data["project_id"],
    #             batch_data["source_type"],
    #             batch_data["text"],
    #             batch_data["dense_vec"]
    #         ])
    #         print(f"✅ 已写入 {len(batch_data['text'])} 条记录到 Milvus")
    # except:
    #     continue


# batch_ids = []
# para_ids = []
# texts = []
# dense_vectors = []
# sparse_vectors = []
# BATCH_SIZE = 100

# for doc in tqdm(col_mongo.find({"vectorized": True}), desc='向量生成中...'):
#     para_id = doc['para_id']
#     text = doc['text']
#     try:
#         vector = get_embeddings(text)
#         # print(vector)
#         # break
#         dense_vector = vector["dense_vecs"]
#         sparse_vector = vector["lexical_weights"]
#         para_ids.append(para_id)
#         texts.append(text)
#         dense_vectors.append(dense_vector)
#         sparse_vectors.append(sparse_vector)
#         batch_ids.append(doc['_id'])
#         if len(para_ids) > BATCH_SIZE:
#             col.insert([para_ids, texts, dense_vectors, sparse_vectors])
#             col_mongo.update_many({"_id": {"$in": batch_ids}}, {"$set": {"vectorized": True}})
#             batch_ids = []
#             para_ids = []
#             texts = []
#             dense_vectors = []
#             sparse_vectors = []

#     except Exception as e:
#         print(f"向量化失败：{para_id}, {type(e).__name__}: {e}")

    
print("生成已完成。")

读取表 上海政府采购公告：4096 条


向量生成中...: 100%|███████████████████████████████████████████████████████████████| 4096/4096 [08:25<00:00,  8.11it/s]


读取表 上海政府采购中标结果：4469 条


向量生成中...:  11%|███████▎                                                        | 512/4469 [00:41<05:17, 12.47it/s]