# M1 Semantic Search Engine with Faiss and DistilBERT

## Objective

- Build a search engine using the FAISS similarity search library and a pre-trained DistilBERT model from Transformers.


- While searching for an optimal document retrieval method for the CDC’s huge knowledge base, you decide to implement a semantic search engine to overcome known limitations of statistical (TF-IDF) full-text search. Its weaknesses stem from the fact that it relies on counting and matching the words of a search query against the words of the documents in the database. Even though modern full-text search engines do handle synonyms, for example, there are still many ways to express the same idea. You know that Transformer models excel at contextual learning, so you decide to apply transfer learning with a pre-trained BERT model to see if you can make your search engine smarter.
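
To make the limitation concrete, here is a minimal sketch of purely lexical matching. The sentences and the `overlap_score` helper are invented for illustration and are not taken from `sentences.json`; a real TF-IDF engine weights terms rather than just counting them, but it shares the same blind spot.

```python
import re

def tokenize(text):
    """Lowercase the text and split it into a set of word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def overlap_score(query, document):
    """Count the distinct words shared by the query and the document."""
    return len(tokenize(query) & tokenize(document))

# Hypothetical documents: the first one answers the query but shares no words with it.
docs = [
    "Influenza killed millions of people in 1918.",
    "The flu season usually peaks in winter.",
]
query = "flu deaths"

scores = [overlap_score(query, d) for d in docs]
print(scores)  # [0, 1]
```

The second sentence wins only because it contains the literal word "flu", which is the kind of mismatch a semantic encoder is meant to reduce.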

## Create a new Jupyter Notebook and load all relevant Python libraries.

In [18]:
import json
from pprint import pprint
import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

## Open the provided JSON file called sentences.json. It contains a list of strings (sentences).

In [2]:
# Load the documents; the with-block closes the file automatically
with open('data/sentences.json', 'r') as file:
    documents = json.load(file)

## Use the AutoTokenizer and AutoModel classes from the Transformers library to load a pre-trained model, along with the appropriate tokenizer.

In [3]:
# Load a DistilBERT model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Create an empty flat index with FAISS.

In [22]:
# Create a flat (exhaustive) inner-product Faiss index that supports custom IDs
index = faiss.IndexIDMap(faiss.IndexFlatIP(768)) # 768 is the dimensionality of the DistilBERT vectors
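
A flat index stores the vectors as they are and compares the query with every one of them at search time. The sketch below mimics that behavior in plain Python for inner-product scoring. `top_k_inner_product` and the toy vectors are illustrative, not part of FAISS or of this notebook.

```python
import heapq

def top_k_inner_product(query, vectors, k=5):
    """Score every stored vector against the query and keep the k best.

    Returns a list of (score, id) pairs, highest inner product first.
    """
    scored = (
        (sum(q * x for q, x in zip(query, vec)), i)
        for i, vec in enumerate(vectors)
    )
    return heapq.nlargest(k, scored)

# Toy 2-dimensional "index" (the notebook's vectors have 768 dimensions).
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hits = top_k_inner_product([2.0, 1.0], vectors, k=2)
print(hits)  # [(3.0, 2), (2.0, 0)]
```

Because every vector is visited, the result is exact. The cost grows linearly with the number of documents, which is acceptable for a small collection like this one.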

## Write an encoder function that inputs a string and outputs a dense PyTorch tensor.

In [26]:
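# Build a function that uses the DistilBERT model to vectorize a text
def encode(document):
    # Tokenize one document; truncate inputs longer than the model's maximum length
    tokens = tokenizer(document, return_tensors='pt', truncation=True)
    # Run the model without tracking gradients; [0] is the last hidden state
    with torch.no_grad():
        hidden = model(**tokens)[0].squeeze(0)
    # Mean-pool the token vectors into a single 768-dimensional vector
    return torch.mean(hidden, dim=0)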
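
`encode` processes one document at a time, so every token vector it averages belongs to real input. If documents were instead encoded in padded batches (an assumption, not what this notebook does), the padding positions would have to be excluded from the mean using the tokenizer's `attention_mask`. A minimal plain-Python sketch of such a masked mean, with a hypothetical helper name and toy values:

```python
def masked_mean(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1, ignoring padding."""
    dim = len(token_vectors[0])
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]

# Two real tokens followed by one padding position (toy 2-dimensional vectors).
token_vectors = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
attention_mask = [1, 1, 0]

pooled = masked_mean(token_vectors, attention_mask)  # padding ignored
naive = masked_mean(token_vectors, [1, 1, 1])        # padding averaged in
print(pooled)  # [2.0, 3.0]
```

Averaging the padding row as well pulls the result toward the padding vector, so the same sentence would get a different embedding depending on how much padding its batch needed.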

## Build a list of vector representations, one for each document, with the reusable encoder function you created in step 5.

In [27]:
# vectorize the documents
vectors = [encode(d) for d in documents]

In [28]:
[v.size() for v in vectors]

[torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768])]

## Populate the empty FAISS index with the output vectors.

In [31]:
# Add the document vectors to the index. FAISS expects a float32 matrix and int64 IDs
index.add_with_ids(
    np.stack([v.numpy() for v in vectors]).astype(np.float32),
    # the IDs will be 0 to len(documents) - 1
    np.arange(len(documents), dtype=np.int64)
)

## Build a search function that accepts a string query, encodes it, searches for similar documents in the index, and returns the top 5 results with their scores.

In [29]:
# Build a function to search the index and return scored results
def search(query, k=5):
    # Encode the query as a (1, 768) float32 array
    encoded_query = encode(query).unsqueeze(dim=0).numpy()
    # index.search returns (scores, ids), each of shape (n_queries, k)
    scores, ids = index.search(encoded_query, k)
    # Pair each hit with its score; FAISS pads with id -1 when fewer than k vectors exist
    return [(documents[_id], score) for _id, score in zip(ids[0], scores[0]) if _id != -1]

## Test your search engine by asking some questions. Check out the attached questions.json for a few suggested questions to start with, but feel free to play around and search for anything you want!

In [30]:
pprint(search("spanish flu casualties", k=2))

[('The Spanish flu, also known as the 1918 flu pandemic, was an unusually '
  'deadly influenza pandemic caused by the H1N1 influenza A virus.',
  51.06952),
 ('As of 2018, approximately 37.9 million people are infected with HIV '
  'globally.',
  45.203133)]
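
The scores above (51.07 and 45.20) are raw inner products, so they depend on the length of the vectors as well as on their direction. One common option, not used in this notebook, is to L2-normalize both the document vectors and the query vector before indexing and searching. The inner product then becomes cosine similarity, which is bounded by 1. FAISS provides `faiss.normalize_L2` for float32 matrices. The plain-Python sketch below shows the effect on toy vectors; all names and values are illustrative.

```python
import math

def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

query = [1.0, 0.0]
aligned_short = [2.0, 0.0]  # same direction as the query, norm 2
off_axis_long = [3.0, 4.0]  # different direction, norm 5

candidates = (aligned_short, off_axis_long)
raw_scores = [dot(query, d) for d in candidates]
cos_scores = [dot(l2_normalize(query), l2_normalize(d)) for d in candidates]

# Raw inner product prefers the longer vector; cosine prefers the aligned one.
print(raw_scores, cos_scores)
```

With raw inner products the off-axis vector wins (3.0 against 2.0) simply because it is longer. After normalization the aligned vector scores 1.0 and the off-axis one about 0.6.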


In [32]:
questions = [
    "How many people have died during Black Death?",
    "Which diseases can be transmitted by animals?",
    "Connection between climate change and a likelihood of a pandemic",
    "What is an example of a latent virus",
    "Viruses in nanotechnology",
    "Giant viruses classification",
    "What are the notable pandemic prevention organizations?",
    "How many leprosy outbreaks are known to happen?",
    "What are the geographic areas with the highest transmission of malaria?",
    "How to prevent the spread of viral infections?",
]

In [35]:
for question in questions:
    pprint(question)
    pprint(search(question, k=1))

'How many people have died during Black Death?'
[('As of 2018, approximately 37.9 million people are infected with HIV '
  'globally.',
  52.61343)]
'Which diseases can be transmitted by animals?'
[('A pandemic is an epidemic of an infectious disease that has spread across a '
  'large region, for instance multiple continents or worldwide, affecting a '
  'substantial number of people.',
  54.049507)]
'Connection between climate change and a likelihood of a pandemic'
[('A pandemic is an epidemic of an infectious disease that has spread across a '
  'large region, for instance multiple continents or worldwide, affecting a '
  'substantial number of people.',
  60.54062)]
'What is an example of a latent virus'
[('A pandemic is an epidemic of an infectious disease that has spread across a '
  'large region, for instance multiple continents or worldwide, affecting a '
  'substantial number of people.',
  59.449497)]
'Viruses in nanotechnology'
[('Current pandemics include COVID-19 (SARS-Co