In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModel

# Both the tokenizer and the encoder weights come from the same Hub checkpoint,
# so the identifier is written once and shared.
MODEL_NAME = "ai-forever/sbert_large_nlu_ru"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [None]:
import pandas as pd
import torch
from datasets import Dataset

In [None]:
# Read the joke corpus; the first CSV column is used as the DataFrame index.
df_articles = pd.read_csv(
    '../data/anekdots.csv',
    index_col=0,
)

# Preview the first rows to check the columns that were loaded.
df_articles.head()

In [None]:
# Pick one entry of the `text_clean` column as a running example for the cells below.
# NOTE(review): `[300]` on a Series is label-based when the index is integer-typed;
# the index here comes from CSV column 0 (`index_col=0`) — confirm a label lookup is
# intended rather than positional access (`.iloc[300]`).
text_example = df_articles['text_clean'][300]
text_example

In [None]:
# Compare the two views the tokenizer offers for the same text:
# `encode` returns integer ids, `tokenize` returns the token strings.
input_ids = tokenizer.encode(text=text_example)
tokens = tokenizer.tokenize(text=text_example)

# The two lengths differ by however many special tokens `encode` adds.
print(len(input_ids), len(tokens))

# Show the first 10 tokens next to their ids. The id list is offset by one to
# skip the leading special token (assumes `encode` adds exactly one at the start,
# as the original slice did — confirm for this tokenizer). The id slice is left
# open-ended so that `zip` stops on the 10-token side; the previous `[1:10]`
# slice held only 9 ids and silently dropped the tenth pair.
for tok, tok_id in zip(tokens[:10], input_ids[1:]):
    print(tok_id, tok)

# Round-trip the ids back to text.
print(tokenizer.decode(token_ids=input_ids))

In [None]:
# Tokenize the example into PyTorch tensors; `max_length=10` with truncation
# keeps this demo input deliberately short.
encoded_input = tokenizer(
    text_example,
    padding=True,
    truncation=True,
    max_length=10,
    return_tensors='pt',
)

In [None]:
# Display the tokenizer output built in the previous cell.
encoded_input

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, ignoring masked positions.

    Args:
        model_output: Indexable model output whose element 0 holds the
            per-token embeddings.
        attention_mask: Mask with one entry per token; it is broadcast across
            the embedding axis and used as a 0/1 weight.

    Returns:
        The mask-weighted sum of the token embeddings along axis 1 divided by
        the mask total along that axis (clamped to at least 1e-9 so an
        all-zero mask does not divide by zero).
    """
    hidden = model_output[0]
    # Give the mask a trailing axis and stretch it to the embeddings' shape.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    numerator = (hidden * weights).sum(dim=1)
    denominator = weights.sum(dim=1).clamp(min=1e-9)
    return numerator / denominator

In [None]:
# Print the loaded model's text representation for inspection.
print(model)

In [None]:
# Run the model on the example input with gradient tracking disabled;
# `out` is inspected by the following cells.
with torch.no_grad():
    out = model(**encoded_input)

In [None]:
# Shape of the hidden state taken at sequence position 0 for each batch item.
out.last_hidden_state[:,0,:].shape

In [None]:
# Shape of the mask-weighted mean over tokens, for comparison with the cell above.
mean_pooling(out, encoded_input['attention_mask']).shape

In [None]:
# Convert the DataFrame into a `datasets.Dataset`, which is what the later
# cells call `.map` and `.get_nearest_examples` on.
anekdot_dataset = Dataset.from_pandas(df_articles)
anekdot_dataset

In [None]:
# Device that get_embeddings() moves the tokenized inputs to.
# NOTE(review): no visible cell moves `model` to this device, so switching away
# from 'cpu' would also need a matching `model.to(device)` — confirm before changing.
device = 'cpu'

In [None]:
def cls_pooling(model_output):
    """Return the hidden state at sequence position 0 for every batch item.

    Args:
        model_output: Object exposing a `last_hidden_state` tensor indexed as
            [batch, position, feature].

    Returns:
        The `[:, 0, :]` slice of `last_hidden_state`.
    """
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0, :]

In [None]:
def get_embeddings(text_list):
    """Encode text with the module-level tokenizer/model and pool the result.

    Args:
        text_list: A string or list of strings, passed straight to `tokenizer`.
            Inputs are padded to a common length and truncated to 128 tokens.

    Returns:
        The tensor produced by `cls_pooling` for the batch (the hidden state at
        sequence position 0 of each input). It is computed without autograd
        tracking, so it carries no gradient history.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, max_length=128, return_tensors="pt"
    )
    # Move every tokenizer tensor to the configured device before the forward pass.
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # This helper is used for inference only. Disabling grad here keeps callers
    # that forget their own `torch.no_grad()` from building an autograd graph on
    # every call; callers that still call `.detach()` on the result keep working.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)

In [None]:
# Smoke-test the helper on the first row only. Indexing the row first and the
# column second yields the same text without pulling the whole column.
with torch.no_grad():
    embedding = get_embeddings(anekdot_dataset[0]["text_clean"])
embedding.shape

In [None]:
# Embed every row and store the vector in a new "embeddings" column.
# Each call receives a single text, so the batch axis has length 1 and `[0]`
# extracts that one vector. Gradients are not needed here, so the whole pass
# runs under `no_grad`; `.detach()` is kept so the conversion to NumPy works
# regardless of how get_embeddings() is implemented.
with torch.no_grad():
    embeddings_dataset = anekdot_dataset.map(
        lambda x: {"embeddings": get_embeddings(x["text_clean"]).detach().cpu().numpy()[0]}
    )

# `get_nearest_examples("embeddings", ...)` can only search a column that has an
# index attached. Without this call the lookup cell fails on a fresh kernel.
embeddings_dataset.add_faiss_index(column="embeddings")

In [None]:
# Embed a free-text query with the same encoder used for the dataset rows.
# The query is wrapped in a list so the result keeps a leading batch axis.
# NOTE(review): the query is in English while the checkpoint name ends in
# `_ru` — confirm this is the intended query for this corpus.
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).detach().cpu().numpy()
question_embedding.shape

In [None]:
# Retrieve the 5 rows whose "embeddings" entries are closest to the query vector.
# `scores` holds one value per hit; `samples` holds the matching rows by column.
scores, samples = embeddings_dataset.get_nearest_examples(
    index_name="embeddings",
    query=question_embedding,
    k=5,
)

In [None]:
import pandas as pd

# Put the retrieved rows in a DataFrame and attach each hit's score.
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
# Order hits from highest to lowest score (rebinding instead of mutating in place).
# NOTE(review): if the index uses a distance metric, a lower score means a closer
# match and this order puts the weakest hit first — confirm the intended direction.
samples_df = samples_df.sort_values("scores", ascending=False)