<a href="https://colab.research.google.com/github/krishnasaiv/search_tool/blob/main/Part_2_Semantic_Search_with_ML_and_BERT/2_1_Semantic_Search_Engine_with_Faiss_and_DistilBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Install Necessary Libraries

In [None]:
! pip install transformers torch
! pip install faiss-gpu


## 2. Download Distilled BERT model 

In [None]:
from transformers import AutoModel, AutoTokenizer

# The tokenizer and the encoder must come from the same pretrained checkpoint,
# so name it once and load both from it.
CHECKPOINT = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModel.from_pretrained(CHECKPOINT)

## 3. Read Data

In [None]:
import json


# Open both files with an explicit UTF-8 encoding instead of relying on the
# platform's default locale encoding, so non-ASCII text decodes the same
# way on every machine.
with open('/content/sentences.json', encoding='utf-8') as file:
    sentences = json.load(file)

with open('/content/questions.json', encoding='utf-8') as file:
    questions = json.load(file)


In [None]:
# Display the first three entries of `questions`.
questions[:3]

In [None]:
# Display the first three entries of `sentences`.
sentences[:3]

## 4. Encode Documents into Vectors

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer

# `tokenizer` and `model` were already created from the "distilbert-base-uncased"
# checkpoint in section 2; reuse those objects here instead of instantiating a
# second, identical copy of each.


def encode(document: str) -> torch.Tensor:
  """Embed a document as the mean of its token-level model outputs.

  Args:
    document: Raw text to embed. Input that tokenizes to more positions than
      the tokenizer's configured maximum is truncated rather than being passed
      to the model at full length.

  Returns:
    The model's first output with the batch axis removed, averaged over the
    remaining leading (token) axis. Assuming the model returns a
    (1, tokens, hidden) tensor, this is a 1-D vector of length `hidden`.
    The result carries no autograd history.
  """
  tokens = tokenizer(document, return_tensors='pt', truncation=True)
  # Inference only: skip autograd bookkeeping for the whole computation.
  with torch.no_grad():
    hidden_states = model(**tokens)[0]
    # Drop only the batch axis; a bare squeeze() would also drop any other
    # size-1 axis (e.g. a single-token sequence) and break the mean below.
    return hidden_states.squeeze(0).mean(dim=0)


# Embed every sentence with `encode`, keeping one vector per sentence in input order.
averaged_vectors = list(map(encode, sentences))

# Show the shape of each embedding.
[vec.shape for vec in averaged_vectors]

## 5. Upload to In-Memory Faiss Vector Database

In [None]:
import faiss
import numpy as np

# Stack the per-sentence embeddings into one contiguous float32 matrix with
# one row per sentence, the input layout Faiss works with.
embedding_matrix = np.ascontiguousarray(
    np.stack([t.numpy() for t in averaged_vectors]), dtype=np.float32
)

# Faiss IDs are 64-bit integers. Sentence i gets ID i (0 .. len(sentences) - 1),
# so an ID returned by a search can be used directly as an index into `sentences`.
sentence_ids = np.arange(len(sentences), dtype=np.int64)

# Flat inner-product index wrapped in an ID map so vectors carry explicit IDs.
# The dimensionality is read from the data instead of being hard-coded to 768.
index = faiss.IndexIDMap(faiss.IndexFlatIP(embedding_matrix.shape[1]))
index.add_with_ids(embedding_matrix, sentence_ids)


## 6. Search Functionality


*   Vectorize Query
*   Rank Documents
*   Retrieve Results



In [None]:
def search(query: str, k: int = 1):
  """Return the indexed sentences that score highest against `query`.

  Args:
    query: Free-text query; embedded with `encode` before searching.
    k: Maximum number of results to return.

  Returns:
    A list of `(sentence, score)` tuples in the order the index returns them.
    The list can be shorter than `k`: when the index has fewer than `k` hits
    it pads the result with ID -1, and those padding entries are dropped.
  """
  # The index searches a batch of queries, so add a leading batch axis of size 1.
  encoded_query = encode(query).unsqueeze(dim=0).numpy()
  scores, ids = index.search(encoded_query, k)
  # Without this filter a padding ID of -1 would silently resolve to sentences[-1].
  return [
      (sentences[_id], score)
      for _id, score in zip(ids[0], scores[0])
      if _id >= 0
  ]

# Show the first question next to its search result (`search` defaults to k=1).
questions[0], search(questions[0])