In [1]:
import numpy as np
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer, util
import torch
from torch import Tensor
from tqdm import tqdm
from typing import Callable
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
def semantic_search(query_embeddings: Tensor,
                    corpus_embeddings: Tensor,
                    query_chunk_size: int = 100,
                    corpus_chunk_size: int = 500000,
                    top_k: int = 10,
                    score_function: Callable[[Tensor, Tensor], Tensor] = util.cos_sim):
    """
    Retrieve, for every query embedding, the ``top_k`` highest-scoring corpus embeddings.

    Scores are computed chunk by chunk (``query_chunk_size`` queries against
    ``corpus_chunk_size`` corpus rows at a time), so the full query x corpus score
    matrix is never materialised. For each query chunk a running top-k is kept on
    the scoring device and merged with the top-k of every corpus chunk. Only the
    final ``top_k`` hits per query are copied back to Python.

    :param query_embeddings: 2-D tensor with one row per query. A 1-D tensor is
        treated as a single query. NumPy arrays and lists of 1-D tensors are converted.
    :param corpus_embeddings: 2-D tensor with one row per corpus entry. NumPy arrays
        and lists of 1-D tensors are converted. Queries are moved to this tensor's device.
    :param query_chunk_size: Number of queries scored per step. Larger is faster but
        needs more memory.
    :param corpus_chunk_size: Number of corpus rows scored per step (default 500,000).
        Larger is faster but needs more memory.
    :param top_k: Maximum number of hits returned per query.
    :param score_function: Callable mapping (queries, corpus) tensors to a
        (n_queries, n_corpus) score matrix where higher means more similar.
        Defaults to cosine similarity.
    :return: A list with one entry per query. Each entry is a list of
        ``(corpus_id, score)`` tuples, where ``corpus_id`` is the row index into
        ``corpus_embeddings``. Entries are sorted by decreasing score and hold at
        most ``top_k`` items. All entries are empty when the corpus is empty.
    """
    if isinstance(query_embeddings, (np.ndarray, np.generic)):
        query_embeddings = torch.from_numpy(query_embeddings)
    elif isinstance(query_embeddings, list):
        query_embeddings = torch.stack(query_embeddings)

    if query_embeddings.dim() == 1:
        query_embeddings = query_embeddings.unsqueeze(0)

    if isinstance(corpus_embeddings, (np.ndarray, np.generic)):
        corpus_embeddings = torch.from_numpy(corpus_embeddings)
    elif isinstance(corpus_embeddings, list):
        corpus_embeddings = torch.stack(corpus_embeddings)

    # Score on the corpus' device so the (usually much larger) corpus never has to move.
    if corpus_embeddings.device != query_embeddings.device:
        query_embeddings = query_embeddings.to(corpus_embeddings.device)

    n_queries = len(query_embeddings)
    n_corpus = len(corpus_embeddings)
    queries_result_list = [[] for _ in range(n_queries)]

    for query_start_idx in tqdm(range(0, n_queries, query_chunk_size)):
        query_chunk = query_embeddings[query_start_idx:query_start_idx + query_chunk_size]
        best_scores, best_ids = None, None

        for corpus_start_idx in range(0, n_corpus, corpus_chunk_size):
            corpus_chunk = corpus_embeddings[corpus_start_idx:corpus_start_idx + corpus_chunk_size]
            scores = score_function(query_chunk, corpus_chunk)

            # Top-k inside this corpus chunk; shift ids from chunk-local to global row indices.
            chunk_scores, chunk_ids = torch.topk(scores, min(top_k, scores.shape[1]),
                                                 dim=1, largest=True, sorted=False)
            chunk_ids = chunk_ids + corpus_start_idx

            if best_scores is None:
                best_scores, best_ids = chunk_scores, chunk_ids
            else:
                # Merge with the running top-k of earlier chunks and keep only the best top_k,
                # so memory per query stays O(top_k) regardless of the number of corpus chunks.
                merged_scores = torch.cat((best_scores, chunk_scores), dim=1)
                merged_ids = torch.cat((best_ids, chunk_ids), dim=1)
                best_scores, keep = torch.topk(merged_scores, min(top_k, merged_scores.shape[1]),
                                               dim=1, largest=True, sorted=False)
                best_ids = torch.gather(merged_ids, 1, keep)

        if best_scores is None:
            # Empty corpus: this query chunk keeps its empty result lists.
            continue

        # Order each query's hits by decreasing score, then hand them back as Python tuples.
        best_scores, order = torch.sort(best_scores, dim=1, descending=True)
        best_ids = torch.gather(best_ids, 1, order)
        for offset, (ids, vals) in enumerate(zip(best_ids.cpu().tolist(), best_scores.cpu().tolist())):
            queries_result_list[query_start_idx + offset] = list(zip(ids, vals))

    return queries_result_list

