## Extract embeddings from BERT

In [16]:
import matplotlib.pyplot as plt
import seaborn as sns

In [215]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertTokenizer, BertModel

# Load pre-trained BERT model and tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Sample sentences (a batch of two).
sentences = ["The police is chasing a criminal on the run.", "The criminal is hiding in the police van."]

# Tokenize the batch; padding=True pads the shorter sentence up to the longest one.
# NOTE: despite its name, `input_ids` holds the tokenizer's full output mapping
# (not just token ids), which is why it is unpacked with ** below.
input_ids = tokenizer(sentences, padding=True, return_tensors='pt')

# Run BERT without tracking gradients (inference only).
with torch.no_grad():
    outputs = model(**input_ids)

# Contextual token embeddings from BERT's final layer.
# Shape: [batch_size, seq_len, hidden_size]. Because of padding, the shorter
# sentence's trailing positions correspond to padding tokens.
input_embeddings = outputs.last_hidden_state

## Attention Mechanism

Attention produces contextualized word embeddings: each token's output is a weighted mix of (projections of) the tokens it is allowed to see, so the same word can receive different representations in different sentences. Put simply, the attention mechanism lets the transformer dynamically weigh how much each part of the input sequence matters for a given position, with the weights computed from the input context itself.

In [217]:
# Unpack the embedding tensor's dimensions:
#   B = batch size, T = sequence length (tokens, incl. padding), C = embedding dimension.
B, T, C = input_embeddings.size()

In [218]:
# Warm-up: a uniform causal average. All scores start equal (zero) and future
# positions are masked to -inf, so after softmax row t puts weight 1/(t+1) on
# positions 0..t. Hence out[:, t] is the mean of the first t+1 embeddings.
tril = torch.ones(T, T).tril()
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# The [T, T] weights broadcast over the batch: [T, T] @ [B, T, C] -> [B, T, C].
out = wei @ input_embeddings

In [223]:
def self_attention(input_embed, is_decoder: bool, *, key_proj=None, query_proj=None, value_proj=None):
    """Apply single-head scaled dot-product self-attention to ``input_embed``.

    Args:
        input_embed: Tensor of shape ``[batch_size, seq_len, embedding_dim]``.
        is_decoder: If True, apply a causal mask so each position attends only
            to itself and earlier positions (e.g. generation). If False, every
            position attends to every position (e.g. text classification).
        key_proj, query_proj, value_proj: Optional modules mapping the last
            dimension of ``input_embed`` to keys / queries / values. Any that
            are omitted are created as fresh, randomly initialised
            ``nn.Linear(C, C)`` layers on every call, so results are only
            reproducible under a fixed RNG seed.

    Returns:
        Tensor of shape ``[batch_size, seq_len, value_dim]``, where
        ``value_dim == embedding_dim`` for the default value projection.
    """
    _, T, C = input_embed.shape

    # Build default projections on the input's device/dtype, in key, query, value
    # order (the same order as before, so seeded results are unchanged).
    factory = {"device": input_embed.device, "dtype": input_embed.dtype}
    key = key_proj if key_proj is not None else nn.Linear(C, C, **factory)
    query = query_proj if query_proj is not None else nn.Linear(C, C, **factory)
    value = value_proj if value_proj is not None else nn.Linear(C, C, **factory)

    k = key(input_embed)
    q = query(input_embed)
    v = value(input_embed)

    # Scores: [B, T, T]; rows index queries, columns index keys. Scaling by
    # sqrt(key dim) keeps the softmax from saturating as the dimension grows.
    wei = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5

    if is_decoder:
        # In a decoder, the current token must not see future tokens (Ex: Generation).
        # In an encoder, all tokens jointly attend to each other (Ex: Text Classification).
        causal = torch.ones(T, T, dtype=torch.bool, device=wei.device).tril()
        wei = wei.masked_fill(~causal, float('-inf'))

    # Normalise over keys (last dim) so each query's weights sum to 1.
    # The diagonal is never masked, so no row is all -inf.
    wei = F.softmax(wei, dim=-1)

    return wei @ v