In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat, pack, unpack, einsum

# Toy hyperparameters shared by the cells below.

# Shape of the fake input batch: (b_s, seq_len, embed_dim)
b_s = 16        # batch size
seq_len = 512   # number of tokens per sequence
embed_dim = 13  # width of the input embeddings and of the final output projection

# Attention geometry: the q/k/v projections produce num_attn_heads * head_dim features
num_attn_heads = 8
head_dim = 10   # features per attention head


In [53]:
# Create some fake training data

# Fixed seed so that the random inputs, the randomly initialised projections and
# the fake recurrence are identical on every "Restart & Run All".
torch.manual_seed(0)

input_data = torch.randn(b_s, seq_len, embed_dim)

Wq = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
Wk = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
Wv = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
Wo = nn.Linear(num_attn_heads * head_dim, embed_dim)  # TODO: look up why Wo can have a bias term

queries = Wq(input_data)
keys = Wk(input_data)
values = Wv(input_data)

# The above is all the same, now we're going to add fake recurrence

recurrence = torch.randn(b_s, seq_len, 2, num_attn_heads * head_dim)  # This is how it will come in.

# Split the size-2 axis into its two slices; each has dim (b_s, seq_len, num_attn_heads * head_dim)
recurrence_keys, recurrence_values = recurrence.unbind(dim=-2)

# Prepend the recurrence keys/values to the current ones along the seq_len dimension (-2),
# so the keys/values are twice as long as the queries.
xl_keys = torch.cat((recurrence_keys, keys), dim=-2)
xl_values = torch.cat((recurrence_values, values), dim=-2)

# print(queries.shape)    -> torch.Size([16, 512, 80])
# print(xl_keys.shape)    -> torch.Size([16, 1024, 80])
# print(xl_values.shape)  -> torch.Size([16, 1024, 80])

# Same as before: we pull the heads out and calculate attention scores per head with einsum.
queries = rearrange(queries, 'b s (h d) -> b h s d', h=num_attn_heads)
xl_keys = rearrange(xl_keys, 'b s (h d) -> b h s d', h=num_attn_heads)  # s is twice as long here
attn_scores = einsum(queries, xl_keys, 'b h s1 d, b h s2 d -> b h s1 s2')

# Scaled dot-product attention: divide the logits by sqrt(head_dim) so that their
# magnitude does not grow with the head width and saturate the softmax.
attn_scores = attn_scores * head_dim ** -0.5

# Apply causal masking -> tokens from the recurrent matrix are all allowed to be seen.
# (1 = score is kept, -inf = score is masked out; 4 queries, 4 recurrent + 4 current keys)
# 1    1    1    1    1 -inf -inf -inf
# 1    1    1    1    1    1 -inf -inf
# 1    1    1    1    1    1    1 -inf
# 1    1    1    1    1    1    1    1

rows, cols = attn_scores.shape[-2:]
# Build the mask on the same device as the scores so masked_fill also works off-CPU.
mask = torch.ones((rows, cols), dtype=torch.bool, device=attn_scores.device).triu(diagonal=(cols - rows) + 1)
attn_scores = attn_scores.masked_fill(mask, float("-inf"))

# dim=-1 normalises over the key axis, i.e. every row of weights sums to 1.
attn_weights = F.softmax(attn_scores, dim=-1)

# All the rest is regular self-attention again, i.e. we split the values into heads,
# take the weighted sum per head, merge the heads and project back with Wo.

xl_values = rearrange(xl_values, 'b s (h d) -> b h s d', h=num_attn_heads)  # s is twice as long here

# print(attn_weights.shape) -> torch.Size([16, 8, 512, 1024])
# print(xl_values.shape)    -> torch.Size([16, 8, 1024, 10])

head_context_vectors = attn_weights @ xl_values
# print(head_context_vectors.shape) -> torch.Size([16, 8, 512, 10])

multihead_context_vector = rearrange(head_context_vectors, 'b h s d -> b s (h d)')
# multihead_context_vector.shape -> torch.Size([16, 512, 80])

