In [1]:
import torch
import math

In [2]:
# Toy dimensions used by the cells below.
max_tokens: int = 12  # not referenced by any other cell visible in this notebook
tokens_per_sample: int = 5  # sequence length: feeds qlen/klen and the side of the causal mask
num_heads: int = 4  # number of heads; one slope is generated per head

In [3]:
# Base slope of the per-head geometric sequence.
# 2 ** -(log2(n) - 3) equals 8 / n, so this is algebraically 2 ** (-8 / num_heads).
head_exponent = -(math.log2(num_heads) - 3)
start = 2 ** (-(2 ** head_exponent))
start

0.25

In [4]:
# Geometric sequence of slopes: entry k is start * ratio**k, where the
# common ratio is the same value as the first term.
ratio = start
slopes = list(start * ratio ** power for power in range(num_heads))
slopes

[0.25, 0.0625, 0.015625, 0.00390625]

In [5]:
# Per-head slopes as a 1-D tensor (one entry per element of `slopes`).
# torch.tensor() is the documented factory for building a tensor from Python
# data; it infers the dtype from the values (Python floats -> default float
# dtype). The capitalised torch.Tensor(...) constructor is a legacy alias whose
# argument handling is ambiguous (e.g. an int argument is treated as a size).
slopes_t = torch.tensor(slopes)
slopes_t

tensor([0.2500, 0.0625, 0.0156, 0.0039])

In [19]:
# Query length and key length are identical here: both are one sample's token count.
qlen = klen = tokens_per_sample
device = "cpu"

In [20]:
# Query-side indices 0..qlen-1 laid out as a column vector, shape (qlen, 1).
context_position = torch.arange(qlen, dtype=torch.long, device=device).unsqueeze(1)
print(context_position.shape)
context_position

torch.Size([5, 1])


tensor([[0],
        [1],
        [2],
        [3],
        [4]])

In [21]:
# Key-side indices 0..klen-1 laid out as a row vector, shape (1, klen).
memory_position = torch.arange(klen, dtype=torch.long, device=device).unsqueeze(0)
print(memory_position.shape)
memory_position

torch.Size([1, 5])


tensor([[0, 1, 2, 3, 4]])

In [22]:
# Broadcasting (1, klen) - (qlen, 1) gives a (qlen, klen) grid whose entry
# [i, j] is the signed offset j - i (key index minus query index).
relative_position = torch.sub(memory_position, context_position)
print(relative_position.shape)
relative_position

torch.Size([5, 5])


tensor([[ 0,  1,  2,  3,  4],
        [-1,  0,  1,  2,  3],
        [-2, -1,  0,  1,  2],
        [-3, -2, -1,  0,  1],
        [-4, -3, -2, -1,  0]])

In [23]:
# Turn signed offsets into negated distances: 0 on the diagonal and -|j - i|
# everywhere else. Re-running this cell is harmless, since -|-|x|| == -|x|.
relative_position = -relative_position.abs()
relative_position

tensor([[ 0, -1, -2, -3, -4],
        [-1,  0, -1, -2, -3],
        [-2, -1,  0, -1, -2],
        [-3, -2, -1,  0, -1],
        [-4, -3, -2, -1,  0]])

In [24]:
# One copy of the negated-distance grid per head: (num_heads, qlen, klen).
per_head_distance = relative_position.unsqueeze(0).expand(num_heads, -1, -1)
# Slopes reshaped to (num_heads, 1, 1) so each head's grid is scaled by its own slope.
per_head_slope = slopes_t[:, None, None]
alibi = per_head_distance * per_head_slope
alibi

