In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

# nn.Softmax is an nn.Module, so it is instantiated (e.g. in the __init__ method of your model) and then called in forward().
# torch.softmax() and nn.functional.softmax() compute the same thing; I would recommend sticking to nn.functional.softmax (imported here as F), since it is the documented API.

In [5]:
# Stage 1 and 2 until you see Stage 3

# Fixed seed so the random fake inputs and weight initializations are reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Every decoder needs these:

seq_len = 512
embed_dim = 256
# head_dim is not hard-coded here: the next cell derives it as
# attn_head_output_dim = embed_dim // num_attn_heads (128 with these values).
num_attn_heads = 2
# Note the output shapes: a single head's output has last dimension
# embed_dim // num_attn_heads; the heads are concatenated back to embed_dim in Stage 3.
batch_size = 16
# The attention scale head_dim ** -0.5 is applied inside AttnHead.forward
# as a division by sqrt(dim_of_key).

In [6]:
# Specific to our implementation

attn_head_input_dim = embed_dim
# Each head gets an equal slice of the embedding. The heads' outputs are later
# concatenated and projected by Wo (embed_dim -> embed_dim), which only lines up
# when embed_dim is an exact multiple of num_attn_heads -- fail early and clearly.
assert embed_dim % num_attn_heads == 0, (
    f"embed_dim ({embed_dim}) must be divisible by num_attn_heads ({num_attn_heads})"
)
attn_head_output_dim = embed_dim // num_attn_heads

About bias: https://www.turing.com/kb/necessity-of-bias-in-neural-networks#what-is-bias-in-a-neural-network?

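However, for some layers (for example the linear projections inside transformers, and convolutional layers), a bias term is often unnecessary and only adds overhead. The reason is that these layers are frequently followed by a normalization layer such as Batch Normalization or Layer Normalization. Normalization re-centers the data (to mean=0 and std=1), so a constant offset added by the preceding layer's bias is subtracted away. The normalization layer's own learnable shift parameter can then play the bias's role instead. With Batch Normalization a per-channel bias is cancelled exactly; with Layer Normalization the cancellation is only partial, so dropping the bias there is mainly an empirical choice.
Therefore, it is common practice to omit the bias term in transformer and convolutional layers that are followed by a normalization layer.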

In [15]:
# Fake input standing in for token embeddings (token ids already mapped to vectors),
# only so that there is something to push through the attention code below.

# Stage 1 -- no batch dimension, shape (seq_len, embed_dim):
#input_data = torch.randn(seq_len, embed_dim)

# Stage 2 -- prepend a batch dimension, shape (batch_size, seq_len, embed_dim):
input_data = torch.randn((batch_size, seq_len, embed_dim))

# Going from Stage 1 to Stage 2 only adds a leading batch dimension to every tensor
# below; none of the attention code has to change.

# Stage 3: everything below is wrapped in an nn.Module whose forward() returns
# the head's context vector.

class AttnHead(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input into queries, keys and values with bias-free linear
    layers, computes scaled dot-product attention scores, masks out future
    positions, and returns the attention-weighted sum of the values.

    Args:
        attn_head_input_dim: size of the last dimension of the input.
        attn_head_output_dim: size of the last dimension of the output
            (the per-head dimension).
    """

    def __init__(self, attn_head_input_dim, attn_head_output_dim):
        super().__init__()

        # Per-head query/key/value projections, created without a bias term.
        self.Wq = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False)
        self.Wk = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False)
        self.Wv = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False)

    def forward(self, input_data):
        """Compute this head's context vectors.

        Args:
            input_data: tensor whose last dimension is attn_head_input_dim and
                whose second-to-last dimension is the sequence axis, e.g.
                (seq_len, embed_dim) or (batch_size, seq_len, embed_dim).

        Returns:
            Tensor with the same leading dimensions as input_data and last
            dimension attn_head_output_dim. Position i only attends to
            positions <= i.
        """
        queries = self.Wq(input_data)
        keys = self.Wk(input_data)
        values = self.Wv(input_data)

        dim_of_key = keys.size(-1)

        # Scaled dot-product scores of shape (..., seq_len, seq_len).
        # `@` and `/` share precedence (left-to-right); parentheses make that explicit.
        attn_scores = (queries @ keys.transpose(-2, -1)) / sqrt(dim_of_key)

        # Causal mask: True strictly above the diagonal, i.e. key positions in the
        # future of each query position. Built on the scores' device so the head
        # also works when the module and its inputs live on a GPU.
        num_queries, num_keys = attn_scores.shape[-2], attn_scores.shape[-1]
        mask = torch.ones(
            (num_queries, num_keys), dtype=torch.bool, device=attn_scores.device
        ).triu(diagonal=1)

        attn_scores = attn_scores.masked_fill(mask, float("-inf"))

        # Softmax over the key axis. The diagonal is never masked, so every row
        # keeps at least one finite score and no NaNs are produced.
        attn_weights = F.softmax(attn_scores, dim=-1)

        head_context_vector = attn_weights @ values

        return head_context_vector

Each head produces its own head_context_vector, whose last dimension is embed_dim // num_attn_heads. The heads' outputs are concatenated and projected back to embed_dim via Wo.
We do that in Stage 3 (next cell).

In [18]:
# Stage 3 == concatenate the heads and bring the result back to embed_dim via Wo

# Wo is defined at the multi-head level, while Wq/Wk/Wv are defined at the level of the attention head

Wo = nn.Linear(embed_dim, embed_dim)

attn_heads = nn.ModuleList(
    [AttnHead(attn_head_input_dim, attn_head_output_dim) for _ in range(num_attn_heads)]
)

# Concatenate along the feature axis: num_attn_heads * attn_head_output_dim columns.
concatenated_head_context_vectors = torch.cat(
    [attn_head(input_data) for attn_head in attn_heads], dim=-1
)

multihead_context_vector = Wo(concatenated_head_context_vectors)

# Multi-head attention maps its input to a tensor of the same shape.
assert multihead_context_vector.shape == input_data.shape, (
    f"expected {tuple(input_data.shape)}, got {tuple(multihead_context_vector.shape)}"
)

# Last expression: Jupyter displays the concatenated and final output shapes
# (compare Stage 1 vs Stage 2 input).
concatenated_head_context_vectors.shape, multihead_context_vector.shape