out = Wo(multihead_context_vector)  # Back to embedding dimension
# out.shape -> torch.Size([16, 512, 13])
assert out.shape == (b_s, seq_len, embed_dim), out.shape

In [28]:
# Intermezzo on triu

import torch

attn_scores = torch.tensor([[0., 1., 2., 3., 4., 5., 6., 7.],
                            [0., 1., 2., 3., 4., 5., 6., 7.],
                            [0., 1., 2., 3., 4., 5., 6., 7.],
                            [0., 1., 2., 3., 4., 5., 6., 7.]])

# The last two dims are (query positions, key positions). For this 2-D tensor
# shape[-2:] is simply the whole shape.
rows, cols = attn_scores.shape[-2:]
# print(attn_scores.shape) -> torch.Size([4, 8])

# triu keeps the entries on and above the chosen diagonal and sets everything below it to False.
# diagonal=k means row i keeps columns j >= i + k. Here k = (cols - rows) + 1 = 5, so row i
# masks every column j > i + (cols - rows): all prepended (recurrent) columns stay visible and
# only the "future" part of the current sequence is masked.
#   - Without recurrence (rows == cols) the offset is 1 -> the usual causal mask.
#   - With a recurrence as long as the sequence (cols == 2 * rows) the offset equals rows + 1.
# The mask shape is taken from the scores instead of being hard-coded, so it follows the input.
mask = torch.ones((rows, cols), dtype=torch.bool).triu(diagonal=(cols - rows) + 1)
print(mask)
attn_scores = attn_scores.masked_fill(mask, float("-inf"))  # Put -inf where mask is True, leave untouched elsewhere.
print(attn_scores)

tensor([[False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])
tensor([[0., 1., 2., 3., 4., -inf, -inf, -inf],
        [0., 1., 2., 3., 4., 5., -inf, -inf],
        [0., 1., 2., 3., 4., 5., 6., -inf],
        [0., 1., 2., 3., 4., 5., 6., 7.]])


In [45]:
# Intermezzo on softmax

# Four rows containing the same ramp 0..7 at growing scales (x1, x10, x100, x1000),
# built by broadcasting a (4, 1) column of scales against the (8,) ramp.
attn_scores = torch.arange(8.) * torch.tensor([[1.], [10.], [100.], [1000.]])

# Normalise over the last axis: dim=-1 is the innermost list, i.e. one row of the matrix.
attn_weights = F.softmax(attn_scores, dim=-1)

print(attn_weights)
print(attn_weights.shape)
print(attn_weights.shape[-1])
print(attn_weights.sum(dim=-1))
# attn_weights[0].sum() -> sums only the first row; indexing with [0] selects the row and
# sum() then collapses its (last) dimension to a single number.

# After the softmax every row sums to 1. Expected output:
# tensor([[5.7661e-04, 1.5674e-03, 4.2606e-03, 1.1582e-02, 3.1482e-02, 8.5577e-02,
#          2.3262e-01, 6.3233e-01],
#         [3.9753e-31, 8.7561e-27, 1.9287e-22, 4.2482e-18, 9.3572e-14, 2.0611e-09,
#          4.5398e-05, 9.9995e-01],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
#          3.7835e-44, 1.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
#          0.0000e+00, 1.0000e+00]])
# torch.Size([4, 8])
# 8
# tensor([1., 1., 1., 1.])


tensor([[5.7661e-04, 1.5674e-03, 4.2606e-03, 1.1582e-02, 3.1482e-02, 8.5577e-02,
         2.3262e-01, 6.3233e-01],
        [3.9753e-31, 8.7561e-27, 1.9287e-22, 4.2482e-18, 9.3572e-14, 2.0611e-09,
         4.5398e-05, 9.9995e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.7835e-44, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.0000e+00]])
torch.Size([4, 8])
8
tensor([1., 1., 1., 1.])


In [None]:
# 24:00 