<a href="https://colab.research.google.com/github/brownsloth/transformers_concepts_notebooks/blob/main/transformers_3_logic_behind_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install bertviz

In [None]:
from bertviz.transformers_neuron_view import BertModel
from transformers import AutoTokenizer, AutoConfig
from bertviz.neuron_view import show
from torch import nn
from torch.nn import functional as F
import torch
from math import sqrt

In [None]:
model_ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)
text = "As the aircraft becomes lighter, it flies higher in the air of lower density to maintain the same airspeed."

show(model, 'bert', tokenizer, text, display_mode='light', layer=0, head=0)

In [None]:
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
inputs.input_ids

In [None]:
config = AutoConfig.from_pretrained(model_ckpt)
config

In [None]:
token_embeddings_layer = nn.Embedding(config.vocab_size, config.hidden_size)  ## randomly initialized: only vocab_size and hidden_size come from the BERT config, not BERT's pretrained embedding weights
token_embeddings_layer
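`nn.Embedding` is a trainable lookup table: calling it on token ids returns the matching rows of its weight matrix. The sketch below uses toy sizes (vocab of 10, hidden size of 4, chosen only for illustration; BERT-base uses 30522 and 768) to show that the forward pass is plain indexing into `emb.weight`.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Toy sizes, assumed for illustration only
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

ids = torch.tensor([[3, 1, 3]])   # (batch=1, seq_len=3)
out = emb(ids)                    # (1, 3, 4)

# Row i of emb.weight is the vector for token id i, so the lookup equals direct indexing
manual = emb.weight[ids]
print(out.shape)
print(torch.equal(out, manual))
# The same id always maps to the same vector
print(torch.equal(out[0, 0], out[0, 2]))
```

Because the table is randomly initialized here, these vectors carry no meaning until they are trained (or loaded from a pretrained checkpoint).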

In [None]:
input_embeddings = token_embeddings_layer(inputs.input_ids)
input_embeddings
input_embeddings.size()

In [None]:
## Get the Q, K, V vectors and take dot products to compute the attention scores
## q, k and v each have shape (batch_size, seq_len, hidden_size)
## For illustration we use the embeddings themselves as q, k and v (no learned projections yet)
query = key = value = input_embeddings

dim_k = key.size(-1)
attention_scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)  # transpose swaps dims 1 and 2: (batch, seq_len, hidden) -> (batch, hidden, seq_len)
## torch.bmm is a batched matrix multiply: it expects 3-D tensors (b, n, m) and (b, m, p)
## and multiplies the two matrices for each batch index b, giving (b, n, p) = (batch, seq_len, seq_len)

attention_scores.size()
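Entry `[b, i, j]` of the score tensor is the dot product of query token `i` with key token `j`, divided by `sqrt(d_k)`. The sketch below checks this on random tensors of an assumed toy shape `(2, 5, 8)` by comparing `torch.bmm` against an explicit triple loop and against the `@` operator.

```python
import torch
from math import sqrt

torch.manual_seed(0)
batch, seq_len, d_k = 2, 5, 8            # toy sizes, assumed for illustration
query = torch.randn(batch, seq_len, d_k)
key = torch.randn(batch, seq_len, d_k)

# key.transpose(1, 2): (batch, seq_len, d_k) -> (batch, d_k, seq_len)
scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(d_k)   # (batch, seq_len, seq_len)

# The same computation written out: one scaled dot product per (query i, key j) pair
loop = torch.empty(batch, seq_len, seq_len)
for b in range(batch):
    for i in range(seq_len):
        for j in range(seq_len):
            loop[b, i, j] = torch.dot(query[b, i], key[b, j]) / sqrt(d_k)

print(scores.shape)
print(torch.allclose(scores, loop, atol=1e-5))
# For 3-D inputs, the @ operator batches over the leading dim and gives the same result
print(torch.allclose(scores, query @ key.transpose(-2, -1) / sqrt(d_k), atol=1e-6))
```

`torch.bmm` only accepts 3-D tensors with equal batch sizes; `@` (`torch.matmul`) also broadcasts over extra leading dimensions, which is handy once a head dimension is added.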

In [None]:
"""
torch.nn.functional (usually imported as F) provides stateless functions: they hold no trainable parameters of their own. It includes:
- Activation functions: relu, sigmoid and tanh are applied element-wise to introduce non-linearity; softmax normalizes along a chosen dim.
- Convolutional operations: conv2d, conv_transpose2d, etc., perform convolutions for feature extraction.
- Pooling operations: max_pool2d, avg_pool2d, etc., reduce spatial dimensions.
- Linear transformations: linear applies y = x W^T + b using a weight (and optional bias) that you pass in.
- Loss functions: mse_loss, cross_entropy, etc., measure the difference between predictions and targets.
- Dropout: dropout randomly zeroes elements to reduce overfitting.
torch.nn.functional is suitable when:
- You need a simple, stateless operation.
- You want to define custom operations within a neural network.
- You need more flexibility than the torch.nn modules provide.
In contrast, torch.nn provides modules (layers) that own learnable parameters, such as nn.Linear, nn.Conv2d and nn.BatchNorm2d. These layers manage their weights and biases internally, whereas with torch.nn.functional you create and pass in such parameters yourself.
"""
attention_weights = F.softmax(attention_scores, dim=-1)  # normalize over the key axis so each query's weights sum to 1
print(attention_weights.shape)
print(attention_weights.sum(dim=-1))  # every entry should be 1
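The scores have shape `(batch, query_position, key_position)`, so the softmax must run over the last dim: each query then gets a probability distribution over the keys. Using `dim=1` would instead normalize each key's column over the queries. The toy scores below (made-up values) show the difference.

```python
import torch
from torch.nn import functional as F

# (batch=1, 2 queries, 3 keys), toy values for illustration
scores = torch.tensor([[[2.0, 1.0, 0.0],
                        [0.0, 3.0, 1.0]]])

row_wise = F.softmax(scores, dim=-1)   # normalize over keys: each query row sums to 1
col_wise = F.softmax(scores, dim=1)    # normalize over queries: each key column sums to 1

print(row_wise.sum(dim=-1))   # each entry is (numerically) 1
print(col_wise.sum(dim=1))    # each entry is (numerically) 1
print(col_wise.sum(dim=-1))   # rows do not sum to 1, so these are not per-query weights
```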

In [None]:
## Weight each token assigns to itself (the diagonal of the attention matrix)
for i in range(attention_weights.shape[1]):
  print(attention_weights[0][i][i])
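Because query = key here, the diagonal score for token i is ||x_i||^2 / sqrt(d), while an off-diagonal score is x_i · x_j / sqrt(d). `nn.Embedding` initializes its weights from N(0, 1), and independent random vectors in high dimension are nearly orthogonal, so the diagonal should dominate and the self-weights printed above should be close to 1. The sketch below reproduces this with random vectors standing in for the untrained embeddings (seq_len of 6 is arbitrary; d = 768 matches BERT-base).

```python
import torch
from math import sqrt
from torch.nn import functional as F

torch.manual_seed(0)
seq_len, d = 6, 768
x = torch.randn(1, seq_len, d)   # stand-in for untrained N(0, 1) embeddings

scores = torch.bmm(x, x.transpose(1, 2)) / sqrt(d)
diag = torch.diagonal(scores, dim1=1, dim2=2)                  # about sqrt(d), roughly 27.7
off_diag = scores[0][~torch.eye(seq_len, dtype=torch.bool)]    # roughly N(0, 1)

print(diag.mean().item(), off_diag.abs().max().item())
weights = F.softmax(scores, dim=-1)
print(torch.diagonal(weights, dim1=1, dim2=2))                 # each entry very close to 1
```

With trained embeddings and learned query/key projections this self-dominance generally disappears, which is one reason the projections matter.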

In [None]:
attention_outputs = torch.bmm(attention_weights, value)

print(attention_outputs.shape)
## Self-attention ~ weighted average of the value vectors: output i = sum_j attention_weights[i, j] * value[j]
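Since each row of the weights sums to 1, every output vector is a weighted average (a convex combination) of the value vectors. The sketch below, on random toy tensors, checks one output row against the explicit sum and confirms each output coordinate stays within the range of the corresponding value coordinates.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
value = torch.randn(1, 4, 3)                        # (batch, seq_len, hidden), toy sizes
weights = F.softmax(torch.randn(1, 4, 4), dim=-1)   # each row sums to 1

outputs = torch.bmm(weights, value)                 # (1, 4, 3)

# Output for token i is sum_j weights[i, j] * value[j]
i = 2
manual = sum(weights[0, i, j] * value[0, j] for j in range(4))
print(torch.allclose(outputs[0, i], manual, atol=1e-6))
```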

## Combining everything so far into a single function that computes attention

In [None]:
def scaled_dot_product_attention(query, key, value):
  dim_k = query.size(-1)
  attn_scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)  # (batch, seq_len, seq_len)
  attn_wts = F.softmax(attn_scores, dim=-1)  # each query's weights over the keys sum to 1
  return torch.bmm(attn_wts, value)
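A quick way to sanity-check a hand-written version is to compare it with PyTorch's built-in `F.scaled_dot_product_attention`, which (assuming PyTorch 2.0 or newer) uses the same default scale of 1/sqrt(d_k). The sketch below restates the function with the softmax over the last dim and compares the two on random inputs of an assumed shape `(2, 5, 16)`.

```python
import torch
from math import sqrt
from torch.nn import functional as F

def scaled_dot_product_attention(query, key, value):
    # (batch, seq_q, d_k) x (batch, d_k, seq_k) -> (batch, seq_q, seq_k)
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)   # normalize over keys
    return torch.bmm(weights, value)      # (batch, seq_q, d_v)

torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 16) for _ in range(3))
ours = scaled_dot_product_attention(q, k, v)
builtin = F.scaled_dot_product_attention(q, k, v)   # requires PyTorch >= 2.0
print(ours.shape, torch.allclose(ours, builtin, atol=1e-5))
```

If the softmax were taken over `dim=1`, this comparison would fail, which makes it a cheap regression check.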

In practice, the query, key and value are not the input embeddings themselves but separate learnable linear projections of each layer's input; these projections let the model capture semantic relationships. Because several such sets of projections (heads) run in parallel, the mechanism is called multi-headed attention. Each head's softmax tends to focus on one aspect of the relationships between tokens, so several heads can capture several aspects, e.g. subject-verb interaction or an adjective attached to a nearby noun. Like filters in a CNN, these aspects are learned rather than hand-engineered.
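The paragraph above can be sketched as a small module. This is a simplified illustration, not BERT's actual implementation: the class names `AttentionHead` and `MultiHeadAttention` are hypothetical, and real BERT layers, for example, compute all heads with one fused projection and add dropout, a residual connection and layer normalization. The defaults of 768 hidden units and 12 heads (64 dims per head) match BERT-base; the sketch assumes `embed_dim` is divisible by `num_heads`.

```python
import torch
from math import sqrt
from torch import nn
from torch.nn import functional as F

def scaled_dot_product_attention(query, key, value):
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(query.size(-1))
    return torch.bmm(F.softmax(scores, dim=-1), value)

class AttentionHead(nn.Module):
    """One head: its own learned projections of the input into q, k and v."""
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        return scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state))

class MultiHeadAttention(nn.Module):
    """Runs several heads in parallel, concatenates them and mixes them with a linear layer."""
    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        head_dim = embed_dim // num_heads   # 768 // 12 = 64
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        # Concatenate per-head outputs along the feature dim: num_heads * head_dim = embed_dim
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        return self.output_linear(x)

torch.manual_seed(0)
mha = MultiHeadAttention()
x = torch.randn(1, 21, 768)   # (batch, seq_len, hidden); seq_len chosen arbitrarily
print(mha(x).shape)
```

Each head has its own `nn.Linear` weights, so after training the heads can specialize on different relationships; the final linear layer lets the model combine what the heads found.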