## Extract embeddings from BERT

In [16]:
import matplotlib.pyplot as plt
import seaborn as sns

In [215]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertTokenizer, BertModel

# Load the pre-trained BERT model and its matching tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Sample sentences (processed together as one batch)
sentences = ["The police is chasing a criminal on the run.", "The criminal is hiding in the police van."]

# Tokenize the sentences into PyTorch tensors; padding=True brings every
# sentence in the batch to the same length.
# NOTE: despite its name, `input_ids` holds the tokenizer's whole output
# (it is unpacked as keyword arguments into the model below), not just the ids.
input_ids = tokenizer(sentences, padding=True,return_tensors='pt')

# Run BERT without tracking gradients (inference only)
with torch.no_grad():
    outputs = model(**input_ids)

# Extract the final-layer hidden state of every token from the BERT outputs

# Shape: [batch_size, seq_len, hidden_size]
input_embeddings = outputs.last_hidden_state

# Attention Mechanism

An attention mechanism is a key component in artificial intelligence and deep learning, particularly in models like transformers. It enables models to focus on specific parts of input data while making predictions or generating outputs. 

Imagine you're reading a long paragraph and trying to summarize it. Your attention is likely to focus more on important sentences or keywords rather than every single word. Similarly, in AI models, the attention mechanism helps the model assign different weights to different parts of the input data, giving more importance to relevant information.

This mechanism allows the model to process sequences of data more effectively by selectively attending to the most relevant elements at each step of the computation, improving the model's performance in various tasks.

<b>Breakdown</b>
1. <b>Input</b>: Input data, such as a sequence of words in NLP or an image in computer vision, is typically transformed into embeddings or feature vectors that capture the data's semantic or structural information.
2. <b> Query, Key, and Value </b>: Input embeddings are further transformed into three sets of vectors: Query, Key, and Value. The transformations are learned linear projections of the embeddings (in the standard Transformer, no non-linear activation is applied to them).
3. <b> Attention Weights </b>: Attention weights indicate the importance or relevance of each element (token, pixel, etc.) in the input sequence. They are computed from a similarity score between Q and K (dot product, scaled dot product, etc.) and then normalised with a softmax so that the weights each query assigns sum to 1.
4. <b> Weighted Sum </b>: Once the attention weights are computed, they are used to calculate a weighted sum of the corresponding Value vectors. This represents the attended information, i.e., the parts of the input that are most relevant or important for the current context.

Example:
   - Q: What the token is looking for?
   - K: Gist of what each token can offer.
   - V: Actual content of each token.

Imagine browsing a streaming service. Your Query might be a genre such as "comedy" or "action-packed". The movie titles, posters, and descriptions serve as the Keys. The actual films serve as the Values.

<b>Multi-head attention</b>
The attention mechanism is applied multiple times in parallel with different sets of Query, Key, and Value transformations. Multiple heads let us attend to several words at once, and each head can focus on a specific aspect of the input, for example the subject, the object, or the action. The outputs of the multiple attention heads are concatenated or combined to provide diverse and richer representations.

In [217]:
# Unpack the embedding dimensions: B = batch_size, T = input_size (tokens per sentence),
# C = embedding_dimension
B, T, C = input_embeddings.size()

In [218]:
# Causal averaging demo: every token is replaced by the plain mean of itself
# and all tokens before it.
# Lower-triangular ones mark the positions a token is allowed to look at.
tril = torch.tril(torch.ones(T, T))
# Start from all-zero scores and set the disallowed (future) positions to -inf,
# so that softmax gives them zero weight and spreads the rest uniformly.
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# (T, T) @ (B, T, C): the weight matrix is broadcast across the batch.
out = torch.matmul(wei, input_embeddings)

In [223]:
def self_attention(input_embed, is_decoder:bool):
    """
    Perform single-head scaled dot-product self-attention on input embeddings.

    The key/query/value projections are created (randomly initialised) on every
    call, so this is a demonstration of the mechanics, not a trained layer.

    Args:
    - input_embed: Input embeddings tensor of shape [batch_size, input_size, embedding_dimension].
    - is_decoder: If True, apply a causal mask so that the current token has no access
      to future tokens (Ex: Generation). If False, all the tokens jointly attend to
      each other (Ex: Text Classification).

    Returns:
    - out: Attended output, shape [batch_size, input_size, embedding_dimension].
    """
    # Shape of the input embeddings [batch_size, input_size, embedding_dimension]
    B, T, C = input_embed.shape

    key = nn.Linear(C, C)
    query = nn.Linear(C, C)
    value = nn.Linear(C, C)

    k = key(input_embed)
    q = query(input_embed)
    v = value(input_embed)

    # Scores of shape [B, T, T]: wei[b, i, j] says how well query token i matches key token j.
    # Scaled by sqrt of the key dimension (C here, as the projections are C -> C).
    wei = q @ k.transpose(-2, -1) / C**0.5

    if is_decoder:
        # Hide the strictly-upper triangle (key position j > query position i).
        allowed = torch.tril(torch.ones(T, T, dtype=torch.bool, device=wei.device))
        wei = wei.masked_fill(~allowed, float('-inf'))

    # Normalise over the KEY axis (last dim) so each query's weights sum to 1.
    # Normalising over dim=1 (the query axis) would leak information from future
    # tokens into earlier positions even with the causal mask applied.
    wei = F.softmax(wei, dim=-1)

    out = wei @ v
    return out

