In [1]:
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
import math

In [2]:
# Total number of relative-position buckets (rows of the bias table).
num_buckets = 6
# Upper end of the log-spaced distance range: a distance equal to
# max_distance maps to the top of the bucket formula.
# (Original note: maximum sequence length.)
max_distance = 20

# Query length, i.e. the input sequence length.
sequence_length = 14
# Key length: equal to sequence_length, or greater when recurrence/memory
# is concatenated in front of the keys.
max_context_length = 14

In [16]:
# Query positions as a column vector (sequence_length, 1) and key positions
# as a flat vector (max_context_length,), so the subtraction broadcasts to a
# (sequence_length, max_context_length) grid of key_index - query_index.
q_pos = torch.arange(sequence_length, dtype=torch.long).unsqueeze(1)
k_pos = torch.arange(max_context_length, dtype=torch.long)
rel_pos = k_pos.unsqueeze(0) - q_pos
rel_pos

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13],
        [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12],
        [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
        [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
        [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8],
        [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7],
        [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
        [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
        [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
        [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
        [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
        [-12, -11, -10,  -9,  -8,  -7,  

In [17]:
# Flip the sign so that keys *before* the query have a positive distance.
n = torch.neg(rel_pos)
n

tensor([[  0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13],
        [  1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12],
        [  2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11],
        [  3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10],
        [  4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9],
        [  5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8],
        [  6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7],
        [  7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6],
        [  8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5],
        [  9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4],
        [ 10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3],
        [ 11,  10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2],
        [ 12,  11,  10,   9,   8,   7,  

In [18]:
# Keys positioned after the query (negative n) all collapse to distance 0,
# leaving only look-back distances in the lower triangle.
n = torch.clamp(n, min=0)
n

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0],
        [10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0],
        [11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0],
        [12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0],
        [13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0]])

In [19]:
# Half of the buckets are reserved for small distances: every distance below
# max_exact gets its own bucket (see the `n < max_exact` mask), the remaining
# buckets are shared by larger distances.
max_exact = num_buckets // 2

In [22]:
# Boolean mask: True where the distance is small enough to be used directly
# as its own bucket index.
is_exact = torch.lt(n, max_exact)
is_exact

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,

In [23]:
# Log-spaced bucket value (still float) for distances >= max_exact:
#   max_exact + log(n / max_exact) / log(max_distance / max_exact)
#               * (num_buckets - max_exact)
# n == max_exact maps to max_exact, and n == max_distance maps to num_buckets.
# Entries with n == 0 evaluate to log(0) = -inf; they lie in the exact range
# and are replaced by `n` when torch.where(is_exact, ...) is applied.
# The result is computed once and kept as float; the integer cast is done in
# its own step so the intermediate values can be inspected.
val_if_large = max_exact + (
    torch.log(n.float() / max_exact)  # tensor: log of distance / scalar
    / math.log(max_distance / max_exact)  # scalar
    * (num_buckets - max_exact)  # scalar
)

In [24]:
# Inspect the float bucket values; -inf appears wherever n == 0 (log of zero).
val_if_large

tensor([[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,
  

In [25]:
# Truncate the float bucket values to integer indices.
# The -inf entries have no meaningful integer value (they show up as
# -9223372036854775808 in the output below); they sit where is_exact is True,
# so torch.where never selects them.
val_if_large = val_if_large.to(torch.long)
val_if_large

tensor([[-9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808],
        [                   1, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808],
        [                   2,                    1, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854

In [26]:
# Final bucket index per (query, key) pair:
#   - small distances (is_exact) use the distance itself as the bucket,
#   - larger distances use the log-spaced value.
# The log-spaced value reaches num_buckets once n >= max_distance (possible
# when max_context_length exceeds max_distance), which would be out of range
# for an embedding table with num_buckets rows, so cap it at the last bucket.
position_bucket_indices = torch.where(
    is_exact,
    n,
    val_if_large.clamp(max=num_buckets - 1),
)

In [27]:
# Inspect the bucket index assigned to every (query row, key column) pair.
position_bucket_indices

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0],
        [5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0],
        [5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0],
        [5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0]])

In [31]:
# Bias table with one row per bucket and one column per head.
n_heads = 4
pos_embedding_bias = nn.Embedding(num_embeddings=num_buckets, embedding_dim=n_heads)

In [32]:
# Inspect the randomly initialised table: shape (num_buckets, n_heads).
pos_embedding_bias.weight

Parameter containing:
tensor([[-0.1905, -0.6951,  1.1648, -0.9169],
        [ 0.0761, -0.3987,  0.1826,  0.1903],
        [-0.2134,  0.7328, -1.4838,  1.1680],
        [ 0.3152, -1.4295, -2.7264, -0.6454],
        [-0.6498, -1.5748, -0.9342,  0.2301],
        [-0.1159,  0.9647,  0.7951, -0.5426]], requires_grad=True)

In [35]:
# Look up the per-head bias for every (query, key) pair by its bucket index;
# the result gains a trailing n_heads dimension.
pos_embedding_values = pos_embedding_bias(position_bucket_indices)

In [36]:
# Expected: (sequence_length, max_context_length, n_heads).
pos_embedding_values.shape

torch.Size([14, 14, 4])