In [1]:
import os

import numpy as np
import pymysql

# Connect to MySQL.
# Connection settings come from environment variables so that credentials do
# not have to be committed with the notebook; the fallbacks are the original
# local-development values, so an unconfigured environment behaves as before.
conn = pymysql.connect(
    host=os.environ.get("MYSQL_HOST", "localhost"),
    user=os.environ.get("MYSQL_USER", "root"),
    password=os.environ.get("MYSQL_PASSWORD", "1234"),
    database=os.environ.get("MYSQL_DATABASE", "xunfei"),
    charset="utf8mb4",
    # DictCursor yields each row as a dict keyed by column name, which is what
    # the later cells rely on (row["..."] lookups and record.items()).
    cursorclass=pymysql.cursors.DictCursor,
)

cursor = conn.cursor()


In [2]:
# Maps each MySQL table name to the source_type label stored with its rows.
# Register one entry per table that should be ingested.
table_mapping = {}
table_mapping["批量查询导出数据（企业信息）"] = "企业"


In [3]:
def record_to_text(record: dict) -> str:
    """Flatten one record into a single "key：value。" sentence string.

    Fields whose value is ``None`` or an empty sized value (``""``, ``b""``,
    empty containers) are omitted. Other falsy values such as ``0`` or
    ``False`` are kept, because a numeric zero is real data rather than a
    missing field.

    Args:
        record: Mapping of field name to value (e.g. one DictCursor row).

    Returns:
        The retained fields rendered as ``key：value`` and joined with "。",
        with a trailing "。". An empty record yields just "。".
    """
    parts = [
        f"{key}：{value}"
        for key, value in record.items()
        # Skip only genuinely missing values; a plain truthiness test would
        # also discard 0 / Decimal("0") / False.
        if value is not None
        and not (hasattr(value, "__len__") and len(value) == 0)
    ]
    return "。".join(parts) + "。"


In [4]:
from FlagEmbedding import BGEM3FlagModel
# Load the BAAI/bge-m3 checkpoint once for the whole notebook:
# fp16 disabled, CLS pooling, pinned to the cuda:0 device.
model = BGEM3FlagModel(
    "BAAI/bge-m3",
    use_fp16=False,
    pooling_method="cls",
    devices=["cuda:0"],
)

def get_embeddings(text):
    """Encode ``text`` with the notebook-level BGE-M3 ``model``.

    Dense and sparse outputs are requested; ColBERT vectors are disabled.
    The result of ``model.encode`` is returned unchanged.
    """
    encode_options = {
        "return_dense": True,
        "return_sparse": True,
        "return_colbert_vecs": False,
    }
    return model.encode(text, **encode_options)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 30 files: 100%|███████████████████████████████████████████████████████████████████████| 30/30 [00:00<?, ?it/s]


In [5]:
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection
# Connect to the local Milvus server under the "default" alias.
connections.connect("default", host="localhost", port="19530")

# Schema: one entity per company record.
fields = [
    # Primary key; the ingestion cell fills it from the "统一社会信用代码" column.
    FieldSchema(name="company_id", dtype=DataType.VARCHAR, max_length=100, is_primary=True),
    FieldSchema(name="source_type", dtype=DataType.VARCHAR, max_length=20),
    # The ingestion cell skips any text whose UTF-8 size exceeds this limit.
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
    # dim must equal the size of the embedding model's dense output.
    FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
    FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR)
]

# NOTE(review): the description looks copied from a different notebook (the
# collection holds company records) — confirm before changing stored metadata.
schema = CollectionSchema(fields, description="Policy Paragraph Embeddings")
collection_name = "AllCompanies"

# WARNING: destructive. Any existing collection with this name is dropped so
# that every run rebuilds it from scratch.
if utility.has_collection(collection_name):
    Collection(collection_name).drop()
collection = Collection(collection_name, schema, consistency_level="Strong")

# Every vector field needs an index before the collection can be loaded.
dense_index = {"index_type": "HNSW", "metric_type": "L2"}
collection.create_index("dense_vector", dense_index)
# Sparse vectors are searched with inner product on an inverted index.
sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
collection.create_index("sparse_vector", sparse_index)
collection.load()

In [6]:
from tqdm import tqdm

BATCH_SIZE = 100  # number of rows sent to Milvus per insert call


def _insert_batch(pending):
    """Insert buffered rows into ``collection`` and return how many were written.

    ``pending`` is a list of (company_id, source_type, text, dense_vector,
    sparse_vector) tuples in schema field order. A failed insert is reported
    and counted as zero rows written, so the failure is visible instead of
    being silently swallowed and one bad batch does not abort the import.
    """
    if not pending:
        return 0
    # Milvus takes column-ordered data: transpose rows into one list per field.
    columns = [list(column) for column in zip(*pending)]
    try:
        collection.insert(columns)
    except Exception as exc:
        print(f"批量写入失败（{len(pending)} 条）：{exc}")
        return 0
    return len(pending)


for table_name, source_type in table_mapping.items():
    # table_name comes from the hard-coded table_mapping, not from user input.
    cursor.execute(f"SELECT * FROM `{table_name}`")
    rows = cursor.fetchall()
    print(f"读取表 {table_name}：{len(rows)} 条")

    pending = []   # rows buffered for the next insert
    inserted = 0   # rows successfully written to Milvus
    skipped = 0    # rows rejected before embedding

    for row in tqdm(rows, desc='向量生成中...'):
        company_id = row["统一社会信用代码"]
        text = record_to_text(row).replace('"', '')
        # A row cannot be stored without a primary key, and the text must fit
        # the 8192-byte VARCHAR limit of the collection's "text" field.
        if not company_id or len(text.encode("utf-8")) > 8192:
            skipped += 1
            continue

        vector = get_embeddings(text)
        # The sparse weights are converted to plain {int: float} pairs, the
        # form Milvus expects for a SPARSE_FLOAT_VECTOR field.
        sparse_vector = {
            int(token_id): float(weight)
            for token_id, weight in vector['lexical_weights'].items()
        }
        pending.append((company_id, source_type, text, vector['dense_vecs'], sparse_vector))

        if len(pending) >= BATCH_SIZE:
            inserted += _insert_batch(pending)
            # Always start a fresh buffer, even after a failure, so a bad
            # batch is not retried (and grown) on every following row.
            pending = []

    # Write whatever is left over after the last full batch.
    inserted += _insert_batch(pending)
    failed = len(rows) - skipped - inserted
    print(f"表 {table_name}：写入 {inserted} 条，跳过 {skipped} 条，失败 {failed} 条")

print("生成已完成。")

读取表 批量查询导出数据（企业信息）：632 条


You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
向量生成中...: 100%|█████████████████████████████████████████████████████████████████| 632/632 [03:32<00:00,  2.98it/s]

生成已完成。



