In [14]:
import json
import os
import zipfile

import numpy as np
import torch
import tqdm
# Use the first CUDA GPU when torch reports one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [15]:
def read_json(file_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized value produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 (the encoding write_json uses) instead of the
    # platform default, so non-ASCII content loads the same way everywhere.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

In [16]:
def write_json(data, path):
    """Serialize ``data`` as JSON into the file at ``path``.

    The target is opened for writing as UTF-8 text, and non-ASCII characters
    are written verbatim instead of being escaped.
    """
    with open(path, mode="w", encoding="utf-8") as out_file:
        json.dump(data, out_file, ensure_ascii=False)

In [17]:
def jaccard_similarity(list1, list2):
    """Return the Jaccard similarity of two collections, compared as sets.

    Duplicates within either input are ignored. The score is
    ``|A ∩ B| / |A ∪ B|`` and lies in ``[0, 1]``.

    Args:
        list1: First iterable of hashable items.
        list2: Second iterable of hashable items.

    Returns:
        The similarity as a float. When both inputs are empty the union is
        empty and the ratio is undefined; ``0.0`` is returned in that case.
    """
    set1, set2 = set(list1), set(list2)
    union_size = len(set1 | set2)
    if union_size == 0:
        # Nothing to compare: avoid dividing by zero and report no overlap.
        return 0.0
    return len(set1 & set2) / union_size

In [18]:
def cosine_similarity(embedding1, embedding2):
    """Cosine similarity of two embedding tensors along their last dimension.

    Each input gets a new leading axis of size 1 before the comparison, so the
    result keeps that leading axis.
    """
    lhs = torch.unsqueeze(embedding1, 0)
    rhs = torch.unsqueeze(embedding2, 0)
    return torch.nn.functional.cosine_similarity(lhs, rhs, dim=-1)

In [19]:
def retrieve_top_k_documents(query_embedding, document_embeddings, query_token_list, document_token_lists, k=3, alpha=0.3):
    """Retrieve the k documents most relevant to a query.

    Every document is scored with a weighted blend of two signals:
    ``alpha * cosine_similarity(embeddings) + (1 - alpha) * jaccard_similarity(token lists)``.

    Args:
        query_embedding: Embedding tensor of the query.
        document_embeddings: Tensor holding the embeddings of all documents.
        query_token_list: Token list of the query.
        document_token_lists: One token list per document, in the same order
            as ``document_embeddings``.
        k: Number of documents to return. Clamped to the number of available
            documents so an oversized k no longer makes ``topk`` raise.
        alpha: Weight of the embedding similarity; the token overlap gets
            ``1 - alpha``.

    Returns:
        A nested list with a single row of document indices, best first
        (``[[i1, ..., ik]]``); callers read the indices through ``result[0]``.
    """
    # Shape is (1, num_documents) for a 1-D query and 2-D document matrix, as
    # the notebook appears to pass them; verify if other shapes are used.
    embedding_similarities = cosine_similarity(query_embedding, document_embeddings)

    # Build the token-overlap scores on the same device as the embedding scores
    # so the blend below never mixes devices, whatever the global `device` is.
    token_id_similarities = torch.tensor(
        [jaccard_similarity(query_token_list, tokens) for tokens in document_token_lists],
        device=embedding_similarities.device,
    )

    combined_similarities = alpha * embedding_similarities + (1 - alpha) * token_id_similarities
    k = min(k, combined_similarities.shape[-1])
    _, top_document_indices = combined_similarities.topk(k)
    return top_document_indices.tolist()

In [20]:
def rerank_documents(top_document_indices, query_embedding, document_embeddings, query_token_list, document_token_lists, k=3):
    """Re-rank initially retrieved documents by token overlap with the query.

    Each candidate is scored with the Jaccard similarity between
    ``query_token_list`` and the candidate's token list, and the k best
    candidates are returned.

    Args:
        top_document_indices: Candidate document indices from the first
            retrieval stage, either nested (``[[i1, i2, ...]]``, the form
            ``retrieve_top_k_documents`` returns) or flat (``[i1, i2, ...]``).
        query_embedding: Unused; kept so existing call sites keep working.
        document_embeddings: Unused; kept so existing call sites keep working.
        query_token_list: Token list of the query.
        document_token_lists: One token list per document, indexable by
            document index.
        k: Maximum number of documents to return.

    Returns:
        A flat list of at most k document indices (indices into
        ``document_token_lists``, not positions within the candidate list),
        ordered from highest to lowest token overlap.
    """
    candidates = top_document_indices
    if candidates and isinstance(candidates[0], (list, tuple)):
        # Unwrap the single-row nested form produced by the retrieval stage.
        candidates = candidates[0]

    scored = [
        (jaccard_similarity(query_token_list, document_token_lists[idx]), idx)
        for idx in candidates
    ]
    # sorted() is stable, so equally scored candidates keep their order from
    # the first retrieval stage.
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    return [idx for _, idx in ranked[:k]]

In [21]:
def zip_fun():
    """Pack ``output/result.json`` into ``output/prediction.zip``.

    Both paths are resolved against the current working directory. The archive
    is created with the standard-library ``zipfile`` module, so no external
    ``zip`` executable is needed and the working directory is never changed.
    An existing ``prediction.zip`` is recreated from scratch on each call.

    Raises:
        FileNotFoundError: If ``output/result.json`` does not exist.
    """
    output_dir = os.path.join(os.getcwd(), "output")
    source_path = os.path.join(output_dir, "result.json")
    if not os.path.isfile(source_path):
        # Fail before touching the archive so a missing result is not hidden
        # behind an empty or stale zip file.
        raise FileNotFoundError(source_path)

    archive_path = os.path.join(output_dir, "prediction.zip")
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        # Store the file at the archive root under its bare name.
        archive.write(source_path, arcname="result.json")

In [22]:
# Load the query test set (512 entries). Original size note: 512*74 (what the 74 measures is not shown here — TODO confirm).
query = read_json('input/query_testset.json')
# Bulk query-embedding tensor is left disabled; the main loop builds each query's embedding tensor itself.
# query_embeddings = torch.tensor([entry['query_embedding'] for entry in query], device=device)
query_token_lists = [entry['query_input_list'] for entry in query]

In [23]:
# Load the retrieval facts (26599 entries). Original size note: 26599*90 (what the 90 measures is not shown here — TODO confirm).
document = read_json('input/document.json')
# Collect every fact's 'facts_embedding' into one tensor placed on `device`.
document_embeddings = torch.tensor([entry['facts_embedding'] for entry in document], device=device)
# Collect every fact's 'fact_input_list', in the same order as `document`.
document_token_lists = [entry['fact_input_list'] for entry in document]

In [25]:
# Build one result record per query: its 'query_input_list' plus the retrieved evidence facts.
results = []
for item in tqdm.tqdm(query):
    result = {}
    query_embedding = torch.tensor(item['query_embedding'], device=device)
    query_token_list=item['query_input_list']
    # First-stage retrieval over all facts; the returned value is indexed with [0] below.
    top_document_indices = retrieve_top_k_documents(query_embedding, document_embeddings,query_token_list,document_token_lists, k=3,alpha=0.3)
    # NOTE(review): reranked_indices is never read afterwards — evidence_list below is built from
    # top_document_indices[0]. Confirm whether the reranked order was meant to be used instead.
    reranked_indices = rerank_documents(top_document_indices,query_embedding, document_embeddings,query_token_list,document_token_lists, k=3)
    result['query_input_list'] = item['query_input_list']
    result['evidence_list'] = [{'fact_input_list': document[index]['fact_input_list']} for index in top_document_indices[0]]
    results.append(result)

100%|██████████| 512/512 [01:50<00:00,  4.64it/s]


In [None]:
# Persist the collected results as JSON, then package that file via zip_fun().
write_json(results, 'output/result.json')
print('write to output/result.json successful')
zip_fun()
