# Semantic search with FAISS (PyTorch)

Install the Transformers, Datasets, and Evaluate libraries, plus FAISS, to run this notebook.

In [2]:
!pip install datasets evaluate transformers[sentencepiece]
!pip install faiss-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
Collecting transformers[sentencepiece]
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-c

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
Installing collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [1]:
from datasets import load_dataset

issues_dataset = load_dataset("lewtun/github-issues", split="train")
issues_dataset

Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'timeline_url', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 3019
})

In [2]:
issues_dataset = issues_dataset.filter(
    lambda x: (not x["is_pull_request"] and len(x["comments"]) > 0)
)
issues_dataset

Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'timeline_url', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 808
})

In [3]:
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 808
})
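
The `symmetric_difference` call in the cell above gives the expected result because every name in `columns_to_keep` exists in the dataset. Here is a small sketch of how it differs from a plain set difference. The column list is shortened for illustration, and `labels_text` is a hypothetical name that is not part of the dataset:

```python
# Shortened, illustrative column list (not the full dataset schema)
columns = ["url", "id", "title", "body", "html_url", "comments"]
# "labels_text" is a hypothetical column that does NOT exist in `columns`
columns_to_keep = ["title", "body", "html_url", "comments", "labels_text"]

# Symmetric difference: names that appear in exactly one of the two collections
sym = set(columns_to_keep).symmetric_difference(columns)
# Plain difference: names in `columns` that are not being kept
diff = set(columns) - set(columns_to_keep)

print(sorted(sym))   # also contains the hypothetical "labels_text"
print(sorted(diff))  # only columns that really exist
```

With the plain difference, a misspelled entry in `columns_to_keep` never ends up in the set of columns to drop, so it may be the safer choice when the keep list is typed by hand.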

In [4]:
columns

['url',
 'repository_url',
 'labels_url',
 'comments_url',
 'events_url',
 'html_url',
 'id',
 'node_id',
 'number',
 'title',
 'user',
 'labels',
 'state',
 'locked',
 'assignee',
 'assignees',
 'milestone',
 'comments',
 'created_at',
 'updated_at',
 'closed_at',
 'author_association',
 'active_lock_reason',
 'pull_request',
 'body',
 'timeline_url',
 'performed_via_github_app',
 'is_pull_request']

In [5]:
columns_to_remove

{'active_lock_reason',
 'assignee',
 'assignees',
 'author_association',
 'closed_at',
 'comments_url',
 'created_at',
 'events_url',
 'id',
 'is_pull_request',
 'labels',
 'labels_url',
 'locked',
 'milestone',
 'node_id',
 'number',
 'performed_via_github_app',
 'pull_request',
 'repository_url',
 'state',
 'timeline_url',
 'updated_at',
 'url',
 'user'}

In [6]:
issues_dataset.set_format("pandas")
df = issues_dataset[:]

In [7]:
df

,html_url,title,comments,body
0,https://github.com/huggingface/datasets/issues...,Protect master branch,"[Cool, I think we can do both :), @lhoestq now...",After accidental merge commit (91c55355b634d0d...
1,https://github.com/huggingface/datasets/issues...,Backwards compatibility broken for cached data...,[Hi ! I guess the caching mechanism should hav...,## Describe the bug\r\nAfter upgrading to data...
2,https://github.com/huggingface/datasets/issues...,OSCAR unshuffled_original_ko: NonMatchingSplit...,[I tried `unshuffled_original_da` and it is al...,## Describe the bug\r\n\r\nCannot download OSC...
3,https://github.com/huggingface/datasets/issues...,load_dataset using default cache on Windows ca...,"[Hi @daqieq, thanks for reporting.\r\n\r\nUnfo...",## Describe the bug\r\nStandard process to dow...
4,https://github.com/huggingface/datasets/issues...,to_tf_dataset keeps a reference to the open da...,"[I did some investigation and, as it seems, th...",To reproduce:\r\n```python\r\nimport datasets ...
...,...,...,...,...
803,https://github.com/huggingface/datasets/issues/6,Error when citation is not given in the Datase...,[Yes looks good to me.\r\nNote that we may ref...,The following error is raised when the `citati...
804,https://github.com/huggingface/datasets/issues/5,ValueError when a split is empty,[To fix this I propose to modify only the file...,"When a split is empty either TEST, VALIDATION ..."
805,https://github.com/huggingface/datasets/issues/4,[Feature] Keep the list of labels of a dataset...,[Yes! I see mostly two options for this:\r\n- ...,It would be useful to keep the list of the lab...
806,https://github.com/huggingface/datasets/issues/3,[Feature] More dataset outputs,[Yes!\r\n- pandas will be a one-liner in `arro...,Add the following dataset outputs:\r\n\r\n- Sp...


In [9]:
df["comments"][1].tolist()

["Hi ! I guess the caching mechanism should have considered the new `filter` to be different from the old one, and don't use cached results from the old `filter`.\r\nTo avoid other users from having this issue we could make the caching differentiate the two, what do you think ?",
 "If it's easy enough to implement, then yes please 😄  But this issue can be low-priority, since I've only encountered it in a couple of `transformers` CI tests.",
 "Well it can cause issue with anyone that updates `datasets` and re-run some code that uses filter, so I'm creating a PR",
 "I just merged a fix, let me know if you're still having this kind of issues :)\r\n\r\nWe'll do a release soon to make this fix available",
 'Definitely works on several manual cases with our dummy datasets, thank you @lhoestq !',
 'Fixed by #2947.']

In [10]:
comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)

,html_url,title,comments,body
0,https://github.com/huggingface/datasets/issues...,Protect master branch,"Cool, I think we can do both :)",After accidental merge commit (91c55355b634d0d...
1,https://github.com/huggingface/datasets/issues...,Protect master branch,@lhoestq now the 2 are implemented.\r\n\r\nPle...,After accidental merge commit (91c55355b634d0d...
2,https://github.com/huggingface/datasets/issues...,Backwards compatibility broken for cached data...,Hi ! I guess the caching mechanism should have...,## Describe the bug\r\nAfter upgrading to data...
3,https://github.com/huggingface/datasets/issues...,Backwards compatibility broken for cached data...,"If it's easy enough to implement, then yes ple...",## Describe the bug\r\nAfter upgrading to data...
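
To make the effect of `explode` concrete, here is a minimal pure-Python sketch of the same idea: one output row per list element, with the other columns copied. The helper name `explode_rows` and the toy rows are made up for illustration, and the sketch simply skips rows with an empty list, which is not necessarily what pandas does.

```python
def explode_rows(rows, column):
    """Return one row per element of rows[i][column], copying the other fields."""
    exploded = []
    for row in rows:
        for item in row[column]:
            new_row = dict(row)      # shallow copy of the original row
            new_row[column] = item   # replace the list with a single element
            exploded.append(new_row)
    return exploded


# Toy data shaped like the issues DataFrame (values are illustrative)
toy_rows = [
    {"title": "Issue A", "comments": ["first comment", "second comment"]},
    {"title": "Issue B", "comments": ["only comment"]},
]

toy_exploded = explode_rows(toy_rows, "comments")
for row in toy_exploded:
    print(row)
```

Each issue title is repeated once per comment, which is why the row count grows after this step.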


In [11]:
from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2964
})

In [12]:
comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

In [13]:
comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2964
})

In [14]:
comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2175
})
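
The length used by this filter is a whitespace-separated word count, not a character or token count. A quick sketch using two of the comments printed earlier in this notebook (the helper name `word_count` is only for illustration):

```python
def word_count(text):
    # str.split() with no arguments splits on any run of whitespace,
    # including the "\r\n" sequences found in GitHub comments
    return len(text.split())


short_comment = "Fixed by #2947."
longer_comment = (
    "Definitely works on several manual cases with our dummy datasets, "
    "thank you @lhoestq !"
)

for comment in (short_comment, longer_comment):
    n_words = word_count(comment)
    # Same condition as the filter above: keep only comments with more than 15 words
    print(n_words, "kept" if n_words > 15 else "dropped")
```

Both comments fall under the threshold, so neither would survive the filter.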

In [15]:
def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

In [16]:
from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

In [17]:
import torch

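# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")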
model.to(device)

MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_

In [18]:
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]
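
The indexing `[:, 0]` keeps the vector at position 0 (the first, CLS-style token) of every sequence in the batch, turning a `(batch, seq_len, hidden)` tensor into a `(batch, hidden)` one. The sketch below mimics that with nested lists so it runs without a model. The numbers are invented and `cls_pooling_lists` is a hypothetical stand-in, not part of the notebook's pipeline.

```python
# Toy "last_hidden_state" with shape (batch=2, seq_len=3, hidden=4) as nested lists
last_hidden_state = [
    [[0.1, 0.2, 0.3, 0.4], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]],
    [[0.5, 0.6, 0.7, 0.8], [3.0, 3.0, 3.0, 3.0], [4.0, 4.0, 4.0, 4.0]],
]


def cls_pooling_lists(hidden):
    # List equivalent of hidden[:, 0]: keep position 0 of every sequence
    return [sequence[0] for sequence in hidden]


pooled = cls_pooling_lists(last_hidden_state)
print(len(pooled), len(pooled[0]))  # batch size and hidden size
```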

In [19]:
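def get_embeddings(text_list):
    # Tokenize, padding to the longest item and truncating to the model's limit
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    # Move every input tensor to the same device as the model
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)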

In [23]:
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape

torch.Size([1, 768])

In [24]:
embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

In [25]:
embeddings_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length', 'text', 'embeddings'],
    num_rows: 2175
})

In [26]:
embeddings_dataset["embeddings"][0]

[-0.15531955659389496,
 -0.10023057460784912,
 -0.07032109051942825,
 -0.07981694489717484,
 -0.10425378382205963,
 -0.1879875659942627,
 0.010403402149677277,
 0.27286383509635925,
 -0.009848777204751968,
 -0.08413968235254288,
 0.2926148772239685,
 -0.07787511497735977,
 -0.13245734572410583,
 0.21589063107967377,
 -0.05622908100485802,
 0.17054562270641327,
 0.20032304525375366,
 -0.04079785570502281,
 -0.09233297407627106,
 0.027923252433538437,
 0.022608397528529167,
 -0.06870279461145401,
 0.0832001119852066,
 0.04363005980849266,
 -0.1837823987007141,
 0.005401225294917822,
 -0.014811672270298004,
 0.15649646520614624,
 -0.4397662580013275,
 -0.4695698320865631,
 0.1452465057373047,
 0.23224009573459625,
 0.025106102228164673,
 0.5705347061157227,
 -0.00010436369484523311,
 -0.00034549616975709796,
 0.1694866269826889,
 -0.040600862354040146,
 -0.14010265469551086,
 -0.21546350419521332,
 -0.4901069700717926,
 -0.4447956681251526,
 -0.07006335258483887,
 -0.008354305289685726,
 

In [27]:
embeddings_dataset.add_faiss_index(column="embeddings")

Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length', 'text', 'embeddings'],
    num_rows: 2175
})

In [28]:
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).detach().cpu().numpy()
question_embedding.shape

(1, 768)

In [29]:
scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

In [30]:
import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
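
Conceptually, a flat (exhaustive) index compares the query with every stored vector and keeps the best `k`. The pure-Python sketch below is not how FAISS is implemented, and its function names are made up for illustration. It also shows why the sort direction depends on the metric: with a distance such as squared L2 (which appears to be the default for a flat FAISS index, though the `metric_type` argument of `add_faiss_index` is worth checking), smaller scores mean closer matches, whereas with an inner-product score larger is better.

```python
def squared_l2(a, b):
    # Squared Euclidean distance: smaller means more similar
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inner_product(a, b):
    # Dot product: larger means more similar
    return sum(x * y for x, y in zip(a, b))


def brute_force_search(query, vectors, k, metric="l2"):
    """Return the k best (score, index) pairs under the chosen metric."""
    if metric == "l2":
        scored = [(squared_l2(query, v), i) for i, v in enumerate(vectors)]
        scored.sort()              # ascending: smallest distance first
    else:
        scored = [(inner_product(query, v), i) for i, v in enumerate(vectors)]
        scored.sort(reverse=True)  # descending: largest score first
    return scored[:k]


# Toy 2-dimensional "embeddings" (illustrative values only)
toy_vectors = [[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]]
toy_query = [1.0, 0.0]

l2_hits = brute_force_search(toy_query, toy_vectors, k=2, metric="l2")
ip_hits = brute_force_search(toy_query, toy_vectors, k=2, metric="ip")
print(l2_hits)
print(ip_hits)
```

The two metrics rank the toy vectors differently, so before reading a sorted results table it is worth confirming which metric produced the scores.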

In [31]:
samples_df

,html_url,title,comments,body,comment_length,text,embeddings,scores
4,https://github.com/huggingface/datasets/issues...,Discussion using datasets in offline mode,Requiring online connection is a deal breaker ...,"`datasets.load_dataset(""csv"", ...)` breaks if ...",57,Discussion using datasets in offline mode \n `...,"[-0.47318071126937866, 0.24578367173671722, -0...",25.505049
3,https://github.com/huggingface/datasets/issues...,Discussion using datasets in offline mode,"The local dataset builders (csv, text , json a...","`datasets.load_dataset(""csv"", ...)` breaks if ...",38,Discussion using datasets in offline mode \n `...,"[-0.44908538460731506, 0.20950642228126526, -0...",24.555553
2,https://github.com/huggingface/datasets/issues...,Discussion using datasets in offline mode,I opened a PR that allows to reload modules th...,"`datasets.load_dataset(""csv"", ...)` breaks if ...",179,Discussion using datasets in offline mode \n `...,"[-0.47164806723594666, 0.2902272641658783, -0....",24.148972
1,https://github.com/huggingface/datasets/issues...,Discussion using datasets in offline mode,"> here is my way to load a dataset offline, bu...","`datasets.load_dataset(""csv"", ...)` breaks if ...",76,Discussion using datasets in offline mode \n `...,"[-0.499260276556015, 0.22699761390686035, -0.0...",22.893995
0,https://github.com/huggingface/datasets/issues...,Discussion using datasets in offline mode,"here is my way to load a dataset offline, but ...","`datasets.load_dataset(""csv"", ...)` breaks if ...",47,Discussion using datasets in offline mode \n `...,"[-0.49025776982307434, 0.2288963347673416, -0....",22.406639


In [32]:
for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()

COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505048751831055
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
```python
datasets = load_dataset('text', data_files=data_files)
```

We'll do a new release soon
SCORE: 24.555553436279297
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's n