In [251]:
def multi_head_attention(input_embed, n_heads):
    """
    Perform multi-head attention on input embeddings.

    Each head gets its own freshly created (randomly initialised) key/query/value
    projections on every call, so this demonstrates the mechanics rather than a
    trained layer. No mask is applied: all tokens attend to each other.

    Args:
    - input_embed: Input embeddings tensor of shape [batch_size, input_size, embedding_dimension].
    - n_heads: Number of attention heads (must be positive).

    Returns:
    - multi_head_output: Concatenated output of multi-head attention, shape [batch_size, input_size, n_heads * head_size].

    Raises:
    - ValueError: If n_heads is not positive.

    multi_head_output must pass through a Linear Layer (not applied here).
    """
    if n_heads <= 0:
        raise ValueError("n_heads must be a positive integer")

    B, T, C = input_embed.shape

    head_size = C//n_heads

    head_outputs = []
    for _ in range(n_heads):
        key = nn.Linear(C, head_size)
        query = nn.Linear(C, head_size)
        value = nn.Linear(C, head_size)

        k = key(input_embed)
        q = query(input_embed)
        v = value(input_embed)

        # Scores [B, T, T], scaled by sqrt of the per-head key dimension
        # (head_size, not the full embedding dimension C).
        wei = q @ k.transpose(-2, -1) / head_size**0.5
        # Normalise over the key axis (last dim) so each query's weights sum to 1.
        wei = F.softmax(wei, dim=-1)

        head_outputs.append(wei @ v)

    print("Shape of each head", head_outputs[-1].shape)
    # Concatenate all heads once along the feature dimension.
    multi_head_output = torch.cat(head_outputs, dim=-1)
    return multi_head_output

In [252]:
# Run the unmasked multi-head attention demo on the BERT embeddings with 8 heads,
# then show the shape of the concatenated result.
multi_head_output = multi_head_attention(input_embeddings, n_heads=8)
print("Shape of concatenated output", multi_head_output.shape)    

Shape of each head torch.Size([2, 12, 96])
Shape of concatenated output torch.Size([2, 12, 768])


In [253]:
def masked_multi_head_attention(input_embed, n_heads):
    """
    Perform masked multi-head attention on input embeddings.

    Each head gets its own freshly created (randomly initialised) key/query/value
    projections on every call, so this demonstrates the mechanics rather than a
    trained layer. A causal mask stops each token from attending to later tokens.

    Args:
    - input_embed: Input embeddings tensor of shape [batch_size, input_size, embedding_dimension].
    - n_heads: Number of attention heads (must be positive).

    Returns:
    - masked_multi_head_output: Concatenated output of masked multi-head attention, shape [batch_size, input_size, n_heads * head_size].

    Raises:
    - ValueError: If n_heads is not positive.

    masked_multi_head_output must pass through a Linear Layer (not applied here).
    """
    if n_heads <= 0:
        raise ValueError("n_heads must be a positive integer")

    B, T, C = input_embed.shape

    head_size = C//n_heads

    # The causal mask is the same for every head, so build it once:
    # True where key position j is in the future of query position i (j > i).
    future = torch.tril(torch.ones(T, T, device=input_embed.device)) == 0

    head_outputs = []
    for _ in range(n_heads):
        key = nn.Linear(C, head_size)
        query = nn.Linear(C, head_size)
        value = nn.Linear(C, head_size)

        k = key(input_embed)
        q = query(input_embed)
        v = value(input_embed)

        # Scores [B, T, T], scaled by sqrt of the per-head key dimension
        # (head_size, not the full embedding dimension C).
        wei = q @ k.transpose(-2, -1) / head_size**0.5

        # Apply mask so that the current token will not have access to future tokens.
        wei = wei.masked_fill(future, float('-inf'))
        # Normalise over the key axis (last dim). Normalising over dim=1 (queries)
        # would let future tokens influence earlier positions despite the mask.
        wei = F.softmax(wei, dim=-1)

        head_outputs.append(wei @ v)

    # Concatenate all heads once along the feature dimension.
    masked_multi_head_output = torch.cat(head_outputs, dim=-1)
    return masked_multi_head_output

In [254]:
# Run the masked (causal) multi-head attention demo on the BERT embeddings with 8 heads
# and display the shape of the concatenated result.
masked_multi_head_output = masked_multi_head_attention(input_embeddings, n_heads=8)
masked_multi_head_output.shape

torch.Size([2, 12, 768])