tensor([[[ 0.0000, -0.2500, -0.5000, -0.7500, -1.0000],
         [-0.2500,  0.0000, -0.2500, -0.5000, -0.7500],
         [-0.5000, -0.2500,  0.0000, -0.2500, -0.5000],
         [-0.7500, -0.5000, -0.2500,  0.0000, -0.2500],
         [-1.0000, -0.7500, -0.5000, -0.2500,  0.0000]],

        [[ 0.0000, -0.0625, -0.1250, -0.1875, -0.2500],
         [-0.0625,  0.0000, -0.0625, -0.1250, -0.1875],
         [-0.1250, -0.0625,  0.0000, -0.0625, -0.1250],
         [-0.1875, -0.1250, -0.0625,  0.0000, -0.0625],
         [-0.2500, -0.1875, -0.1250, -0.0625,  0.0000]],

        [[ 0.0000, -0.0156, -0.0312, -0.0469, -0.0625],
         [-0.0156,  0.0000, -0.0156, -0.0312, -0.0469],
         [-0.0312, -0.0156,  0.0000, -0.0156, -0.0312],
         [-0.0469, -0.0312, -0.0156,  0.0000, -0.0156],
         [-0.0625, -0.0469, -0.0312, -0.0156,  0.0000]],

        [[ 0.0000, -0.0039, -0.0078, -0.0117, -0.0156],
         [-0.0039,  0.0000, -0.0039, -0.0078, -0.0117],
         [-0.0078, -0.0039,  0.0000, -0.00

In [25]:
def fill_with_neg_inf(t):
    """Return a tensor of ``t``'s shape and type in which every element is -inf.

    The fill is done on a float32 version of ``t`` and the result is cast back
    with ``type_as``, which keeps the helper usable for FP16 inputs.

    Gotcha: ``t.float()`` is expected to return ``t`` itself when ``t`` is
    already float32, in which case the in-place ``fill_`` overwrites the
    caller's tensor; for other dtypes a converted copy is filled instead.
    Use the return value and do not rely on either behaviour.
    """
    as_float32 = t.float()
    as_float32.fill_(float("-inf"))
    return as_float32.type_as(t)

In [26]:
# Causal mask: start from a square of -inf, then keep only the strict upper
# triangle (diagonal=1). triu zeroes everything on and below the diagonal, so
# allowed positions hold 0 and future positions hold -inf.
mask_shape = (tokens_per_sample, tokens_per_sample)
neg_inf_square = fill_with_neg_inf(torch.zeros(mask_shape))
future_mask = torch.triu(neg_inf_square, diagonal=1)
future_mask

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [27]:
# Combine the causal mask with the per-head bias.
#
# The 2-D causal mask is rebuilt here under its own name instead of reading the
# previous value of `future_mask`. Rebinding `future_mask` from itself made the
# cell non-idempotent: a second run would add `alibi` again and push the result
# to 4-D via another unsqueeze. Built this way, any number of re-runs yields the
# same (num_heads, tokens_per_sample, tokens_per_sample) tensor.
causal_mask = torch.triu(
    fill_with_neg_inf(torch.zeros((tokens_per_sample, tokens_per_sample))),
    diagonal=1,
)
# (1, T, T) + (num_heads, T, T): -inf stays -inf above the diagonal, and the
# on/below-diagonal zeros take the value of the bias.
future_mask = causal_mask.unsqueeze(0) + alibi
future_mask

tensor([[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
         [-0.2500,  0.0000,    -inf,    -inf,    -inf],
         [-0.5000, -0.2500,  0.0000,    -inf,    -inf],
         [-0.7500, -0.5000, -0.2500,  0.0000,    -inf],
         [-1.0000, -0.7500, -0.5000, -0.2500,  0.0000]],

        [[ 0.0000,    -inf,    -inf,    -inf,    -inf],
         [-0.0625,  0.0000,    -inf,    -inf,    -inf],
         [-0.1250, -0.0625,  0.0000,    -inf,    -inf],
         [-0.1875, -0.1250, -0.0625,  0.0000,    -inf],
         [-0.2500, -0.1875, -0.1250, -0.0625,  0.0000]],

        [[ 0.0000,    -inf,    -inf,    -inf,    -inf],
         [-0.0156,  0.0000,    -inf,    -inf,    -inf],
         [-0.0312, -0.0156,  0.0000,    -inf,    -inf],
         [-0.0469, -0.0312, -0.0156,  0.0000,    -inf],
         [-0.0625, -0.0469, -0.0312, -0.0156,  0.0000]],

        [[ 0.0000,    -inf,    -inf,    -inf,    -inf],
         [-0.0039,  0.0000,    -inf,    -inf,    -inf],
         [-0.0078, -0.0039,  0.0000,    -i

In [28]:
# Sanity check on the combined mask's dimensions (displayed as a torch.Size).
future_mask.size()

torch.Size([4, 5, 5])

In [30]:
# Peek at the top-left corner of every head's mask, e.g. what a shorter
# 3-token sequence would use.
preview_len = 3
future_mask[:, :preview_len, :preview_len]

tensor([[[ 0.0000,    -inf,    -inf],
         [-0.2500,  0.0000,    -inf],
         [-0.5000, -0.2500,  0.0000]],

        [[ 0.0000,    -inf,    -inf],
         [-0.0625,  0.0000,    -inf],
         [-0.1250, -0.0625,  0.0000]],

        [[ 0.0000,    -inf,    -inf],
         [-0.0156,  0.0000,    -inf],
         [-0.0312, -0.0156,  0.0000]],

        [[ 0.0000,    -inf,    -inf],
         [-0.0039,  0.0000,    -inf],
         [-0.0078, -0.0039,  0.0000]]])

In [75]:
# "mascara normal" = plain (non-ALiBi) mask: a lower-triangular 0/1 matrix with
# two leading singleton dimensions, shape (1, 1, T, T).
# The side length comes from tokens_per_sample rather than a hard-coded 5, so
# it always matches the (T, T) trailing dimensions of `future_mask` built from
# the same setting.
mascara_normal = (
    torch.tril(torch.ones((tokens_per_sample, tokens_per_sample)))
    .view((1, 1, tokens_per_sample, tokens_per_sample))
)
mascara_normal

tensor([[[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]]])

In [76]:
# Broadcast-add the 0/1 lower-triangular mask (1, 1, T, T) onto the per-head
# additive mask; positions that are -inf stay -inf.
# NOTE(review): the saved output of this cell does not match what the visible
# cells compute (element [0, 0, 0, 0] would be 1 + 0 = 1.0, yet -1.0 is shown),
# and the execution counter jumps from 30 to 75. The output was probably
# produced from different kernel state -- restart and run all to confirm.
torch.add(mascara_normal, future_mask)

tensor([[[[-1.0000,    -inf,    -inf,    -inf,    -inf],
          [-1.0000, -0.5000,    -inf,    -inf,    -inf],
          [-1.0000, -0.5000,  0.0000,    -inf,    -inf],
          [-1.0000, -0.5000,  0.0000,  0.5000,    -inf],
          [-1.0000, -0.5000,  0.0000,  0.5000,  1.0000]],

         [[ 0.5000,    -inf,    -inf,    -inf,    -inf],
          [ 0.5000,  0.6250,    -inf,    -inf,    -inf],
          [ 0.5000,  0.6250,  0.7500,    -inf,    -inf],
          [ 0.5000,  0.6250,  0.7500,  0.8750,    -inf],
          [ 0.5000,  0.6250,  0.7500,  0.8750,  1.0000]],

         [[ 0.8750,    -inf,    -inf,    -inf,    -inf],
          [ 0.8750,  0.9062,    -inf,    -inf,    -inf],
          [ 0.8750,  0.9062,  0.9375,    -inf,    -inf],
          [ 0.8750,  0.9062,  0.9375,  0.9688,    -inf],
          [ 0.8750,  0.9062,  0.9375,  0.9688,  1.0000]],

         [[ 0.9688,    -inf,    -inf,    -inf,    -inf],
          [ 0.9688,  0.9766,    -inf,    -inf,    -inf],
          [ 0.9688,  0.97