## Extract embeddings from BERT

In [16]:
import matplotlib.pyplot as plt
import seaborn as sns

In [92]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertTokenizer, BertModel

# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Sample sentence
sentence = "The police is chasing a criminal on the run."

# Tokenize the sentence and convert to input IDs
tokens = tokenizer.tokenize(sentence)
input_ids = tokenizer.convert_tokens_to_ids(tokens) # List of token ids

# Tensor of token ids with batch size as 1; so shape will be torch.Size([1, 10])
input_ids = torch.tensor(input_ids).unsqueeze(0)

# Get BERT embeddings
with torch.no_grad():
    outputs = model(input_ids)

# Extract word embeddings from BERT outputs

# Shape: [seq_len, hidden_size]
input_embeddings = outputs.last_hidden_state

In [93]:
input_embeddings.shape

torch.Size([1, 10, 768])

## Attention Mechanism

Attention enables contextualized word embeddings by allowing the model to selectively focus on different parts of the input sequence when making predictions. Put simply, the attention mechanism allows the transformer to dynamically weigh the importance of different parts of the input sequence based on the current task and context.

In [94]:
# Shape of the input embeddings [batch_size, input_size,embedding_dimension]
B, T, C = input_embeddings.shape

In [95]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
out = wei @ input_embeddings

In [175]:
key = nn.Linear(C, C)
query = nn.Linear(C, C)
value = nn.Linear(C, C)

k = key(input_embeddings)
q = query(input_embeddings)
v = value(input_embeddings)
wei = q @ k.permute(0,2,1)

tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)


out = wei @ v

In [None]:
def self_attention(input_embed, block_type:str):
    # Shape of the input embeddings [batch_size, input_size,embedding_dimension]
    B, T, C = input_embed.shape
    
    key = nn.Linear(C, C)
    query = nn.Linear(C, C)
    value = nn.Linear(C, C)

    k = key(input_embeddings)
    q = query(input_embeddings)
    v = value(input_embeddings)
    wei = q @ k.permute(0,2,1) / C**0.5

    tril = torch.tril(torch.ones(T,T))
    
    if block_type == "decoder":
        wei = wei.masked_fill(tril == 0, float('-inf'))
    
    wei = F.softmax(wei, dim=1)

    out = wei @ v
    return 

In [176]:
wei.shape

torch.Size([1, 10, 10])

In [177]:
wei.shape

torch.Size([1, 10, 10])

In [178]:
out.shape

torch.Size([1, 10, 768])

In [106]:
# plt.figure(figsize=(15,5))
# sns.heatmap(out.squeeze(0))

In [166]:
out

tensor([[[-9.1074e-04,  1.0913e-02,  8.6040e-03,  ..., -2.1879e-04,
          -1.4773e-02, -5.9723e-03],
         [-2.2574e-02,  1.4292e-01,  1.8794e-01,  ..., -2.0085e-02,
          -1.4592e-01, -4.8681e-02],
         [ 1.6720e-02,  2.1470e-01,  3.2068e-01,  ..., -2.9511e-02,
          -2.2448e-01, -7.7544e-02],
         ...,
         [ 8.5476e-02,  1.6527e-01,  2.0111e-01,  ..., -7.6960e-02,
          -1.6512e-01, -5.5056e-02],
         [ 2.0144e-01,  3.1347e-01,  2.0510e-01,  ..., -1.6995e-01,
          -4.2555e-01, -1.9355e-02],
         [ 6.9687e-01,  1.5490e+00,  9.0180e-01,  ..., -4.2238e-01,
          -1.4563e+00, -7.3655e-01]]], grad_fn=<UnsafeViewBackward0>)