# BERT and Related Models - Embeddings

(C) 2024-2025 by [Damir Cavar](http://damir.cavar.me/)

**Download:** This and various other Jupyter notebooks are available from my [GitHub repo](https://github.com/dcavar/python-tutorial-for-ipython).

This notebook shows how to extract embeddings for single words or short texts from a BERT model.

Prerequisites:
You will need to install the `transformers` Python module, as well as PyTorch (`torch`), which the BERT model code runs on.

In [None]:
!pip install -U transformers

We need to import `torch` (PyTorch) and `transformers`. We also import `random` for seeding and `dot` from NumPy.

In [10]:
import random
import torch
from transformers import BertTokenizer, BertModel
from numpy import dot

We seed the random number generators for reproducibility. If CUDA is available, we also seed the generators on all GPUs.

In [11]:
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)
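Seeding makes runs repeatable: re-seeding a generator with the same value replays the same sequence of numbers. The following is a minimal sketch with Python's standard `random` module; the helper name `draw` is hypothetical and only serves the illustration.

```python
import random

def draw(seed: int, n: int = 3) -> list[float]:
    """Hypothetical helper: seed the generator, then draw n floats in [0, 1)."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first = draw(42)
second = draw(42)
# Same seed, same sequence of numbers
print(first == second)  # True
```

Note that `BertModel.from_pretrained` returns the model in evaluation mode, where dropout is disabled, so the embeddings computed below do not depend on the seed. Seeding matters once randomness enters, for example during fine-tuning.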

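The following function returns the embedding of a text (which may be a single word) as a list of floats. It averages BERT's last hidden states over all token positions, including the special `[CLS]` and `[SEP]` tokens.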

In [12]:
def get_embedding(text: str, tokenizer, model) -> list[float]:
    # Tokenize the text as a batch of one; [CLS] and [SEP] are added
    example_encoding = tokenizer(
        [ text ],
        padding            = True,
        truncation         = True,
        return_tensors     = 'pt',
        add_special_tokens = True
    )
    example_input_ids = example_encoding['input_ids']
    example_attention_mask = example_encoding['attention_mask']
    # Inference only: no gradients are needed
    with torch.no_grad():
        example_outputs = model(example_input_ids, attention_mask=example_attention_mask)
        # Average pooling over all token positions, shape (1, hidden_size)
        example_embedding = example_outputs.last_hidden_state.mean(dim=1)
    # Return the single vector as a plain Python list of floats
    return example_embedding[0].tolist()
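The `mean(dim=1)` call averages over every token position. For a single text this is fine, but when several texts are padded into one batch, the padding positions would be averaged in as well. A common remedy is masked mean pooling, which counts only positions whose attention-mask value is 1. The sketch below shows the arithmetic on plain Python lists; the function name `masked_mean` and the toy numbers are illustrative assumptions, not part of `transformers`.

```python
def masked_mean(hidden: list[list[float]], mask: list[int]) -> list[float]:
    """Average token vectors, counting only positions where mask == 1."""
    dim = len(hidden[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(hidden, mask):
        if keep:
            for i, value in enumerate(vector):
                totals[i] += value
            count += 1
    if count == 0:
        raise ValueError("mask selects no tokens")
    return [t / count for t in totals]

# Toy "last hidden state": 4 positions, 2 dimensions; the last one is padding
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [100.0, 100.0]]
mask = [1, 1, 1, 0]
print(masked_mean(hidden, mask))          # [3.0, 4.0]
print(masked_mean(hidden, [1, 1, 1, 1]))  # [27.25, 28.0]
```

In PyTorch the same idea is usually written with broadcasting, as `(hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)` for a hidden-state tensor of shape `(batch, tokens, hidden_size)` and a mask of shape `(batch, tokens)`.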

In the following code segment we initialize the BERT tokenizer and load the model:

In [13]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model     = BertModel.from_pretrained('bert-base-uncased')
# Dimensionality of the embeddings (768 for bert-base-uncased)
vector_length = len(get_embedding("test", tokenizer, model))
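`BertTokenizer` produces WordPiece tokens: for the uncased model it lowercases the text, splits it on whitespace and punctuation, and then breaks each word into the longest pieces found in its vocabulary, marking word-internal pieces with `##`. A word that is not in the vocabulary can therefore become several tokens, and `get_embedding` averages all of them together with `[CLS]` and `[SEP]`. The toy sketch below implements greedy longest-match-first splitting over a tiny hypothetical vocabulary; the real `bert-base-uncased` vocabulary has about 30,000 entries, so its actual splits will differ.

```python
def wordpiece(word: str, vocab: set[str], unk: str = "[UNK]") -> list[str]:
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # continuation pieces carry a ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces

# Hypothetical mini-vocabulary, for illustration only
vocab = {"elder", "##berry", "berry", "fig", "house", "##s"}

def encode(text: str) -> list[str]:
    """Lowercase, split on whitespace, WordPiece each word, add [CLS]/[SEP]."""
    tokens = ["[CLS]"]
    for word in text.lower().split():
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    return tokens

print(encode("Elderberry fig"))  # ['[CLS]', 'elder', '##berry', 'fig', '[SEP]']
print(encode("houses"))          # ['[CLS]', 'house', '##s', '[SEP]']
```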



This is the list of words for which we want to pull embeddings. Converting it to a `set` removes duplicates, but the resulting order is arbitrary:

In [14]:
word_list = list(set("""
apple banana cherry date elderberry fig grapefruit honeydew
car airplane train bus bicycle motorcycle boat ship
dog cat rabbit hamster parrot goldfish
""".split()))

Pull the embedding for each word and print the word, the length of its vector, and the vector itself:

In [15]:
for word in word_list:
    vector = get_embedding(word, tokenizer, model)
    print(word, len(vector), vector)

ship 768 [0.21604371070861816, 0.05331966280937195, -0.09693580865859985, 0.09807302802801132, -0.04365711286664009, -0.026934444904327393, 0.13867424428462982, -0.28003647923469543, -0.007644951343536377, -0.24840660393238068, 0.1805824637413025, 0.11090791970491409, 0.1181827187538147, 0.19309325516223907, -0.22989672422409058, -0.31791219115257263, -0.12864427268505096, 0.19062228500843048, -0.005026951432228088, 0.025382719933986664, 0.17606811225414276, -0.06510838121175766, 0.38882994651794434, -0.13406813144683838, 0.13015006482601166, 0.11217167228460312, -0.12591029703617096, -0.10384338349103928, -0.14498236775398254, 0.017191609367728233, 0.021489113569259644, -0.03962324932217598, 0.08012358099222183, 0.361190527677536, -0.04363026097416878, -0.04444580897688866, 0.0674561932682991, -0.10804545134305954, -0.15796804428100586, -0.03498661518096924, 0.08850470930337906, -0.02489652670919895, 0.1742640882730484, 0.017730534076690674, 0.15391556918621063, 0.03255397081375122, -
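The notebook imports `dot` from NumPy without using it so far. One common use of such vectors is to compare them with cosine similarity: the dot product of two vectors divided by the product of their lengths. A minimal sketch in plain Python with toy 3-dimensional vectors (the values are made up, not BERT outputs); `dot(a, b)` from NumPy would compute the same numerator.

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot_product = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot_product / (norm_a * norm_b)

print(cosine_similarity([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]))  # 0.0
print(cosine_similarity([1.0, 2.0, 2.0], [2.0, 4.0, 4.0]))  # 1.0
```

With the vectors returned by `get_embedding`, one could, for example, compare `apple` with `banana` and with `car` to see whether words from the same group in `word_list` score higher.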

In [16]:
word = "The ARTICLE house"
vector = get_embedding(word, tokenizer, model)
print(word, len(vector), vector)

The ARTICLE house 768 [0.21632826328277588, -0.3678334057331085, -0.08473453670740128, 0.10511256754398346, -0.03370572254061699, -0.3001091182231903, 0.0015765547286719084, 0.10472764074802399, 0.07939986884593964, -0.16032372415065765, -0.05395383760333061, -0.07013607025146484, -0.03223909065127373, 0.18324092030525208, -0.2950148582458496, 0.06273850053548813, -0.1579228937625885, -0.08000735193490982, -0.040428049862384796, -0.16987191140651703, 0.28086358308792114, 0.022985758259892464, 0.12371919304132462, 0.1609281301498413, 0.48000651597976685, 0.09604588896036148, -0.10451064258813858, -0.14005446434020996, -0.7152451276779175, -0.04349450767040253, -0.011680593714118004, -0.14365258812904358, 0.05332493782043457, 0.26480770111083984, -0.24656371772289276, -0.29148587584495544, -0.08998049795627594, 0.011383774690330029, -0.2665092349052429, 0.08129433542490005, 0.20978307723999023, 0.12836439907550812, 0.1973499357700348, 0.09299448877573013, 0.08594222366809845, 0.059681154
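Because `bert-base-uncased` is an uncased model, its tokenizer lowercases the input and strips accents before WordPiece splitting, so `The ARTICLE house` is embedded exactly like `the article house`. The sketch below mimics that normalization step with the standard `unicodedata` module: decompose characters (NFD) and drop combining marks (Unicode category `Mn`). It is a simplified stand-in for illustration, not the tokenizer's own code.

```python
import unicodedata

def normalize_uncased(text: str) -> str:
    """Lowercase, then remove accents by dropping combining marks after NFD."""
    text = text.lower()
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

print(normalize_uncased("The ARTICLE house"))  # the article house
print(normalize_uncased("Café Ćevapi"))        # cafe cevapi
```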
