In [129]:
from datasets import load_dataset

# Load the 'BeIR/trec-news-generated-queries' dataset; the rows expose 'query',
# 'text' and 'title' fields that later cells read.
dataset = load_dataset('BeIR/trec-news-generated-queries')


Found cached dataset json (/home/ubuntu/.cache/huggingface/datasets/BeIR___json/BeIR--trec-news-generated-queries-58e8f34dd4c75682/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/1 [00:00<?, ?it/s]

In [130]:
# Keep only the 'train' split and shuffle it with a fixed seed so the row order
# (and therefore the sample used below) is reproducible.
# NOTE(review): this rebinds `dataset` from the loaded object to the shuffled split,
# so re-running this cell without re-running the load cell applies ['train'] to the
# already-selected split — confirm, and consider a separate variable name.
dataset = dataset['train'].shuffle(seed=42)

Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/BeIR___json/BeIR--trec-news-generated-queries-58e8f34dd4c75682/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8f0959c31593cac5.arrow


In [135]:
# Collect (query, document text, title) from the first rows of the shuffled split.
# The three lists are index-aligned: queries[i] belongs to documents[i] and titles[i];
# the evaluation cell relies on that alignment.
num_examples = min(10000, len(dataset))  # never index past the end of a smaller split
queries = []
documents = []
titles = []
for i in range(num_examples):
    row = dataset[i]  # fetch each row once instead of once per field
    queries.append(row['query'])
    documents.append(row['text'])
    titles.append(row['title'])

In [146]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

# Only a warning is printed when no GPU is visible; execution continues either way.
if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add GPU to your notebook")


# The bi-encoder embeds passages, sentences and queries so they can be compared
# with semantic search.
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
# Cap on the model's input sequence length (presumably in tokens — confirm).
bi_encoder.max_seq_length = 512

# The cross-encoder re-ranks the candidates gathered by BM25 and the bi-encoder
# (20 from each in search()) to improve the quality of the final ranking.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [152]:
# Split every document into passages of at most `passage_words` whitespace-separated
# words. Each passage record remembers the index of the document it came from, which
# the evaluation cell uses to decide whether a hit is correct.
passage_words = 400
passages = []
for doc_id, document in enumerate(documents):
    words = document.split()
    for start in range(0, len(words), passage_words):
        passages.append({
            "text": " ".join(words[start:start + passage_words]),
            "document_id": doc_id,
        })

# Encode the passage *texts*. Passing plain strings (as the sentence cell does) keeps
# both corpora encoded the same way instead of depending on how encode() treats
# dict records.
encoded_passages = bi_encoder.encode([passage["text"] for passage in passages], show_progress_bar=True)

Batches:   0%|          | 0/651 [00:00<?, ?it/s]

In [149]:
import nltk
# Fetch the Punkt tokenizer data, then break each document into sentences.
# Every sentence record keeps the index of its source document.
nltk.download('punkt')
sentences = []
for doc_index, document in enumerate(documents):
    sentences.extend(
        {'text': sentence_text, 'document_id': doc_index}
        for sentence_text in nltk.sent_tokenize(document)
    )

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [150]:
# Embed the text of every sentence record with the bi-encoder.
sentence_texts = [record['text'] for record in sentences]
encoded_sentences = bi_encoder.encode(sentence_texts, show_progress_bar=True)

Batches:   0%|          | 0/9773 [00:00<?, ?it/s]

In [156]:
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm
import numpy as np


def bm25_tokenizer(text):
    """Tokenize `text` for BM25 indexing and querying.

    The text is lower-cased and split on whitespace; leading and trailing
    punctuation is stripped from each piece. Pieces that end up empty or that
    are English stop words are dropped.

    Returns:
        list[str]: the remaining tokens, in their original order.
    """
    stripped = (piece.strip(string.punctuation) for piece in text.lower().split())
    return [
        word
        for word in stripped
        if word and word not in _stop_words.ENGLISH_STOP_WORDS
    ]


# BM25 index over the passage corpus.
tokenized_corpus = [bm25_tokenizer(record['text']) for record in tqdm(passages)]
bm25_passages = BM25Okapi(tokenized_corpus)

# BM25 index over the sentence corpus (the token list is rebuilt for this corpus).
tokenized_corpus = [bm25_tokenizer(record['text']) for record in tqdm(sentences)]
bm25_sentences = BM25Okapi(tokenized_corpus)

  0%|          | 0/20809 [00:00<?, ?it/s]

  0%|          | 0/312717 [00:00<?, ?it/s]

In [175]:
def search(query, bm25, corpus, corpus_embeddings, num_bm_25=20, num_bi_encoder=20, top_k=1):
    """Hybrid retrieve-and-re-rank search over `corpus`.

    Candidates are gathered from two retrievers (BM25 and the module-level
    `bi_encoder`), de-duplicated, and then re-ranked with the module-level
    `cross_encoder`.

    Args:
        query: the query string.
        bm25: BM25 index built over `corpus`; must provide `get_scores(tokens)`.
        corpus: sequence of records with a 'text' entry, aligned with the index
            and with `corpus_embeddings`.
        corpus_embeddings: bi-encoder embeddings of the corpus (NumPy array or
            torch tensor).
        num_bm_25: number of BM25 candidates to collect.
        num_bi_encoder: number of bi-encoder candidates to collect.
        top_k: number of re-ranked results to return.

    Returns:
        list[int]: up to `top_k` corpus indices, best cross-encoder score first.
        Empty when no candidate could be retrieved (e.g. an empty corpus).
    """
    candidates = set()

    # Lexical candidates. Clamp the count so np.argpartition does not raise when
    # the corpus holds fewer entries than requested.
    bm25_scores = np.asarray(bm25.get_scores(bm25_tokenizer(query)))
    n_lexical = min(num_bm_25, len(bm25_scores))
    if n_lexical > 0:
        top_n = np.argpartition(bm25_scores, -n_lexical)[-n_lexical:]
        candidates.update(int(idx) for idx in top_n)

    # Semantic candidates. Put the query embedding on the same device as the
    # corpus embeddings instead of forcing CUDA, which fails without a GPU and
    # does not match corpus embeddings held on the CPU.
    if num_bi_encoder > 0:
        question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
        if isinstance(corpus_embeddings, torch.Tensor):
            question_embedding = question_embedding.to(corpus_embeddings.device)
        else:
            question_embedding = question_embedding.cpu()
        bi_encoder_hits = util.semantic_search(
            question_embedding, corpus_embeddings, top_k=num_bi_encoder
        )[0]  # hits for the first (and only) query
        candidates.update(int(hit['corpus_id']) for hit in bi_encoder_hits)

    if not candidates:
        return []

    # Sorted so the candidate order (and any tie-breaking) is deterministic.
    hits = sorted(candidates)

    # Re-rank the merged candidates with the cross-encoder, best score first.
    cross_encoder_scores = cross_encoder.predict(
        [(query, corpus[hit]['text']) for hit in hits]
    )
    ranked = [hits[i] for i in np.argsort(cross_encoder_scores)[::-1]]
    return ranked[:top_k]

In [179]:
# Top-1 retrieval accuracy for the passage-level and sentence-level corpora.
# queries[i] is aligned with document i, so a result counts as correct when the
# returned passage/sentence carries document_id == i.
# NOTE(review): if the same document text occurs under several row indices, a hit
# on a duplicate is counted as a miss — confirm against the dataset.
ac_passages = 0
ac_sentences = 0
samples = 100
for i in range(samples):
    # Dedicated names for the retrieved ids: rebinding `documents` here would clobber
    # the corpus list that the chunking and sentence-splitting cells read.
    passage_hits = search(queries[i], bm25_passages, passages, encoded_passages)
    passage_doc_ids = [passages[hit]['document_id'] for hit in passage_hits]
    if i in passage_doc_ids:
        ac_passages += 1

    sentence_hits = search(queries[i], bm25_sentences, sentences, encoded_sentences)
    sentence_doc_ids = [sentences[hit]['document_id'] for hit in sentence_hits]
    if i in sentence_doc_ids:
        ac_sentences += 1

print('Passage accuracy: ', ac_passages/samples)
print('Sentence accuracy: ', ac_sentences/samples)

Passage accuracy:  0.62
Sentence accuracy:  0.5