In [3]:
# Training queries (file looks like MS MARCO doc-ranking train): tab-separated, no header,
# columns (qid, query text). Keep the text column as a Series indexed by qid.
queries = pd.read_csv(
    r"queries.doctrain.tsv",
    sep='\t',
    header=None,
    names=['qid', 'query'],
    index_col=['qid'],
)['query']

In [4]:
# Plain Python list of query strings, in the same order as `queries.index`.
query_list = queries.to_list()

In [5]:
# Bi-encoder used for semantic search. In this notebook it only encodes the queries;
# the corpus embeddings are loaded from a file named after `model_name`, so this must be
# the same model that produced that file.
model_name = 'msmarco-MiniLM-L6-cos-v5'
bi_encoder = SentenceTransformer(model_name, device=device)
bi_encoder.max_seq_length = 256     # Inputs longer than 256 tokens are truncated
top_k = 32                          # Number of candidates retrieved per query by the bi-encoder

In [6]:
# Embed every query once. convert_to_tensor=True lets the result go straight into semantic_search.
question_embedding = bi_encoder.encode(
    query_list,
    device=device,
    convert_to_tensor=True,
    show_progress_bar=True,
)

Batches:   0%|          | 0/11470 [00:00<?, ?it/s]

In [7]:
# Pre-computed corpus embeddings for the same bi-encoder model.
# NOTE: torch.load unpickles its input — only load files you produced or trust.
corpus_embeddings_path = f'corpus_embeddings_{model_name}.pt'
corpus_embeddings = torch.load(corpus_embeddings_path, map_location=device, mmap=True)

In [8]:
# One list of (corpus_id, score) hits per query, best first.
results = semantic_search(
    query_embeddings=question_embedding,
    corpus_embeddings=corpus_embeddings,
    top_k=top_k,
)

100%|██████████| 3671/3671 [08:47<00:00,  6.95it/s]


In [9]:
# Turn each (corpus_id, score) hit into (qid, corpus_id), dropping the score.
# zip() would silently truncate on a length mismatch, so check the pairing explicitly.
assert len(results) == len(queries), "expected exactly one result list per query"
# NOTE: this rebinds `results` from itself. Re-running the cell on already-converted
# results yields (qid, qid) pairs, so re-run the search cell before re-running this one.
results = [[(qid, corpus_id) for corpus_id, _score in hits]
           for qid, hits in zip(queries.index, results)]

In [10]:
# Persist the converted hit lists so the evaluation can be repeated without re-running the search.
results_path = f'semantic_search_results_{model_name}.pkl'
with open(results_path, 'wb') as out_file:
    pickle.dump(results, out_file)

In [11]:
# One integer per line: maps a corpus row index (the corpus_id from semantic_search) to a document id.
with open('corpus_id_to_doc_id.txt') as id_file:
    doc_ids = [int(line) for line in id_file.read().splitlines()]

In [12]:
# Relevance judgements: for each qid, the 'pid' document id that the evaluation
# compares against doc_ids[corpus_id].
relevance = pd.read_csv(r"msmarco-doctrain-qrels-idconverted.tsv", sep='\t', header=None, names=['qid', 'pid'], index_col=['qid'], dtype=int)['pid']
# The evaluation does `relevance[qid] == ...`, which only yields a scalar when each qid
# appears once. Fail early with a clear message instead of an ambiguous-truth-value error.
assert relevance.index.is_unique, "expected exactly one relevant document per query id"
relevance.map(type).value_counts()

pid
<class 'int'>    367013
Name: count, dtype: int64

In [13]:
# Rank cutoff for the evaluation: only the first n_results hits per query are inspected.
n_results: int = 10

In [14]:
# Mean reciprocal rank @ n_results. For each query, add 1/rank of the first hit whose
# document id equals the query's relevant document; add nothing if there is no such hit
# within the cutoff. Then average over all queries.
pak_sum = 0
for qid, result in tqdm(zip(queries.index, results), total=len(results)):
    relevant_doc = relevance[qid]  # loop-invariant for the inner loop
    # Slicing (instead of indexing up to n_results) tolerates queries with fewer hits.
    for rank, (_, corpus_id) in enumerate(result[:n_results], start=1):
        if doc_ids[corpus_id] == relevant_doc:
            pak_sum += 1 / rank
            break
pak_sum / len(results)

367013it [00:03, 100236.19it/s]


0.3147422776294608