In [1]:
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
import pandas as pd
from pymilvus import AnnSearchRequest, WeightedRanker
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import torch
# from langchain_milvus import Milvus
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection

In [2]:
# Load the embedding model and open the Milvus collection
def load_db(
    model_name=r'autodl-tmp/embedding_model/BAAI/bge-m3',
    uri='vectordb/milvus_mix/milvus_m3_2.db',
    collection_name='hybrid_demo',
    device='cpu',
    use_fp16=False,
):
    """Create the BGE-M3 embedding function and open a Milvus collection.

    All arguments default to the values this notebook was written with, so
    ``load_db()`` behaves as before; pass different values to point at another
    model checkpoint, database or collection.

    Args:
        model_name: Path or name of the BGE-M3 model handed to
            ``BGEM3EmbeddingFunction``.
        uri: Connection URI handed to ``connections.connect``
            (a local ``.db`` file by default -- presumably Milvus Lite; confirm).
        collection_name: Name of an already existing collection to open.
        device: Device string forwarded to the embedding function.
        use_fp16: Whether the embedding function should use half precision.

    Returns:
        A ``(milvus_collection, embedding_model)`` tuple.
    """
    embedding_model = BGEM3EmbeddingFunction(
        model_name=model_name,
        use_fp16=use_fp16,
        device=device,
    )
    connections.connect(uri=uri)
    milvus_collection = Collection(name=collection_name)
    # load() is called before the collection is handed to the search helpers.
    milvus_collection.load()
    return milvus_collection, embedding_model

In [3]:
# Initialise once: open the Milvus collection and build the embedding model
milvus_collection, embedding_model = load_db()


  return self.fget.__get__(instance, owner)()


In [4]:
def hybrid_search(
    col,
    query_dense_embedding,
    query_sparse_embedding,
    sparse_weight=1.0,
    dense_weight=1.0,
    limit=10
):
    """Search the dense and sparse vector fields together and merge the hits.

    Args:
        col: Collection-like object exposing ``hybrid_search``.
        query_dense_embedding: Query vector for the ``dense_vector`` field.
        query_sparse_embedding: Query vector for the ``sparse_vector`` field.
        sparse_weight: Weight given to the sparse request by the ranker.
        dense_weight: Weight given to the dense request by the ranker.
        limit: Maximum number of hits per request and in the merged result.

    Returns:
        A list of dicts, one per hit, shaped as
        ``{'metadata': {'title', 'time', 'infosource'}, 'page_content': text}``.
    """
    # One ANN request per vector field, both scored with inner product.
    dense_req = AnnSearchRequest(
        [query_dense_embedding],
        'dense_vector',
        {'metric_type': 'IP', 'params': {}},
        limit=limit,
    )
    sparse_req = AnnSearchRequest(
        [query_sparse_embedding],
        'sparse_vector',
        {'metric_type': 'IP', 'params': {}},
        limit=limit,
    )

    # Weights are listed in the same order as the requests: sparse, then dense.
    ranker = WeightedRanker(sparse_weight, dense_weight)
    result_sets = col.hybrid_search(
        [sparse_req, dense_req],
        rerank=ranker,
        limit=limit,
        output_fields=['text', 'title', 'time', 'infosource'],
    )

    # Only one query was submitted, so the first result set holds every hit.
    documents = []
    for hit in result_sets[0]:
        metadata = {
            'title': hit.get('title'),
            'time': hit.get('time'),
            'infosource': hit.get('infosource'),
        }
        documents.append({'metadata': metadata, 'page_content': hit.get('text')})
    return documents

