In [None]:
import torch
from scipy.spatial.distance import cosine
from torch import Tensor
from transformers import BertTokenizer, BertModel

In [None]:
def tokenize(text: str, tokenizer: BertTokenizer, max_length: int) -> tuple:
    """Encode ``text`` into a fixed-length id sequence plus attention mask.

    The tokenizer is asked to add the special tokens, truncate anything longer
    than ``max_length`` and pad anything shorter up to ``max_length``.

    Args:
        text: Sentence to encode.
        tokenizer: Callable tokenizer whose result can be indexed with
            ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length passed to the tokenizer for truncation and padding.

    Returns:
        ``(token_ids, attention_mask)`` exactly as produced by the tokenizer.
    """
    # ``padding='max_length'`` is the supported spelling of the deprecated
    # ``pad_to_max_length=True`` flag.
    encoding = tokenizer(text,
                         add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
                         truncation=True,
                         max_length=max_length,
                         padding='max_length',
                         return_attention_mask=True)

    return encoding['input_ids'], encoding['attention_mask']

In [None]:
def tokenize1(text: str, tokenizer: BertTokenizer, max_length: int) -> tuple:
    """Manually build tokens, ids and attention mask for ``text``.

    The text is wrapped in ``[CLS]`` / ``[SEP]``, tokenized, truncated to
    ``max_length`` tokens (the final token is always kept so the sequence
    still ends with the separator) and then right-padded to ``max_length``.

    Args:
        text: Sentence to encode.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.
        max_length: Target sequence length.

    Returns:
        ``(tokens, token_ids, attention_mask)``. ``tokens`` holds the
        (possibly truncated) token strings without padding; ``token_ids`` and
        ``attention_mask`` are padded to ``max_length``.
    """
    tokens = tokenizer.tokenize("[CLS] " + text + " [SEP]")

    # Truncate over-long input instead of silently returning a sequence longer
    # than max_length. At least two slots are needed for the wrapper tokens.
    if 2 <= max_length < len(tokens):
        tokens = tokens[:max_length - 1] + tokens[-1:]

    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    pad_count = max(max_length - len(token_ids), 0)

    attention_mask = [1] * len(token_ids) + [0] * pad_count
    # NOTE(review): padding id is hard-coded to 0 — assumes the vocabulary's
    # pad token has id 0; confirm when using a different tokenizer.
    token_ids = token_ids + [0] * pad_count
    return tokens, token_ids, attention_mask

In [None]:
def vectorize(token_ids: Tensor, attn_mask: Tensor, model: BertModel) -> Tensor:
    """Return one vector for the first sequence in the batch.

    The model is run without gradient tracking and the second-to-last hidden
    layer is averaged over the tokens that the attention mask marks as real
    (mask value 1). Padding positions are excluded so the result does not
    depend on how much padding was added.

    Args:
        token_ids: Token id tensor passed to the model as first argument.
        attn_mask: Attention mask passed to the model as second argument;
            its first row selects which token vectors enter the average.
        model: Callable model; ``output[2]`` must be the sequence of per-layer
            hidden states (i.e. the model was built with
            ``output_hidden_states=True``).

    Returns:
        1-D tensor: the masked mean of the token vectors.
    """
    with torch.no_grad():
        output = model(token_ids, attn_mask)

    hidden_states = output[2]
    # Second-to-last layer, first batch element: shape = (#tokens, hidden_size)
    token_vectors = hidden_states[-2][0]

    # Column of 0/1 weights, one per token, in the dtype of the hidden states.
    weights = attn_mask[0].to(token_vectors.dtype).unsqueeze(-1)
    # clamp avoids division by zero when the mask is all zeros.
    real_token_count = weights.sum().clamp(min=1)

    # result[j] = average over real tokens i of token_vectors[i, j]
    vector = (token_vectors * weights).sum(dim=0) / real_token_count
    return vector

In [None]:
def token_vectorize(text: str, tokenizer: BertTokenizer, model: BertModel, max_length: int) -> Tensor:
    """Tokenize ``text`` and return the vector computed by ``vectorize``.

    The id and mask lists returned by ``tokenize`` are each turned into a
    tensor with a leading batch dimension of size 1 before being handed to
    ``vectorize`` together with ``model``.
    """
    ids, mask = tokenize(text, tokenizer, max_length)
    ids_batch = torch.tensor(ids).view((1, -1))
    mask_batch = torch.tensor(mask).view((1, -1))
    return vectorize(ids_batch, mask_batch, model)

In [None]:
# Every sentence is truncated / padded to this many tokens.
max_length = 24

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Hidden states of all layers are requested because vectorize() reads them.
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
model.eval()

# Sample sentences used by the cells below.
texts = [
    "Here is the sentence I want embeddings for.",
    "This is the sentence I want no embeddings for.",
    ("After stealing money from the bank vault, "
     "the bank robber was seen fishing on the Mississippi river bank."),
    ("The bank robber was seen fishing on the Mississippi river bank "
     "after stealing money from the bank vault"),
]

In [None]:
# Sanity check: the library tokenizer call and the manual tokenization must
# agree, and vectorizing their output must reproduce token_vectorize exactly.
for sentence in texts:
    expected = token_vectorize(sentence, tokenizer, model, max_length)

    ids, mask = tokenize(sentence, tokenizer, max_length)
    manual_tokens, manual_ids, manual_mask = tokenize1(sentence, tokenizer, max_length)
    print('\n', sentence, '\n', manual_tokens)
    assert ids == manual_ids
    assert mask == manual_mask

    actual = vectorize(torch.tensor(ids).view((1, -1)),
                       torch.tensor(mask).view((1, -1)),
                       model)
    assert torch.equal(expected, actual)

In [None]:
# Embed every sentence once up front; the nested loop below only compares the
# cached vectors instead of re-running the model for every pair.
vectors = [token_vectorize(text, tokenizer, model, max_length) for text in texts]

# Print the cosine distance (scipy's `cosine`) for every ordered pair.
for s, v in zip(texts, vectors):
    for t, w in zip(texts, vectors):
        diff = cosine(v, w)
        print('\n', s)
        print(t)
        print(diff)