In [None]:
!pip install transformers

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizer
from sklearn import preprocessing
from sklearn.feature_extraction.text import CountVectorizer
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
from torch import nn

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
# Google Drive "share" link to the CSV file.
share_link = 'https://drive.google.com/file/d/1NKoxe-KUirKp91yLZhCz63QFCTXEGPoo/view?usp=share_link'
# The file id is the second-to-last path segment of the share link; the
# `uc?id=` form of the URL is what gets passed to pandas for reading.
file_id = share_link.split('/')[-2]
url = 'https://drive.google.com/uc?id=' + file_id
data = pd.read_csv(url)

In [None]:
# Normalise the text columns with the vectorised `.str` accessors.
# Unlike `.apply(str.strip)` / `.apply(str.lower)`, these do not raise a
# TypeError when a cell is missing (NaN); missing values simply stay missing.
data['url_with_ans'] = data['url_with_ans'].str.strip()  # drop surrounding whitespace
data['question'] = data['question'].str.lower()  # lowercase the question text

In [None]:
# The tokenizer and the encoder are loaded from the same pretrained checkpoint.
checkpoint = "cointegrated/LaBSE-en-ru"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint).to(device)
# Put the module in evaluation mode. Kept as the last expression of the cell
# so the notebook still displays the returned module.
model.eval()

In [None]:
# Map every distinct `url_with_ans` value to an integer class id and store
# that id in a new `target` column.
label_enc = preprocessing.LabelEncoder().fit(data['url_with_ans'])
data['target'] = label_enc.transform(data['url_with_ans'])

In [None]:
def vectorize(sentences):
  """Embed a batch of sentences with the module-level `tokenizer` and `model`.

  Args:
    sentences: list of strings to encode together as one padded batch.

  Returns:
    Tensor with one L2-normalised row per input sentence, taken from the
    model's `pooler_output`. The tensor lives on the model's device.
  """
  # `truncation=True` clips inputs that exceed the tokenizer's maximum length
  # instead of letting them overflow the encoder's position embeddings.
  encoded_input = tokenizer(
      sentences, padding=True, truncation=True, return_tensors='pt'
  )
  # The tokenizer returns CPU tensors, while the model may have been moved to
  # the GPU; without this the forward pass fails with a device mismatch.
  encoded_input = encoded_input.to(model.device)
  with torch.no_grad():  # inference only, no autograd bookkeeping needed
    model_output = model(**encoded_input)
  embeddings = model_output.pooler_output
  # Normalise each row (dim=1) to unit length.
  bert_embeddings = torch.nn.functional.normalize(embeddings, dim=1)

  return bert_embeddings

In [None]:
# Build one reference vector per class id: embed every question of the class
# and add the embeddings together. The sum is not divided by the class size.
centroids = {}
for class_number, _ in enumerate(label_enc.classes_):
  class_questions = data.loc[data.target == class_number, 'question'].tolist()
  centroids[class_number] = vectorize(class_questions).sum(dim=0)

In [None]:
def inference(question: str, centroids):
  """Score a single question against every class vector.

  Args:
    question: the query text; it is embedded with `vectorize`.
    centroids: mapping from class id to that class's reference vector.

  Returns:
    Dict with the same keys as `centroids`; each value is the result of
    `F.cosine_similarity` between the query embedding and the class vector.
  """
  query_embedding = vectorize([question])
  return {
      label: F.cosine_similarity(query_embedding, class_vector)
      for label, class_vector in centroids.items()
  }

In [None]:
# Example query scored against every class vector.
query = "Какие правила лабораторных работ по программированию"
result = inference(question=query, centroids=centroids)

In [None]:
# Pick the class id whose similarity score is the largest, then decode it
# back to the original `url_with_ans` label.
ind, _ = max(result.items(), key=lambda pair: pair[1])
best_label = label_enc.inverse_transform([ind])[0]
print(best_label)
# Full score table, for inspection.
print(result)

In [None]:
# Show which original label is stored under class id 1.
decoded = label_enc.inverse_transform([1])
print(decoded[0])

In [None]:
# Number of questions per class id (displayed as the cell's output).
data.target.value_counts()