In [16]:
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

documents = [
    "Bert is awesome",
    "I'm looking for a job at Huggingface",
    "I'm selling a used car in good condition",
    "Huggingface team is awesome",
    "My car is outside",
    "Hugginface team is awesome"
    "Faiss consumes a lot of memory",
]

vectors = [
  # tokenize the document, return it as PyTorch tensors (vectors),
  # and pass it onto the model
  model(**tokenizer(document, return_tensors='pt'))[0].detach().squeeze()
  for document in documents
]

[v.size() for v in vectors]

[torch.Size([5, 768]),
 torch.Size([12, 768]),
 torch.Size([12, 768]),
 torch.Size([12, 768]),
 torch.Size([6, 768]),
 torch.Size([16, 768])]

In [17]:
import torch

averaged_vectors = [torch.mean(vector, dim=0) for vector in vectors]

[v.size() for v in averaged_vectors]

[torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768])]

In [18]:
def encode(document: str) -> torch.Tensor:
  tokens = tokenizer(document, return_tensors='pt')
  vector = model(**tokens)[0].detach().squeeze()
  return torch.mean(vector, dim=0)

In [19]:
import faiss
import numpy as np

index = faiss.IndexIDMap(faiss.IndexFlatIP(768)) # the size of our vector space
# index all the documents, we need them as numpy arrays first
index.add_with_ids(
    np.array([t.numpy() for t in averaged_vectors]),
    # the IDs will be 0 to len(documents)
    np.array(range(0, len(documents)))
)

def search(query: str, k=1):
  encoded_query = encode(query).unsqueeze(dim=0).numpy()
  top_k = index.search(encoded_query, k)
  scores = top_k[0][0]
  results = [documents[_id] for _id in top_k[1][0]]
  return list(zip(results, scores))

In [23]:
documents[1]
# => "I'm looking for a job at Huggingface"

search(documents[1], k=2)

[("I'm looking for a job at Huggingface", 62.06402),
 ("I'm selling a used car in good condition", 51.57861)]