In [2]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel

### Batch inputs (two sentences) have different numbers of tokens

In [3]:
# Two toy movie reviews of different lengths, batched together.
reviews = [
    "The Matrix is great",  # 4 words -> 6 BERT tokens once [CLS] and [SEP] are added
    "A terrible movie",     # 3 words -> 5 BERT tokens once [CLS] and [SEP] are added
]
review1, review2 = reviews

reviews

['The Matrix is great', 'A terrible movie']

### The BERT tokenizer converts batch inputs into tokens

In [4]:
# Load the pretrained tokenizer for the bert-base-uncased checkpoint
# (only the tokenizer is created here; the encoder model is loaded separately).
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Encode every review in a single call: return PyTorch tensors, pad the shorter
# review up to the longest one in the batch, and truncate beyond 128 tokens.
encode_options = dict(
    max_length=128,
    truncation=True,
    padding=True,
    return_tensors="pt",
)
inputs = tokenizer(reviews, **encode_options)

In [5]:
# Check which container class the tokenizer call returned
type(inputs)

transformers.tokenization_utils_base.BatchEncoding

In [6]:
# Each tensor in the encoding has shape torch.Size([batch_size, seq_len])
for field in ('input_ids', 'attention_mask', 'token_type_ids'):
    print(inputs[field].shape)

torch.Size([2, 6])
torch.Size([2, 6])
torch.Size([2, 6])


### Padding when two sentences have different lengths

In [7]:
# Inspect the second (shorter) review to see how it was padded
second_ids = inputs['input_ids'][1]
print(second_ids)  # Token IDs
print(tokenizer.convert_ids_to_tokens(second_ids))  # Tokens

tensor([ 101, 1037, 6659, 3185,  102,    0])
['[CLS]', 'a', 'terrible', 'movie', '[SEP]', '[PAD]']


In [8]:
# Load the pretrained encoder used to produce embeddings. eval() switches the
# module to evaluation mode (no training) and returns the module itself,
# so the call can be chained onto the load.
model = AutoModel.from_pretrained('bert-base-uncased').eval()

# Forward pass with gradient tracking disabled
with torch.no_grad():
    outputs = model(**inputs)

# (batch_size, seq_len, hidden_dim)
outputs.last_hidden_state.shape

torch.Size([2, 6, 768])

### Sentences as a 3D tensor. Assume:
- 3 sentences,
- each sentence has 2 words,
- each word has 5 features.

![shapes](https://www.tensorflow.org/static/guide/images/tensor/3-axis_front.png)

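#### What is the dimension of the sentence embeddings?
- (3,5): averaging over the word axis (dim=1) collapses the (3, 2, 5) tensor to one 5-feature vector per sentence.

`torch.mean(data, dim=1)`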

### A sentence embedding is the average of its word embeddings

In [10]:
# Average over the token axis (dim=1) to get one vector per sentence
outputs.last_hidden_state.mean(dim=1)

tensor([[ 0.1656, -0.2764, -0.0298,  ...,  0.0087, -0.0636,  0.2763],
        [ 0.1329,  0.0747, -0.2481,  ..., -0.2341,  0.2315, -0.1357]])

### (Optional) What is the potential issue with using the average of word embeddings as the sentence embedding?

The mean includes padding tokens (where attention_mask=0), which can dilute the embedding quality. BERT’s padding tokens produce non-informative embeddings, and averaging them may introduce noise, especially for short reviews with many padding tokens.

In [16]:
# Masked mean-pooling, step 1: fetch the attention mask produced by the tokenizer
# (1 marks a real token, 0 marks a padding position)
attention_mask = inputs['attention_mask']  # (batch_size, seq_len)
attention_mask

tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0]])

In [17]:
# Append a trailing axis to the mask, then broadcast it to the embeddings' shape
hidden_states = outputs.last_hidden_state
mask = attention_mask[:, :, None].expand(hidden_states.shape)  # (batch_size, seq_len, hidden_dim)
mask

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0]]])

In [18]:
# Zero out the embeddings at padded positions (mask is 0 there and 1 elsewhere)
masked_embeddings = torch.mul(outputs.last_hidden_state, mask)
masked_embeddings

tensor([[[-3.8348e-02,  9.5097e-02,  1.4332e-02,  ..., -1.7143e-01,
           1.2736e-01,  3.7117e-01],
         [-3.7472e-01, -6.2022e-01,  1.2133e-01,  ..., -2.7666e-02,
           1.5813e-01,  1.7997e-01],
         [ 7.1591e-01, -1.9231e-01,  1.5049e-01,  ..., -4.0711e-01,
           1.9909e-01,  2.7043e-01],
         [-3.6584e-01, -3.0518e-01,  5.0851e-04,  ...,  1.1478e-01,
          -2.0296e-01,  9.8816e-01],
         [ 4.8723e-02, -7.2430e-01, -1.8481e-01,  ...,  3.9914e-01,
           9.7036e-02,  4.0537e-02],
         [ 1.0081e+00,  8.8626e-02, -2.8047e-01,  ...,  1.4469e-01,
          -7.6039e-01, -1.9232e-01]],

        [[-1.0380e-01,  4.6764e-03, -1.2088e-01,  ..., -2.1156e-01,
           2.9962e-01, -1.0300e-02],
         [-1.1521e-01,  2.1597e-01, -4.0657e-01,  ..., -5.8376e-01,
           8.9380e-01,  4.3011e-01],
         [ 4.4965e-01,  2.5421e-01,  2.4422e-02,  ..., -3.6552e-01,
           2.4427e-01, -6.5578e-01],
         [ 6.2745e-02,  6.8042e-02, -9.1592e-01,  ...