We will implement relative positional encoding, with the modifications from the T5 paper.

The basic idea of relative positional encoding is to use a matrix like this:

tensor([[ 0, 1, 2, 3],
        [-1, 0, 1, 2],
        [-2,-1, 0, 1],
        [-3,-2,-1, 0]])

so that each position is encoded relative to the "current token" (the query).

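Note that for decoder self-attention we use causal masking (meaning: we don't want to give ANY info about future tokens).
The causal attention mask itself hides the future tokens; in the position matrix we simply set all future positions to 0, so it looks like this: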

tensor([[ 0, 0, 0, 0],
        [-1, 0, 0, 0],
        [-2,-1, 0, 0],
        [-3,-2,-1, 0]])

In this implementation we also flip the sign so there are no negatives (this doesn't matter as long as it's consistent), so it becomes:

tensor([[ 0, 0, 0, 0],
        [ 1, 0, 0, 0],
        [ 2, 1, 0, 0],
        [ 3, 2, 1, 0]])
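All three matrices above can be produced with broadcasting (column of query positions against a row of key positions). A minimal sketch, with seq_len = 4 chosen only to match the examples:

```python
import torch

seq_len = 4
q_pos = torch.arange(seq_len).unsqueeze(1)   # query positions as a column, shape (4, 1)
k_pos = torch.arange(seq_len).unsqueeze(0)   # key positions as a row, shape (1, 4)

rel = k_pos - q_pos          # broadcasting -> (4, 4), entry [i, j] = j - i
causal = rel.clamp(max=0)    # future keys (j > i) all collapse to 0
flipped = -causal            # same information, non-negative: distance into the past

print(rel)
print(causal)
print(flipped)
```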

We then apply "buckets", where whole ranges of distances share the same positional encoding, while the closest past tokens keep their exact position. For a longer sequence, the last row looks something like this:

[7, 7, 7, 7, 6, 6, 6, 5, 5, 5, 4, 4, 3, 2, 1, 0]
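To see the bucketing for a single distance, here is a minimal scalar sketch of the formula used in the code below, assuming this notebook's settings (num_buckets=6, max_distance=20) rather than the larger settings behind the example row above. The function name `t5_causal_bucket` is hypothetical.

```python
import math

def t5_causal_bucket(distance, num_buckets=6, max_distance=20):
    """Map a non-negative distance into the past to a bucket index (T5-style, causal)."""
    num_exact = num_buckets // 2             # buckets 0 .. num_exact-1 each hold one distance
    if distance < num_exact:
        return distance
    # log-spaced buckets between num_exact and max_distance
    scaled = math.log(distance / num_exact) / math.log(max_distance / num_exact)
    bucket = num_exact + int(scaled * (num_buckets - num_exact))
    return min(bucket, num_buckets - 1)      # everything at or beyond max_distance shares the last bucket

print([t5_causal_bucket(d) for d in range(16)])
```

Buckets 0, 1, 2 hold one distance each, bucket 3 covers distances 3 to 5, bucket 4 covers 6 to 10, and everything from 11 on lands in bucket 5.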


Some general info about nn.Embedding:

Both nn.Linear and nn.Embedding turn a token into a vector (3-dimensional in the nn.Embedding(10, 3) example below). That's the whole point: to convert a token
into an ideally meaningful vector, i.e. a numeric, fixed-size representation of a word. The difference is in the input:
nn.Linear expects a one-hot vector the size of the vocabulary, with a single 1 at the index of the specific word, while
nn.Embedding expects just this index (not a whole vector).

If nn.Linear (with bias=False) and nn.Embedding were initialized with the same weights (nn.Linear stores them transposed, as (out_features, in_features)), their outputs would be exactly the same.
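A quick check of this equivalence, as a minimal sketch: the vocabulary size, dimension and token ids below are arbitrary.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim = 10, 3

embedding = nn.Embedding(vocab_size, dim)
linear = nn.Linear(vocab_size, dim, bias=False)

# nn.Linear stores its weight as (out_features, in_features): the transpose of the embedding table
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

token_ids = torch.tensor([2, 7, 7])
one_hot = F.one_hot(token_ids, num_classes=vocab_size).float()

out_embedding = embedding(token_ids)   # index lookup, shape (3, 3)
out_linear = linear(one_hot)           # one-hot rows times the weight, shape (3, 3)
print(torch.allclose(out_embedding, out_linear))
```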

By default, the weights of both layers are updated during training; in this respect they are like any other
layers in your network. However, you can tell the network not to update the weights of a specific layer, like this:

from torch import nn

embedding = nn.Embedding(10, 3)
embedding.weight.requires_grad = False  # freeze this layer's weights

This makes sense if you use pretrained word embeddings such as Word2Vec or GloVe. If you initialize your weights randomly, you
certainly want them to be updated during training.
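PyTorch also has a constructor for the pretrained case, nn.Embedding.from_pretrained, whose freeze=True flag sets requires_grad=False on the table. A minimal sketch, assuming a random tensor as a stand-in for real Word2Vec/GloVe vectors:

```python
import torch
from torch import nn

torch.manual_seed(0)
pretrained = torch.randn(10, 3)   # stand-in for real pretrained vectors (random, for illustration)

# Option 1: built-in constructor; freeze=True means the table gets no gradient
frozen = nn.Embedding.from_pretrained(pretrained, freeze=True)

# Option 2: freeze an existing layer by hand
manual = nn.Embedding(10, 3)
manual.weight.requires_grad = False

head = nn.Linear(3, 1)            # a trainable layer on top
loss = head(frozen(torch.tensor([1, 4, 4]))).sum()
loss.backward()

print(frozen.weight.grad)           # None: nothing is accumulated for the frozen table
print(head.weight.grad is not None) # the trainable layer still gets gradients
```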

How is this matrix used:

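Basically, we'll wrap all this in a class that stores one embedding vector per relative-position bucket. As in the video, the size of this
vector equals the number of attention heads: each head gets its own learned scalar bias for every bucket, so each head can weigh distances differently.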

Then, once we have the query-key matrix of dot products (scalars), we add the value for the current attention head (taken from the
bucket's vector) to each dot product. Since this is a plain addition, everything is differentiable all the way back to the vectors stored
in the class's nn.Embedding object. Note that forward() does not return the nn.Embedding itself, only a tensor built from it; that doesn't
matter, because the tensor is computed from the embedding weights, so backward() can backpropagate into them.
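The mechanism can be sketched in a few lines. This is a minimal illustration, assuming toy sizes, simplified bucket indices (clamped distances, no log bucketing), plain unscaled dot products, and no causal mask:

```python
import torch
from torch import nn

torch.manual_seed(0)
num_heads, q_len, k_len, head_dim, num_buckets = 2, 4, 4, 8, 6

# one learnable scalar per (bucket, head): this is why the vector size equals num_heads
rel_pos_embeddings = nn.Embedding(num_buckets, num_heads)

# simplified bucket indices: distance into the past, clamped (no log bucketing here)
q_idx = torch.arange(q_len).unsqueeze(1)
k_idx = torch.arange(k_len).unsqueeze(0)
buckets = (q_idx - k_idx).clamp(min=0, max=num_buckets - 1)      # (q_len, k_len)

bias = rel_pos_embeddings(buckets).permute(2, 0, 1).unsqueeze(0)  # (1, heads, q_len, k_len)

q = torch.randn(1, num_heads, q_len, head_dim)
k = torch.randn(1, num_heads, k_len, head_dim)
v = torch.randn(1, num_heads, k_len, head_dim)

scores = q @ k.transpose(-1, -2) + bias    # scalar bias added to each query-key dot product
out = scores.softmax(dim=-1) @ v
out.pow(2).sum().backward()                # toy loss, only to trigger backprop

print(rel_pos_embeddings.weight.grad.shape)  # the gradient reached the embedding table
```

Only buckets 0 to 3 occur for a length-4 sequence, so the gradient rows for buckets 4 and 5 stay zero.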

In [20]:
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
import math

num_buckets = 6      # The T5 paper uses 32
max_distance = 20    # Distances at or beyond this all share the last bucket (the T5 paper uses 128)
seq_len = 15         # Query length
max_context_len = 15 # Key length: normally the same as the query length, but not in Transformer-XL, where cached keys are concatenated for recurrence

# Now we construct a matrix as per the above

q_pos = torch.arange(seq_len, dtype=torch.long)               # Top row
k_pos = torch.arange(max_context_len, dtype=torch.long) 

# Trick:
#[0, 1, 2, 3] - [[0], == (via broadcasting) [[0, 1, 2, 3] - [[0, 0, 0, 0], == [[ 0, 1, 2, 3], 
#                [1],                        [0, 1, 2, 3]    [1, 1, 1, 1],     [-1, 0, 1, 2],
#                [2],                        [0, 1, 2, 3]    [2, 2, 2, 2],     [-2,-1, 0, 1],
#                [3]]                        [0, 1, 2, 3]]   [3, 3, 3, 3]]     [-3,-2,-1, 0]]

# So we need to convert q_pos to a column vector:
q_pos = q_pos.reshape(q_pos.shape[0], 1)

rel_pos = k_pos - q_pos
#rel_pos # With seq_len 10 for query and max_context_len 15 for (concatenated) keys this gives:
# Query goes "up/down" since we only have the current sequence, but we match it with a concat of keys for recurrence ->>
#tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
#        [-1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
#        [-2, -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
#        [-3, -2, -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
#        [-4, -3, -2, -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
#        [-5, -4, -3, -2, -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
#        [-6, -5, -4, -3, -2, -1,  0,  1,  2,  3,  4,  5,  6,  7,  8],
#        [-7, -6, -5, -4, -3, -2, -1,  0,  1,  2,  3,  4,  5,  6,  7],
#        [-8, -7, -6, -5, -4, -3, -2, -1,  0,  1,  2,  3,  4,  5,  6],
#        [-9, -8, -7, -6, -5, -4, -3, -2, -1,  0,  1,  2,  3,  4,  5]])

# Next: since we're building a decoder (causal self-attention), we clamp the future to 0, i.e. we don't encode anything for the future
# We also flip the sign so past distances are positive, just for convenience: it doesn't matter since it's all relative and consistent

rel_pos = -rel_pos


rel_pos = torch.max(rel_pos, torch.zeros_like(rel_pos))

#rel_pos # For 10x10 (queries x keys):
#tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
#        [3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
#        [4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
#        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0],
#        [6, 5, 4, 3, 2, 1, 0, 0, 0, 0],
#        [7, 6, 5, 4, 3, 2, 1, 0, 0, 0],
#        [8, 7, 6, 5, 4, 3, 2, 1, 0, 0],
#        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]])

#rel_pos # For seq_len/query 10 and max_context_len/keys 15:
#tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [7, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
#        [8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
#        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]])                 

# Now for the T5 modifications == the buckets

# The first half of the buckets hold exact distances, i.e. "buckets" with just one distance in them

num_token_buckets = num_buckets // 2    # This is 3 if num_buckets is 6, so 0, 1, 2 are exact and in single-item buckets

# We're making the changes by applying masks on the matrix elements

# First a mask that puts "True" on items that don't need to change (first 3), and False on all the others

is_exact = rel_pos < num_token_buckets

#is_exact
# Last line is [False, False, False, False, False, False, False,  True,  True,  True]]) so that's True for ... 2, 1, 0]])  

# Second: a matrix that puts distances into logarithmically wider and wider buckets, up to max_distance.
# It maps each distance >= num_token_buckets to a value from num_token_buckets up to num_buckets (capped at num_buckets - 1 below)

val_if_large = \
num_token_buckets + \
(torch.log(rel_pos.float() / num_token_buckets) / math.log(max_distance / num_token_buckets) * (num_buckets - num_token_buckets))

# val_if_large
# [5.4360, 5.3188, 5.1922, 5.0546, 4.9039, 4.7373, 4.5510, 4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf]]) 

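# long() converts to int64, truncating toward zero (same as floor for the non-negative values we keep)
val_if_large = val_if_large.long()

# print(val_if_large) -> for 15 x 15 q x k
# [5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, -9223372036854775808]]) # Casting -inf to int is undefined; here it gave the smallest int64.
# Harmless: torch.where below uses rel_pos wherever is_exact is True.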

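# Cap at num_buckets - 1. This is needed, not just a precaution: any distance >= max_distance would otherwise map to
# num_buckets or more (e.g. distance 20 -> 6), which is out of range for an embedding table with num_buckets rows.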
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

position_bucket_indices = torch.where(is_exact, rel_pos, val_if_large) # Where is_exact is True take the value from rel_pos, otherwise from val_if_large

position_bucket_indices #-> 0, 1, 2 are always exact, from 3 on we start proper bucketing

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0],
        [5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0],
        [5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0],
        [5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0],
        [5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0]])
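The bucketing computed above, written as one formula. Here n is the distance into the past (rel_pos), B is num_buckets, n_e = floor(B/2) is num_token_buckets and D is max_distance; the floor plays the role of .long(), which matches it for the non-negative values that are kept:

```latex
\text{bucket}(n) =
\begin{cases}
n, & n < n_e \\[4pt]
\min\!\left(n_e + \left\lfloor \dfrac{\ln(n / n_e)}{\ln(D / n_e)}\,(B - n_e) \right\rfloor,\; B - 1\right), & n \ge n_e
\end{cases}
```

Worked example with B = 6, n_e = 3, D = 20 and n = 14: ln(14/3) / ln(20/3) x 3 = 1.5404 / 1.8971 x 3 ≈ 2.436, so the floor is 2 and the bucket is 3 + 2 = 5, as in the last row above. At n = 20 the fraction is exactly 1, giving 3 + 3 = 6, which is why the cap at B - 1 = 5 is needed.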

Note on the Embedding layer: it is a lookup table. It gives the same result as a bias-free Linear layer applied to one-hot inputs, but it takes the indices directly, which avoids multiplying by mostly-zero vectors.

--

Now we need to turn ALL of these items into positional embeddings.

In our reference implementation, this is done like this:

self.position_embeddings = nn.Embedding(self.seq_len, config['embed_size'])

Where embed_size matches the embed_size of the token embeddings, so that we can add the vectors up
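For contrast, a minimal sketch of that absolute scheme, assuming hypothetical config values: one learned vector of size embed_size per absolute position, added once to the token embeddings at the input (whereas the relative version below adds per-head scalars to the attention scores in every attention computation).

```python
import torch
from torch import nn

# hypothetical config values, for illustration only
config = {"vocab_size": 100, "embed_size": 16}
seq_len = 8

token_embeddings = nn.Embedding(config["vocab_size"], config["embed_size"])
position_embeddings = nn.Embedding(seq_len, config["embed_size"])  # one learned vector per absolute position

token_ids = torch.randint(0, config["vocab_size"], (2, seq_len))   # (batch, seq_len)
positions = torch.arange(seq_len)                                   # 0, 1, ..., seq_len - 1

# (2, 8, 16) + (8, 16): the position vectors broadcast over the batch
x = token_embeddings(token_ids) + position_embeddings(positions)
print(x.shape)
```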

In our reference implementation we use absolute, not relative, positions, and the Embedding layer has to "learn" a vector of weights
for each of the absolute positions.

Here, there will be one such vector for each relative-position bucket, i.e. buckets 0, 1, ..., 5 -> 6 vectors, where each vector has size num_attn_heads.

In [23]:
num_attn_heads = 4

rel_pos_embeddings = nn.Embedding(num_buckets, num_attn_heads)
#rel_pos_embeddings
#Embedding(6, 4)

#print(rel_pos_embeddings.weight)
#tensor([[-0.1299,  1.1240, -1.4298,  0.9117],   <--- 
#        [-0.1299,  1.9153, -1.3115,  0.3001],
#        [ 1.4696,  0.7510,  0.6996, -1.7691],
#        [-1.2145, -1.5550, -2.3301, -1.0363],
#        [ 0.7550, -1.3407,  1.7770,  0.6528],
#        [ 0.1610,  2.0491, -2.0410,  0.7201]], requires_grad=True)

# Now each head needs to add one scalar from a num_attn_heads-sized vector to its attention scores (the query-key dot-product matrix).
# So in the position_bucket_indices matrix we replace each bucket index by the corresponding vector from rel_pos_embeddings. Each head can
# then add its own slice to its query-key matrix, element by element.

q_k_rel_pos_embeddings = rel_pos_embeddings(position_bucket_indices)

#q_k_rel_pos_embeddings.shape
#torch.Size([15, 15, 4])   -> (seq_len for queries, max_context_len for keys, size of pos. embedding vector)

# Note: in the paper, each head has its own (seq_len, max_context_len) bias matrix, so each head only uses one of the
# num_attn_heads (4 in this case) "slices". So we need to bring the last dimension to the front (to easily get to a slice).

# So we need to go from (seq_len, max_context_len, num_attn_heads) -> (batch, num_attn_heads, seq_len, max_context_len)


#print(q_k_rel_pos_embeddings.shape)  # -> torch.Size([10, 20, 4]) so (seq_len, max_context_len, num_attn_heads)

#print(q_k_rel_pos_embeddings.transpose(0,2).shape) # -> torch.Size([4, 20, 10]) so (num_attn_heads, max_context_len, seq_len)

# The above is what the video does, but it gives (num_attn_heads, max_context_len, seq_len): each head's matrix comes out transposed
# (keys x queries). With equal query and key lengths the shape even matches, so this would go unnoticed. The dimensions we want are
# (num_attn_heads, seq_len, max_context_len), i.e. (4, 10, 20).

q_k_rel_pos_embeddings = q_k_rel_pos_embeddings.transpose(-1,-2).transpose(0,1) # -> (4, 10, 20), same as .permute(2, 0, 1)


#print(rel_pos_embeddings.weight)
#print(q_k_rel_pos_embeddings[0])
#
#tensor([[-0.6612, -1.0359,  0.1423, -0.1087], <- This is the num_attn_head size vector for rel pos 0
#        [ 0.3478,  0.1189,  0.9446, -0.2021], <- ""                                                1
#        [ 0.4463, -0.3085,  0.4393, -0.6813], <- ""                                                2
#        [-0.1137, -0.1136, -1.1700, -0.3337], <- ""                                                3
#        [-0.6651, -1.0278, -0.1142,  0.3004], <- ""                                                4
#        [-1.2078,  0.9238, -0.5395,  1.7939]], requires_grad=True) <- ""                           5
#
# We are taking the first vertical slice here (head 0), so bucket 0 -> -0.6612, bucket 1 -> 0.3478, etc.
#
#tensor([[-0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [ 0.3478, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [ 0.4463,  0.3478, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.1137,  0.4463,  0.3478, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.1137, -0.1137,  0.4463,  0.3478, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.1137, -0.1137, -0.1137,  0.4463,  0.3478, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.6651, -0.1137, -0.1137, -0.1137,  0.4463,  0.3478, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.6651, -0.6651, -0.1137, -0.1137, -0.1137,  0.4463,  0.3478, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.6651, -0.6651, -0.6651, -0.1137, -0.1137, -0.1137,  0.4463,  0.3478,
#         -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612],
#        [-0.6651, -0.6651, -0.6651, -0.6651, -0.1137, -0.1137, -0.1137,  0.4463,
#          0.3478, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612, -0.6612,
#         -0.6612, -0.6612, -0.6612, -0.6612]], grad_fn=<SelectBackward0>)

# Now add the batch dimension so that it becomes (1, 4, 10, 20) so (1, num_attn_heads, seq_len, max_context_len)

q_k_rel_pos_embeddings = q_k_rel_pos_embeddings.unsqueeze(0)

q_k_rel_pos_embeddings.shape

torch.Size([1, 4, 15, 15])
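The reordering can be checked on a random stand-in tensor of the same (seq_len, max_context_len, num_attn_heads) shape. A minimal sketch showing that the two transposes equal a single permute(2, 0, 1), while transpose(0, 2) transposes each head's matrix:

```python
import torch

# stand-in for the lookup result: (seq_len, max_context_len, num_attn_heads) = (10, 20, 4)
x = torch.randn(10, 20, 4)

a = x.transpose(-1, -2).transpose(0, 1)   # the notebook's version -> (4, 10, 20)
b = x.permute(2, 0, 1)                    # same reordering in one call -> (4, 10, 20)
c = x.transpose(0, 2)                     # also swaps queries and keys -> (4, 20, 10)

print(a.shape, b.shape, c.shape)
print(torch.equal(a, b))                     # True
print(torch.equal(c, b.transpose(-1, -2)))   # True: each head's matrix is transposed
```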

In [15]:
# Class version

# Again this returns a tensor of size (1, num_attn_heads, queries, keys), where each attention head
# gets its own queries x keys "layer" (slice) of biases

class RelativePosition(nn.Module):

    def __init__(self, scaling_factor, num_buckets, max_distance, num_attn_heads):
        
        super().__init__()
        self.scaling_factor = scaling_factor
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.rel_pos_embeddings = nn.Embedding(num_buckets, num_attn_heads)

    # With each forward pass we rebuild the bias tensor from the current embedding weights, which the optimizer
    # updated after the previous backward() pass
    
    def forward(self, seq_len, max_context_len):     # These may differ between calls, e.g. when cached keys are concatenated for recurrence
        
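        device = self.rel_pos_embeddings.weight.device   # build the indices on the same device as the embedding table
        q_pos = torch.arange(seq_len, dtype=torch.long, device=device)
        k_pos = torch.arange(max_context_len, dtype=torch.long, device=device)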
        q_pos = q_pos.reshape(q_pos.shape[0], 1)
        rel_pos = k_pos - q_pos
        rel_pos = -rel_pos
        rel_pos = torch.max(rel_pos, torch.zeros_like(rel_pos))
        # We have rel_pos, now we need is_exact matrix and val_if_large matrix
        
        num_token_buckets = self.num_buckets // 2
        is_exact = rel_pos < num_token_buckets

        val_if_large = \
        num_token_buckets + \
        (torch.log(rel_pos.float() / num_token_buckets) / math.log(self.max_distance / num_token_buckets) * (self.num_buckets - num_token_buckets))
        val_if_large = val_if_large.long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, self.num_buckets - 1))

        position_bucket_indices = torch.where(is_exact, rel_pos, val_if_large)

        q_k_rel_pos_embeddings = self.rel_pos_embeddings(position_bucket_indices)

        q_k_rel_pos_embeddings = q_k_rel_pos_embeddings.transpose(-1,-2).transpose(0,1)
        
        q_k_rel_pos_embeddings = q_k_rel_pos_embeddings.unsqueeze(0)

        return q_k_rel_pos_embeddings * self.scaling_factor # This is new in the class.        
        

In [24]:
relpos = RelativePosition(1, 6, 20, 4)
relpos(15,15).shape

torch.Size([1, 4, 15, 15])
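To use the bias in batched decoder attention, it is added to the (batch, heads, queries, keys) scores and broadcasts over the batch; a separate causal mask still hides the future, since the bias only gives future keys bucket 0. A self-contained sketch, assuming a hypothetical helper `causal_relative_bias` that mirrors the class logic. One deliberate difference: it clamps distances to at least 1 before the log, so no -inf is cast to an integer (those entries are exact and discarded by torch.where anyway).

```python
import math
import torch
from torch import nn

def causal_relative_bias(emb, num_buckets, max_distance, q_len, k_len):
    """Hypothetical helper: T5-style causal bucketing, returns (1, heads, q_len, k_len)."""
    device = emb.weight.device
    q_pos = torch.arange(q_len, device=device).unsqueeze(1)
    k_pos = torch.arange(k_len, device=device).unsqueeze(0)
    rel_pos = (q_pos - k_pos).clamp(min=0)            # distance into the past, future -> 0
    num_exact = num_buckets // 2
    is_exact = rel_pos < num_exact
    # clamp(min=1) keeps log() finite; those entries are exact and discarded by torch.where
    val_if_large = (num_exact + torch.log(rel_pos.clamp(min=1).float() / num_exact)
                    / math.log(max_distance / num_exact) * (num_buckets - num_exact)).long()
    val_if_large = val_if_large.clamp(max=num_buckets - 1)
    buckets = torch.where(is_exact, rel_pos, val_if_large)
    return emb(buckets).permute(2, 0, 1).unsqueeze(0)

torch.manual_seed(0)
batch, heads, q_len, head_dim = 2, 4, 15, 8
emb = nn.Embedding(6, heads)
bias = causal_relative_bias(emb, num_buckets=6, max_distance=20, q_len=q_len, k_len=q_len)

q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, q_len, head_dim)
scores = q @ k.transpose(-1, -2) + bias             # bias broadcasts over the batch dimension

# the causal mask is separate: the bias alone does not hide future keys
future = torch.triu(torch.ones(q_len, q_len), diagonal=1).bool()
scores = scores.masked_fill(future, float("-inf"))
attn = scores.softmax(dim=-1)
print(bias.shape, attn.shape)
```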