In [1]:
# Reload edited project modules (such as src.settings) before every cell runs,
# so changes under src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import BertTokenizer, BertModel
import json
from tqdm.notebook import tqdm

from src.settings import POLISH_ANNOTATIONS_FPATH, EMBEDDINGS_DIR

In [3]:
# Hugging Face Hub id of the Polish BERT checkpoint; used for both tokenizer and model.
CHECKPOINT = "dkleczek/bert-base-polish-cased-v1"
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

device(type='cuda')

In [4]:
# Tokenization happens on the CPU, so the tokenizer takes no device argument
# (an extra `device=` kwarg would only be stored in its config and ignored).
tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)
# output_hidden_states=True makes the forward pass also return every layer's hidden states.
# `.eval()` is already the from_pretrained default; it is spelled out to make the
# inference-only intent explicit (dropout disabled). Chained so the cell displays nothing.
model = BertModel.from_pretrained(CHECKPOINT, output_hidden_states=True).to(DEVICE).eval()

In [5]:
def get_embedding(text, tokenizer, model, device, max_length=None, layer=-2):
    """Return a fixed-size sentence embedding for ``text``.

    The text is wrapped in BERT's ``[CLS]`` / ``[SEP]`` markers and run through
    ``model`` without gradient tracking. The hidden states of one layer
    (second-to-last by default) are then averaged over all tokens, including
    the two special tokens.

    Args:
        text: Raw input string.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``
            (a Hugging Face ``BertTokenizer`` in this notebook).
        model: Callable accepting ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and ``output_hidden_states`` keywords and
            returning an object with a ``hidden_states`` sequence
            (a Hugging Face ``BertModel`` in this notebook).
        device: Torch device the input tensors are created on; must match
            the model's device.
        max_length: Maximum number of tokens fed to the model, special tokens
            included. Longer inputs are truncated, keeping the closing
            ``[SEP]``. Defaults to ``model.config.max_position_embeddings``
            when available, else 512.
        layer: Index into ``hidden_states`` of the layer to average
            (``-2`` = second-to-last).

    Returns:
        1-D tensor of size ``hidden_size`` on ``device``.

    Raises:
        ValueError: If ``max_length`` is smaller than 2 (no room for
            ``[CLS]`` and ``[SEP]``).
    """
    if max_length is None:
        max_length = getattr(getattr(model, "config", None), "max_position_embeddings", 512)
    if max_length < 2:
        raise ValueError(f"max_length must be at least 2, got {max_length}")

    marked_text = "[CLS] " + text + " [SEP]"
    tokens = tokenizer.tokenize(marked_text)
    if len(tokens) > max_length:
        # Inputs longer than the position-embedding table fail inside the model;
        # keep the head of the text plus the closing [SEP].
        tokens = tokens[: max_length - 1] + tokens[-1:]

    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)], device=device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            # Single-segment input: every token is attended to and belongs to
            # segment A (token type 0). Passing these by keyword avoids the
            # positional-argument trap where the 2nd argument is attention_mask.
            attention_mask=torch.ones_like(input_ids),
            token_type_ids=torch.zeros_like(input_ids),
            # Request hidden states explicitly instead of relying on how the
            # model happened to be configured at load time.
            output_hidden_states=True,
        )

    # hidden_states[layer] is indexed [batch, token, hidden]; take the only batch item.
    token_vecs = outputs.hidden_states[layer][0]
    return torch.mean(token_vecs, dim=0)

In [6]:
# Embed every annotation record (one JSON object per line), in file order.
with open(POLISH_ANNOTATIONS_FPATH, 'r', encoding='utf-8') as file:
    # Read the records up front so the progress bar's total always matches the
    # file instead of a hard-coded count. Blank lines are skipped because
    # json.loads rejects them.
    annotation_lines = [line for line in file if line.strip()]

embeddings = []
for line in tqdm(annotation_lines):
    data = json.loads(line)
    text = data['text']
    embedding = get_embedding(text, tokenizer, model, DEVICE)
    embeddings.append(embedding.tolist())

  0%|          | 0/40355 [00:00<?, ?it/s]

In [7]:
# Derive a filesystem-safe file name from the checkpoint id, e.g.
# "dkleczek/bert-base-polish-cased-v1" -> "dkleczek__bert_base_polish_cased_v1".
safe_model_name = CHECKPOINT.replace('/', '__').replace('-', '_')
out_path = EMBEDDINGS_DIR / f"polish_annotations__{safe_model_name}.json"
# Create the output directory if needed (e.g. on a fresh checkout) instead of
# failing with FileNotFoundError after the expensive embedding step.
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, 'w', encoding='utf-8') as outfile:
    json.dump(embeddings, outfile)

print(f"Embeddings saved to {out_path}")

Embeddings saved to /app/data/preprocessed/embeddings/polish_annotations__dkleczek__bert_base_polish_cased_v1.json
