# Rel Pos

### From Memorizing Transformers paper:

"Position bias. For dense attention within the local context, we use the T5 relative position bias (Raffel
et al., 2020). As noted by Dai et al. (2019), adding a global position encoding to each token does not
work well when processing long documents. We don’t use a position bias for the retrieved memories.
Experiments on the PG19 dataset (Sun et al., 2021) have shown that relative position does not appear
to matter at long range, and the T5 relative bias puts all long-range tokens in the same bucket anyway."

### From T5 paper:

"Since self-attention is order-independent (i.e. it is an operation on sets), it is common
to provide an explicit position signal to the Transformer. While the original Transformer
used a sinusoidal position signal or learned position embeddings, it has recently become
more common to use relative position embeddings (Shaw et al., 2018; Huang et al., 2018a).
Instead of using a fixed embedding for each position, relative position embeddings produce
a different learned embedding according to the offset between the “key” and “query” being
compared in the self-attention mechanism. We use a simplified form of position embeddings
where each “embedding” is simply a scalar that is added to the corresponding logit used
for computing the attention weights. For efficiency, we also share the position embedding
parameters across all layers in our model, though within a given layer each attention head
uses a different learned position embedding. Typically, a fixed number of embeddings are
learned, each corresponding to a range of possible key-query offsets. In this work, we use 32
embeddings for all of our models with ranges that increase in size logarithmically up to an
offset of 128 beyond which we assign all relative positions to the same embedding. Note
that a given layer is insensitive to relative position beyond 128 tokens, but subsequent layers
can build a sensitivity to larger offsets by combining local information from previous layers."

In [None]:
# RELATIVE POSITION MATRIX
# tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13],
#         [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12],
#         [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
#         [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
#         [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
#         [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8],
#         [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7],
#         [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
#         [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
#         [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
#         [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
#         [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
#         [-12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1],
#         [-13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0]])

### Idea

- Position information enters attention as a bias that is added to the QK logits (the query-key dot products)
- Relative position embeddings encode, for each query token, how far away every other token is
- Instead of giving every offset n its own index, the T5 relative position bias "buckets" ranges of offsets into the same index
- First we build this matrix of bucket indices. The indices are then looked up in an embedding layer, which yields one learned scalar per attention head. These values are added to the QK logits during attention. The embedding weights are trained with the rest of the network.
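
As a minimal sketch of the last bullet, assuming a single example and a hypothetical, already-computed bucket index matrix `bucket_idx`, the learned scalars are looked up per (query, key) pair and added to the QK logits before the softmax. All names and sizes here are illustrative and are not taken from either paper.

```python
import torch
from torch import nn

torch.manual_seed(0)
heads, seq_len, dim_head, num_buckets = 2, 4, 8, 3

q = torch.randn(heads, seq_len, dim_head)
k = torch.randn(heads, seq_len, dim_head)

# QK logits: one (seq_len, seq_len) score matrix per head
logits = torch.einsum('hid,hjd->hij', q, k) / dim_head ** 0.5

# Hypothetical bucket indices: how far back each key is from each query,
# capped at the last bucket (no logarithmic spacing in this toy example)
pos = torch.arange(seq_len)
bucket_idx = (pos[:, None] - pos[None, :]).clamp(min=0, max=num_buckets - 1)

# One learned scalar per (bucket, head)
bias_table = nn.Embedding(num_buckets, heads)
bias = bias_table(bucket_idx).permute(2, 0, 1)  # (heads, seq_len, seq_len)

# The bias shifts the logits; the softmax then normalises as usual
attn = (logits + bias).softmax(dim=-1)
```

Because the bias is a plain addition to the logits, it adds only `num_buckets * heads` parameters and leaves the rest of the attention computation unchanged.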

### Recipe

- Construct a relative position matrix
- Keep small offsets exact. For larger offsets, spread the values logarithmically over a finite number of buckets. (Past a certain max distance (128 in T5) we just map everything to the last bucket)
- Initialize the embedding weights that the bucket indices will look up
- Map the bucketed relative position matrix to these weights
- Add the resulting matrix to the attention logits when we perform self-attention. Self-attention now incorporates the relative positions between tokens as a piece of information
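
The bucketing step of the recipe can be sketched as one function and checked against the numbers in the T5 quote (32 buckets, one shared bucket from offset 128 onward). This assumes the one-directional variant used in this notebook, where `n` is how far back a key is from the query. The helper name `bucket_from_distance` is hypothetical, and the `clamp(min=1)` is an addition made here only to avoid `log(0)`.

```python
import math
import torch

def bucket_from_distance(n, num_buckets=32, max_distance=128):
    """Map non-negative distances n (LongTensor) to bucket indices."""
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # Avoid log(0); these entries are overwritten by torch.where below
    n_safe = n.clamp(min=1).float()
    val_if_large = max_exact + (
        torch.log(n_safe / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    # Everything at or beyond max_distance lands in the last bucket
    val_if_large = val_if_large.clamp(max=num_buckets - 1)
    return torch.where(is_small, n, val_if_large)

distances = torch.arange(0, 200)
buckets = bucket_from_distance(distances)
```

With these defaults, distances 0 to 15 keep their own bucket, and every distance from 128 upward shares bucket 31.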



In [None]:
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
import math

In [None]:
num_buckets = 6 # total number of bucket indices
max_distance = 20 # offsets at or beyond this distance all share the last bucket

sequence_length = 14
max_context_length = 14

In [None]:
q_pos = torch.arange(sequence_length, dtype=torch.long)
q_pos = q_pos.reshape(q_pos.shape[0], 1)
k_pos = torch.arange(max_context_length, dtype=torch.long)
rel_pos = k_pos - q_pos
rel_pos

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13],
        [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12],
        [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
        [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
        [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8],
        [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7],
        [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
        [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
        [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
        [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
        [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
        [-12, -11, -10,  -9,  -8,  -7,  

In [None]:
n = -rel_pos
n

tensor([[  0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13],
        [  1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12],
        [  2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11],
        [  3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10],
        [  4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9],
        [  5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8],
        [  6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7],
        [  7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6],
        [  8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5],
        [  9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4],
        [ 10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3],
        [ 11,  10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2],
        [ 12,  11,  10,   9,   8,   7,  

In [None]:
n = torch.max(n, torch.zeros_like(n))
n

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0],
        [10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0],
        [11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0],
        [12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0],
        [13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0]])

In [None]:
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
max_exact

3

In [None]:
is_small = n < max_exact
is_small

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,

The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.

So offsets below a threshold (`max_exact`) each get their own bucket (offset by 1, offset by 2, offset by 3...). Beyond that point there are more possible offsets than positional embedding "buckets" (like bins), so the remaining offsets are spread logarithmically over the remaining buckets, with each bucket covering a wider range than the last: e.g. [1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 8, 8]
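
Written out as a formula, the mapping looks roughly as follows. The symbols are shorthand introduced here: $B$ is `num_buckets`, $m = \lfloor B/2 \rfloor$ is `max_exact`, $d_{\max}$ is `max_distance`, and $n$ is the clamped distance from the cells above. The outer $\min$ corresponds to the `torch.min` step in the class version further down.

$$
\text{bucket}(n) =
\begin{cases}
n, & n < m \\[4pt]
\min\!\left(B - 1,\; m + \left\lfloor \dfrac{\log(n/m)}{\log(d_{\max}/m)}\,(B - m) \right\rfloor\right), & n \ge m
\end{cases}
$$

For example, with $B = 6$, $m = 3$ and $d_{\max} = 20$, a distance of $n = 7$ gives $3 + \lfloor 1.3399 \rfloor = 4$.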

In [None]:
# Log-spaced bucket value for every distance. Kept as float here so the
# intermediate values can be inspected; converted to int in a later cell.
# Entries with n == 0 evaluate to log(0) = -inf and are discarded later
# by the torch.where on is_small.
val_if_large = max_exact + (
    torch.log(n.float() / max_exact)      # elementwise log of matrix divided by scalar
    / math.log(max_distance / max_exact)  # scalar
    * (num_buckets - max_exact)           # scalar
)

In [None]:
val_if_large

tensor([[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,
  

In [None]:
val_if_large = val_if_large.long()
val_if_large

tensor([[-9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808],
        [                   1, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808],
        [                   2,                    1, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854

In [None]:
position_bucket_indices = torch.where(is_small, n, val_if_large)
position_bucket_indices

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0],
        [5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0],
        [5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0],
        [5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0]])

In [None]:
heads = 4
relative_position_bias = nn.Embedding(num_buckets, heads)
relative_position_bias

Embedding(6, 4)

In [None]:
relative_position_bias.weight

Parameter containing:
tensor([[-0.7429,  1.0904,  0.8333,  0.2273],
        [-0.4354,  1.5028,  1.2448,  0.7483],
        [-0.0872, -1.7964,  0.8769, -0.3345],
        [ 2.7251,  1.2114, -0.8605, -0.7477],
        [-0.9158, -1.1318, -1.0841, -1.7824],
        [ 0.1928,  0.4019,  2.2828, -1.8319]], requires_grad=True)

In [None]:
relative_position_values = relative_position_bias(position_bucket_indices)
relative_position_values.shape

torch.Size([14, 14, 4])

In [None]:
# Need to rearrange from (sequence, context, heads) -> (batch, heads, sequence, context).
# permute(2, 0, 1) keeps sequence before context; transpose(0, 2) would swap them.
relative_position_values = relative_position_values.permute(2, 0, 1).unsqueeze(0)
relative_position_values.shape


torch.Size([1, 4, 14, 14])
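
Because the example above is square (14 x 14), the printed shape alone cannot show whether the sequence and context axes ended up in the right order. The following sketch repeats the rearrangement with a non-square shape. The sizes and the random `bucket_idx` are stand-ins chosen only for this check.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, ctx_len, heads, num_buckets = 3, 5, 4, 6

bias_table = nn.Embedding(num_buckets, heads)
# Stand-in bucket indices; any integers in [0, num_buckets) work for this check
bucket_idx = torch.randint(0, num_buckets, (seq_len, ctx_len))

values = bias_table(bucket_idx)               # (seq_len, ctx_len, heads)
bias = values.permute(2, 0, 1).unsqueeze(0)   # (1, heads, seq_len, ctx_len)

# Entry [0, h, i, j] is the scalar head h learned for the bucket of pair (i, j)
h, i, j = 1, 2, 4
assert torch.equal(bias[0, h, i, j], bias_table.weight[bucket_idx[i, j], h])

# By contrast, transpose(0, 2) swaps the sequence and context axes
swapped = values.transpose(0, 2).unsqueeze(0)  # (1, heads, ctx_len, seq_len)
print(bias.shape, swapped.shape)
```

With `permute(2, 0, 1)` the bias lines up with attention logits of shape (batch, heads, sequence, context) and broadcasts over the batch dimension.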

### Putting it all in a class:

In [None]:
class RelativePosition(nn.Module):
  def __init__(
      self,
      rp_scale,
      num_buckets = 32,
      rp_max_distance = 128,
      heads = 8
  ):
      super().__init__()
      self.scale = rp_scale
      self.num_buckets = num_buckets
      self.rp_max_distance = rp_max_distance
      self.relative_attention_embedding = nn.Embedding(num_buckets, heads)

  def relative_position_bucket(self, relative_position_matrix):
      n = -relative_position_matrix
      n = torch.max(n, torch.zeros_like(n))

      max_exact = self.num_buckets // 2

      is_small = n < max_exact
      val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(self.rp_max_distance / max_exact) * (self.num_buckets - max_exact)).long()
      val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, self.num_buckets - 1))

      return torch.where(is_small, n, val_if_large)

  def forward(self, sequence_length, max_context_length):
      # Build the position indices on the same device as the embedding weights
      device = self.relative_attention_embedding.weight.device
      sequence_pos = torch.arange(sequence_length, dtype=torch.long, device=device)
      context_pos = torch.arange(max_context_length, dtype=torch.long, device=device)
      sequence_pos = sequence_pos.reshape(sequence_pos.shape[0], 1)
      # Offset of every context (key) position from every sequence (query) position
      rel_pos = context_pos - sequence_pos

      position_bucket_indices = self.relative_position_bucket(rel_pos)

      rp_values = self.relative_attention_embedding(position_bucket_indices)
      # Rearrange (sequence, context, heads) -> (1, heads, sequence, context)
      rp_values = rp_values.permute(2, 0, 1)
      rp_values = rp_values.unsqueeze(0)
      return rp_values * self.scale