In [1]:
from transformers import BertTokenizer, BertModel
import torch

# Load the pretrained uncased BERT encoder. With output_hidden_states=True,
# every forward pass also returns the hidden states of all layers
# (the embedding output plus one tensor per encoder layer).
model = BertModel.from_pretrained(
    "bert-base-uncased",
    output_hidden_states=True,
)
# Put the module in inference mode (dropout disabled) so repeated runs give
# identical embeddings. As the last expression, the module is displayed.
model.eval()

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "Hello, how are you? I want to create an embedding for this sentence."

# Wrap the sentence in BERT's special tokens. The markers come from the
# tokenizer itself (not hard-coded strings), and they are added *around* the
# WordPiece output. That way the sentence text is never scanned for
# special-token markers, so a literal "[SEP]" inside `text` stays plain text.
marked_request = f"{tokenizer.cls_token} {text} {tokenizer.sep_token}"  # readable form; kept so later cells that read it still work
tokenized_text = [tokenizer.cls_token, *tokenizer.tokenize(text), tokenizer.sep_token]
print("Tokenized text:", tokenized_text)

Tokenized text: ['[CLS]', 'hello', ',', 'how', 'are', 'you', '?', 'i', 'want', 'to', 'create', 'an', 'em', '##bed', '##ding', 'for', 'this', 'sentence', '.', '[SEP]']


In [4]:
# Look up the vocabulary id of every WordPiece token, then print a
# two-column table: token left-aligned, id right-aligned.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

for token, token_id in zip(tokenized_text, indexed_tokens):
    row = '{:<12} {:>6}'.format(token, token_id)
    print(row)

[CLS]           101
hello          7592
,              1010
how            2129
are            2024
you            2017
?              1029
i              1045
want           2215
to             2000
create         3443
an             2019
em             7861
##bed          8270
##ding         4667
for            2005
this           2023
sentence       6251
.              1012
[SEP]           102


In [5]:
segments_ids = [1] * len(tokenized_text)
print("Segment IDs:", segments_ids)

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

print("Tokens tensor:", tokens_tensor)
print("Segments tensor:", segments_tensors)

Segment IDs: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Tokens tensor: tensor([[ 101, 7592, 1010, 2129, 2024, 2017, 1029, 1045, 2215, 2000, 3443, 2019,
         7861, 8270, 4667, 2005, 2023, 6251, 1012,  102]])
Segments tensor: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


In [8]:
with torch.no_grad():
    outputs = model(tokens_tensor, segments_tensors)

last_hidden_states = outputs.last_hidden_state
print("Last hidden states shape:", last_hidden_states.shape)

hidden_states = outputs.hidden_states
print("Number of hidden states:", len(hidden_states))
print("Shape of each hidden state:", [state.shape for state in hidden_states])

# Extract the last hidden state for the first token ([CLS])
cls_embedding = last_hidden_states[0][0]
print("CLS embedding shape:", cls_embedding.shape)
print("CLS embedding:", cls_embedding)

Last hidden states shape: torch.Size([1, 20, 768])
Number of hidden states: 13
Shape of each hidden state: [torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768]), torch.Size([1, 20, 768])]
CLS embedding shape: torch.Size([768])
CLS embedding: tensor([-6.4543e-02, -1.8406e-02, -5.7792e-02, -1.3213e-01, -3.7210e-01,
         1.0858e-02,  4.5951e-01,  2.0098e-01,  1.3049e-01, -5.1479e-01,
         1.7972e-01, -4.9159e-01,  1.8552e-01,  1.7872e-01,  2.9401e-01,
        -5.4633e-02, -4.1804e-03,  3.9316e-01,  2.3027e-01, -1.0770e-01,
        -2.4033e-01, -2.2257e-01,  2.4798e-01,  4.7015e-02,  7.0802e-02,
        -3.5405e-01, -1.2678e-01,  1.2807e-01, -2.0577e-01, -3.2061e-01,
         7.9583e-02,  5.1757e-01, -4.3851e-01, -2.7259e-01,  2