<a href="https://colab.research.google.com/github/linhlinhle997/e2e-qa-distilbert/blob/develop/faiss_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q datasets evaluate

In [None]:
!sudo apt-get install -y libopenblas-dev
!pip install -q condacolab
import condacolab
condacolab.install()  # restarts the Colab runtime; run the next cell after it reconnects

In [None]:
!mamba install -c pytorch faiss-gpu -y

In [None]:
import numpy as np
import collections
import torch
import faiss
import evaluate
from datasets import load_dataset
from tqdm.auto import tqdm

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer,
)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## Load Dataset

In [None]:
raw_ds = load_dataset("squad", split="train")

In [None]:
raw_ds

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})

In [None]:
print(raw_ds["question"][0])
print(raw_ds["context"][0])
print(raw_ds["answers"][0])

To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
{'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}
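In each `answers` entry, `answer_start` holds character offsets into `context`, so the gold answer can be recovered by slicing. The sketch below checks this on the first training example printed above; `answer_char_span` is a hypothetical helper written for illustration, not part of `datasets`.

```python
# Context and answers of the first SQuAD training example, copied from the output above.
context = (
    "Architecturally, the school has a Catholic character. "
    "Atop the Main Building's gold dome is a golden statue of the Virgin Mary. "
    'Immediately in front of the Main Building and facing it, is a copper statue '
    'of Christ with arms upraised with the legend "Venite Ad Me Omnes". '
    "Next to the Main Building is the Basilica of the Sacred Heart. "
    "Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. "
    "It is a replica of the grotto at Lourdes, France where the Virgin Mary "
    "reputedly appeared to Saint Bernadette Soubirous in 1858. "
    "At the end of the main drive (and in a direct line that connects through "
    "3 statues and the Gold Dome), is a simple, modern stone statue of Mary."
)
answers = {"text": ["Saint Bernadette Soubirous"], "answer_start": [515]}


def answer_char_span(context: str, answers: dict) -> tuple[int, int]:
    """Return (start, end) character offsets of the first gold answer.

    Hypothetical helper: `end` is exclusive, so context[start:end] is the answer text.
    """
    start = answers["answer_start"][0]
    end = start + len(answers["text"][0])
    return start, end


start, end = answer_char_span(context, answers)
print(start, end, repr(context[start:end]))
```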


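Filter out samples that have no answer. SQuAD v1.1 provides an answer for every question, so the row count stays at 87,599 here; the filter only matters for data with unanswerable questions, such as SQuAD v2.0.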

In [None]:
raw_ds = raw_ds.filter(
    lambda x: len(x["answers"]["text"]) > 0
)
raw_ds

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})

## Initialize the pre-trained model

In [None]:
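model_name = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
if device == "cuda":
    model = model.half()  # fp16 on GPU only; half precision is best supported there
model.eval()  # inference only: make sure dropout is disabled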

## Create Vector Embeddings

In [None]:
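def get_embeddings(text_list):
    """Return one vector per input text: the final hidden state of the [CLS] token."""
    with torch.no_grad():
        encoded_input = tokenizer(
            text_list,
            padding=True,      # pad to the longest text in the batch
            truncation=True,   # cut texts longer than the model limit (512 tokens for DistilBERT)
            return_tensors="pt"
        )
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        model_output = model(**encoded_input)
        # Keep only position 0 ([CLS]); cast to float32, the dtype FAISS works with
        return model_output.last_hidden_state[:, 0].float().cpu().numpy()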
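`get_embeddings` represents each text by the final hidden state of the `[CLS]` token (`last_hidden_state[:, 0]`). A commonly used alternative is mean pooling: average the token vectors while ignoring padding positions via `attention_mask`. The NumPy sketch below only illustrates the two pooling choices on a toy array; it is not what this notebook uses, and applying it would mean feeding it `model_output.last_hidden_state` and `encoded_input["attention_mask"]` (converted to NumPy) instead.

```python
import numpy as np


def cls_pool(last_hidden_state: np.ndarray) -> np.ndarray:
    """Take the vector at position 0 ([CLS]): (batch, seq, hidden) -> (batch, hidden)."""
    return last_hidden_state[:, 0]


def mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors over real tokens only (mask == 1), skipping padding."""
    mask = attention_mask[..., None].astype(np.float32)   # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(axis=1)       # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)        # (batch, 1), avoids division by zero
    return summed / counts


# Toy batch: one sequence of 3 tokens with hidden size 2; the last token is padding.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]], dtype=np.float32)
mask = np.array([[1, 1, 0]])

print(cls_pool(hidden))         # [[1. 2.]]
print(mean_pool(hidden, mask))  # [[2. 3.]]
```

The padded `[100, 100]` vector has no effect on the mean because its mask entry is 0.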

Test the function on a single context:

In [None]:
embedding = get_embeddings(raw_ds["context"][0])
embedding.shape

(1, 768)

In [None]:
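batch_size = 32
# Note: despite its name, this column stores embeddings of the *context* passages
# (see batch["context"] below), which the question is later searched against.
embedding_column = "question_embedding"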

embedding_ds = raw_ds.map(
    lambda batch: {embedding_column: get_embeddings(batch["context"])},
    batched=True,
    batch_size=batch_size
)
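With `batched=True`, `map` calls the lambda on up to `batch_size` rows at a time (here 32 contexts), and each row of the returned `(batch, 768)` array becomes that example's vector in the new column. The plain-Python sketch below mimics only that chunking; `embed_in_batches` and the stand-in `fake_embed` are hypothetical names used for illustration.

```python
from typing import Callable, List, Sequence

import numpy as np


def embed_in_batches(
    texts: Sequence[str],
    embed_fn: Callable[[List[str]], np.ndarray],
    batch_size: int = 32,
) -> np.ndarray:
    """Apply embed_fn to consecutive chunks of `texts` and stack the results row-wise."""
    if not texts:
        raise ValueError("texts must not be empty")
    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])  # the last chunk may be shorter
        chunks.append(embed_fn(batch))                  # shape (len(batch), dim)
    return np.concatenate(chunks, axis=0)               # shape (len(texts), dim)


def fake_embed(batch: List[str]) -> np.ndarray:
    """Hypothetical stand-in for get_embeddings: (character count, word count) per text."""
    return np.array([[len(t), len(t.split())] for t in batch], dtype=np.float32)


texts = [f"context number {i}" for i in range(70)]
vectors = embed_in_batches(texts, fake_embed, batch_size=32)
print(vectors.shape)  # (70, 2): batches of 32, 32 and 6 rows
```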

In [None]:
embedding_ds.add_faiss_index(column=embedding_column)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'question_embedding'],
    num_rows: 87599
})
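Called with only `column`, `add_faiss_index` builds an exact (brute-force) flat FAISS index with the L2 metric and, unless a `device` argument is given, keeps it on the CPU. For this metric FAISS reports *squared* Euclidean distances, so a smaller score means a closer match. A NumPy sketch of the same exact search, for illustration only (`exact_l2_search` is a hypothetical name):

```python
import numpy as np


def exact_l2_search(index_vectors: np.ndarray, query: np.ndarray, k: int):
    """Return (scores, ids) of the k rows closest to `query` by squared L2 distance.

    index_vectors: (n, dim) array; query: (dim,) or (1, dim) array.
    Scores are sorted ascending (smaller = more similar), as with a flat L2 index.
    """
    q = query.reshape(1, -1).astype(np.float32)
    diffs = index_vectors.astype(np.float32) - q     # (n, dim) via broadcasting
    sq_dists = np.einsum("ij,ij->i", diffs, diffs)   # squared L2 distance per row
    ids = np.argsort(sq_dists, kind="stable")[:k]    # indices of the k smallest distances
    return sq_dists[ids], ids


toy_vectors = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]], dtype=np.float32)
toy_scores, toy_ids = exact_l2_search(toy_vectors, np.array([0.0, 0.5]), k=2)
print(toy_scores, toy_ids)  # [0.25 1.25] [0 1]
```

Because the scores are squared distances, values such as the 29.19 printed for the top hit further down only rank the candidates relative to each other; they are not similarities.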

## Search for similar samples with a question

In [None]:
question = 'When did Beyonce start becoming popular?'

input_quest_embedding = get_embeddings([question])
input_quest_embedding.shape

(1, 768)

In [None]:
TOP_K = 5

scores, samples = embedding_ds.get_nearest_examples(
    embedding_column, input_quest_embedding, k=TOP_K
)
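`get_nearest_examples` returns the `k` scores (here squared L2 distances) together with a dict holding the matching rows' columns, which is why `samples['question'][idx]` works in the loop below. L2 distance on raw, un-normalized `[CLS]` vectors can be dominated by vector length; a common alternative is cosine similarity, obtained by L2-normalizing every vector and ranking by inner product. With `datasets` this would mean normalizing the stored embeddings and passing `metric_type=faiss.METRIC_INNER_PRODUCT` to `add_faiss_index` (an untested suggestion here). A NumPy sketch of cosine ranking, using the hypothetical helper `cosine_top_k`:

```python
import numpy as np


def cosine_top_k(index_vectors: np.ndarray, query: np.ndarray, k: int):
    """Return (similarities, ids) of the k rows most similar to `query` by cosine.

    Scores are sorted descending (larger = more similar), unlike L2 distances.
    """
    def normalize(x: np.ndarray) -> np.ndarray:
        norms = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.clip(norms, 1e-12, None)  # guard against zero vectors

    db = normalize(index_vectors.astype(np.float32))        # (n, dim), unit rows
    q = normalize(query.reshape(1, -1).astype(np.float32))  # (1, dim)
    sims = (db @ q.T).ravel()                               # dot of unit vectors = cosine
    ids = np.argsort(-sims, kind="stable")[:k]
    return sims[ids], ids


# Row 0 points in the query's direction but is long; row 1 is short but off-direction.
toy_vectors = np.array([[10.0, 10.0], [1.0, 0.0], [0.0, -1.0]], dtype=np.float32)
sims, toy_ids = cosine_top_k(toy_vectors, np.array([1.0, 1.0]), k=2)
print(np.round(sims, 3), toy_ids)
```

Squared L2 distance would rank row 1 first here (distance 1 versus 162), while cosine ranks row 0 first. For unit-length vectors the two agree, since ||a - b||^2 = 2 - 2 * cos(a, b).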

In [None]:
for idx, score in enumerate(scores):
    print(f"Top {idx + 1}\tScore: {score}")
    print(f"Question: {samples['question'][idx]}")
    print(f"Context: {samples['context'][idx]}")
    print(f"Answer: {samples['answers'][idx]}")
    print()

Top 1	Score: 29.190715789794922
Question: Who is the most influential recording artist of all time?
Context: Various music journalists, critical theorists, and authors have deemed Madonna the most influential female recording artist of all time. Author Carol Clerk wrote that "during her career, Madonna has transcended the term 'pop star' to become a global cultural icon." Rolling Stone of Spain wrote that "She became the first viral Master of Pop in history, years before the Internet was massively used. Madonna was everywhere; in the almighty music television channels, 'radio formulas', magazine covers and even in bookshops. A pop dialectic, never seen since The Beatles's reign, which allowed her to keep on the edge of tendency and commerciality." Laura Barcella in her book Madonna and Me: Women Writers on the Queen of Pop (2012) wrote that "really, Madonna changed everything the musical landscape, the '80s look du jour, and most significantly, what a mainstream female pop star could (