In [None]:
!pip install faiss-cpu
!pip install transformers

In [6]:
import json
from pprint import pprint
import faiss
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer

In [4]:
# Load the corpus to index. Later cells iterate over `dataset` and look items
# up by integer position, so the file is expected to hold a JSON array.
# The encoding is given explicitly: JSON text is UTF-8, and relying on the
# platform default breaks on non-ASCII sentences on some systems.
with open('sentences.json', 'r', encoding='utf-8') as f:
  dataset = json.load(f)

In [None]:
# Name of the pretrained checkpoint; the same name is passed to both loaders
# below so the tokenizer and the model always match.
model_name = 'distilbert-base-uncased'

# Tokenizer and model are module-level objects used by the `encode` cell.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [7]:
def encode(document: str) -> torch.Tensor:
  """Embed a single document as the mean of its token vectors.

  Args:
    document: The text to embed. Input longer than the tokenizer's maximum
      length is truncated instead of making the model call fail.

  Returns:
    A 1-D tensor (one value per hidden dimension) that does not track
    gradients, so it can be converted with `.numpy()` directly.
  """
  tokens = tokenizer(document, return_tensors='pt', truncation=True)
  # Inference only: skip building the autograd graph for the forward pass.
  with torch.no_grad():
    hidden_states = model(**tokens)[0]
  # Drop only the batch axis (a single document was passed in). A bare
  # squeeze() would also remove a length-1 token axis and make the mean
  # below collapse the hidden dimension instead of the tokens.
  token_vectors = hidden_states[0]
  return token_vectors.mean(dim=0)

In [9]:
# One embedding per dataset entry, in dataset order.
vectors = list(map(encode, dataset))

In [13]:
# Stack the per-document embeddings into one (n_documents, dim) matrix.
embedding_matrix = np.stack([t.numpy() for t in vectors])

# Ids are the positions of the documents in `dataset`. FAISS requires 64-bit
# integer ids; the default integer dtype of np.array(range(...)) is
# platform-dependent, so the dtype is set explicitly.
document_ids = np.arange(len(dataset), dtype=np.int64)

# Flat inner-product index wrapped in an id map. The dimension is read from
# the embeddings rather than hard-coded, so it follows the chosen model.
index = faiss.IndexIDMap(faiss.IndexFlatIP(embedding_matrix.shape[1])) # the size of our vector space

index.add_with_ids(embedding_matrix, document_ids)

In [15]:
def search(query: str, k=1):
  """Return the indexed documents that score highest against `query`.

  Args:
    query: Free-text query; it is embedded with `encode`.
    k: Maximum number of results to return.

  Returns:
    A list of `(document, score)` tuples in the order the index returns them.
    The list is shorter than `k` when the index holds fewer than `k` vectors.
  """
  # The index expects a 2-D batch of queries, hence the added leading axis.
  encoded_query = encode(query).unsqueeze(dim=0).numpy()
  scores, ids = index.search(encoded_query, k)
  # FAISS pads missing results with id -1. Indexing a Python list with -1
  # would silently return the last document, so padded slots are dropped.
  return [
      (dataset[_id], score)
      for _id, score in zip(ids[0], scores[0])
      if _id >= 0
  ]

In [16]:
# Example query: show the two best-scoring sentences with their scores.
pprint(search("spanish flu casualties", k=2))

[('The Spanish flu, also known as the 1918 flu pandemic, was an unusually '
  'deadly influenza pandemic caused by the H1N1 influenza A virus.',
  51.069515),
 ('As of 2018, approximately 37.9 million people are infected with HIV '
  'globally.',
  45.203125)]
