# Semantic search with FAISS (PyTorch)

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

In [1]:
%%capture
!pip install datasets evaluate transformers[sentencepiece]
!pip install faiss-gpu

Also, log in to Hugging Face.

In [2]:
from huggingface_hub import notebook_login

notebook_login()


In [section 5](https://huggingface.co/course/chapter5/5), we created a <font color='blue'>dataset</font> of <font color='blue'>GitHub issues</font> and <font color='blue'>comments</font> from the 🤗 Datasets repository. In this section we'll use this information to <font color='blue'>build a search engine</font> that can help us find answers to our most pressing questions about the library!


## Using embeddings for semantic search

As we saw in [Chapter 1](https://huggingface.co/course/chapter1), Transformer-based language models represent each <font color='blue'>token</font> in a <font color='blue'>span of text</font> as an <font color='blue'>embedding vector</font>. It turns out that one can <font color='blue'>pool</font> the <font color='blue'>individual embeddings</font> to create a <font color='blue'>vector representation</font> for <font color='blue'>whole sentences</font>, <font color='blue'>paragraphs</font>, or (in some cases) <font color='blue'>documents</font>. These <font color='blue'>embeddings</font> can then be used to <font color='blue'>find similar documents</font> in the corpus by computing the <font color='blue'>dot-product similarity</font> (or some other similarity metric) between each embedding and returning the documents with the greatest overlap.
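To make the dot-product step concrete, here is a minimal NumPy sketch that ranks a few toy document vectors against a toy query vector. The vectors are made up for illustration and stand in for already-pooled embeddings (the model used later in this section produces 768-dimensional vectors), and `rank_by_dot_product` is a hypothetical helper, not a library function.

```python
import numpy as np


def rank_by_dot_product(query, doc_embeddings):
    """Return document indices sorted from most to least similar, plus their scores."""
    # One dot product per document: (n_docs, dim) @ (dim,) -> (n_docs,)
    scores = doc_embeddings @ query
    # argsort is ascending, so reverse it to put the highest score first
    order = np.argsort(scores)[::-1]
    return order, scores[order]


# Three toy "document" embeddings in a 4-dimensional space
docs = np.array([
    [0.9, 0.1, 0.0, 0.2],  # doc 0
    [0.0, 0.8, 0.5, 0.1],  # doc 1
    [0.7, 0.0, 0.1, 0.8],  # doc 2
])
query = np.array([1.0, 0.0, 0.0, 0.5])

order, scores = rank_by_dot_product(query, docs)
print(order)   # [2 0 1]
print(scores)
```

The largest dot product wins, so doc 2 is returned first even though doc 0 has the larger value in the first dimension.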

In this section we'll use <font color='blue'>embeddings</font> to <font color='blue'>develop</font> a <font color='blue'>semantic search engine</font>. These search engines offer several advantages over conventional approaches that match keywords in a query against the documents; for example, because they compare meaning rather than exact wording, they can retrieve a relevant document even when it shares few words with the query.

<div class="flex justify-center">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter5/semantic-search.svg" alt="Semantic search."/>
</div>


## Loading and preparing the dataset


The first thing we need to do is <font color='blue'>download</font> our <font color='blue'>dataset</font> of <font color='blue'>GitHub issues</font>, so let's use the `load_dataset()` function as usual:

In [27]:
from datasets import load_dataset

issues_dataset = load_dataset("lewtun/github-issues", split="train")
issues_dataset

Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'timeline_url', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 3019
})

Here we've specified the <font color='blue'>default `train` split</font> in `load_dataset()`, so it returns a <font color='blue'>`Dataset`</font> instead of a `DatasetDict`. The first order of business is to <font color='blue'>filter out</font> the <font color='blue'>pull requests</font>, as these are rarely useful for answering user queries and would introduce noise into our search engine. As should be familiar by now, we can use the <font color='blue'>`Dataset.filter()` function</font> to <font color='blue'>exclude</font> these <font color='blue'>rows</font> in our dataset. While we're at it, let's also <font color='blue'>filter out rows</font> with <font color='blue'>no comments</font>, since these provide no answers to user queries:


In [4]:
issues_dataset = issues_dataset.filter(
    lambda x: (not x["is_pull_request"] and len(x["comments"]) > 0)
)
issues_dataset

Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'timeline_url', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 808
})

We can see that there are a <font color='blue'>lot of columns</font> in our <font color='blue'>dataset</font>, most of which we don't need to build our search engine. From a search perspective, the <font color='blue'>most informative columns</font> are <font color='blue'>`title`</font>, <font color='blue'>`body`</font>, and <font color='blue'>`comments`</font>, while <font color='blue'>`html_url`</font> provides us with a link back to the source issue. Let's use the `Dataset.remove_columns()` function to drop the rest:


In [5]:
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 808
})
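Because every name in `columns_to_keep` is an existing column, the symmetric difference used above gives the same result as a plain set difference. The pure-Python sketch below checks this on a shortened, made-up column list. The two differ only when `columns_to_keep` contains a name that isn't a column: the symmetric difference then puts that name into the set of columns to remove, where (as far as we know) `remove_columns()` would reject it, whereas the plain difference silently ignores the typo.

```python
# Shortened, illustrative column list (not the full dataset schema)
columns = ["url", "id", "title", "body", "html_url", "comments", "state"]
columns_to_keep = ["title", "body", "html_url", "comments"]

sym = set(columns_to_keep).symmetric_difference(columns)
plain = set(columns) - set(columns_to_keep)
print(sorted(sym))  # ['id', 'state', 'url']

# A misspelled keep-list behaves differently in the two versions
typo_keep = ["title", "body", "html_url", "comment"]  # "comments" misspelled
typo_sym = set(typo_keep).symmetric_difference(columns)
typo_plain = set(columns) - set(typo_keep)
print(sorted(typo_sym))    # the typo "comment" leaks into the removal set
print(sorted(typo_plain))  # the typo is silently ignored
```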

To <font color='blue'>create</font> our <font color='blue'>embeddings</font> we'll <font color='blue'>augment each comment</font> with the <font color='blue'>issue's title</font> and <font color='blue'>body</font>, since these fields often include useful contextual information. Because our <font color='blue'>`comments` column</font> is currently a <font color='blue'>list of comments</font> for each issue, we need to <font color='blue'>explode the column</font> so that each row consists of an `(html_url, title, body, comment)` tuple. In Pandas we can do this with the [`DataFrame.explode()` function](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.explode.html), which creates a <font color='blue'>new row</font> for <font color='blue'>each element in a list-like column</font>, while replicating all the other column values. To see this in action, let's first switch to the Pandas `DataFrame` format:


In [6]:
issues_dataset.set_format("pandas")
df = issues_dataset[:]

If we inspect the first row in this `DataFrame` we can see there are <font color='blue'>two comments</font> associated with this issue:

In [40]:
comments = df["comments"][0]
for i, comment in enumerate(comments, 1):
    print(f"\nComment {i}:\n{comment.strip()}")


Comment 1:
Cool, I think we can do both :)

Comment 2:
@lhoestq now the 2 are implemented.

Please note that for the the second protection, finally I have chosen to protect the master branch only from **merge commits** (see update comment above), so no need to disable/re-enable the protection on each release (direct commits, different from merge commits, can be pushed to the remote master branch; and eventually reverted without messing up the repo history).


When we explode `df`, we expect to get <font color='blue'>one row</font> for <font color='blue'>each</font> of these <font color='blue'>comments</font>. Let's check if that's the case:

In [37]:
comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)

|   | html_url | title | comments | body |
|---|----------|-------|----------|------|
| 0 | https://github.com/huggingface/datasets/issues... | Protect master branch | Cool, I think we can do both :) | After accidental merge commit (91c55355b634d0d... |
| 1 | https://github.com/huggingface/datasets/issues... | Protect master branch | @lhoestq now the 2 are implemented.\r\n\r\nPle... | After accidental merge commit (91c55355b634d0d... |
| 2 | https://github.com/huggingface/datasets/issues... | Backwards compatibility broken for cached data... | Hi ! I guess the caching mechanism should have... | ## Describe the bug\r\nAfter upgrading to data... |
| 3 | https://github.com/huggingface/datasets/issues... | Backwards compatibility broken for cached data... | If it's easy enough to implement, then yes ple... | ## Describe the bug\r\nAfter upgrading to data... |


Great, we can see the <font color='blue'>rows</font> have been <font color='blue'>replicated</font>, with the <font color='blue'>`comments` column</font> containing the <font color='blue'>individual comments</font>! Now that we're finished with Pandas, we can quickly switch back to a `Dataset` by loading the `DataFrame` in memory:


In [9]:
from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2964
})

Okay, this has given us a few thousand comments to work with!
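Conceptually, `explode()` turns each row holding a list into one row per list element. The plain-Python sketch below mimics that on a dictionary of columns; `explode_column` is a hypothetical helper for illustration only, and the toy titles and comments are made up. One deliberate difference from Pandas: Pandas keeps a row with `NaN` for an empty list, whereas this sketch simply drops it, which is harmless here because we already removed issues without comments.

```python
def explode_column(columns, list_column):
    """Return a new dict of columns with one row per element of `list_column`.

    `columns` maps column names to equal-length lists (one entry per row).
    """
    exploded = {name: [] for name in columns}
    for i, items in enumerate(columns[list_column]):
        # Rows whose list is empty produce no output rows in this sketch
        for item in items:
            for name, values in columns.items():
                # The exploded column gets the single element; the others are replicated
                exploded[name].append(item if name == list_column else values[i])
    return exploded


issues = {
    "title": ["Issue 1", "Issue 2", "Issue 3"],
    "comments": [["first", "second"], ["third"], []],
}
rows = explode_column(issues, "comments")
print(rows["title"])     # ['Issue 1', 'Issue 1', 'Issue 2']
print(rows["comments"])  # ['first', 'second', 'third']
```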



<Tip>

✏️ **Try it out!** See if you can use `Dataset.map()` to explode the `comments` column of `issues_dataset` <font color='blue'>without</font> resorting to the <font color='blue'>use of Pandas</font>. This is a little tricky; you might find the ["Batch mapping"](https://huggingface.co/docs/datasets/about_map_batch#batch-mapping) section of the 🤗 Datasets documentation useful for this task.

</Tip>

In [10]:
# Exercise: explode the `comments` column with Dataset.map() instead of Pandas.
# This version also keeps only comments longer than 15 words, combining the
# explode and length-filter steps into a single pass.

def explode_comments(batch):
    # Output columns mirror the input columns; each list gets one entry per kept comment
    new_examples = {key: [] for key in batch.keys()}

    # Iterate over each issue in the batch
    for i in range(len(batch["comments"])):
        # Gather the current issue as a dict of column -> value
        example = {key: batch[key][i] for key in batch.keys()}

        # Emit one row per sufficiently long comment, copying the other columns
        for comment in example["comments"]:
            if len(comment.split()) > 15:
                for key in example.keys():
                    if key == "comments":
                        new_examples[key].append(comment)
                    else:
                        new_examples[key].append(example[key])

    return new_examples


# With batched=True the function may return more (or fewer) rows than it receives
exploded_comments_dataset = issues_dataset.map(explode_comments, batched=True)
exploded_comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2175
})

Now that we have one comment per row, let's create a new `comment_length` column that <font color='blue'>contains</font> the <font color='blue'>number of words</font> per <font color='blue'>comment</font>:


In [11]:
comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)


We can use this <font color='blue'>new column</font> to <font color='blue'>filter</font> out <font color='blue'>short comments</font>, which typically include things like "cc @lewtun" or "Thanks!" that are not relevant for our search engine. There's no precise number to select for the filter, but around <font color='blue'>15 words</font> seems like a good start:


In [12]:
comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset


Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2175
})

Having cleaned up our dataset a bit, let's <font color='blue'>concatenate</font> the <font color='blue'>issue title</font>, <font color='blue'>description</font>, and <font color='blue'>comments together</font> in a new <font color='blue'>`text` column</font>. As usual, we'll write a simple function that we can pass to `Dataset.map()`:

In [13]:
def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)


We're finally ready to create some embeddings! Let's take a look.

## Creating text embeddings

We saw in [Chapter 2](https://huggingface.co/course/chapter2) that we can obtain <font color='blue'>token embeddings</font> by using the <font color='blue'>`AutoModel` class</font>. All we need to do is pick a suitable checkpoint to load the model from. Fortunately, there's a library called <font color='blue'>`sentence-transformers`</font> that is dedicated to <font color='blue'>creating embeddings</font>. As described in the library's [documentation](https://www.sbert.net/examples/applications/semantic-search/README.html#symmetric-vs-asymmetric-semantic-search), our use case is an example of <font color='blue'>asymmetric semantic search</font> because we have a <font color='blue'>short query</font> whose <font color='blue'>answer</font> we'd like to find in a <font color='blue'>longer document</font>, like an issue comment. The handy [model overview table](https://www.sbert.net/docs/pretrained_models.html#model-overview) in the documentation indicates that the <font color='blue'>`multi-qa-mpnet-base-dot-v1` checkpoint</font> has the <font color='blue'>best performance</font> for semantic search, so we'll use that for our application. We'll also load the tokenizer using the same checkpoint:


In [14]:
from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

To <font color='blue'>speed up</font> the embedding process, it helps to place the <font color='blue'>model and inputs</font> on a <font color='blue'>GPU</font> device, so let's do that now:

In [15]:
import torch

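# Use the GPU if one is available; fall back to the (much slower) CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")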
model.to(device)

MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_

Note that we load the checkpoint with `AutoModel`, the PyTorch class, so no extra arguments are needed: the `multi-qa-mpnet-base-dot-v1` checkpoint only has <font color='blue'>PyTorch weights</font>. If you were using TensorFlow instead, you would load it with `TFAutoModel` and pass <font color='blue'>`from_pt=True`</font> to the <font color='blue'>`from_pretrained()` method</font>, which <font color='blue'>automatically converts</font> the weights to the <font color='blue'>TensorFlow format</font>. As you can see, it is very simple to switch between frameworks in 🤗 Transformers!

As we mentioned earlier, we'd like to <font color='blue'>represent each entry</font> in our GitHub issues corpus as a <font color='blue'>single vector</font>, so we need to "pool" or average our token embeddings in some way. One popular approach is to perform <font color='blue'>CLS pooling</font> on our <font color='blue'>model's outputs</font>, where we simply <font color='blue'>collect the last hidden state</font> for the special `[CLS]` token. The following function does the trick for us:


In [16]:
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]
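The slice `[:, 0]` in `cls_pooling()` keeps the hidden state at the first position of every sequence in the batch. To make that concrete, and to contrast it with the other common choice of averaging the token vectors, here is a NumPy sketch on a fake `last_hidden_state` of shape `(batch, seq_len, hidden)`. The numbers are made up, mean pooling is shown only for comparison (it is not what this section uses), and the mask follows the tokenizer's `attention_mask` convention of 1 for real tokens and 0 for padding.

```python
import numpy as np

# Fake model output: 2 sequences, 4 token positions, hidden size 3.
# The second sequence has one padding token at the end.
last_hidden_state = np.array([
    [[1.0, 0.0, 0.0], [2.0, 2.0, 2.0], [3.0, 1.0, 0.0], [0.0, 1.0, 3.0]],
    [[0.5, 0.5, 0.5], [1.0, 3.0, 2.0], [2.0, 0.0, 1.0], [9.0, 9.0, 9.0]],
])
attention_mask = np.array([
    [1, 1, 1, 1],
    [1, 1, 1, 0],  # last position is padding
])

# CLS pooling: keep the vector at position 0 of every sequence
cls_vectors = last_hidden_state[:, 0]

# Masked mean pooling: average only the non-padding positions
mask = attention_mask[:, :, None]                # (batch, seq_len, 1)
summed = (last_hidden_state * mask).sum(axis=1)  # (batch, hidden)
counts = mask.sum(axis=1)                        # (batch, 1)
mean_vectors = summed / counts

print(cls_vectors.shape, mean_vectors.shape)  # (2, 3) (2, 3)
```

Either way each sequence collapses to one vector of the hidden size. Note that the padding values (the 9s) never influence the result: CLS pooling ignores every position but the first, and the mask removes them from the mean.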

Next, we'll create a <font color='blue'>helper function</font> that will <font color='blue'>tokenize</font> a <font color='blue'>list of documents</font>, <font color='blue'>place</font> the <font color='blue'>tensors</font> on the <font color='blue'>GPU</font>, feed them to the model, and finally <font color='blue'>apply CLS pooling</font> to the outputs:

In [17]:
def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

We can test that the function works by <font color='blue'>feeding it the first text entry</font> in our corpus and inspecting the output shape:

In [18]:
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape

torch.Size([1, 768])

Great, we've converted the first entry in our corpus into a <font color='blue'>768-dimensional vector</font>! We can use `Dataset.map()` to apply our `get_embeddings()` function to each row in our corpus, so let's create a new `embeddings` column as follows:


In [19]:
embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)


Notice that we've converted the embeddings to NumPy arrays -- that's because 🤗 Datasets requires this format when we try to index them with FAISS, which we'll do next.
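Calling the model once per row, as above, is simple, but sending several texts through the model at once is usually faster on a GPU (`Dataset.map(..., batched=True)` is one way to get there). The sketch below shows the chunking idea on its own. `embed_in_batches` and `fake_embed` are hypothetical names: `embed_fn` is assumed to map a list of strings to a 2-D NumPy array with one row per string, and `fake_embed` is a toy stand-in so the snippet runs without a model.

```python
import numpy as np


def embed_in_batches(texts, embed_fn, batch_size=32):
    """Embed `texts` in chunks of `batch_size` and stack them into one (len(texts), dim) array."""
    chunks = []
    for start in range(0, len(texts), batch_size):
        # Each call sees at most `batch_size` texts
        chunks.append(embed_fn(texts[start:start + batch_size]))
    return np.vstack(chunks)


# Toy stand-in for the real model: a 2-d vector of (character count, word count)
def fake_embed(batch):
    return np.array([[len(t), len(t.split())] for t in batch], dtype=np.float32)


texts = ["how do I load a dataset", "offline mode", "cache error on reload"]
vectors = embed_in_batches(texts, fake_embed, batch_size=2)
print(vectors.shape)  # (3, 2)
```

With the real model, something like `lambda batch: get_embeddings(batch).detach().cpu().numpy()` could play the role of `embed_fn`; running it inside `torch.no_grad()` also avoids storing gradient information that inference doesn't need.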


## Using FAISS for efficient similarity search

Now that we have a <font color='blue'>dataset of embeddings</font>, we need some way to <font color='blue'>search</font> over them. To do this, we'll use a special data structure in 🤗 Datasets called a _FAISS index_. [FAISS](https://faiss.ai/) (short for <font color='blue'>Facebook AI Similarity Search</font>) is a library that provides <font color='blue'>efficient algorithms</font> to quickly <font color='blue'>search</font> and <font color='blue'>cluster embedding vectors</font>.

The basic idea behind FAISS is to <font color='blue'>create</font> a special <font color='blue'>data structure</font> called an <font color='blue'>index</font> that allows one to <font color='blue'>find which embeddings</font> are <font color='blue'>similar</font> to an <font color='blue'>input embedding</font>. Creating a FAISS index in 🤗 Datasets is simple -- we use the `Dataset.add_faiss_index()` function and specify which column of our dataset we'd like to index:


In [20]:
%%capture
!pip install faiss-cpu

In [21]:
embeddings_dataset.add_faiss_index(column="embeddings")


Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length', 'text', 'embeddings'],
    num_rows: 2175
})
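The simplest kind of FAISS index, a "flat" index, performs an exhaustive search: it scores the query against every stored vector and keeps the top `k`. As far as we understand, this is what `add_faiss_index()` builds when no `string_factory` is given (a flat inner-product index), which fits our dot-product checkpoint; treat that as an assumption rather than a guarantee. The NumPy sketch below reproduces exhaustive inner-product search on toy vectors; `flat_ip_search` is a hypothetical name, not a FAISS API, and FAISS itself is heavily optimized compared to this.

```python
import numpy as np


def flat_ip_search(index_vectors, queries, k):
    """Exhaustive inner-product search: return (scores, ids), each of shape (n_queries, k)."""
    # (n_queries, dim) @ (dim, n_vectors) -> (n_queries, n_vectors)
    all_scores = queries @ index_vectors.T
    # Indices of the k largest scores per query, best first
    ids = np.argsort(-all_scores, axis=1)[:, :k]
    scores = np.take_along_axis(all_scores, ids, axis=1)
    return scores, ids


corpus = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [0.6, 0.8],
    [-1.0, 0.0],
])
query = np.array([[1.0, 0.2]])  # a batch containing a single query

scores, ids = flat_ip_search(corpus, query, k=2)
print(ids)     # [[0 2]]
print(scores)
```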

We can now perform <font color='blue'>queries</font> on this <font color='blue'>index</font> by doing a <font color='blue'>nearest neighbor lookup</font> with the `Dataset.get_nearest_examples()` function. Let's test this out by first embedding a question as follows:

In [22]:
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape

(1, 768)

Just like with the documents, we now have a <font color='blue'>768-dimensional vector</font> representing the <font color='blue'>query</font>, which we can <font color='blue'>compare against the whole corpus</font> to <font color='blue'>find</font> the most <font color='blue'>similar embeddings</font>:

In [23]:
scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

The `Dataset.get_nearest_examples()` function returns a tuple: the <font color='blue'>scores</font> that <font color='blue'>rank</font> the <font color='blue'>similarity between</font> the <font color='blue'>query</font> and each retrieved <font color='blue'>document</font>, and the corresponding samples (here, the <font color='blue'>5 best matches</font>). Let's collect these in a `pandas.DataFrame` so we can easily sort them:


In [24]:
import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

Now we can iterate over the first few rows to see how well our query matched the available comments:

In [25]:
for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()

COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505016326904297
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
```python
datasets = load_dataset('text', data_files=data_files)
```

We'll do a new release soon
SCORE: 24.555540084838867
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's n

Not bad! Our second hit seems to match the query.

<Tip>

✏️ **Try it out!** Create your own query and see whether you can find an <font color='blue'>answer</font> in the <font color='blue'>retrieved documents</font>. You might have to increase the `k` parameter in `Dataset.get_nearest_examples()` to broaden the search.

</Tip>

In [26]:
# Exercise: Increase k in Dataset.get_nearest_examples()

# Define a different question
question = "What are the best practices for writing clean code in Python?"

# Get the embedding for the new question
question_embedding = get_embeddings([question]).cpu().detach().numpy()

# Perform a nearest neighbor search with k=10
scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=10
)

# Convert the samples to a pandas DataFrame and display the top results
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

# Print the top results
for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()

COMMENT: Hi ! The ClassLabel feature type encodes the labels as integers.
The integer corresponds to the index of the label name in the `names` list of the ClassLabel.
Here that means that the labels are 'entailment' (0), 'neutral' (1), 'contradiction' (2).

You can get the label names back by using `a.features['label'].int2str(i)`.

SCORE: 45.326255798339844
TITLE: making labels consistent across the datasets
URL: https://github.com/huggingface/datasets/issues/2207

COMMENT: Alternatively huggingface could consider some submodule type structure like:

`import huggingface.datasets`
`import huggingface.transformers`

`datasets` is a very common module in ML and should be an end-user decision and not scope all of python ¯\_(ツ)_/¯ 

SCORE: 45.293704986572266
TITLE: Add helper to resolve namespace collision
URL: https://github.com/huggingface/datasets/issues/1590

COMMENT: > `from_dict` was added in #350 that was unfortunately not included in the 0.3.0 release. It's going to be 