## Query, Key, Value

The attention head, and the paper that introduced it (*Attention Is All You Need*), are in many ways remarkably simple considering their impact.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

In [2]:
%run ../lib/bookreader.py

In [19]:
alice = BookReader()
alice.read("../resources/alice.txt")

vocab_size = len(alice.itos)

The text is lowercase only.


### Explore the mechanism

Take a text sample of reasonable context length.

We want to think about how the tokens affect each other.

If we throw our minds way back to bigrams, we created a lookup table of every character against every other character.

For a context of length 36 we can think of doing the same thing again: a 36 × 36 table of every token in the context against every other token.

We're using dot-product attention.

(The paper compares this to additive attention, which computes the compatibility using a feed-forward network with a single hidden layer. WaveNet used successive dilated convolutions, each of which we can loosely think of as a kind of additive attention.)
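To make the comparison concrete, here is a minimal sketch using a hypothetical toy string rather than `alice`: the bigram table is indexed by *character* pairs and is the same for every context, while the attention-style table is indexed by *position* pairs within one context. The random embeddings stand in for learned ones, and there are no query/key projections yet.

```python
import torch

text = "hello world"  # hypothetical toy sample, not from alice.txt
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
V = len(chars)

# Bigram view: one fixed V x V table, indexed by (current char, next char)
bigram = torch.zeros(V, V, dtype=torch.long)
for a, b in zip(text, text[1:]):
    bigram[stoi[a], stoi[b]] += 1

# Attention view: a T x T table, indexed by (position, position) within this one context
torch.manual_seed(0)
T = len(text)
emb = torch.randn(V, 4)              # toy random embeddings, 4 dims per character
x = emb[[stoi[c] for c in text]]     # (T, 4): one embedding per position
scores = x @ x.T                     # (T, T): every position scored against every other

print(bigram.shape, scores.shape)
```

Because nothing here depends on position, the repeated `l` characters get identical rows of scores. The same effect is why values such as 4.9846 (the space character) repeat in the `lu[0][0]` output further down.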
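A sketch of the two scoring functions side by side, assuming already-projected toy queries and keys. `W_q`, `W_k` and `w_out` are hypothetical weights for the additive version, which follows the paper's description of a feed-forward network with a single hidden layer (the tanh activation is an assumption here).

```python
import torch

torch.manual_seed(0)
T, dk, hidden = 5, 8, 16
q = torch.randn(T, dk)  # one (already projected) query vector per position
k = torch.randn(T, dk)  # one (already projected) key vector per position

# Dot-product attention scores: one matrix multiply covers all T x T pairs
dot_scores = q @ k.T                                          # (T, T)

# Additive attention scores: a one-hidden-layer feed-forward net on every (query, key) pair
W_q = torch.randn(dk, hidden)  # hypothetical learnable weights
W_k = torch.randn(dk, hidden)
w_out = torch.randn(hidden)
# broadcast (T, 1, hidden) + (1, T, hidden) -> (T, T, hidden)
h = torch.tanh((q @ W_q)[:, None, :] + (k @ W_k)[None, :, :])
add_scores = h @ w_out                                        # (T, T)

print(dot_scores.shape, add_scores.shape)
```

The dot-product form is a single matrix multiply, which is why the paper describes it as much faster and more space-efficient in practice; the additive form has to build a (T, T, hidden) intermediate.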

In [17]:
xs, ys = alice.sample_batch(5, 36)
alice.decode(xs[0].tolist())

'ting for the hedgehogs  and in a ver'

In [81]:
embedding_dimension = 3
# attention dimension
dk = 48

em = torch.randn(vocab_size, embedding_dimension)

emx = em[xs]

print(emx.shape)

B, T, C = emx.shape

q = torch.randn(C, dk)
k = torch.randn(C, dk)

qt = emx @ q
kt = emx @ k

print(qt.shape, kt.mT.shape)

# x.mT is equivalent to x.transpose(-2, -1).
lu = qt @ kt.mT
# Andrej uses the equivalent explicit transpose in the lecture
lub = qt @ torch.transpose(kt, -2, -1)

print(lu.shape)

print(lu[0][0])
print(lub[0][0])

torch.Size([5, 36, 3])
torch.Size([5, 36, 48]) torch.Size([5, 48, 36])
torch.Size([5, 36, 36])
tensor([ 21.5033, -21.4774,  12.3940,   6.1019,   4.9846,  22.4517,   5.4259,
        -23.3626,   4.9846,  21.5033,  -1.7026,  15.2632,   4.9846,  -1.7026,
         15.2632,   0.9464,   6.1019,  15.2632,  -1.7026,   5.4259,   6.1019,
          8.3596,   4.9846,   4.9846,  23.0301,  12.3940,   0.9464,   4.9846,
        -21.4774,  12.3940,   4.9846,  23.0301,   4.9846,  -2.3029,  15.2632,
        -23.3626])
tensor([ 21.5033, -21.4774,  12.3940,   6.1019,   4.9846,  22.4517,   5.4259,
        -23.3626,   4.9846,  21.5033,  -1.7026,  15.2632,   4.9846,  -1.7026,
         15.2632,   0.9464,   6.1019,  15.2632,  -1.7026,   5.4259,   6.1019,
          8.3596,   4.9846,   4.9846,  23.0301,  12.3940,   0.9464,   4.9846,
        -21.4774,  12.3940,   4.9846,  23.0301,   4.9846,  -2.3029,  15.2632,
        -23.3626])


### What have we here?
A kind of lookup table for how each token depends on all the others.

If we apply softmax to it we get a probability-like view: non-negative weights that sum to 1.

In [73]:
# normalise over the last dim (the keys) so each query's row of weights sums to 1
probs = F.softmax(lu, dim=-1)
probs.shape, probs.sum(dim=-1)[0][0]

(torch.Size([5, 36, 36]), tensor(1.))
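The dimension passed to softmax matters. Assuming the (B, T, T) layout used here, where rows are queries and columns are keys, this sketch with toy scores compares normalising over the last dimension (each query's weights over the keys sum to 1) with normalising over `dim=1` (each *column* sums to 1 instead):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(2, 4, 4)  # (B, T_query, T_key) toy attention scores

row_probs = F.softmax(scores, dim=-1)  # normalise over keys: each query row sums to 1
col_probs = F.softmax(scores, dim=1)   # normalise over queries: each key column sums to 1

print(row_probs.sum(dim=-1))  # ones along the rows
print(col_probs.sum(dim=1))   # ones, but along the other axis
```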

### Does the k, q, v language make sense?

We've got two learnable tensors here, q and k. The thing doing the 'querying' is really the input tensor.

The attention head's query 'vector' is really a matrix that projects the input tensor into the attention head's space.

We also multiply the input by the 'key' matrix, which again projects the input into the attention head's space.

For me it's really the dot product of these two projections that is the query-key lookup: probs[0][0] is where we look to see how we think
the first token interacts with the other tokens in the context.
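The same projections are commonly written as bias-free `nn.Linear` layers. A minimal sketch with toy sizes (the names `query` and `key` are illustrative) showing that this is the same matrix product as `emx @ q` above; `nn.Linear` just stores the transposed (dk, C) matrix:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, dk = 2, 6, 3, 8
x = torch.randn(B, T, C)  # stand-in for the embedded input, like emx

# Learnable projections from the embedding space (C) into the head's space (dk)
query = nn.Linear(C, dk, bias=False)
key = nn.Linear(C, dk, bias=False)

qt = query(x)  # (B, T, dk)
kt = key(x)    # (B, T, dk)

# nn.Linear keeps its weight as (dk, C), so this equals x @ W for a (C, dk) matrix W
assert torch.allclose(qt, x @ query.weight.T, atol=1e-6)

scores = qt @ kt.mT  # (B, T, T): the query-key lookup
print(scores.shape)
```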

#### Value

The 'value' matrix works in a similar way: we first multiply the input by it to project the input into the head's space, and the softmaxed query-key weights then mix these projected values.
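A small check with toy random tensors, assuming softmax is taken over the last dimension, that each output row is the attention-weighted average of the value rows:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, dk = 4, 3
scores = torch.randn(T, T)  # toy query-key scores
vt = torch.randn(T, dk)     # toy projected values, one row per position

probs = F.softmax(scores, dim=-1)  # each row: weights over positions, summing to 1
out = probs @ vt                   # (T, dk)

# Row 0 of the output is the probability-weighted sum of the value rows
manual = sum(probs[0, j] * vt[j] for j in range(T))
print(torch.allclose(out[0], manual, atol=1e-6))
```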

#### Scale the q, k table
The paper introduces a scaling factor for the query-key lookup:
 *'We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has
 extremely small gradients. To counteract this effect, we scale the dot products by 1/sqrt(dk)'*
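The paper backs this up with a short argument: assume the components of $q$ and $k$ are independent random variables with mean 0 and variance 1. Then the dot product has mean 0 and variance $d_k$, and dividing by $\sqrt{d_k}$ brings the variance back to 1. The scaling is applied to the scores *before* the softmax:

$$
q \cdot k = \sum_{i=1}^{d_k} q_i k_i, \qquad
\mathbb{E}[q \cdot k] = \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \qquad
\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k
$$

$$
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1,
\qquad
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V
$$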

In [77]:
v = torch.randn(C, dk)

vt = emx @ v
print(vt.shape)

# scale the query-key scores *before* the softmax, as in the paper,
# then softmax over the last dim so each query's weights over the keys sum to 1
probs = F.softmax(lu * dk**-0.5, dim=-1)
out = probs @ vt
out.shape

torch.Size([5, 36, 48])


torch.Size([5, 36, 48])
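An empirical check of the paper's suspicion, a sketch that assumes unit-variance random components as stand-ins for real queries and keys: the variance of q·k grows with dk, the 1/sqrt(dk) factor brings it back to about 1, and without the factor the softmax over a 36-score row is more peaked (closer to one-hot, where gradients are tiny).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 10_000  # number of independent (q, k) pairs per setting
results = {}
for dk in (4, 64, 512):
    q = torch.randn(n, dk)  # components with mean 0, variance 1
    k = torch.randn(n, dk)
    dots = (q * k).sum(dim=-1)  # n dot products
    scaled = dots * dk**-0.5    # the paper's 1/sqrt(dk) scaling

    # softmax over a row of 36 scores (the context length used above)
    row = dots[:36]
    peak_raw = F.softmax(row, dim=-1).max().item()
    peak_scaled = F.softmax(row * dk**-0.5, dim=-1).max().item()

    results[dk] = (dots.var().item(), scaled.var().item(), peak_raw, peak_scaled)
    print(f"dk={dk:4d}  var(q.k)={results[dk][0]:8.1f}  var(scaled)={results[dk][1]:.2f}  "
          f"max softmax weight raw={peak_raw:.3f} scaled={peak_scaled:.3f}")
```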

## Masked attention

In a decoder network we don't want a token to gain information from the tokens that follow it (it's compared to being able to see the answers to questions in advance).

To do this we use a simple triangular mask: keep the lower-left triangle (tril) and set the upper-right elements to -inf, so that after the softmax those positions get zero weight and each row is still normalized to sum to 1.
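The triangular idea can also be seen as a matrix multiplication. A minimal sketch with toy sizes and all-equal scores: the masked softmax turns into a row-normalised lower-triangular matrix, and multiplying the values by it mixes each token only with the tokens before it (here, a running mean).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, C = 4, 2
x = torch.randn(T, C)  # toy sequence of T token vectors

# Positions to hide: True strictly above the diagonal, i.e. "future" tokens
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# With all-equal (zero) scores, mask + softmax gives uniform weights over current and past tokens
scores = torch.zeros(T, T).masked_fill(future, float("-inf"))
wei = F.softmax(scores, dim=-1)
print(wei)

# ...which is just a row-normalised lower-triangular matrix
tril = torch.tril(torch.ones(T, T))
assert torch.allclose(wei, tril / tril.sum(dim=-1, keepdim=True))

# Multiplying by x mixes each token only with those before it: here, a running mean
out = wei @ x
assert torch.allclose(out[2], x[:3].mean(dim=0), atol=1e-6)
```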

In [98]:
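qk = torch.randn(6, 6)
print(qk)
tri = torch.tril(qk)
print(tri)
# mask by position (True above the diagonal), not by value: `tri == 0` would also
# hide any genuine 0.0 score in the lower triangle
mask = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)
out = qk.masked_fill(mask, float('-inf'))
print(out)
F.softmax(out, dim=-1)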

tensor([[ 0.2135, -0.4079, -0.2872, -0.8498,  1.6850, -1.2782],
        [ 0.3345,  1.6951,  0.1399,  1.2216,  0.0459, -1.8412],
        [ 0.0864, -0.6933, -1.9226,  0.1483,  0.3684,  1.7560],
        [ 0.7434,  0.3826,  1.2543,  1.6146, -0.7157,  0.6360],
        [-1.8041, -0.0030, -1.1502,  2.4319,  0.8799,  1.5553],
        [-0.6071, -0.8940, -0.4337,  1.4431, -0.1497,  2.5575]])
tensor([[ 0.2135,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.3345,  1.6951,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0864, -0.6933, -1.9226,  0.0000,  0.0000,  0.0000],
        [ 0.7434,  0.3826,  1.2543,  1.6146,  0.0000,  0.0000],
        [-1.8041, -0.0030, -1.1502,  2.4319,  0.8799,  0.0000],
        [-0.6071, -0.8940, -0.4337,  1.4431, -0.1497,  2.5575]])
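tensor([[ 0.2135,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.3345,  1.6951,    -inf,    -inf,    -inf,    -inf],
        [ 0.0864, -0.6933, -1.9226,    -inf,    -inf,    -inf],
        [ 0.7434,  0.3826,  1.2543,  1.6146,    -inf,    -inf],
        [-1.8041, -0.0030, -1.1502,  2.4319,  0.8799,    -inf],
        [-0.6071, -0.8940, -0.4337,  1.4431, -0.1497,  2.5575]])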

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.7959, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6279, 0.2879, 0.0842, 0.0000, 0.0000, 0.0000],
        [0.1738, 0.1212, 0.2897, 0.4153, 0.0000, 0.0000],
        [0.0108, 0.0653, 0.0207, 0.7453, 0.1579, 0.0000],
        [0.0278, 0.0209, 0.0331, 0.2160, 0.0439, 0.6583]])
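Putting the pieces together, a sketch of one causal self-attention head. It assumes bias-free `nn.Linear` projections and a lower-triangular buffer (a common way to write this, not necessarily the code used later in this notebook); `CausalHead` and its arguments are hypothetical names. It scales the scores before the softmax, as the paper describes, and builds the mask from positions rather than values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalHead(nn.Module):
    """One causal self-attention head (illustrative sketch; names are hypothetical)."""

    def __init__(self, n_embd: int, head_size: int, block_size: int):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones as a buffer: saved and moved with the module, never trained
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, T, _ = x.shape                                     # x: (B, T, n_embd), T <= block_size
        q, k, v = self.query(x), self.key(x), self.value(x)   # each (B, T, head_size)
        scores = (q @ k.mT) * k.shape[-1] ** -0.5             # scale the scores *before* softmax
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # hide future positions
        weights = F.softmax(scores, dim=-1)                   # each query row sums to 1
        return weights @ v                                    # (B, T, head_size)


torch.manual_seed(0)
head = CausalHead(n_embd=3, head_size=48, block_size=36)  # same sizes as the cells above
x = torch.randn(5, 36, 3)
out = head(x)
print(out.shape)
```

PyTorch 2.x also provides `torch.nn.functional.scaled_dot_product_attention` (with an `is_causal` flag), which covers the scale, mask, softmax and value steps in one call.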

### Why multiple heads?

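It's not really clear from the lecture why we have multiple heads rather than just increasing the dimension dk of a single head.

Again we look to the paper, and the softmax is the difference: with multiple smaller heads, each with its own softmax, the model can learn a different kind of attention in each head, which would otherwise be averaged out in a single larger head. In the paper's words, *'Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.'*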
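A minimal sketch, with hypothetical random weights and toy sizes, comparing one 48-dim head against four 12-dim heads of the same total width. The point is only the shapes: the single head produces one T × T attention pattern per sequence, while the split version produces one pattern per head, each with its own softmax. The paper also applies a learned output projection after concatenating the heads, which is omitted here.

```python
import torch
import torch.nn.functional as F


def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask; q, k, v: (..., T, d). Returns (output, weights)."""
    T, d = q.shape[-2], q.shape[-1]
    scores = (q @ k.mT) * d ** -0.5
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
B, T, C = 2, 8, 3
d_model, n_heads = 48, 4
head_dim = d_model // n_heads  # 4 heads x 12 dims: same total width as one 48-dim head

x = torch.randn(B, T, C)
Wq, Wk, Wv = (torch.randn(C, d_model) for _ in range(3))  # hypothetical projection weights
q, k, v = x @ Wq, x @ Wk, x @ Wv                           # each (B, T, 48)

# One large head: a single softmax, so one T x T attention pattern per sequence
single_out, single_w = causal_attention(q, k, v)


def split_heads(t):
    # (B, T, 48) -> (B, n_heads, T, head_dim)
    return t.view(B, T, n_heads, head_dim).transpose(1, 2)


# Several smaller heads: one softmax (and one attention pattern) per head
multi_out, multi_w = causal_attention(split_heads(q), split_heads(k), split_heads(v))
multi_out = multi_out.transpose(1, 2).reshape(B, T, d_model)  # concatenate heads: (B, T, 48)

print(single_w.shape, multi_w.shape)
print(single_out.shape, multi_out.shape)
```

Each head's dimension is d_model / n_heads, which matches the paper's sizing, so the total cost stays similar to a single full-width head.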