In [5]:
def dense_search(col, query_dense_embedding, limit=10):
    """Search only the dense vector field and return the hits as plain dicts.

    Args:
        col: Collection-like object exposing ``search``.
        query_dense_embedding: Query vector for the ``dense_vector`` field.
        limit: Maximum number of hits to return.

    Returns:
        A list of dicts, one per hit, shaped as
        ``{'metadata': {'title', 'time', 'infosource'}, 'page_content': text}``.
    """
    result_sets = col.search(
        [query_dense_embedding],
        anns_field="dense_vector",
        param={"metric_type": "IP", "params": {}},
        limit=limit,
        output_fields=['text', 'title', 'time', 'infosource'],
    )

    # Only one query vector was submitted, so the first result set holds every hit.
    documents = []
    for hit in result_sets[0]:
        metadata = {
            'title': hit.get('title'),
            'time': hit.get('time'),
            'infosource': hit.get('infosource'),
        }
        documents.append({'metadata': metadata, 'page_content': hit.get('text')})
    return documents

In [6]:
import csv
import os

In [9]:
# Recall evaluation: run every test question through one of the two search
# helpers defined above and record whether the expected title was retrieved.
test_data = pd.read_excel('qa_data_1250.xlsx')[['Question', 'Title']]
# test_data = pd.read_excel('500条.xlsx')[['Question', 'title']]

search_method = ['混合搜素', '稠密搜索']
current_search = search_method[0]
save_path = f'{current_search}优化测试结果_2.csv'
header = ['Question', 'True_Title', 'Recalling', 'Recall_Weight']

# Reject an unknown method up front. Previously the loop only printed a
# message and then carried on with an undefined (or stale) `results`.
if current_search not in search_method:
    raise ValueError(f'无该方法: {current_search}')

# The CSV is opened once in append mode instead of once per question.
with open(save_path, 'a', newline='', encoding='utf-8') as f:
    fw = csv.writer(f)
    # Only a brand-new (empty) file receives the header row.
    if not os.path.getsize(save_path):
        fw.writerow(header)

    for idx in range(test_data.shape[0]):
        search_question = test_data.loc[idx, 'Question']
        true_title = test_data.loc[idx, 'Title']

        embed_input = embedding_model([search_question])
        if current_search == '混合搜素':
            # Hybrid (sparse + dense) vector recall
            results = hybrid_search(
                milvus_collection,
                embed_input['dense'][0],
                embed_input['sparse']._getrow(0),
                sparse_weight=0.7,
                dense_weight=1.0,
                limit=10
            )
        else:
            # Dense-only vector recall
            results = dense_search(milvus_collection, embed_input["dense"][0], limit=10)

        meta = [file['metadata']['title'] for file in results]
        print('===============')

        # prdict: 1 when the expected title is among the recalled titles, else 0.
        # arg_weight: 1-based rank of the first match, or '' when there is none.
        # Both are assigned on every iteration, so an empty result list can no
        # longer leak the previous question's values (or raise NameError).
        if true_title in meta:
            prdict = 1
            arg_weight = meta.index(true_title) + 1
        else:
            prdict = 0
            arg_weight = ''

        data_ = [search_question, true_title, prdict, arg_weight]
        fw.writerow(data_)
        # Flush per row so partial results survive an interrupted run.
        f.flush()
        

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.




In [None]:
 # Scratch cell: re-read the evaluation workbook and display its two columns.
 # NOTE(review): this line starts with a stray space -- fine in a notebook cell
 # as far as I can tell, but an IndentationError in a plain .py export; confirm.
 pd.read_excel('qa_data_1250.xlsx')[['Question', 'Title']]

In [None]:
# Scratch cell: show the question at row `idx` (the last row the loop reached).
# A single .loc[row, column] lookup replaces the chained .loc[row][column] form.
test_data.loc[idx, 'Question']

In [None]:
# Scratch cell: display the expected title left over from the evaluation loop.
true_title

In [None]:
# Scratch cell: load and display the alternative workbook that the
# commented-out `test_data` line in the evaluation cell refers to.
pd.read_excel('500条.xlsx')

In [32]:
# Scratch cell: tiny dict used to try out dict key listing.
dd = {5: 10}

In [41]:
# Scratch cell: iterating a dict yields its keys, so this lists the keys of `dd`.
list(dd)

[5]

<built-in method keys of dict object at 0x7f1021f79e00>
