In [1]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

# Warn (but do not abort) when no CUDA device is visible to PyTorch.
if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add GPU to your notebook")


# We use the bi-encoder to encode all passages, so that we can use it with semantic search.
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 512     # Truncate long passages to 512 tokens

# The bi-encoder retrieves a list of candidate passages. We then use a cross-encoder
# to re-rank that candidate list and improve the quality of the final results.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


In [2]:
from datasets import load_dataset

# Load the training split first, then shuffle it with a fixed seed so that the
# row order (and therefore every later slice of it) is reproducible.
dataset = load_dataset('BeIR/trec-news-generated-queries', split='train')
dataset = dataset.shuffle(seed=42)


Found cached dataset json (/home/ubuntu/.cache/huggingface/datasets/BeIR___json/BeIR--trec-news-generated-queries-58e8f34dd4c75682/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)
Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/BeIR___json/BeIR--trec-news-generated-queries-58e8f34dd4c75682/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8f0959c31593cac5.arrow


In [3]:
# Collect up to 10,000 rows into three parallel lists: queries[i], passages[i]
# and titles[i] all come from the same dataset row i.
# The count is clamped to the dataset size so a smaller dataset does not raise
# an IndexError.
max_samples = min(10000, len(dataset))

queries = []
passages = []
titles = []
for i in range(max_samples):
    # Fetch the row once instead of indexing the dataset three times per iteration.
    row = dataset[i]
    queries.append(row['query'])
    passages.append(row['text'])
    titles.append(row['title'])

In [4]:
# Embed every passage with the bi-encoder. The result is requested as a tensor
# because it is later passed to util.semantic_search.
corpus_embeddings = bi_encoder.encode(
    passages,
    show_progress_bar=True,
    convert_to_tensor=True,
)

Batches:   0%|          | 0/313 [00:00<?, ?it/s]

In [7]:
def search(query, top_k=3, num_candidates=100):
    """Retrieve passages for ``query`` with the bi-encoder, then re-rank them
    with the cross-encoder.

    Args:
        query: Query string to search for.
        top_k: Number of re-ranked hits to return.
        num_candidates: Number of passages fetched by the bi-encoder before
            re-ranking. It is raised to ``top_k`` when smaller, so that
            ``top_k`` results can actually be returned.

    Returns:
        A list of at most ``top_k`` hit dicts as returned by
        ``util.semantic_search``, each extended with a ``'cross_score'`` entry
        and sorted by that score in descending order. Empty when the
        bi-encoder returns no hits.
    """
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # Follow the device of the corpus embeddings instead of forcing CUDA:
    # an unconditional .cuda() raises on machines without a GPU.
    question_embedding = question_embedding.to(corpus_embeddings.device)

    candidate_count = max(num_candidates, top_k)
    # semantic_search returns one hit list per query; we only passed one query.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=candidate_count)[0]
    if not hits:
        return []

    # Score every (query, passage) pair with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross_score'] = cross_score

    reranked = sorted(hits, key=lambda hit: hit['cross_score'], reverse=True)
    return reranked[:top_k]
    

In [10]:
# Evaluate Recall@k on the first n queries. Query i counts as a success when
# passage i (the passage taken from the same dataset row) is among the top-k
# re-ranked hits.
# NOTE(review): matching is by corpus index; if the corpus contains duplicate
# passage texts, a duplicate at another index is scored as a miss - confirm.
n = 200
k = 3
recall_at_k = 0
for i in range(n):
    results = search(queries[i], top_k=k)
    hit_ids = {result['corpus_id'] for result in results}
    if i in hit_ids:
        recall_at_k += 1
print("Recall@{}: {}".format(k, recall_at_k/n))

Recall@3: 0.815
