In [1]:
import random
import numpy as np
import pandas as pd
import faiss
import torch
import re
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer

seed = 10

# Seed Python's `random`, NumPy's legacy global RNG and PyTorch (CPU, then all
# CUDA devices) with the same value so repeated runs draw the same numbers.
for _seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seeder(seed)
del _seeder

# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
# Load the prepared dialogue table (original/cleaned response, context, label columns).
df = pd.read_csv(filepath_or_buffer='../data/prepared_with_context+label+negative.csv')

In [3]:
# Keep only rows with label == 1; the search index below is built from their contexts.
df = df.loc[df['label'] == 1]

In [4]:
# Inspect the rows with label == 1 kept by the filter cell.
df

Unnamed: 0,original_response,response,context,label
0,"Oh, hi!",oh hi,two hundred pound transvestite with a skin co...,1
1,Hi?,hi,hi,1
2,"Oh, that’s nice.",oh that’s nice,we don’t mean to interrupt we live across the...,1
3,"Oh, okay, well, guess I’m your new neighbour,...",oh okay well guess i’m your new neighbour penny,oh… uh… no… we don’t live together… um… we li...,1
4,Hi.,hi,leonard sheldon,1
...,...,...,...,...
29899,"Well, then you get it.",well then you get it,well sometimes women don’t care sometimes it ...,1
29900,Right.,right,okay um let’s try this think of yourself as o...,1
29901,"Let’s forget the toy thing, okay? Um, maybe…",let’s forget the toy thing okay um maybe…,well then you get it [SEP] because there’s on...,1
29902,All right. What do you think is happening?,all right what do you think is happening,right [SEP] although amy’s already taken me o...,1


In [5]:
import  nltk
from nltk.tokenize import wordpunct_tokenize
from nltk.corpus import stopwords

# Fetch NLTK's stop-word corpus (the output below shows it is skipped when already
# up to date) and load the Russian list used by the TF-IDF vectorizer.
# NOTE(review): the dialogue data shown above is English, so a Russian stop-word
# list likely filters nothing — confirm whether stopwords.words("english") was intended.
nltk.download('stopwords')
russian_stopwords = stopwords.words("russian")

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Di\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [6]:
# Drop rows with a missing value in any column. Rebind instead of using
# inplace=True: pandas discourages inplace, and mutating the frame produced by
# the boolean filter above can trigger chained-assignment warnings.
df = df.dropna(how='any')

In [7]:
def clean_symbols(x):
    """Remove a fixed set of ASCII punctuation characters from ``x``, then lowercase it.

    Removed: " # $ % & ( ) * + , - / : ; < = > @ \\ ^ _ ` { | } ~ . ! ?
    Kept: apostrophes, square brackets and non-ASCII punctuation such as ’ or ….
    """
    # Inside a character class "|" is a literal, so the separators only add "|"
    # itself (also listed explicitly): the pattern is just a set of characters.
    symbol_regex = re.compile(r'[\"|\#|\$|\%|\&|\(|\)|\*|\+|\,|\-|\/|\:|\;|\<|\=|\>|\@|\\|\^|\_|\`|\{|\||\}|\~|\.|\!|\?]')
    stripped = symbol_regex.sub('', x)
    # Strip first, then lowercase (str.lower can be context-sensitive, so order matters).
    return stripped.lower()

