In [42]:
# followed this tutorial-: https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet

# Libraries

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Semantic Embeddings class
## does the embedding and averaging pooling and as well

In [12]:
class semanticEmbedding:
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    def get_embedding(self, sentences):
    # Tokenize sentences
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.detach().numpy()

In [17]:
# creating an instance of the class
model = semanticEmbedding()

In [14]:
a = model.get_embedding("I am kishan")

In [16]:
print(a.shape)

(1, 768)


In [36]:
import faiss 
import numpy as np
import pickle

In [37]:
class FaissIdx:

    def __init__(self, model, dim=768):
        self.index = faiss.IndexFlatIP(dim) # need to load the pickle model in the final file
        # self.index = faiss
        # Maintaining the document data
        self.doc_map = dict()
        self.model = model
        self.ctr = 0

    def add_doc(self, document_text):
        self.index.add(self.model.get_embedding(document_text))
        self.doc_map[self.ctr] = document_text # store the original document text
        self.ctr += 1

    def search_doc(self, query, k=3):
        D, I = self.index.search(self.model.get_embedding(query), k)
        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]
    def save_index(self, index_filename, doc_map_filename):
        # Save Faiss index to file
        faiss.write_index(self.index, index_filename)

        # Save document map to file using pickle
        with open(doc_map_filename, 'wb') as f:
            pickle.dump(self.doc_map, f)

    def load_index(self, index_filename, doc_map_filename):
        # Load Faiss index from file
        self.index = faiss.read_index(index_filename)

        # Load document map from file using pickle
        with open(doc_map_filename, 'rb') as f:
            self.doc_map = pickle.load(f)

In [40]:
index = FaissIdx(model)
index.add_doc("laptop computers")
index.add_doc("doctor's office")
index.search_doc("PC computer")

[{'laptop computers': 0.72760224}, {"doctor's office": 0.25554448}]

In [21]:
import requests
from io import StringIO
import pandas as pd

In [22]:
res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [23]:
data.drop_duplicates(subset="sentence_A", inplace=True)

In [27]:
total = data.shape[0]
total

3146

In [30]:
from IPython.display import clear_output

In [32]:
for idx, row in data.iterrows():
    index.add_doc(row['sentence_A'])
    print(((idx)/total)*100)
    if idx%10==0:
        clear_output(wait=True)

142.8480610298792
142.97520661157023
143.006993006993


In [33]:
index.search_doc("yellow train is going by")


[{'A yellow dog is stopping on white snow on a sunny day': 0.44798172},
 {'A yellow dog is running on white snow on a sunny day': 0.382093},
 {'A schoolgirl with a black bag is on a crowded train': 0.38048464}]

In [34]:
index.search_doc("sprinting with football")


[{'A player is running with the ball': 0.59186924},
 {'A player is running with the ball': 0.59186924},
 {'A player is running with the ball': 0.59186924}]

# Saving the index mapping in a pkl file

In [41]:
index.save_index('index.bin', 'doc_map.pkl')