In [69]:
import torch
from transformers import BertTokenizer, BertModel

# --- Model setup ------------------------------------------------------------
# Tokenizer and encoder are both loaded from the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# --- Input ------------------------------------------------------------------
sentence = "The police is chasing a criminal on the run."

# Split the sentence into tokens; the ids are looked up in the next step.
# NOTE(review): tokenize() + convert_tokens_to_ids() does not insert the
# [CLS]/[SEP] special tokens, so the model sees exactly one id per token
# (10 for this sentence). Calling tokenizer(sentence, return_tensors='pt')
# would add them, but would also change the sequence length used below.
tokens = tokenizer.tokenize(sentence)

# Wrapping the id list in an outer list supplies the batch dimension directly,
# giving a tensor of shape torch.Size([1, 10]) (batch size 1).
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

# --- Forward pass -----------------------------------------------------------
# Gradients are not needed: the activations are only read, never trained on.
with torch.no_grad():
    outputs = model(input_ids)

# Pick the single batch element of the final hidden states, which drops the
# batch dimension.
# Shape: [seq_len, hidden_size]
input_embeddings = outputs.last_hidden_state[0]

In [71]:
# Dimensions of the embedding matrix: one row per token, one column per hidden unit.
input_embeddings.size()

torch.Size([10, 768])

In [73]:
# Full embedding vector of the first token (row 0, every column).
input_embeddings[0, :]

tensor([ 4.5731e-02, -1.6975e-01, -1.2351e-01,  6.4223e-01,  1.6973e-01,
        -6.0670e-01, -4.2367e-02, -3.9192e-02,  7.1000e-02, -8.2006e-01,
        -2.8014e-03, -3.8255e-02, -4.1160e-01,  3.0972e-01,  1.6829e-01,
         8.7796e-01,  2.7160e-03,  1.4480e-01,  2.8266e-01, -5.4223e-01,
         5.6544e-01,  5.0288e-01, -5.3114e-01,  3.6462e-01, -5.5024e-01,
        -1.2004e-01,  1.1906e-01, -2.9636e-01, -1.9004e-01, -1.7190e-01,
         1.4404e-01, -5.1432e-01,  1.0791e-01,  7.3809e-01, -1.5954e-02,
        -1.0558e+00,  5.5525e-01, -1.9024e-02, -2.7733e-01,  1.5997e-01,
        -4.1892e-01,  1.6732e-01,  1.1009e-01,  4.2427e-01, -4.0523e-01,
        -2.9339e-02, -5.3721e-01, -4.5965e-01, -7.4765e-02,  2.7219e-01,
        -3.7822e-01,  1.4941e-01,  4.1616e-03,  1.1602e+00, -3.3779e-01,
        -2.0065e-01,  5.5285e-01, -8.1083e-01,  1.2337e-01, -9.1753e-02,
         8.3569e-01, -2.1187e-01,  1.5810e-01, -3.7538e-01,  1.6588e-01,
         3.9229e-01,  7.1644e-01, -3.4856e-01,  5.0

In [74]:
# Display the whole embedding matrix (PyTorch abbreviates the printout with "...").
input_embeddings

tensor([[ 0.0457, -0.1697, -0.1235,  ..., -0.6835,  0.1324,  0.1968],
        [ 0.2084, -0.2574, -0.2877,  ..., -0.6672,  0.2406,  0.3226],
        [-0.2241, -0.1168, -0.5282,  ..., -0.3569,  0.1170,  0.4262],
        ...,
        [-0.1020, -0.6876, -0.8670,  ..., -0.1890,  0.5862,  0.0672],
        [ 0.5326, -0.5168, -0.1144,  ..., -0.5453,  0.1622, -0.3743],
        [ 0.2081, -0.7447, -0.1060,  ..., -0.5307,  0.4500,  0.1150]])

In [None]:
Each row of this input holds the information about one token as a 768-dimensional vector.
Let's say we want to know the information for the 1st token (i.e. row 0 of the input).


In [78]:
# Square 768 x 768 all-zero matrix; 768 matches the width of each row of
# input_embeddings so the two can be multiplied.
decoder_embeddings = torch.zeros(768, 768)

In [79]:
# Display the zero-initialised matrix to confirm its contents.
decoder_embeddings

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [83]:
# Multiplying [seq_len, 768] by [768, 768] keeps one 768-wide row per token.
torch.matmul(input_embeddings, decoder_embeddings).shape

torch.Size([10, 768])

In [101]:
# First 768 x 768 matrix of standard-normal samples (mean 0, variance 1).
t1 = torch.randn(size=(768, 768))

In [102]:
# Second 768 x 768 matrix of standard-normal samples, drawn independently of t1.
t2 = torch.randn(size=(768, 768))

In [104]:
# Unscaled matrix product: every entry is a sum of 768 pairwise products.
t3 = torch.matmul(t1, t2)

In [111]:
# Scale the product by sqrt(d), where d is the contracted dimension (768 here).
# Each entry of t1 @ t2 sums d products of independent standard-normal values,
# so its variance is about d; dividing by sqrt(d) brings it back to about 1.
#
# The product is recomputed from t1 and t2 instead of dividing t3 by itself:
# `t3 = t3 / ...` shrinks t3 again every time the cell is re-run, whereas this
# form gives the same result no matter how often it is executed.
t3 = (t1 @ t2) / t1.shape[-1] ** 0.5

In [113]:
# Variance over all entries of the scaled product; expected to be close to 1.
torch.var(t3)

tensor(1.0054)