In [2]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
#import faiss

In [3]:
data = pd.read_csv("Data/data.csv")

In [7]:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-albert-small-v2')
model = AutoModel.from_pretrained('sentence-transformers/paraphrase-albert-small-v2')

In [6]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [5]:
def create_embedding(data):
    encoded_input = tokenizer(list(data), padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])

    #Uncomment if using faiss
    # sentence_embedding = sentence_embedding.numpy()

    return sentence_embedding

In [6]:
dataset_embedding = create_embedding(data.Question)


KeyboardInterrupt



In [154]:
# Uncomment if using faiss
# index = faiss.IndexFlatL2(768)
# index.add(test)

In [8]:
dataset_embedding = torch.load("Data/chitchat_dataset_embedding.pt")

In [18]:
test_model = SentenceTransformer('sentence-transformers/paraphrase-albert-small-v2')
# embeddings = model.encode(sentences)
# print(embeddings)

Downloading: 100%|██████████| 690/690 [00:00<00:00, 339kB/s]
Downloading: 100%|██████████| 190/190 [00:00<00:00, 190kB/s]
Downloading: 100%|██████████| 3.71k/3.71k [00:00<00:00, 1.87MB/s]
Downloading: 100%|██████████| 827/827 [00:00<00:00, 413kB/s]
Downloading: 100%|██████████| 122/122 [00:00<00:00, 119kB/s]
Downloading: 100%|██████████| 229/229 [00:00<00:00, 45.9kB/s]
Downloading: 100%|██████████| 46.7M/46.7M [00:09<00:00, 5.08MB/s]
Downloading: 100%|██████████| 53.0/53.0 [00:00<00:00, 26.4kB/s]
Downloading: 100%|██████████| 245/245 [00:00<00:00, 124kB/s]
Downloading: 100%|██████████| 760k/760k [00:00<00:00, 838kB/s]  
Downloading: 100%|██████████| 1.31M/1.31M [00:01<00:00, 852kB/s] 
Downloading: 100%|██████████| 465/465 [00:00<00:00, 232kB/s]


In [37]:
def search(query):
    query_encoding = test_model.encode(query)[0]
    print(type(query_encoding))
    # cosin_sim = torch.nn.CosineSimilarity()
    ques_index = np.argmax(util.cos_sim(query_encoding,dataset_embedding))
    # distance, ques_index = index.search(query_encoding, 1)
    return ques_index 

In [38]:
query = ["Who are you"]

data.iloc[search(query).item()]
ques_index = search(query)
print(data.iloc[ques_index.item()])
# bl = test_model.encode(query)


<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Question            Who are you?
Answer      I don't have a name.
Name: 1576, dtype: object


In [27]:
torch.save(dataset_embedding, "chitchat_dataset_embedding.pt")