In [8]:
class SimpleSearchEngine:
    """TF-IDF document retriever with an optional FAISS backend.

    Every document is cleaned with ``preprocess`` and vectorized by a TF-IDF
    model (``wordpunct_tokenize`` tokens, 1- to 3-grams, at most 4096
    features). A query is cleaned the same way and matched against the
    documents either with scikit-learn's ``cosine_similarity`` or, when
    ``isFaiss`` is true, with an exact FAISS inner-product index over
    L2-normalized vectors, which ranks by the same cosine score.

    Args:
        text_database: the documents to index. Any iterable of strings works
            (e.g. a pandas Series); it is materialized into a list once.
        isFaiss: if true, ``display_relevant_docs`` searches the FAISS index
            instead of calling ``cosine_similarity``.

    Attributes:
        raw_procesed_data: cleaned documents the vectorizer is fitted on.
        retriever: the fitted ``TfidfVectorizer``.
        base: sparse TF-IDF matrix, one row per document.
        index: FAISS index (built only when ``isFaiss`` is true, else None).
        inverted_index: maps row position -> original (uncleaned) document.
    """

    def __init__(self, text_database: list[str], isFaiss):
        # Materialize once so a one-shot iterable is not consumed by the first pass.
        documents = list(text_database)
        self.raw_procesed_data = [self.preprocess(sample) for sample in documents]
        self.base = None
        self.retriever = None
        self.index = None
        self.inverted_index = {}
        self.isFaiss = isFaiss
        # Fit on the cleaned text: queries go through the same cleaning in
        # display_relevant_docs, so both sides share one tokenization/vocabulary.
        self._init_retriever(self.raw_procesed_data)
        # Hits are mapped back to the original strings, which callers compare
        # against their own (uncleaned) data.
        self._init_inverted_index(documents)

    @staticmethod
    def preprocess(sentence: str) -> str:
        """Clean one document or query with ``clean_symbols``."""
        sentence = clean_symbols(sentence)
        return sentence

    def _init_faiss(self):
        """Build an exact FAISS index whose inner-product scores equal cosine similarity."""
        # Densify once and keep the array: faiss.normalize_L2 works in place, so it
        # must run on the same buffer that is added to the index (normalizing a
        # temporary copy has no effect on what gets indexed).
        dense_base = self.base.toarray().astype('float32')
        faiss.normalize_L2(dense_base)
        # For unit-length vectors the inner product is the cosine similarity,
        # matching the ranking of retrieve_documents.
        self.index = faiss.IndexFlatIP(dense_base.shape[1])
        self.index.add(dense_base)

    def _init_retriever(self, text_database: list[str]):
        """Fit the TF-IDF vectorizer on ``text_database`` and store its matrix in ``self.base``."""
        # NOTE(review): the dialogue data in this notebook is English, but the
        # stop-word list is Russian, so it likely removes nothing — confirm.
        self.retriever = TfidfVectorizer(stop_words=russian_stopwords,
                                         ngram_range=(1, 3),
                                         # keep the 4096 terms with the highest frequency across the corpus
                                         max_features=4096,
                                         tokenizer=wordpunct_tokenize,
                                         # token_pattern is unused when a tokenizer is given;
                                         # None silences scikit-learn's warning about it
                                         token_pattern=None)
        # One TF-IDF row per document (the document representation).
        self.base = self.retriever.fit_transform(text_database)
        if self.isFaiss:
            self._init_faiss()

    def retrieve(self, query: str):
        """Vectorize ``query`` with the fitted TF-IDF model.

        Returns the 1 x n_features sparse matrix produced by ``TfidfVectorizer.transform``.
        """
        return self.retriever.transform([query])

    def retrieve_documents(self, query: str, top_k=3) -> np.ndarray:
        """Return row positions of the ``top_k`` documents most cosine-similar to ``query``, best first."""
        query_vector = self.retrieve(query)
        cosine_similarities = cosine_similarity(query_vector, self.base).flatten()
        # argsort is ascending: reverse for best-first, then keep the first top_k.
        relevant_indices = np.argsort(cosine_similarities, axis=0)[::-1][:top_k]
        return relevant_indices

    def retrieve_documents_faiss(self, query: str, top_k=3):
        """Return row positions of the ``top_k`` best FAISS matches for ``query``, best first."""
        query_vector = self.retrieve(query).toarray().astype('float32')
        faiss.normalize_L2(query_vector)
        # FAISS pads missing results with -1 when k exceeds the number of stored
        # vectors; never ask for more than exist, and drop any padding.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype='int64')
        _, relevant_indices = self.index.search(query_vector, k=k)
        hits = relevant_indices[0]
        return hits[hits >= 0]

    def _init_inverted_index(self, text_database: list[str]):
        """Map each row position to its original document text."""
        self.inverted_index = dict(enumerate(text_database))

    def display_relevant_docs(self, query: str, full_database, top_k=3) -> list[str]:
        """Return the ``top_k`` original documents most relevant to ``query``, best first.

        Args:
            query: raw query text; cleaned with ``preprocess`` before vectorizing.
            full_database: unused; kept so existing call sites keep working.
            top_k: maximum number of documents to return.
        """
        query = self.preprocess(query)

        if self.isFaiss:
            docs_indexes = self.retrieve_documents_faiss(query, top_k)
        else:
            docs_indexes = self.retrieve_documents(query, top_k)

        return [self.inverted_index[ind] for ind in docs_indexes]

In [9]:
# Index the dialogue contexts; isFaiss=True routes retrieval through the FAISS index.
simple_search_engine = SimpleSearchEngine(text_database=df['context'], isFaiss=True)



In [10]:
def get_chatbot_response(query, search_engine=None, data=None):
    """Answer ``query`` with the response whose context is the best retrieval match.

    Args:
        query: user text (or several turns joined with " [SEP] ").
        search_engine: object exposing ``display_relevant_docs(query, contexts)``;
            defaults to the notebook-level ``simple_search_engine``.
        data: DataFrame with ``context`` and ``original_response`` columns;
            defaults to the notebook-level ``df``.

    Returns:
        The ``original_response`` of the first row whose ``context`` equals the
        top-ranked retrieved context.

    Raises:
        IndexError: if nothing is retrieved or no row has that context.
    """
    # Resolve the notebook globals only when no explicit argument is given, so
    # the function can be run against any engine/table.
    engine = simple_search_engine if search_engine is None else search_engine
    dialogues = df if data is None else data

    best_context = engine.display_relevant_docs(query, dialogues['context'])[0]
    matching = dialogues.loc[dialogues['context'] == best_context, 'original_response']
    return matching.iloc[0]

In [22]:
# Interactive demo: read three user turns from stdin and print the bot reply to each.
# Even turns (0, 2, ...) start a fresh context from the new query; odd turns append
# the query to the previous one after " [SEP] ". Bot replies never enter the context.
all_context = ''
count = 0
for i in range(3):
    query = input()
    all_context = query if count % 2 == 0 else ' [SEP] '.join((all_context, query))
    count += 1

    print('-' + query)
    print('---' + get_chatbot_response(all_context))
    print('\n')

-Hi
--- Hi?


-Nice to meet you
--- Good.


-What are you doing?
--- I’m settling once and for all who is the smartest around here. Okay, are you ready?


