In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, AutoModelForPreTraining, AutoModelForMaskedLM, AutoModelForTextEncoding
from ..constants import ROCAR_CSV, FINE_TUNED_BERT_MODEL_PATH
from ..utils.format import preprocess_text

In [None]:
# Read the ROCAR CSV (path constant resolved against the parent directory) and
# run the project's text preprocessing over the "input" column in one chain.
df = pd.read_csv("../" / ROCAR_CSV).assign(
    input=lambda frame: frame["input"].apply(preprocess_text)
)

In [None]:
# Tokenizer of the base Romanian BERT checkpoint. Reference:
# Stefan Dumitrescu, Andrei-Marius Avram, and Sampo Pyysalo. 2020. The birth of Romanian BERT. In Findings of the Association for Computational Linguistics: EMNLP 2020, pages 4324–4328, Online. Association for Computational Linguistics.
# https://huggingface.co/dumitrescustefan/bert-base-romanian-cased-v1
tokenizer = AutoTokenizer.from_pretrained(
    "dumitrescustefan/bert-base-romanian-cased-v1"
)

# Load the masked-LM weights from the fine-tuned checkpoint path. The extra
# keyword overrides the matching config attribute at load time, so forward
# passes return the hidden states of every layer.
model = AutoModelForMaskedLM.from_pretrained(
    "../" / FINE_TUNED_BERT_MODEL_PATH,
    output_hidden_states=True,
)

torch.cuda.empty_cache()

# Last expression of the cell: display the model.
model

In [None]:
class TextDataset(Dataset):
    """Map-style dataset that serves pre-tokenized examples one at a time.

    Args:
        encodings: Mapping from field name (e.g. ``"input_ids"``,
            ``"attention_mask"``) to a per-example sequence of values. Any
            object supporting ``items()`` and key lookup works, including a
            plain ``dict``. It must contain an ``"input_ids"`` entry, which
            defines the dataset length.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with one tensor per field."""
        return {
            field: torch.tensor(values[idx])
            for field, values in self.encodings.items()
        }

    def __len__(self):
        """Return the number of examples, i.e. the length of ``"input_ids"``."""
        # Key lookup (not attribute access) keeps this consistent with
        # __getitem__, which already treats `encodings` as a mapping, and lets
        # plain dicts be used as well.
        return len(self.encodings["input_ids"])


# Tokenize every preprocessed text in one call: sequences longer than 512
# tokens are truncated and padding is enabled (padding=True).
encodings = tokenizer(
    df["input"].tolist(),
    truncation=True,
    padding=True,
    max_length=512,
)

In [None]:
from torch.utils.data import DataLoader


def encode_texts(model, batch_size=8):
    """Compute last-layer hidden states for every example in the global ``encodings``.

    The already-tokenized notebook-level ``encodings`` are wrapped in a
    ``TextDataset`` and pushed through ``model`` in batches, in inference mode.

    Args:
        model: A ``torch.nn.Module`` whose forward pass accepts ``input_ids``
            (positional), ``attention_mask`` and ``output_hidden_states``, and
            whose output exposes a ``hidden_states`` sequence.
        batch_size: Number of examples per forward pass.

    Returns:
        A tensor built by concatenating, along dimension 0, the last entry of
        ``hidden_states`` from every batch.
    """
    dataset = TextDataset(encodings)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Batches are created on the CPU; send them wherever the model weights live
    # so the function also works when the model has been moved to a GPU.
    device = next(model.parameters()).device

    embeddings = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            outputs = model(
                batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                # Ask for hidden states explicitly rather than relying on the
                # model config having been modified elsewhere in the notebook.
                output_hidden_states=True,
            )
            embeddings.append(outputs.hidden_states[-1])
    return torch.cat(embeddings)

In [None]:
# encode_texts is defined as encode_texts(model, batch_size=8) and reads the
# already-tokenized `encodings` itself, so only the model is passed here
# (passing the texts and tokenizer as well raises a TypeError).
embeddings = encode_texts(model)

embeddings.shape