In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

#nn.Softmax is an nn.Module, which can be initialized, e.g., in the __init__ method of your model and then used in forward.
#torch.softmax() and nn.functional.softmax compute the same thing; I would recommend sticking to nn.functional.softmax, since it's documented.
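A quick check that the three softmax variants mentioned above agree. The random input and its shape are arbitrary and only for illustration; softmax is taken over the last dimension, as it is for the attention scores later on.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 5)  # arbitrary toy scores

softmax_module = nn.Softmax(dim=-1)       # module form: create once (e.g. in __init__), call in forward
out_module = softmax_module(x)
out_functional = F.softmax(x, dim=-1)     # functional form
out_torch = torch.softmax(x, dim=-1)      # torch-level function

same_results = torch.allclose(out_module, out_functional) and torch.allclose(out_functional, out_torch)
rows_sum_to_one = torch.allclose(out_functional.sum(dim=-1), torch.ones(2, 3))
print(same_results, rows_sum_to_one)  # True True
```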

In [6]:
# Stages 1 and 2 (Stage 3 is handled separately)

# Every decoder needs these:

seq_len = 512
embed_dim = 256
#head_dim = 32 -> not set directly; our implementation computes the equivalent (attn_head_output_dim) below
num_attn_heads = 2
# See the difference in output shape: the last dimension is embed_dim // num_attn_heads
batch_size = 16
#scaling_factor = head_dim ** -0.5

In [7]:
# Specific to our implementation

attn_head_input_dim = embed_dim
attn_head_output_dim = embed_dim // num_attn_heads
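A small sketch of how the per-head size follows from the settings above, and why the commented-out `scaling_factor = head_dim ** -0.5` is the same number the attention code later divides by (`sqrt(dim_of_key)`). It assumes, as the code does, that `embed_dim` splits evenly across the heads.

```python
from math import sqrt, isclose

embed_dim = 256
num_attn_heads = 2

# The embedding must split evenly across the heads
assert embed_dim % num_attn_heads == 0
attn_head_output_dim = embed_dim // num_attn_heads  # plays the role of head_dim
print(attn_head_output_dim)  # 128

# Multiplying by head_dim ** -0.5 is the same as dividing by sqrt(head_dim)
scaling_factor = attn_head_output_dim ** -0.5
print(isclose(scaling_factor, 1 / sqrt(attn_head_output_dim)))  # True
```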

About bias: https://www.turing.com/kb/necessity-of-bias-in-neural-networks#what-is-bias-in-a-neural-network?

However, for certain layers, such as the linear projections in transformers and convolutional layers, a bias term is often unnecessary and only adds overhead to the model. The reason is that these layers are typically followed by a normalization layer, such as Batch Normalization or Layer Normalization, which centers the data at mean 0 (and scales it to std 1). Batch Normalization subtracts the per-feature mean over the batch, so it removes a per-feature bias completely; Layer Normalization subtracts the mean across the features of each token, so it removes only the part of the bias that is the same for every feature.
Therefore, it is common practice to omit the bias term in transformer and convolutional layers that are followed by a normalization layer.
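A toy check of the bias argument, using a plain `nn.Linear` followed by a normalization layer rather than an attention head (the sizes and the bias values are arbitrary choices for illustration). It compares the same layer with and without a bias: BatchNorm cancels a per-feature bias exactly, while LayerNorm only cancels a bias that is identical for every feature.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 8)  # (batch, features): toy data

# Two linear layers with identical weights; only one has a (non-uniform) bias
linear = nn.Linear(8, 4, bias=True)
linear_no_bias = nn.Linear(8, 4, bias=False)
with torch.no_grad():
    linear_no_bias.weight.copy_(linear.weight)
    linear.bias.copy_(torch.tensor([1.0, -2.0, 3.0, 0.5]))

bn = nn.BatchNorm1d(4)  # training mode: normalizes each feature with batch statistics
ln = nn.LayerNorm(4)    # normalizes each row across its 4 features

# BatchNorm subtracts the per-feature batch mean, so the per-feature bias disappears
bn_cancels_bias = torch.allclose(bn(linear(x)), bn(linear_no_bias(x)), atol=1e-3)
# LayerNorm subtracts the mean across features, so a non-uniform bias survives
ln_cancels_bias = torch.allclose(ln(linear(x)), ln(linear_no_bias(x)), atol=1e-3)
# ...but a bias that is the same for every feature is removed by LayerNorm too
ln_cancels_uniform = torch.allclose(ln(linear_no_bias(x) + 3.0), ln(linear_no_bias(x)), atol=1e-3)

print(bn_cancels_bias, ln_cancels_bias, ln_cancels_uniform)  # True False True
```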

In [9]:
# Random stand-in for the embedded input, just to push something through here (in a real model, token ids are first mapped to embeddings)
# Stage 1: without batch dimension
#input_data = torch.randn(seq_len, embed_dim)
# Stage 2: simply add the batch dimension
input_data = torch.randn(batch_size, seq_len, embed_dim)
#Look at shape below to see the difference in the output - all the code remains the same!

# Note that the weights don't need a batch dimension: each head has its own Wq/Wk/Wv, shared by every example and trained across batches
# nn.Linear acts only on the last dimension, so in the batched version (Stage 2) the same W's are applied to each batch element's (seq_len, embed_dim) matrix
Wq = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False) # After training, Wq maps each token to a query: what this token is looking for
Wk = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False) # After training, Wk maps each token to a key: what this token offers to the queries of the surrounding tokens
Wv = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False) # Wv maps each token to a value: the information that is passed on

queries = Wq(input_data)     # (seq_len, embed_dim) @ (embed_dim, output_dim) -> (seq_len, output_dim), where output_dim is the size of the context vector
keys = Wk(input_data)
values = Wv(input_data)

dim_of_key = keys.size(-1) 

attn_scores = queries @ keys.transpose(-2,-1)/sqrt(dim_of_key) # divide by the sqrt of the "inner" dimension: key_dim, which must equal query_dim
# attn_scores.shape -> (seq_len, seq_len) in Stage 1, (batch_size, seq_len, seq_len) in Stage 2; each cell holds one query-key dot product
# attn_scores are the scaled dot products of the (reduced) embedding vectors of the tokens, projected into queries and keys

# Here is where we will apply masking (before we apply the softmax)

# In self-attention, attn_scores is always square: every token is scored against every token
mask = torch.ones((attn_scores.shape[-1], attn_scores.shape[-1]), dtype=torch.bool).triu(diagonal=1)

# tensor([[False,  True,  True,  ...,  True,  True,  True], ...

attn_scores = attn_scores.masked_fill(mask, float("-inf"))

attn_weights = F.softmax(attn_scores, dim = -1)
# tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00, 0.0000e+00], ...

# Now we apply these weights to the value vectors; the value vectors carry the information that's passed on
# to the next layer.
# The query-key scores go through a softmax, so much of their information is lost there; that's why we only use them as weights.


head_context_vector = attn_weights@values # (seq_len, seq_len) @ (seq_len, output_dim) -> (seq_len, output_dim), plus a leading batch_size dim in Stage 2
#head_context_vector2 = torch.bmm(attn_weights, values) -> same result (Stage 2 only: torch.bmm requires 3-D tensors)

head_context_vector.shape

torch.Size([16, 512, 128])
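As a sanity check on the cell above, the sketch below repeats the same steps on small toy tensors (sizes chosen arbitrarily) and compares the result with `torch.bmm` and with PyTorch's built-in `F.scaled_dot_product_attention`. It assumes PyTorch 2.0 or later, where that function exists; with `is_causal=True` it applies the same lower-triangular mask as the `triu` mask above, and its default scale is 1/sqrt of the last query dimension.

```python
import torch
import torch.nn.functional as F
from math import sqrt

torch.manual_seed(0)
batch_size, seq_len, head_dim = 2, 4, 8  # small toy sizes
queries = torch.randn(batch_size, seq_len, head_dim)
keys = torch.randn(batch_size, seq_len, head_dim)
values = torch.randn(batch_size, seq_len, head_dim)

# Same steps as the cell above
attn_scores = queries @ keys.transpose(-2, -1) / sqrt(keys.size(-1))
mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)  # True above the diagonal
attn_scores = attn_scores.masked_fill(mask, float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
manual = attn_weights @ values

# Every row of weights sums to 1, and no token attends to a later token
rows_sum_to_one = torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, seq_len))
future_masked = bool((attn_weights.masked_select(mask.expand_as(attn_weights)) == 0).all())

# torch.bmm gives the same result for 3-D (batched) inputs
bmm_matches = torch.allclose(manual, torch.bmm(attn_weights, values))

# Built-in attention; add a heads dimension of size 1, then remove it again
reference = F.scaled_dot_product_attention(
    queries.unsqueeze(1), keys.unsqueeze(1), values.unsqueeze(1), is_causal=True
).squeeze(1)
sdpa_matches = torch.allclose(manual, reference, atol=1e-5)

print(rows_sum_to_one, future_masked, bmm_matches, sdpa_matches)  # True True True True
```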

Each head produces a head_context_vector; the outputs of all heads are concatenated and projected back to embed_dim via Wo.
We do that in Stage 3, in a separate file.
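For orientation only, here is a hypothetical sketch of that step, not the Stage 3 file itself. `CausalHead` wraps the single-head steps from the cell above, and `MultiHeadSketch` (also a made-up name) concatenates the heads' outputs along the last dimension and maps them back to `embed_dim` with `Wo`. Leaving out the bias in `Wo` mirrors the choice made for Wq/Wk/Wv and is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


class CausalHead(nn.Module):
    """One causal attention head: the same steps as the cell above, wrapped in a module."""

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.Wq = nn.Linear(embed_dim, head_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, head_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, head_dim, bias=False)

    def forward(self, x):  # x: (batch, seq_len, embed_dim)
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        scores = q @ k.transpose(-2, -1) / sqrt(k.size(-1))
        seq_len = scores.size(-1)
        mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=x.device).triu(diagonal=1)
        weights = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
        return weights @ v  # (batch, seq_len, head_dim)


class MultiHeadSketch(nn.Module):
    """Hypothetical wrapper: run all heads, concatenate their outputs, project with Wo."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([CausalHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.Wo = nn.Linear(num_heads * head_dim, embed_dim, bias=False)

    def forward(self, x):
        concat = torch.cat([head(x) for head in self.heads], dim=-1)  # (batch, seq_len, num_heads * head_dim)
        return self.Wo(concat)                                         # (batch, seq_len, embed_dim)


torch.manual_seed(0)
mha = MultiHeadSketch(embed_dim=256, num_heads=2)
out = mha(torch.randn(16, 512, 256))
print(out.shape)  # torch.Size([16, 512, 256])
```

The per-head output keeps the shape seen above (`[16, 512, 128]`); only `Wo` brings the concatenated result back to the embedding size, so the block can be stacked or added to a residual stream.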