In [1]:
!pip install faiss-cpu
!pip install transformers
!pip install sentence-transformers

import json
import torch
import numpy as np
import faiss
import transformers
import sentence_transformers
from transformers import AutoTokenizer, AutoModel


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [9]:
# Checkpoint shared by the tokenizer and the encoder (they must match).
# Alternative checkpoint noted in the original cell: "distilbert-base-uncased".
MODEL_NAME = "bert-base-uncased"

# Read the corpus explicitly as UTF-8: the sentences contain non-ASCII
# characters and the platform's default encoding is not guaranteed to be UTF-8.
with open('/content/sentences.json', encoding='utf-8') as f:
  sentences = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# Encode every sentence into a matrix of per-token embeddings.
# This is inference only, so run under no_grad: no autograd graph is built for
# the forward passes and the outputs need no separate detach().
with torch.no_grad():
  vectors = [
    # tokenize the sentence as PyTorch tensors (truncating anything beyond the
    # model's maximum input length), run the model, take its first output and
    # drop only the leading batch axis of size 1
    model(**tokenizer(sentence, return_tensors='pt', truncation=True))[0].squeeze(0)
    for sentence in sentences
  ]
[v.size() for v in vectors]

[torch.Size([35, 768]),
 torch.Size([37, 768]),
 torch.Size([25, 768]),
 torch.Size([18, 768]),
 torch.Size([24, 768]),
 torch.Size([55, 768]),
 torch.Size([57, 768]),
 torch.Size([24, 768]),
 torch.Size([27, 768]),
 torch.Size([35, 768]),
 torch.Size([43, 768])]

In [11]:
# Mean-pool along dim 0 (the token axis) so each sentence becomes one vector.
averaged_vectors = [token_embeddings.mean(dim=0) for token_embeddings in vectors]

In [12]:
# Stack the pooled sentence vectors into one (n_sentences, dim) float32 matrix,
# which is the input format faiss works with.
sentence_embeddings = np.stack([t.numpy() for t in averaged_vectors]).astype(np.float32)
# faiss requires int64 ids; np.array(range(...)) does not guarantee that dtype
# on every platform. The ids 0 .. len(sentences) - 1 double as list positions
# in `sentences`.
sentence_ids = np.arange(len(sentences), dtype=np.int64)

# Exact inner-product index, sized from the data instead of a hard-coded 768.
# NOTE: the vectors are not normalised, so scores are raw dot products rather
# than cosine similarities.
index = faiss.IndexIDMap(faiss.IndexFlatIP(sentence_embeddings.shape[1]))
index.add_with_ids(sentence_embeddings, sentence_ids)

In [13]:

def search(query: str, k=1):
  """Return up to `k` indexed sentences scored against `query`.

  Args:
    query: free-text query; it is embedded with `encode`.
    k: maximum number of hits to request from the index.

  Returns:
    A list of `(sentence, score)` tuples in the order returned by the index.
    The list is shorter than `k` when the index holds fewer than `k` vectors.
  """
  # faiss expects a 2-D (n_queries, dim) array, hence the added leading axis.
  encoded_query = encode(query).unsqueeze(dim=0).numpy()
  scores, ids = index.search(encoded_query, k)
  # faiss pads missing results with id -1. Without the filter, sentences[-1]
  # would silently return the last sentence as a bogus hit.
  return [
    (sentences[_id], score)
    for _id, score in zip(ids[0], scores[0])
    if _id != -1
  ]

def encode(document: str) -> torch.Tensor:
  """Embed `document` as the mean of the model's per-token output vectors.

  Args:
    document: text to embed. Input longer than the tokenizer's maximum
      length is truncated rather than passed to the model as-is.

  Returns:
    A 1-D tensor: the first model output averaged over the token axis.
  """
  tokens = tokenizer(document, return_tensors='pt', truncation=True)
  # Inference only: skip building the autograd graph.
  with torch.no_grad():
    # squeeze(0) drops just the batch axis, never any other axis of size 1.
    token_embeddings = model(**tokens)[0].squeeze(0)
  return torch.mean(token_embeddings, dim=0)

In [18]:
# Sanity check: fetch the five best-scoring sentences for a short query.
search(query='Spanish pandemic', k=5)

[('The Spanish flu, also known as the 1918 flu pandemic, was an unusually deadly influenza pandemic caused by the H1N1 influenza A virus.',
  70.525696),
 ('The most fatal pandemic in recorded history was the Black Death (also known as The Plague), which killed an estimated 75–200 million people in the 14th century.',
  63.727413),
 ('Current pandemics include COVID-19 (SARS-CoV-2) and HIV/AIDS.', 63.29453),
 ('The COVID-19 pandemic, also known as the coronavirus pandemic, is an ongoing pandemic of coronavirus disease 2019 (COVID-19) caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2).',
  62.70733),
 ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.',
  62.37272)]

In [7]:
# Load the list of questions. The encoding is explicit so that decoding does
# not depend on the platform's default locale encoding.
with open('/content/questions.json', encoding='utf-8') as f:
  questions = json.load(f)

In [14]:
# For each question, print it followed by its three best-scoring sentences.
for q in questions:
  print(q)
  top_hits = search(q, k=3)
  print(top_hits)

How many people have died during Black Death?
[('The most fatal pandemic in recorded history was the Black Death (also known as The Plague), which killed an estimated 75–200 million people in the 14th century.', 56.909508), ('The death toll of Spanish Flu is estimated to have been somewhere between 17 million and 50 million, and possibly as high as 100 million, making it one of the deadliest pandemics in human history.', 55.359592), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 53.01925)]
Which diseases can be transmitted by animals?
[('Cholera is an infection of the small intestine by some strains of the bacterium Vibrio cholerae.', 60.600384), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 60.030334), ('The most fat