In [1]:
import torch
from transformers import BertTokenizer, BertModel

In [2]:
# Load the pre-trained BERT tokenizer and encoder.
# Both are created from the same checkpoint name so that the token ids produced
# by the tokenizer match the vocabulary the model was trained with; keeping the
# name in one constant prevents the two calls from drifting apart.
MODEL_NAME = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [3]:
# Tokenize one example sentence and request PyTorch tensors ("pt") back,
# so the result can be passed straight to the model.
example_text = "This is an example text"
inputs = tokenizer(example_text, return_tensors="pt")

In [4]:
# Inspect the tokenizer output: a mapping holding the 'input_ids',
# 'token_type_ids' and 'attention_mask' tensors for the encoded sentence.
inputs

{'input_ids': tensor([[ 101, 2023, 2003, 2019, 2742, 3793,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [5]:
# Run a forward pass with autograd disabled; only the activations are needed.
with torch.no_grad():
    outputs = model(**inputs)

# Embeddings for each token, read from the model output's `last_hidden_state`.
last_hidden_states = outputs.last_hidden_state

In [7]:
# Shape check on the token embeddings. The recorded run shows (1, 7, 768):
# one sequence, seven token ids (matching `input_ids`), 768 features per token.
last_hidden_states.shape

torch.Size([1, 7, 768])

In [8]:
# Pooling strategies
# 1. CLS pooling: represent the whole sequence by the embedding of its first
#    token, which is [CLS] (position 0 on the sequence axis).
CLS_POSITION = 0
cls_embedding = last_hidden_states[:, CLS_POSITION]  # keeps batch and feature axes
print("CLS Pooling Embedding Shape:", cls_embedding.shape)

CLS Pooling Embedding Shape: torch.Size([1, 768])


In [9]:
# 2. Mean pooling: average the token embeddings of each sequence.
# The attention mask is 1 for real tokens and 0 for padding, so multiplying by
# it zeroes out padded positions and summing it counts the real tokens.
attention_mask = inputs['attention_mask']
token_weights = attention_mask.unsqueeze(-1)  # add a trailing axis to broadcast over features
masked_embeddings = last_hidden_states * token_weights
token_counts = attention_mask.sum(dim=1, keepdim=True)  # real tokens per sequence
summed_embeddings = masked_embeddings.sum(dim=1)  # collapse the sequence axis
mean_embedding = summed_embeddings / token_counts

In [10]:
# Confirm that mean pooling collapsed the sequence axis; the recorded run
# prints torch.Size([1, 768]), the same shape as the CLS-pooled embedding.
print("Mean Pooling Embedding Shape:", mean_embedding.shape)

Mean Pooling Embedding Shape: torch.Size([1, 768])


In [17]:
# Display the raw per-token embeddings; the tensor repr abbreviates the
# middle rows and columns with "...".
last_hidden_states

tensor([[[-0.1671,  0.0339, -0.0868,  ..., -0.4128,  0.4176,  0.8275],
         [-0.7717, -0.2977,  0.0357,  ..., -0.6559,  0.6945,  0.0457],
         [-0.1457, -0.6747,  0.3640,  ..., -0.3649,  0.4561,  0.5992],
         ...,
         [-0.6385, -0.2565,  0.0449,  ..., -0.4831,  0.3721,  0.5082],
         [-0.5544, -0.0528, -0.3377,  ...,  0.4267,  0.4717,  0.4110],
         [ 0.8272,  0.1619, -0.4493,  ...,  0.1640, -0.7352, -0.1326]]])