# Word Embeddings

In [1]:
import torch
from transformers import AutoTokenizer, AutoModel


# Tokenizer and encoder are loaded from the same pre-trained BERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
model = AutoModel.from_pretrained('google-bert/bert-base-uncased')

# Small batch of example sentences to embed.
sentences = ["It's a tokenization example.", "This is Elon Musk"]

# Encode the whole batch at once as PyTorch tensors.
#   - add_special_tokens=False: no special tokens are inserted by the tokenizer.
#   - padding='max_length': every row is padded to a fixed length, so each
#     sample carries a tail of padding tokens after its real tokens.
tokens = tokenizer(
    sentences,
    padding='max_length',
    add_special_tokens=False,
    return_tensors='pt',
)

# Forward pass without gradient tracking; only the final hidden states are
# needed as per-token embeddings.
with torch.no_grad():
    outputs = model(**tokens)
token_embeddings = outputs.last_hidden_state

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
from dp_gfn.utils.pretrains import batch_token_embeddings_to_batch_word_embeddings


# Collapse per-token embeddings into per-word embeddings with the project
# helper. Presumably agg_func='mean' averages the sub-token vectors of each
# word and max_word_length fixes the padded word dimension -- confirm against
# dp_gfn.utils.pretrains.
batch_word_embeddings = batch_token_embeddings_to_batch_word_embeddings(
    tokens=tokens, token_embeddings=token_embeddings, agg_func='mean', max_word_length=150
)

# Display the resulting tensor shape as the cell output.
batch_word_embeddings.shape

check
0 0
check
0 1
check
0 2
check
0 3
check
0 4
check
0 5
check
0 6
check
1 0
check
1 1
check
1 2
check
1 3


torch.Size([2, 150, 768])

In [216]:
# Manual token -> word pooling: for every sample, average the embeddings of the
# tokens that make up each word, then pad the word sequence to `max_word_len`
# rows using the embeddings of the padding tokens that follow the last word.
batch_size = token_embeddings.shape[0]
max_word_len = 160
batch_embeddings = []

for sample_idx in range(batch_size):
    sample_tokens = token_embeddings[sample_idx]
    word_embeddings = []
    start, end = 0, 0

    # BatchEncoding.word_ids returns one entry per token: the index of the word
    # that token belongs to, or None for tokens that belong to no word (e.g.
    # padding). dict.fromkeys de-duplicates while keeping first-seen order; a
    # plain set() has no guaranteed order, so None could come up early and
    # words would be visited out of sequence.
    for w_idx in dict.fromkeys(tokens.word_ids(sample_idx)):
        if w_idx is None:
            # Skip non-word tokens instead of aborting the whole loop.
            continue

        # BatchEncoding.word_to_tokens gives the [start, end) token span that
        # encodes this word.
        start, end = tokens.word_to_tokens(sample_idx, w_idx)
        word_embedding = torch.mean(sample_tokens[start:end], dim=0)
        word_embeddings.append(word_embedding)

    if word_embeddings:
        # Truncate so an unusually long sample cannot exceed max_word_len rows.
        word_embeddings = torch.stack(word_embeddings, dim=0)[:max_word_len]
    else:
        # No words at all: start from an empty (0, hidden) tensor so the
        # concatenation below still works (torch.stack rejects an empty list).
        word_embeddings = sample_tokens[:0]

    # `end` is now the first token position after the last word, i.e. where the
    # padding tokens begin; borrow their embeddings as padding rows.
    # NOTE: this assumes the tokenizer padded each row with at least `n_pad`
    # tokens after the last word (true for padding='max_length' here).
    n_pad = max_word_len - word_embeddings.shape[0]
    word_embeddings = torch.cat([word_embeddings, sample_tokens[end:end + n_pad]], dim=0)
    batch_embeddings.append(word_embeddings)

In [217]:
# Stack the per-sample word embeddings into one batch tensor. Keep the tensor
# itself in `batch_embeddings` (assigning `.shape` to it would discard the
# embeddings) and display the shape as the cell output.
batch_embeddings = torch.stack(batch_embeddings, dim=0)
batch_embeddings.shape

torch.Size([2, 160, 768])

In [181]:
# Scratch check: shape of the slice of sample 0 that runs from token position
# `end` up to position 160.
# NOTE(review): `end` is whatever value an earlier cell left behind, which may
# belong to a different sample -- confirm before relying on this number.
token_embeddings[0, end:160].shape

torch.Size([152, 768])