# Загрузка и импорт библиотек

In [None]:
%pip install --user annoy

In [5]:
import pandas as pd
from tqdm import tqdm
import pickle
from transformers import AutoTokenizer, AutoModel
import torch
from annoy import AnnoyIndex
import random

# Энкодер

In [1]:
# Suffix identifying the fine-tuned checkpoint; it is interpolated into the
# model weight filenames and (minus the leading underscore) the index filename.
version = '_tiny2_context_softmax_final'
# Token limit passed to the tokenizer when truncating input text.
max_length = 2048

In [None]:
# Disabled alternative encoder, kept as a bare string literal so it is never
# executed: mean-pooled embeddings from "ai-forever/sbert_large_mt_nlu_ru".
'''
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_mt_nlu_ru")
model_question = AutoModel.from_pretrained("ai-forever/sbert_large_mt_nlu_ru")
model_answer = AutoModel.from_pretrained("ai-forever/sbert_large_mt_nlu_ru")'''

def embed_bert_cls(model_output):
    """Return unit-length embeddings taken from the first token position.

    Args:
        model_output: Encoder output exposing a 3-D ``last_hidden_state``
            tensor; position 0 along the second axis is selected (the
            [CLS] slot for BERT-style tokenizers).

    Returns:
        A 2-D tensor with one L2-normalized row per input sequence.
    """
    cls_vectors = model_output.last_hidden_state[:, 0]
    # p=2 / dim=1 are the defaults of normalize, spelled out for clarity.
    return torch.nn.functional.normalize(cls_vectors, p=2.0, dim=1)

# One shared tokenizer and two separate encoder instances built from the same
# base checkpoint: one is applied to questions, the other to answers.
encoder_checkpoint = "cointegrated/rubert-tiny2"
tokenizer = AutoTokenizer.from_pretrained(encoder_checkpoint)
model_question = AutoModel.from_pretrained(encoder_checkpoint)
model_answer = AutoModel.from_pretrained(encoder_checkpoint)

In [3]:
# Replace the base weights with the fine-tuned ones selected by `version`.
# map_location forces the tensors onto the CPU.
model_question.load_state_dict(
    torch.load(
        f'models/model_anchor{version}.bin',
        map_location=torch.device('cpu'),
    )
)
model_answer.load_state_dict(
    torch.load(
        f'models/model_pos_neg{version}.bin',
        map_location=torch.device('cpu'),
    )
)

<All keys matched successfully>

# Annoy

In [6]:
# Candidate answers: keep only rows whose 'is_visitor_message' value is
# 'ОПЕРАТОР' (Russian for "OPERATOR") and drop repeated message texts.
df = (
    pd.read_csv('data/grammar_cities_fix.csv')
    .loc[lambda frame: frame['is_visitor_message'] == 'ОПЕРАТОР']
    .drop_duplicates(subset=['message'])
)
df.shape

(2934, 9)

In [7]:
# Embedding width of the encoder in use; must match the vectors added below.
# f = 1024  # for sbert
f = 312  # for tiny2

# Build an angular-distance Annoy index over the answer-encoder embeddings of
# every message in df. The dataframe label `i` doubles as the Annoy item id so
# search results can be mapped back with df['message'][id].
# NOTE(review): Annoy documents that add_item allocates for max(i)+1 items, so
# gaps left in df.index by the filtering above are not free — confirm that
# this is acceptable for larger datasets.
t = AnnoyIndex(f, 'angular')
for i in tqdm(df.index):
    # Use the shared max_length constant so that answers and questions are
    # truncated identically (it was previously a hard-coded duplicate).
    encoded_input = tokenizer(df['message'][i], padding=True, truncation=True, max_length=max_length, return_tensors='pt')
    with torch.no_grad():
        model_output = model_answer(**encoded_input)
    sentence_embedding = embed_bert_cls(model_output)
    # Single text per call -> batch of one; store row 0 as a plain float list,
    # the same form the query vectors take.
    t.add_item(i, sentence_embedding[0].tolist())

t.build(100)  # 100 trees
# version[1:] drops the leading underscore from the suffix for the filename.
t.save(f'{version[1:]}.ann')

100%|██████████| 2934/2934 [00:16<00:00, 175.44it/s]


True

In [8]:
# Reopen the saved index under a fresh handle and sanity-check it by listing
# the neighbours of item 0.
u = AnnoyIndex(f, 'angular')
u.load(f'{version[1:]}.ann')  # super fast, will just mmap the file
neighbours_of_item_zero = u.get_nns_by_item(0, 10)  # the 10 nearest neighbors
print(neighbours_of_item_zero)

[18, 19, 28, 32, 56, 69, 77, 117, 119, 126]


In [13]:
def get_question_embs(question):
    """Encode *question* with the question encoder.

    The text is tokenized (truncated to ``max_length`` tokens), run through
    ``model_question`` without gradient tracking, and reduced to a normalized
    first-token embedding by ``embed_bert_cls``.

    Returns:
        The first row of the embedding batch as a plain list of floats.
    """
    tokens = tokenizer(
        question,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    with torch.no_grad():
        encoder_output = model_question(**tokens)
    # Only row 0 is returned, even if more than one text was passed in.
    first_vector = embed_bert_cls(encoder_output)[0]
    return first_vector.tolist()

In [10]:
def pretty_response(response):
    """Print search hits ordered by ascending score.

    Args:
        response: A pair ``(ids, distances)`` of parallel sequences, as
            returned by an Annoy query with ``include_distances=True``.
            Each id is used as a label into ``df['message']``.

    Prints one block per hit with its id, message text and score.
    """
    ids, distances = response
    messages = df['message']
    # Rows lead with the score so the default sort orders by distance first.
    ranked = sorted(
        [distance, item_id, messages[item_id]]
        for item_id, distance in zip(ids, distances)
    )
    for score, item_id, text in ranked:
        print(f"\nID: {item_id} \nSummary: {text}\nScore: {score}")

In [11]:
# Test set; the preview loop further down reads its 'anchor_one_str' and
# 'positive' columns.
df_test = pd.read_csv('data/test_onestr.csv')

In [None]:
def read_list(filename):
    """Load and return the object pickled in the file at *filename*.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(filename, 'rb') as source:
        return pickle.load(source)

# Preview the first ten test pairs (anchor text, then its positive answer).
# Iterate df_test's own index: the labels of the filtered answer frame `df`
# do not correspond to rows of df_test and could pick arbitrary rows or raise
# KeyError.
for i in df_test.index[:10]:
    print(df_test['anchor_one_str'][i])
    print(df_test["positive"][i])
    print()

In [None]:
# Ad-hoc search: put the dialogue context to look up in `context`, embed it
# with the question encoder, and print the 20 nearest indexed answers.
context = ''
v = get_question_embs(context)
nearest_answers = u.get_nns_by_vector(v, 20, include_distances=True)
pretty_response(nearest_answers)