# Bert Self Attention

Here I create an annotated version of the HuggingFace implementation of the [BertSelfAttention class](https://github.com/huggingface/transformers/blob/ab756f713c7dd5257e27bb74edd906dfdfbf0e5d/src/transformers/modeling_bert.py#L183).

Note that this is a different implementation from the one in PyTorch's [MultiheadAttention](https://pytorch.org/docs/master/_modules/torch/nn/modules/activation.html#MultiheadAttention), although the quantitative results should be similar.

In [1]:
import platform; print("Platform", platform.platform())
import sys; print("Python", sys.version)
import torch; print("PyTorch", torch.__version__)
import torch.nn as nn
import torch.nn.functional as F

Platform Linux-4.15.0-1060-aws-x86_64-with-debian-buster-sid
Python 3.6.5 |Anaconda, Inc.| (default, Apr 29 2018, 16:14:56) 
[GCC 7.2.0]
PyTorch 1.3.1


## Annotation of HuggingFace Class

I insert new comments, or amend comments that were already there, without altering any of the code.

In [44]:
# Creating the pytorch module
class BertSelfAttention(nn.Module):
    # Multi-head attention as a PyTorch module: self-attention by default,
    # cross-attention when encoder hidden states are passed to forward().
    # NOTE: forward() calls math.sqrt, so `math` must already be imported
    # by the time the first forward pass runs.

    # initializing with configuration
    # (config must expose hidden_size, num_attention_heads,
    # output_attentions and attention_probs_dropout_prob)
    def __init__(self, config):
        # Initializing parent classes
        super().__init__()
        # Check that hidden size is divisible by # of attn heads
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        # Flag from the config: when truthy, forward() returns the attention
        # probabilities alongside the context layer
        self.output_attentions = config.output_attentions

        # store  # of attention heads ...
        self.num_attention_heads = config.num_attention_heads
        # store size for a single attention head
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        # Size of all heads concatenated; given the divisibility check above
        # this equals config.hidden_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Create the query/key/value projections as linear layers mapping
        # hidden_size -> all_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # dropout layer with p from configuration
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    # transform a tensor of shape [batch_size, seq_len, all_head_size] to
    # [batch_size, num_attn_heads, seq_len, attn_head_size]
    # right before doing our dot product. With the head axis moved next to
    # the batch axis, torch.matmul handles every attention head separately.
    def transpose_for_scores(self, x):
        # split the last dimension into (num_attn_heads, attn_head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        # swap dims 1 and 2; the 4-index permute means x must be 3-D on entry
        return x.permute(0, 2, 1, 3)

    # forward pass
    #   hidden_states: input to the query projection (and to the key/value
    #       projections when no encoder states are given)
    #   attention_mask: optional, ADDED to the raw scores before the softmax
    #       (presumably large negative values at positions to hide --
    #       confirm against the caller)
    #   head_mask: optional, MULTIPLIED into the attention probabilities
    #   encoder_hidden_states / encoder_attention_mask: optional, used only
    #       for cross-attention
    # returns a tuple: (context_layer,) or (context_layer, attention_probs)
    def forward(self,hidden_states,attention_mask=None,
        head_mask=None, encoder_hidden_states=None,encoder_attention_mask=None):
        # calculate outputs from query
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            # note: this replaces attention_mask even when
            # encoder_attention_mask is None
            attention_mask = encoder_attention_mask
        else:
            # if not instantiated as a cross-attention module, the keys
            # and values come from the hidden states
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        # we need to manipulate the outputs from [batch_size, seq_len, all_head_size] to
        # [batch_size, num_attn_heads, seq_len, attn_head_size]
        # in order to perform our dot product aka similarity calculation
        # for each attention head separately
        # we do this for the query, key, and value outputs
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attn scores.
        # (the key's last two dims are swapped so the head dimension contracts)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Divide attn scores by the square root of the attn head dimension
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Check whether attention mask is provided
        if attention_mask is not None:
            # Apply the attention mask (precomputed for
            # all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to,
        # which might seem a bit unusual, but is taken from
        # the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Check whether head mask is provided
        if head_mask is not None:
            # Mask heads if we want to
            attention_probs = attention_probs * head_mask

        # Standard matrix mult. between attention probabilities
        # and the value layer outputs (call it context)
        context_layer = torch.matmul(attention_probs, value_layer)

        # Swap the 1 and 2 dimensions so:
        # [obs, # attn heads, seq_len, attn head dimension] becomes
        # [obs, seq_len, # attn heads, attn head dimension]
        # (.contiguous() because the view() below needs contiguous memory)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # Calculate the new shape which would be
        # [obs, seq_len, #attn heads * attn head dimension]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

        # Reshape the tensor, merging the heads back into one dimension
        context_layer = context_layer.view(*new_context_layer_shape)

        # Check whether we want to output the attentions
        # if so, return a tuple of context output and attention probabilities
        # if not, return a tuple of context output only
        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs

## Quick Check

To define the inputs, we use a slightly simplified version of the BertConfig class from HuggingFace ([link](https://github.com/huggingface/transformers/blob/master/src/transformers/configuration_bert.py)).

In [62]:
class BertConfig:
    """Plain holder for BERT hyperparameters.

    Each constructor argument is stored unchanged as an instance attribute
    of the same name; nothing is validated or derived here.
    """

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        output_attentions=False,
    ):
        # Vocabulary and model sizes
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Activation name and intermediate size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size

        # Dropout probabilities
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        # Position / token-type table sizes
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size

        # Initialisation range and layer-norm epsilon
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Whether attention probabilities should be returned as well
        self.output_attentions = output_attentions

In [63]:
# Default hyperparameters (hidden_size=768, 12 attention heads);
# output_attentions=False makes forward() return only the context layer
config = BertConfig(output_attentions=False)
# Build the attention module from that config
bsa = BertSelfAttention(config)

In [64]:
# Display the module's repr, which lists its registered sub-layers
bsa

BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [65]:
# BertSelfAttention.forward calls math.sqrt, so math has to be imported
# before the forward pass below
import math
# Random input of shape (32, max_position_embeddings, hidden_size)
x = torch.rand(32, config.max_position_embeddings, config.hidden_size)
# forward() returns a tuple; element 0 is the context layer
output_context = bsa(x)[0]

In [67]:
output_context.shape  # (obs, seq_len, hidden features): same shape as the input x

torch.Size([32, 512, 768])

### Exploring the `transpose_for_scores` function

In [68]:
# Toy dimensions, small enough to inspect by eye
hidden_size = 16
num_attention_heads = 2
# The hidden size has to split evenly across the heads
assert hidden_size % num_attention_heads == 0
# Width of one head, and of all heads side by side
attention_head_size = hidden_size // num_attention_heads
all_head_size = attention_head_size * num_attention_heads
# Query projection: hidden_size -> all_head_size
query = nn.Linear(hidden_size, all_head_size)

In [69]:
# Random batch of hidden states: (batch_size, seq_len, hidden_size)
batch_size, seq_len = 32, 10
hidden_states = torch.rand(batch_size, seq_len, hidden_size)
# Run the batch through the query projection and show the result's shape
mixed_query_layer = query(hidden_states)
mixed_query_layer.shape

torch.Size([32, 10, 16])

Originally, the query output has the shape [batch_size, seq_len, hidden_size] (the same shape as the hidden states it was computed from).

In [70]:
# Every dimension except the last: the prefix that transpose_for_scores keeps
hidden_states.size()[:-1]

torch.Size([32, 10])

In [71]:
# Full shape, for comparison with the truncated one above
hidden_states.shape

torch.Size([32, 10, 16])

In [72]:
# transpose function begins
def transpose_for_scores(x):
    # Split the last dimension of x into attention heads and move the head
    # axis in front of the sequence axis:
    #    [batch_size, seq_len, hidden_size]
    # -> [batch_size, seq_len, num_attn_heads, attn_head_size]
    # -> [batch_size, num_attn_heads, seq_len, attn_head_size]
    # Reads num_attention_heads and attention_head_size from module scope.
    leading_dims = x.size()[:-1]
    # Same data, with the last dimension viewed as (heads, head_size)
    per_head = x.view(*leading_dims, num_attention_heads, attention_head_size)
    # Exchange the seq_len and num_attn_heads axes
    return per_head.permute(0, 2, 1, 3)

In [73]:
# Split the projected queries into heads and show the resulting shape
query_layer = transpose_for_scores(mixed_query_layer)
query_layer.shape

torch.Size([32, 2, 10, 8])

Before the matrix multiply, however, it is split into its respective attention heads, giving the shape [batch_size, num_attention_heads, seq_len, attention_head_size]. Brilliant.

### Quantitative Comparison to PyTorch Multihead Attention

TBD