Retrieve & Re-Rank Demo over Simple Wikipedia
This example demonstrates the Retrieve & Re-Rank setup and lets you search over Simple English Wikipedia.

You can enter a query or a question. The script then uses semantic search to find relevant passages in Simple English Wikipedia, which is used because it is much smaller than the full English Wikipedia and fits more easily in RAM.

For semantic search, we use SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') and retrieve 32 passages that potentially answer the input query.
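Conceptually, this semantic-search step is a cosine-similarity lookup followed by a top-k selection. The sketch below reproduces that idea with NumPy on random vectors. It is a simplified stand-in for `util.semantic_search`, not its actual implementation (the library also processes the corpus in chunks and returns `corpus_id`/`score` dicts), and the helper name `cosine_top_k` is hypothetical.

```python
import numpy as np

def cosine_top_k(query_emb, corpus_embs, top_k):
    """Hypothetical helper: (index, cosine score) pairs for the top_k most similar rows."""
    # Normalize so that a dot product equals cosine similarity
    q = query_emb / np.linalg.norm(query_emb)
    c = corpus_embs / np.linalg.norm(corpus_embs, axis=1, keepdims=True)
    scores = c @ q
    # Indices of the top_k scores, highest first
    order = np.argsort(-scores)[:top_k]
    return [(int(i), float(scores[i])) for i in order]

rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 384))             # stand-in for corpus_embeddings
query = corpus[42] + 0.01 * rng.normal(size=384)  # a query very close to passage 42
hits = cosine_top_k(query, corpus, top_k=32)
print(hits[0][0])  # → 42
```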

Next, we use a more powerful CrossEncoder (cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')) that scores the relevance of each retrieved passage to the query. The cross-encoder further boosts performance, especially when you search over a corpus the bi-encoder was not trained on.
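The control flow of retrieve & re-rank can be sketched without any models: a cheap scorer ranks every passage, and an expensive scorer runs only on the top candidates. Both scorers below are hypothetical stand-ins (raw keyword counts in place of the bi-encoder, length-normalized keyword coverage in place of the cross-encoder); only the two-stage structure mirrors the notebook.

```python
def first_stage_score(query, passage):
    # Hypothetical cheap scorer (stand-in for the bi-encoder):
    # total number of times any query word occurs in the passage
    words = passage.lower().split()
    return sum(words.count(w) for w in set(query.lower().split()))

def second_stage_score(query, passage):
    # Hypothetical expensive scorer (stand-in for the cross-encoder): looks at the
    # query and passage together and rewards covering distinct query words
    # relative to passage length
    words = passage.lower().split()
    covered = set(query.lower().split()) & set(words)
    return len(covered) / len(words)

def retrieve_and_rerank(query, passages, top_k):
    # Stage 1: score every passage cheaply, keep only the top_k candidates
    candidates = sorted(range(len(passages)),
                        key=lambda i: first_stage_score(query, passages[i]),
                        reverse=True)[:top_k]
    # Stage 2: run the expensive scorer on the top_k candidates only
    reranked = [(i, second_stage_score(query, passages[i])) for i in candidates]
    reranked.sort(key=lambda pair: pair[1], reverse=True)
    return reranked, len(candidates)  # second value = number of expensive calls

passages = [
    "cold war cold war cold war films and games",
    "the cold war ended in 1991",
    "a recipe for soup",
    "war and peace",
]
results, expensive_calls = retrieve_and_rerank("cold war ended", passages, top_k=2)
print(results[0][0], expensive_calls)  # → 1 2
```

Here the first stage prefers the repetitive passage 0, the second stage moves the more focused passage 1 to the top, and the expensive scorer runs only `top_k` times. In the notebook, the cross-encoder has to process every (query, passage) pair jointly, so restricting it to the 32 bi-encoder hits is what keeps each search reasonably fast.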

In [1]:
!pip install -U sentence-transformers rank_bm25

In [4]:
import json
import gzip
import os
import torch
from sentence_transformers import SentenceTransformer, util, CrossEncoder

In [3]:
if torch.cuda.is_available():
  print("GPU available and ready to go")

GPU available and ready to go


In [5]:
# Bi-encoder
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256 # truncate long passages to 256 tokens
top_k = 32 # Number of passages we want to retrieve

# The bi-encoder retrieves top_k candidate passages; the cross-encoder
# then re-ranks this shortlist
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

if not os.path.exists(wikipedia_filepath):
  util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz',
                wikipedia_filepath)

In [7]:
passages = []
with gzip.open(wikipedia_filepath,'rt',encoding='utf8') as f:
  for line in f:
    data = json.loads(line.strip())
    # Add the first paragraph
    passages.append(data['paragraphs'][0])

print('Passages:',len(passages))

Passages: 169597
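The loop above assumes each line of the `.jsonl.gz` file is one JSON object with a `paragraphs` list. The sketch below writes a tiny file in that assumed shape to a temporary directory and reads it back the same way. The `title` field and the sample text are made up for illustration; the real file may contain other fields.

```python
import gzip
import json
import os
import tempfile

# Hypothetical records in the shape the notebook relies on: a "paragraphs" list per article
records = [
    {"title": "Example A", "paragraphs": ["First paragraph of A.", "Second paragraph of A."]},
    {"title": "Example B", "paragraphs": ["First paragraph of B."]},
]

path = os.path.join(tempfile.mkdtemp(), "tiny.jsonl.gz")
# Write one JSON object per line, gzip-compressed, in text mode
with gzip.open(path, "wt", encoding="utf8") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# Read it back like the notebook does: keep only the first paragraph of each article
first_paragraphs = []
with gzip.open(path, "rt", encoding="utf8") as f:
    for line in f:
        data = json.loads(line.strip())
        first_paragraphs.append(data["paragraphs"][0])

print(first_paragraphs)  # → ['First paragraph of A.', 'First paragraph of B.']
```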


In [8]:
# Encode the passages

corpus_embeddings = bi_encoder.encode(passages,
                                      convert_to_tensor=True,
                                      show_progress_bar=True)

print("Shape of embeddings:",corpus_embeddings.shape)

Shape of embeddings: torch.Size([169597, 384])
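The shape above also gives a quick estimate of how much memory the corpus embeddings occupy, which is one reason the smaller Simple English Wikipedia is convenient here. A back-of-the-envelope calculation, assuming the default float32 dtype (4 bytes per value) and ignoring the passage texts and the models themselves:

```python
# Assumes float32 embeddings (4 bytes per value)
num_passages = 169597  # from the "Passages:" output above
dim = 384              # embedding size, from the shape printed above
bytes_total = num_passages * dim * 4
print(bytes_total)                    # → 260500992
print(round(bytes_total / 2**20, 1))  # → 248.4 (MiB)
```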


In [9]:
# We also compare the results to lexical search (keyword search). Here, we use 
# the BM25 algorithm which is implemented in the rank_bm25 package.

from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm
import numpy as np
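For reference, BM25 scores a document by summing, over the query terms, an IDF weight times a saturated, length-normalized term frequency. The sketch below is a textbook BM25 with k1 = 1.5 and b = 0.75 and the smoothed IDF `log((N - n + 0.5) / (n + 0.5) + 1)`. `rank_bm25.BM25Okapi` may handle IDF differently, so these scores are not expected to match the notebook's exactly. The helper name `bm25_scores` is hypothetical.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, tokenized_corpus, k1=1.5, b=0.75):
    """Hypothetical helper: textbook BM25 score of every document for one query."""
    n_docs = len(tokenized_corpus)
    avgdl = sum(len(doc) for doc in tokenized_corpus) / n_docs
    # Document frequency: in how many documents each term appears
    df = Counter(term for doc in tokenized_corpus for term in set(doc))
    scores = []
    for doc in tokenized_corpus:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if tf[term] == 0:
                continue  # terms absent from the document contribute nothing
            # Smoothed IDF (always positive)
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            # Term-frequency saturation (k1) and document-length normalization (b)
            denom = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / denom
        scores.append(score)
    return scores

corpus = [
    ["cold", "war", "ended", "1991"],
    ["cold", "norton", "village", "essex"],
    ["soup", "recipe"],
]
print(bm25_scores(["cold", "war"], corpus))
```

The first document matches both query terms and ranks highest, the second matches only the more common term "cold", and the third scores zero.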

In [26]:
# Lowercase the text, strip punctuation, and remove stop words

def bm25_tokenizer(text):
  tokenized_doc = []
  for token in text.lower().split():
    token = token.strip(string.punctuation)

    if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
      tokenized_doc.append(token)
  return tokenized_doc

tokenized_corpus = []
for passage in tqdm(passages):
  tokenized_corpus.append(bm25_tokenizer(passage))

bm25 = BM25Okapi(tokenized_corpus)

In [27]:
# Search all Wikipedia passages for those that answer the query,
# comparing BM25, bi-encoder retrieval, and cross-encoder re-ranking
def search(query):
    print("Input question:", query)

    ##### BM25 search (lexical search) #####
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
    top_n = np.argpartition(bm25_scores, -5)[-5:]
    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
    
    print("Top-3 lexical search (BM25) hits")
    for hit in bm25_hits[0:3]:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages.
    # With convert_to_tensor=True the embedding stays on the model's device (the same
    # device as corpus_embeddings), so no .cuda() call is needed; this also runs on CPU.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach each cross-encoder score to its hit
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    # Output of top-3 hits from bi-encoder
    print("\n-------------------------\n")
    print("Top-3 Bi-Encoder Retrieval hits")
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    for hit in hits[0:3]:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))

    # Output of top-3 hits from re-ranker
    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[0:3]:
        print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
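The BM25 branch above uses `np.argpartition` to find the five best scores without fully sorting all 169,597 of them: partitioning places the k largest values in the last k positions, but in arbitrary order, so those k still need a small sort afterwards. A standalone illustration on made-up scores:

```python
import numpy as np

scores = np.array([0.3, 2.5, 1.1, 9.0, 0.0, 4.2, 7.7, 3.3])  # made-up BM25 scores
k = 3
# argpartition puts the indices of the k largest scores in the last k slots,
# in no particular order, without sorting the whole array
top_unsorted = np.argpartition(scores, -k)[-k:]
# Sort just those k indices by score, highest first
top_sorted = top_unsorted[np.argsort(-scores[top_unsorted])]
print(top_sorted.tolist())  # → [3, 6, 5]
```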


In [28]:
search(query = "What is the capital of the United States?")

Input question: What is the capital of the United States?
Top-3 lexical search (BM25) hits
	13.316	Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.
	11.434	Ohio is one of the 50 states in the United States. Its capital is Columbus. Columbus also is the largest city in Ohio.
	11.179	Nevada is one of the United States' states. Its capital is Carson City. Other big cities are Las Vegas and Reno.

-------------------------

Top-3 Bi-Encoder Retrieval hits
	0.622	Cities in the United States:
	0.597	The United States Capitol is the building where the United States Congress meets. It is the center of the legislative branch of the U.S. federal government. It is in Washington, D.C., on top of Capitol Hill at the east end of the National Mall.
	0.596	In the United States:

-
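The BM25 hits illustrate the limits of keyword matching. After lowercasing, punctuation stripping, and stop-word removal, the query is reduced to a few keywords, and the "Capital punishment" passage, which repeats "capital" and "United States" several times, scores highest. The sketch below mimics `bm25_tokenizer` with a small, hand-picked stop-word set instead of scikit-learn's `ENGLISH_STOP_WORDS`. The four words in it all appear in scikit-learn's list, so the notebook's tokenizer should produce the same tokens for this query.

```python
import string

# Hypothetical stand-in for sklearn's ENGLISH_STOP_WORDS (a tiny subset, for illustration)
STOP_WORDS = {"what", "is", "the", "of"}

def toy_bm25_tokenizer(text):
    # Same steps as bm25_tokenizer: lowercase, split on whitespace,
    # strip surrounding punctuation, drop empty tokens and stop words
    tokens = []
    for token in text.lower().split():
        token = token.strip(string.punctuation)
        if token and token not in STOP_WORDS:
            tokens.append(token)
    return tokens

print(toy_bm25_tokenizer("What is the capital of the United States?"))
# → ['capital', 'united', 'states']
```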

In [29]:
search(query = "When did the cold war end?")

Input question: When did the cold war end?
Top-3 lexical search (BM25) hits
	17.374	The Cold War was the tense relationship between the United States (and its allies), and the Soviet Union (the USSR and its allies) between the end of World War II and the fall of the Soviet Union. It is called the "Cold" War because the US and the USSR never actually fought each other directly. Instead, they opposed each other in conflicts known as proxy wars, where each country chose a side to support.
	17.291	The Reagan Doctrine was a document by the United States under the Reagan Administration. It was about being against the global influence of the Soviet Union during the final years of the Cold War. The doctrine lasted for less than a decade, it was the most important document of United States foreign policy from the early 1980s until the end of the Cold War in 1991.
	15.420	Cold Norton is a village and civil parish in Maldon District, Essex, England. In 2001 there were 1103 people living in Cold N

In [31]:
search(query = "Indira Gandhi")

Input question: Indira Gandhi
Top-3 lexical search (BM25) hits
	20.053	Indira Gandhi (19 November 1917 – 31 October 1984) was an Indian politician. She was Prime Minister of India from 1966 to 1977.She was the daughter of Jawaharlal Nehru, who was also Prime Minister of India. Her son, Rajiv Gandhi, later became Prime Minister of India. She married Feroze Gandhi, who was not related to the civil rights leader, Mahatma Gandhi.
	19.768	Swaraj was the Minister of External Affairs of India in the first Narendra Modi government (2014 – 2019). She was the second woman to hold the office, after Indira Gandhi.
	18.388	Rajiv Ratna Gandhi (; 20 August 1944 – 21 May 1991) was the seventh Prime Minister of India. He served as prime minister from 1984 to 1989. He took office after the 1984 murder of Prime Minister Indira Gandhi, his mother. He became the youngest Indian prime minister.

-------------------------

Top-3 Bi-Encoder Retrieval hits
	0.768	Indira Gandhi (19 November 1917 – 31 October 19

In [32]:
search(query="Dalai Lama")

Input question: Dalai Lama
Top-3 lexical search (BM25) hits
	31.141	The Dalai Lama is a religious figure in Tibetan Buddhism. He is its highest spiritual teacher of the Gelugpa school. A new Dalai Lama is said to be the reborn old Dalai Lama. This line goes back to 1391. The 14th and current Dalai Lama is Tenzin Gyatso.
	25.197	In the Tibetan language, lama means teacher. A lama is a religious teacher, guide, or mentor of Tibetan Buddhism. The meaning is similar to the Sanskrit term "guru". The word "lama" is part of the title, Dalai Lama.
	24.787	Kyabje Gelek Rimpoche (26 October 1939 – 14 February 2017) was a Tibetan-American Buddhist lama. He was born in Lhasa, China. Gelek was a nephew of the 13th Dalai Lama, Thubten Gyatso. He was tutored by many of the same masters who tutored the current (14th) Dalai Lama, Tenzin Gyatso.

-------------------------

Top-3 Bi-Encoder Retrieval hits
	0.758	The Dalai Lama is a religious figure in Tibetan Buddhism. He is its highest spiritual teacher