## Embedding with BERT

BERT doesn't have a `SentenceTransformers` implementation, so we have to extract the embeddings from the last hidden state manually. The classification (`[CLS]`) token's vector is at position 0 of each sequence.

In [1]:
from transformers import AutoModel, AutoTokenizer
import torch

In [2]:
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check;
# the `torch.mps.is_available` spelling may be missing on older PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [3]:
model_checkpoint = "bert-base-uncased"

# Tokenizer and encoder are loaded from the same checkpoint so their
# vocabularies agree; the encoder is then moved onto the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)
model = model.to(device)

print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [7]:
def get_embeddings(sentences: list[str]) -> torch.Tensor:
    """Return the position-0 (``[CLS]``) vector of BERT's last hidden state per sentence.

    Args:
        sentences: Texts to encode. The batch is padded to the longest input
            and truncated to the tokenizer's maximum length.

    Returns:
        Tensor of shape ``(len(sentences), hidden_size)`` on ``device``.
    """
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
    # Inference only: without no_grad, autograd records the full forward graph,
    # wasting memory and returning embeddings that still require grad.
    with torch.no_grad():
        outputs = model(**inputs)
    # last_hidden_state is (batch, seq_len, hidden); index 0 of seq_len is [CLS].
    return outputs.last_hidden_state[:, 0, :]

In [8]:
# Encode the whole batch at once; expect one row per sentence.
embeddings = get_embeddings(sentences)
embedding_shape = embeddings.shape
print(embedding_shape)

torch.Size([3, 768])


In [None]:
# BertModel has no SentenceTransformers-style `.encode`; use get_embeddings instead.

# A bare string works: the tokenizer treats it as a batch of one, so the
# result keeps a leading batch dimension of size 1.
string_embedding = get_embeddings(sentences[0])
print(f"String embedding shape: {string_embedding.shape}")

# A list containing a single string yields the same (1, hidden_size) shape.
singleton_embedding = get_embeddings(sentences[:1])
print(f"Singleton shape: {singleton_embedding.shape}")

# A list of several strings also works: inputs are padded to a common length
# and encoded together, giving one row per sentence.
embeddings = get_embeddings(sentences)
print(f"Multiple embeddings: {embeddings.shape}")