In [15]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
#import faiss

In [3]:
# Load the dataset from a CSV file. Later cells read its `Question` column
# to build embeddings and fetch matched rows by position with `data.iloc`.
data = pd.read_csv("Data/data.csv")

In [2]:
# The tokenizer and the encoder are loaded from the same pretrained
# checkpoint, so its identifier is written once and shared by both calls.
PRETRAINED_CHECKPOINT = 'sentence-transformers/paraphrase-albert-small-v2'
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)
model = AutoModel.from_pretrained(PRETRAINED_CHECKPOINT)

In [4]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable object whose element 0 holds the token
            embeddings tensor.
        attention_mask: Tensor that is given a trailing axis and expanded to
            the token embeddings' size; positions with value 0 contribute
            nothing to the average.

    Returns:
        The mask-weighted sum of the token embeddings over dimension 1,
        divided by the summed mask.
    """
    hidden_states = model_output[0]
    # Give the mask a trailing axis and stretch it to the embeddings' shape
    # so it can be multiplied element-wise.
    weights = attention_mask.unsqueeze(-1).expand_as(hidden_states).float()
    weighted_total = (hidden_states * weights).sum(1)
    # Clamp the denominator so a row whose mask is all zeros does not
    # trigger a division by zero.
    weight_total = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_total / weight_total

In [18]:
def create_embedding(data):
    """Encode texts into mean-pooled sentence embeddings.

    Args:
        data: An iterable of strings (for example a list or a pandas
            Series), or a single string, which is treated as one text.

    Returns:
        The tensor produced by ``mean_pooling`` for the tokenized batch,
        computed from the module-level ``tokenizer`` and ``model``.
    """
    # list("some text") would split a bare string into single characters and
    # embed each character as its own sentence, so wrap it in a one-item batch.
    texts = [data] if isinstance(data, str) else list(data)
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Forward pass only; no gradients are needed to produce embeddings.
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])

    #Uncomment if using faiss
    # sentence_embedding = sentence_embedding.numpy()

    return sentence_embedding

In [19]:
# Embed every entry of the Question column once up front; search() compares
# each query embedding against this result.
dataset_embedding = create_embedding(data.Question)

In [154]:
# Uncomment if using faiss
# index = faiss.IndexFlatL2(768)
# index.add(dataset_embedding)

In [23]:
def search(query):
    """Find the dataset question most similar to the query.

    The query is embedded with ``create_embedding`` and compared to the
    module-level ``dataset_embedding`` by cosine similarity.

    Args:
        query: A list holding one query string, or a bare string, which is
            wrapped into a one-item list.

    Returns:
        A zero-dimensional integer tensor holding the position of the
        highest-similarity entry; call ``.item()`` to get a plain int.
    """
    # Wrap a bare string here so it is embedded as one sentence rather than
    # being iterated character by character further down.
    if isinstance(query, str):
        query = [query]
    query_encoding = create_embedding(query)
    # Compared along dim 1, the same default the original
    # torch.nn.CosineSimilarity() used; the shapes must be broadcastable.
    similarity = torch.nn.functional.cosine_similarity(query_encoding, dataset_embedding, dim=1)
    # torch.argmax keeps the computation in torch; np.argmax on a tensor
    # depends on an implicit numpy conversion that fails for non-CPU tensors.
    ques_index = torch.argmax(similarity)
    # distance, ques_index = index.search(query_encoding, 1)
    return ques_index

In [26]:
# Example lookup. The query is wrapped in a list because create_embedding
# calls list() on its argument.
query = ["Where are you from"]

# .item() converts the returned index into a plain int for positional row
# lookup with data.iloc.
ques_index = search(query)
print(data.iloc[ques_index.item()])


Question                               Where are you from?
Answer      I'm digital. I don't have a physical location.
Name: 3919, dtype: object


In [27]:
# Write the precomputed question embeddings to a file in the working
# directory (presumably so they can be reloaded later instead of re-encoding
# the dataset; confirm against the consumer of this file).
torch.save(dataset_embedding, "chitchat_dataset_embedding.pt")